Spark’s groupByKey should be avoided – and here’s why

13 June 2023
  • Open Source Software

Written by Enrico Minack, Open Source Software Contributor.

Apache Spark is very popular for processing tabular data of arbitrary size. A common operation is to group the data by some columns and then process each group further. Spark offers two ways of grouping data: groupBy and groupByKey. While the latter works, it can cause performance issues in some cases. As good practice, avoid groupByKey whenever possible to prevent those issues.

Grouping data

Spark provides two ways to group and process data. Grouping can be done via groupBy and groupByKey. These functions return a RelationalGroupedDataset and a KeyValueGroupedDataset[K, V], respectively.

If you are already familiar with the differences between these two types of grouped datasets, you can skip ahead to the Performance Considerations section below. If not, you may rightly ask:

Why are there two different types of grouped data?

The two types of grouped dataset provide different operations on groups. But before we dive into processing groups, here is the example dataset ds that we use throughout this article:

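// run in spark-shell, which imports spark.implicits._ and org.apache.spark.sql.functions._ automatically
// disable broadcast joins, so that even our tiny example datasets are joined via a sort-merge join
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)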

case class Val(id: Long, number: Int)
val ds = Seq((1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)).toDF("id", "number").as[Val]
ds.show
+---+------+
| id|number|
+---+------+
|  1|     1|
|  1|     2|
|  1|     3|
|  2|     2|
|  2|     3|
|  3|     3|
+---+------+

We will use column "id" as the grouping column, so we will get three groups.

Aggregating groups

Both types of grouped dataset, RelationalGroupedDataset and KeyValueGroupedDataset[K, V], allow aggregating groups. This is the most common way to process grouped data. An aggregation returns a single row per group.

Grouping with groupBy and aggregating the groups returns a DataFrame with schema id: int, sum: bigint:

ds.groupBy("id")
  .agg(sum("number").as("sum"))
  .show()
+---+---+
| id|sum|
+---+---+
|  1|  6|
|  2|  5|
|  3|  3|
+---+---+

Grouping with groupByKey and aggregating the groups returns a Dataset[(Long, Int)] with schema key: bigint, sum: bigint:

ds.groupByKey(row => row.id)
  .agg(sum("number").as("sum").as[Int])
  .show()
+---+---+
|key|sum|
+---+---+
|  1|  6|
|  2|  5|
|  3|  3|
+---+---+
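To double-check the numbers in both tables, here is a plain-Scala sketch (no Spark involved) that computes the same per-group sums over an in-memory copy of ds. It only models the result of the aggregation, not how Spark executes it; the names rows and sums are illustrative.

```scala
// In-memory copy of the example dataset as (id, number) pairs
val rows: Seq[(Long, Int)] = Seq((1L, 1), (1L, 2), (1L, 3), (2L, 2), (2L, 3), (3L, 3))

// One result per group: the sum of "number", widened to Long like Spark's sum over an int column
val sums: Map[Long, Long] =
  rows.groupBy(_._1).map { case (id, group) => id -> group.map(_._2.toLong).sum }

// Print one line per group, ordered by id
sums.toSeq.sortBy(_._1).foreach { case (id, sum) => println(s"$id -> $sum") }
```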

We have seen that both types of grouped data behave very similarly when it comes to aggregating groups, but only one of them allows iterating over the values of a group.

Iterate group values

Only KeyValueGroupedDataset[K, V] allows processing groups with a user-defined function. That function receives the group key and an iterator over the group's values, and can return an arbitrary number of rows. Hence, each group of a groupByKey-grouped dataset can be turned into zero, one or many rows, which gives the user more flexibility than groupBy.

// return first and last element of iterator
def firstAndLast[T](id: Long, it: Iterator[T]): Iterator[T] = {
  if (it.hasNext) {
    val first = it.next()
    if (it.hasNext) {
      Iterator(first, it.reduceLeft((b, n) => n))
    } else {
      Iterator.single(first)
    }
  } else {
    Iterator.empty
  }
}

// now get the first and last row of each group
// group by row id
ds.groupByKey(row => row.id)
  // call firstAndLast for each id and group iterator
  .flatMapGroups(firstAndLast)
  .show()
+---+------+
| id|number|
+---+------+
|  1|     1|
|  1|     3|
|  2|     2|
|  2|     3|
|  3|     3|
+---+------+
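The semantics of flatMapGroups can be modelled without Spark: group the rows by key, hand each group's values to the user function as an iterator, and concatenate whatever it returns. The following is a minimal plain-Scala sketch with a hypothetical helper flatMapGroupsLocal; it is not how Spark implements the method. One caveat: Spark sorts by the grouping key only, so in general the order of values within a group (and thus which rows count as first and last) is not guaranteed, whereas this in-memory version simply keeps the input order.

```scala
// Same logic as the article's firstAndLast: first and last element of an iterator
def firstAndLast[T](id: Long, it: Iterator[T]): Iterator[T] =
  if (it.hasNext) {
    val first = it.next()
    if (it.hasNext) Iterator(first, it.reduceLeft((_, last) => last))
    else Iterator.single(first)
  } else Iterator.empty

// Hypothetical in-memory stand-in for groupByKey(key).flatMapGroups(process):
// group the data, pass each group's values as an iterator, concatenate the results
def flatMapGroupsLocal[K: Ordering, V, U](data: Seq[V])(key: V => K)(
    process: (K, Iterator[V]) => Iterator[U]): Seq[U] =
  data.groupBy(key).toSeq.sortBy(_._1).flatMap { case (k, values) => process(k, values.iterator) }

val rows: Seq[(Long, Int)] = Seq((1L, 1), (1L, 2), (1L, 3), (2L, 2), (2L, 3), (3L, 3))
val result: Seq[(Long, Int)] = flatMapGroupsLocal(rows)(_._1)((id, it) => firstAndLast(id, it))
result.foreach(println)
```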

We have seen that the two grouped datasets RelationalGroupedDataset and KeyValueGroupedDataset[K, V] provide similar, but also differing, functions on grouped data.

But are there differences other than the functional ones (the API)?

In fact, there can be a significant performance penalty for using one over the other.

Performance Considerations

Before Spark can process individual groups, it first has to rearrange the data. This is expensive and involves partitioning the data by the grouping columns. Calling mapGroups or flatMapGroups additionally involves sorting each partition by the grouping columns.

If data is already partitioned and sorted, Spark skips these steps and processes the groups right away. This saves a lot of time and processing power. But it can only do so if it knows which columns are used for grouping.
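The two preparatory steps can be modelled in a few lines of plain Scala. This sketch hash-partitions rows by key, sorts each partition by key, and then cuts each partition into runs of equal keys, which is how every group ends up complete and contiguous within a single partition. It is an illustration under assumptions, not Spark code: the two input partitions and the partition count are made up, and hashCode stands in for the Murmur3 hash Spark uses for hash partitioning.

```scala
// Rows as (id, number) pairs, spread over two made-up input partitions
val input: Seq[Seq[(Long, Int)]] = Seq(
  Seq((1L, 1), (2L, 2), (3L, 3)),
  Seq((1L, 2), (2L, 3), (1L, 3))
)
val numPartitions = 4

// Hypothetical partitioner: hashCode stands in for Spark's Murmur3 hash
def partitionOf(key: Long): Int = Math.floorMod(key.hashCode, numPartitions)

// Step 1, like "Exchange hashpartitioning": every row moves to the partition of its key
val partitioned: IndexedSeq[Seq[(Long, Int)]] =
  (0 until numPartitions).map(p => input.flatten.filter { case (id, _) => partitionOf(id) == p })

// Step 2, like "Sort": sort each partition by key so that equal keys become adjacent
val sortedPartitions: IndexedSeq[Seq[(Long, Int)]] = partitioned.map(_.sortBy(_._1))

// Step 3: a single pass over a sorted partition cuts it into groups (runs of equal keys)
def splitIntoGroups(part: Seq[(Long, Int)]): List[(Long, Seq[Int])] =
  if (part.isEmpty) Nil
  else {
    val key = part.head._1
    val (run, rest) = part.span(_._1 == key)
    (key, run.map(_._2)) :: splitIntoGroups(rest)
  }

val groups: IndexedSeq[(Long, Seq[Int])] = sortedPartitions.flatMap(splitIntoGroups)
groups.foreach(println)
```

If the input were already partitioned and sorted this way, for example by a preceding sort-merge join on the same key, steps 1 and 2 would change nothing and only step 3 would be needed.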

A common situation where data are already partitioned and sorted occurs right after a join on the prospective grouping column:

// join dataset ds with another dataset on column id
ds.join(spark.range(4), "id").as[Val]
  // group by column id
  .groupByKey(row => row.id)
  // process groups via iterator
  .mapGroups((id, it) => id * it.size)
  // show the query plan
  .explain
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SerializeFromObject [input[0, bigint, false] AS value#824L]
   +- MapGroups …, [value#821L], [id#806L, number#807], obj#823: bigint
      +- Sort [value#821L ASC NULLS FIRST], false, 0                                      <=== ❌
         +- Exchange hashpartitioning(value#821L, 200), ENSURE_REQUIREMENTS, [id=#1283]   <=== ❌
            +- AppendColumns …, [input[0, bigint, false] AS value#821L]
               +- Project [id#806L, number#807]
                  +- SortMergeJoin [id#806L], [id#811L], Inner
                     :- Sort [id#806L ASC NULLS FIRST], false, 0
                     :  +- Exchange hashpartitioning(id#806L, 200), …, [id=#1275]
                     :     +- LocalTableScan [id#806L, number#807]
                     +- Sort [id#811L ASC NULLS FIRST], false, 0
                        +- Exchange hashpartitioning(id#811L, 200), …, [id=#1276]
                           +- Range (0, 3, step=1, splits=16)

We can read the following from this query plan:

The join (SortMergeJoin) triggers the partitioning (Exchange hashpartitioning(id#806L, 200)) and sorting (Sort [id#806L ASC NULLS FIRST]) of dataset ds by column "id", and likewise of the second dataset (Range (0, 3, step=1, splits=16)). The joined dataset (Project [id#806L, number#807]) is therefore already partitioned and sorted by column "id", so grouping it by "id" should not require another partitioning or sorting step.

Then, the grouping key is added as a new column "value" (AppendColumns … AS value#821L), by executing the function row => row.id for each row. Spark cannot know that the values of this new column are equivalent to column "id", because the expression row => row.id is a Scala function that is opaque to Spark. So Spark partitions and sorts all data by column value#821L, not knowing that this is redundant.

Avoid groupByKey(...), use groupBy(...).as[...] instead

If we used groupBy("id") instead, Spark would know the missing bit. But how can we access the mapGroups and flatMapGroups methods when using groupBy rather than groupByKey?

We can get from a RelationalGroupedDataset to a KeyValueGroupedDataset[K, V] via as[K, V]:

// join dataset ds with another dataset on column id
ds.join(spark.range(4), "id").as[Val]
  // group by column id
  .groupBy("id")
  // turn into a KeyValueGroupedDataset[Long, (Long, Int)]
  .as[Long, (Long, Int)]
  // process groups via iterator
  .mapGroups((id, it) => id * it.size)
  // show the query plan
  .explain
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SerializeFromObject [input[0, bigint, false] AS value#845L]
   +- MapGroups …, [id#806L], [id#806L, number#807], obj#844: bigint
      +- Project [id#806L, number#807]
         +- SortMergeJoin [id#806L], [id#827L], Inner
            :- Sort [id#806L ASC NULLS FIRST], false, 0
            :  +- Exchange hashpartitioning(id#806L, 200), …, [id=#1312]
            :     +- LocalTableScan [id#806L, number#807]
            +- Sort [id#827L ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(id#827L, 200), …, [id=#1313]
                  +- Range (0, 3, step=1, splits=16)

Now that Spark knows we want to group by column "id", it skips the redundant partitioning and sorting that the groupByKey plan performed after the join.
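Rather than reading plans by eye, we can count the shuffles programmatically. The following sketch assumes a local SparkSession and disables adaptive query execution, so that queryExecution.executedPlan is a plain tree of physical operators; it then counts ShuffleExchangeExec nodes (the Exchange lines in the plans above) for both variants. The data uses Long ids so that both join sides share the bigint key type, and the names countShuffles, viaGroupByKey and viaGroupBy are our own.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

val spark = SparkSession.builder().master("local[2]").appName("count-shuffles").getOrCreate()
import spark.implicits._

// As in the article: force a sort-merge join even for tiny data
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
// Turn off adaptive query execution so that executedPlan is a plain operator tree
spark.conf.set("spark.sql.adaptive.enabled", false)

// Long ids, so both sides of the join use bigint keys
val ds = Seq((1L, 1), (1L, 2), (1L, 3), (2L, 2), (2L, 3), (3L, 3)).toDF("id", "number").as[(Long, Int)]
val joined = ds.join(spark.range(4), "id").as[(Long, Int)]

// Number of shuffle exchanges (Exchange hashpartitioning ...) in a physical plan
def countShuffles(plan: SparkPlan): Int =
  plan.collect { case e: ShuffleExchangeExec => e }.size

// Group via an opaque Scala function vs. via a column Spark can see
val viaGroupByKey = joined.groupByKey(_._1).mapGroups((id, it) => id * it.size)
val viaGroupBy = joined.groupBy("id").as[Long, (Long, Int)].mapGroups((id, it) => id * it.size)

val shufflesGroupByKey = countShuffles(viaGroupByKey.queryExecution.executedPlan)
val shufflesGroupBy = countShuffles(viaGroupBy.queryExecution.executedPlan)
println(s"groupByKey: $shufflesGroupByKey shuffles, groupBy(...).as[...]: $shufflesGroupBy shuffles")

spark.stop()
```

With this setup we expect the groupByKey variant to report exactly one more shuffle than the groupBy(...).as[...] variant, namely the exchange on the appended key column.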

Other methods of KeyValueGroupedDataset[K, V] benefit from groupBy(...).as[...] as well. For instance, aggregating groups does not involve sorting, but it does skip the redundant partitioning step:

First with groupByKey:

// join dataset ds with another dataset on column id
ds.join(spark.range(4), "id").as[Val]
  // group by column id
  .groupByKey(row => row.id)
  // aggregate each group
  .agg(sum("number").as("sum").as[Int])
  // show the query plan
  .explain
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[value#917L], functions=[sum(number#807)])
   +- Exchange hashpartitioning(value#917L, 200), ENSURE_REQUIREMENTS, [id=#1397]         <=== ❌
      +- HashAggregate(keys=[value#917L], functions=[partial_sum(number#807)])
         +- Project [number#807, value#917L]
            +- AppendColumns …, [input[0, bigint, false] AS value#917L]
               +- Project [id#806L, number#807]
                  +- SortMergeJoin [id#806L], [id#907L], Inner
                     :- Sort [id#806L ASC NULLS FIRST], false, 0
                     :  +- Exchange hashpartitioning(id#806L, 200), …, [id=#1387]
                     :     +- LocalTableScan [id#806L, number#807]
                     +- Sort [id#907L ASC NULLS FIRST], false, 0
                        +- Exchange hashpartitioning(id#907L, 200), …, [id=#1388]
                           +- Range (0, 3, step=1, splits=16)

We again see a redundant partitioning (Exchange hashpartitioning(value#917L, 200)).

Now with groupBy(...).as[...]:

// join dataset ds with another dataset on column id
ds.join(spark.range(4), "id").as[Val]
  // group by column id
  .groupBy("id")
  // turn into a KeyValueGroupedDataset[Long, (Long, Int)]
  .as[Long, (Long, Int)]
  // aggregate each group
  .agg(sum("number").as("sum").as[Int])
  // show the query plan
  .explain
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[id#806L], functions=[sum(number#807)])
   +- HashAggregate(keys=[id#806L], functions=[partial_sum(number#807)])
      +- Project [id#806L, number#807]
         +- SortMergeJoin [id#806L], [id#881L], Inner
            :- Sort [id#806L ASC NULLS FIRST], false, 0
            :  +- Exchange hashpartitioning(id#806L, 200), …, [id=#1347]
            :     +- LocalTableScan [id#806L, number#807]
            +- Sort [id#881L ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(id#881L, 200), …, [id=#1348]
                  +- Range (0, 3, step=1, splits=16)
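Both aggregation plans compute the sum in two phases: a partial HashAggregate (partial_sum) before the exchange and a final one after it. Neither phase needs sorted input, which is why only the partitioning step is at stake for agg. The plain-Scala sketch below models this scheme with maps; the partition layout and the hashCode-based routing are illustrative assumptions. When the input is already partitioned by the key, as after the join in the groupBy(...).as[...] plan, the exchange step is unnecessary and the two phases run back to back.

```scala
// Rows as (id, number) pairs, spread over two made-up input partitions
val input: Seq[Seq[(Long, Int)]] = Seq(
  Seq((1L, 1), (2L, 2), (3L, 3)),
  Seq((1L, 2), (2L, 3), (1L, 3))
)
val numPartitions = 2

// Sum values per key with a hash map; input order does not matter, so no sorting is needed
def sumByKey(rows: Seq[(Long, Long)]): Map[Long, Long] =
  rows.foldLeft(Map.empty[Long, Long]) { case (acc, (k, v)) =>
    acc.updated(k, acc.getOrElse(k, 0L) + v)
  }

// Phase 1, like "partial_sum": pre-aggregate every input partition independently
val partial: Seq[Map[Long, Long]] =
  input.map(part => sumByKey(part.map { case (k, v) => (k, v.toLong) }))

// Exchange: route each partial result to partition hash(key) mod n
// (hashCode is a stand-in for Spark's Murmur3 hash)
val exchanged: Seq[Seq[(Long, Long)]] =
  (0 until numPartitions).map { p =>
    partial.flatMap(_.toSeq).filter { case (k, _) => Math.floorMod(k.hashCode, numPartitions) == p }
  }

// Phase 2, like "sum": merge the partial sums; partitions now hold disjoint keys
val finalSums: Map[Long, Long] = exchanged.map(sumByKey).reduce(_ ++ _)
println(finalSums.toSeq.sortBy(_._1))
```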

Working with DataFrames

Using groupBy(...).as[...] with DataFrames is a bit tricky, as you need to provide an encoder for the Row values.

It is easiest to reuse the encoder of the DataFrame that is being grouped:

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.Row

// join dataset ds with another dataset on column id, creates a DataFrame
val df = ds.join(spark.range(4), "id")

// group dataframe by column id
df.groupBy("id")
  // turn into a KeyValueGroupedDataset[Long, Row]
  .as[Long, Row](Encoders.scalaLong, df.encoder)
  // aggregate each group
  .agg(sum("number").as("sum").as[Int])
  // show the query plan
  .explain
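The snippet above only shows the plan. To actually process the Row groups, the following sketch calls mapGroups on the KeyValueGroupedDataset[Long, Row] and computes the largest number per id. It assumes a local SparkSession and a hypothetical DataFrame df with columns id (bigint) and number (int); rows are read by position, so getInt(1) refers to the number column.

```scala
import org.apache.spark.sql.{Encoders, Row, SparkSession}

val spark = SparkSession.builder().master("local[2]").appName("groupBy-as-row").getOrCreate()
import spark.implicits._

// Hypothetical DataFrame with the columns of the joined example: id (bigint), number (int)
val df = Seq((1L, 1), (1L, 2), (1L, 3), (2L, 2), (2L, 3), (3L, 3)).toDF("id", "number")

val maxPerId: Seq[(Long, Int)] = df
  .groupBy("id")
  // KeyValueGroupedDataset[Long, Row], reusing the DataFrame's own Row encoder
  .as[Long, Row](Encoders.scalaLong, df.encoder)
  // each group arrives as (key, Iterator[Row]); column 1 is "number"
  .mapGroups((id, rows) => (id, rows.map(_.getInt(1)).max))
  .collect()
  .sortBy(_._1)
  .toSeq

maxPerId.foreach(println)
spark.stop()
```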

Summary

We have seen that using groupByKey can have a significant impact on performance when data is already partitioned and sorted. It can be considered good practice to prefer groupBy(...) (RelationalGroupedDataset) over groupByKey(...) (KeyValueGroupedDataset[K, V]).

If you really need a KeyValueGroupedDataset[K, V], use groupBy(...).as[K, V] instead of groupByKey(...). This lets Spark’s query optimiser see the grouping columns and skip redundant partitioning and sorting.
