I haven’t found a good beginner-level explanation of Apache Spark’s aggregateByKey method, so I’m writing one up.
This was part of a Coursera big data class: https://www.coursera.org/learn/big-data-essentials
The simple case I was using: given a social graph file of (userId, followerId) pairs, find the user with the most followers.
So this is a simple grouping by userId with an aggregated count of the followers each user has.
Implementing this using the aggregateByKey method is as follows:
```python
# from the pyspark shell
import operator

# read in a local file
raw_data = sc.textFile('/data/twitter/twitter_sample_small.txt')

# parse each line: userId and followerId separated by a tab
def parse_edge(s):
    user, follower = s.split('\t')
    return (int(user), int(follower))

# cache the intermediate rdd after parsing it
edges = raw_data.map(parse_edge).cache()

# apply aggregateByKey - see explanation below the code
fol_agg = edges.aggregateByKey(0, lambda v1, v2: v1 + 1, operator.add)

# top user/key with most followers.
# use operator.itemgetter to make sure the values (aggregated counts),
# not the keys/userIds, are used for the comparison
top_user = fol_agg.top(1, key=operator.itemgetter(1))
```
Explanation:
0 is the starting value (the “zero value”) of the accumulator for each key. Note that it is applied once per key in each partition that contains the key, so it should be a neutral element for the operation (0 for addition).
The 1st function (lambda v1, v2: v1 + 1) folds together all records with the same key, but only within a single partition.
Say the 1st 3 records’ key/value pairs are:
1\t5
2\t6
1\t7
For userId (key) 1, the 1st iteration takes the accumulator 0 (the start value) and the value 5, and returns 0 + 1 = 1.
The 2nd iteration for key 1 takes the accumulator 1 (from the previous step) and the value 7 from the next record, and returns 1 + 1 = 2.
So key 1 now has an aggregated value of 2, and so on…
Note that the actual values 5 and 7 are never used directly; each record is just a proxy telling us that the key has one more follower, counting exactly once per record.
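The within-partition folding described above can be sketched in plain Python, with no Spark needed. Using the three example records from above (assumed here to all sit in one partition), each key’s values are folded with the same lambda, starting from the zero value 0:

```python
from functools import reduce

# the three example records (key, value) from above, all in one partition
records = [(1, 5), (2, 6), (1, 7)]

# the 1st function from the aggregateByKey call: ignore the record's
# actual value and just bump the accumulator by one
seq_func = lambda v1, v2: v1 + 1

# group values by key, preserving arrival order
groups = {}
for k, v in records:
    groups.setdefault(k, []).append(v)

# fold each key's values starting from the zero value 0
partition_result = {k: reduce(seq_func, vs, 0) for k, vs in groups.items()}
print(partition_result)  # {1: 2, 2: 1} -- key 1 folded twice, key 2 once
```

This reproduces the walkthrough: key 1 ends at 2 after two iterations, key 2 at 1.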
The 2nd function passed (operator.add) then merges the per-key accumulators produced by the 1st lambda, but it does this across different partitions.
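To see both functions working together, here is a minimal pure-Python simulation (no Spark required). The split of the records into two partitions is hypothetical, invented for illustration: the 1st function folds within each partition, then operator.add merges the per-partition counts for each key:

```python
import operator

seq_func = lambda v1, v2: v1 + 1   # within a partition
comb_func = operator.add           # across partitions

# two hypothetical partitions of (userId, followerId) records
partitions = [
    [(1, 5), (2, 6), (1, 7)],
    [(1, 9), (2, 3)],
]

# step 1: in each partition, fold each key's values from the zero value 0
per_partition = []
for part in partitions:
    counts = {}
    for k, v in part:
        counts[k] = seq_func(counts.get(k, 0), v)
    per_partition.append(counts)

# step 2: merge the per-partition accumulators for each key with comb_func
merged = {}
for counts in per_partition:
    for k, acc in counts.items():
        merged[k] = comb_func(merged.get(k, 0), acc)

print(merged)  # {1: 3, 2: 2}

# pick the user with the most followers, comparing on the count (index 1)
top_user = max(merged.items(), key=operator.itemgetter(1))
print(top_user)  # (1, 3)
```

The final max with operator.itemgetter(1) mirrors what fol_agg.top(1, key=operator.itemgetter(1)) does in the Spark version: it compares on the aggregated counts, not the userIds.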
Source code (slightly modified to run as a Python script rather than from the pyspark shell):