Difference between map and flatMap in Spark

Apache Spark is a powerful distributed computing framework that leverages in-memory caching and optimized query execution to process data quickly. Data transformation is one of the key components of Spark, and in this article, we will explore the distinctions between two important transformation functions, namely map() and flatMap(), and their respective use cases.

The difference between map and flatMap in Spark is that map() transforms every element of an RDD into a new element utilizing a specified function. In contrast, flatMap() applies a function to each element, which produces a sequence of values that are then flattened into a new RDD. Essentially, map performs a one-to-one transformation, while flatMap performs a one-to-many transformation.
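
To see the contrast at a glance, here is a minimal sketch (assuming a running PySpark shell, where sc is the built-in SparkContext and the sample strings are our own):

>>> rdd = sc.parallelize(["a b", "c d"])

# map: one output element per input element
>>> rdd.map(lambda s: s.split(" ")).collect()
[['a', 'b'], ['c', 'd']]

# flatMap: the per-element results are flattened into one RDD
>>> rdd.flatMap(lambda s: s.split(" ")).collect()
['a', 'b', 'c', 'd']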


Before discussing the map and flatMap transformation functions, let’s first understand what a transformation is in Spark.

What is Transformation in Spark?

A transformation is an operation that takes an RDD (Resilient Distributed Dataset) as input and returns a new RDD as output. Transformations are evaluated lazily, which means they are not executed until an action is called.

Some of the common transformations used in Spark are:

map, flatMap, filter, reduceByKey, groupBy

Transformations never modify an existing RDD, because RDDs are immutable. Instead, each transformation creates a new RDD from an existing one.
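
As a quick illustration of this lazy evaluation, the sketch below (assuming the PySpark shell, where sc is the built-in SparkContext) defines a map transformation that does no work until the collect() action is called:

>>> rdd = sc.parallelize([1, 2, 3])

# Defining the transformation returns immediately; nothing is computed yet
>>> doubled = rdd.map(lambda x: x * 2)

# The action triggers the actual computation
>>> doubled.collect()
[2, 4, 6]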

map() in Spark

map() in Spark is a transformation that applies a function to each element of an RDD (Resilient Distributed Dataset) and returns a new RDD with the same number of elements. The function applied to each element must return exactly one value.

One => One Transformation

Example:

Let’s say we have an RDD containing the following integers: [1, 2, 3]. We can use map() to double each integer in the RDD:

#Create a pyspark session:

pyspark

>>> rdd = sc.parallelize([1, 2, 3])

>>> mapped_rdd = rdd.map(lambda x: x * 2)

#Output

>>> mapped_rdd.collect()

[2, 4, 6]                                                                    

>>>

In this example, we created an RDD with the parallelize method and provided a list of integers. Then we applied the map transformation with a lambda function that multiplies each element by 2. The result is a new RDD with the values 2, 4, and 6.

NOTE: In PySpark, we typically use lambda functions to define the transformation logic that is passed to map().
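
A regular named function works just as well as a lambda; for example (a small sketch, the function name double is our own):

# A named function can be passed to map() in place of a lambda
>>> def double(x):
...     return x * 2
...
>>> rdd.map(double).collect()
[2, 4, 6]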

flatMap() in Spark

flatMap() in Spark is also a transformation, but the function it applies to each element of an RDD can return zero, one, or multiple values. The output is a flattened RDD in which all the returned values are concatenated together.

one => many Transformation

Example:

Let’s say we have an RDD containing the following strings: ["hello world", "Learn Share"]. We can use flatMap to split each string into words:

#Create a pyspark session:

pyspark
>>> rdd = sc.parallelize(["hello world", "Learn Share"])

>>> flat_mapped_rdd = rdd.flatMap(lambda x: x.split(" "))

#Output

>>> flat_mapped_rdd.collect()

['hello', 'world', 'Learn', 'Share']                                            
>>>

In this example, we created an RDD with the parallelize method and passed a list of two strings. Then we used the flatMap transformation with a lambda function that splits each string into words and returns a sequence of words. flatMap flattens those sequences into a new RDD with each word as a separate element. The result is a new RDD with the values “hello”, “world”, “Learn”, and “Share”.

Note: flatMap is used to apply a one-to-many transformation to the elements of an RDD. In this case, the lambda function returns a sequence of words for each string, which are then flattened into a new RDD.
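
Because the function passed to flatMap can also return zero values, some elements can contribute nothing at all to the result. A small sketch with our own sample data:

>>> rdd = sc.parallelize(["", "spark", "map flatMap"])

# The empty string produces no words, so it simply disappears from the output
>>> rdd.flatMap(lambda s: s.split()).collect()
['spark', 'map', 'flatMap']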

When to Use map vs. flatMap

In general, map is useful for applying a transformation to each element of an RDD, while flatMap is useful for transforming each element into multiple elements and flattening the result into a single RDD.

For example, suppose we have an RDD containing the following strings: ["a,b,c", "d,e,f"]. If we want to split each string on the commas while keeping each result as a nested list, we can use map:

#Create a pyspark session: 
pyspark
>>> rdd = sc.parallelize(["a,b,c", "d,e,f"])

>>> mapped_rdd = rdd.map(lambda x: x.split(","))

#Output

>>> mapped_rdd.collect()

[['a', 'b', 'c'], ['d', 'e', 'f']]                                              
>>> 

The resulting RDD, mapped_rdd, will contain the values [['a', 'b', 'c'], ['d', 'e', 'f']], one nested list per input string.

However, if we want a single flat RDD of individual values instead of nested lists, we can use flatMap:

#Create a pyspark session:

pyspark

>>> rdd = sc.parallelize(["a,b,c", "d,e,f"])

>>> flat_mapped_rdd = rdd.flatMap(lambda x: x.split(","))

#Output
                                        
>>> flat_mapped_rdd.collect()

['a', 'b', 'c', 'd', 'e', 'f']                                                  
>>> 

The resulting RDD, flat_mapped_rdd, will contain the values ['a', 'b', 'c', 'd', 'e', 'f'].
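
Another way to see the relationship: flatMap behaves roughly like a map followed by a flattening step. The sketch below (continuing with the same rdd of comma-separated strings) reproduces the flatMap result by flattening the nested lists that map produces:

# map gives nested lists; flattening them with an identity flatMap gives the same flat RDD
>>> rdd.map(lambda x: x.split(",")).flatMap(lambda lst: lst).collect()
['a', 'b', 'c', 'd', 'e', 'f']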

Real-World Example

Let’s consider a real-world example where we have a dataset of tweets, each stored as a string. Now, we want to create an RDD containing the hashtags used in those tweets.

We can use the flatMap transformation to achieve this:

#Create a Pyspark session

pyspark
>>> tweets_rdd = sc.parallelize(["I love #Learn-Share!", "Learn-share is a platform to share knowledge 'https://lean-share.com'"])

>>> hashtags_rdd = tweets_rdd.flatMap(lambda tweet: tweet.split(" ")).filter(lambda word: word.startswith("#"))

#Output

>>> hashtags_rdd.collect()

['#Learn-Share!']                                                            

>>>

In the above code, we created an RDD with all the tweets. Then we applied the flatMap transformation to split the strings into a sequence of words. Using the # symbol, we filter out the hashtags and store them in a new RDD (hashtags_rdd).

Now, if we want to create a new RDD that contains the length of each hashtag used in the tweets, we can use the map transformation:

# Continuation of the above example

>>> lengths_rdd = hashtags_rdd.map(lambda hashtag: len(hashtag))

#Output

>>> lengths_rdd.collect()

[13]                                                                            

>>> 

#If we had two hashtags, the output would look like this

>>> lengths_rdd.collect()
[13, 10]

In the above code, we used the map transformation on the hashtags_rdd RDD to transform each hashtag into its corresponding length.
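
Since every transformation simply returns a new RDD, the same result can be produced by chaining the steps into a single pipeline (a sketch using the same tweets_rdd as above):

>>> tweets_rdd.flatMap(lambda tweet: tweet.split(" ")) \
...     .filter(lambda word: word.startswith("#")) \
...     .map(lambda hashtag: len(hashtag)) \
...     .collect()
[13]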

Conclusion

In conclusion, map and flatMap are both useful transformation operations in Spark, but they have their own use cases. map is used for transforming each element into a single value, while flatMap is used for transforming each element into multiple values and flattening the result into a single RDD.

Understanding the differences between these two functions is essential to optimizing and streamlining your Spark data processing workflows.

If you have any further questions or would like to add something, please use the comments to start a discussion.

Good luck with your learning!!
