Written by Enrico Minack, Open Source Software Contributor.
Spark is used to process tabular data of arbitrary size. One common operation is to group the data by some grouping columns. All rows that have the same values in those grouping columns (also called group keys) belong to the same group. These groups can then be further processed; for instance, all rows of a group can be aggregated into a single row per group.
As in many other systems, Spark allows two grouped datasets to be co-grouped. Two groups that have the same group keys can then be aggregated or further transformed into new rows.
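For illustration, here is a minimal sketch of that pattern (df1 and df2 are hypothetical DataFrames that both contain the key columns "id" and "day" plus a value column "left" and "right", respectively):

import pandas as pd

def merge_groups(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    # receives all rows of one group from each side as pandas DataFrames
    # and returns the new rows for that group
    return pd.merge(left, right, on=["id", "day"], how="outer")

result = df1.groupBy("id", "day").cogroup(df2.groupBy("id", "day")) \
    .applyInPandas(merge_groups, schema="id long, day long, left long, right long")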
When one of the datasets contains a window function that is partitioned by the same columns as the group columns, but in a different order, the result of the operation may be incorrect (below Spark 3.2.4) due to bug SPARK-42168.
Incorrect results are difficult to spot because your code does not break or throw an exception. Spark happily continues its computation, which renders dependent computations incorrect as well. The final result is not what you asked for, and debugging such a situation is cumbersome and tedious.
Eliciting the PySpark co-group bug
Here is an example of how the PySpark co-group bug occurs:
First, we group the datasets left_df and right_with_window_df by columns "id" and "day". Then we co-group those grouped datasets and count the rows per group on the left and right side of the co-group:
…
import pandas as pd

left_grouped_df = left_df.groupBy("id", "day")
right_grouped_df = right_with_window_df.groupBy("id", "day")

def group_sizes(key, left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    return pd.DataFrame([{
        "id": key[0],
        "day": key[1],
        "lefts": len(left.index),
        "rights": len(right.index)
    }])

df = left_grouped_df.cogroup(right_grouped_df) \
    .applyInPandas(group_sizes, schema="id long, day long, lefts integer, rights integer")

df.orderBy("id", "day").show(10)
We would expect the following result (as returned by Spark 3.2.4, 3.3.0 and above):
id  | day | lefts | rights
 0  |   0 | 10000 | 10000
 0  |   1 | 10000 | 10000
 0  |   2 | 10000 | 10000
... | ... |   ... |   ...
 1  |  99 | 10000 | 10000
However, with Spark 3.0, 3.1, 3.2.0, 3.2.1 and 3.2.3, we get a result where some group keys (e.g. id=0 and day=1) appear in two rows. Two rows for the same group key are not allowed in the result of cogroup.
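A quick way to spot such an incorrect result is to check the co-grouped output for duplicate group keys; a correct cogroup result has at most one row per (id, day):

# df is the co-grouped result from above; any row printed here indicates the bug
df.groupBy("id", "day").count().where("count > 1").show()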
In which situation does the PySpark co-grouping bug occur?
This bug surfaces when one of your grouped datasets contains a window function that is partitioned by your group columns, but in a different order.
For example, we first define two datasets, left_df and right_df:
from pyspark.sql.functions import col

# the right numbers here are important to expose the bug
ids = 2
days = 100
vals = 10000
parts = 10

# create two example datasets left_df and right_df
id_df = spark.range(ids)
day_df = spark.range(days).withColumnRenamed("id", "day")
vals_df = spark.range(vals).withColumnRenamed("id", "value")
df = id_df.join(day_df).join(vals_df)

left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()

# SPARK-42132: this bug requires us to alias all columns from id_day_df here
right_df = df.select(
    col("id").alias("id"),
    col("day").alias("day"),
    col("value").alias("right")
).repartition(parts).cache()
…
Both of these datasets, left_df and right_df, contain the columns id, day and left (or right, respectively), with 10,000 rows for each combination of id and day.
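If you want to confirm this yourself, an optional sanity check is to count the rows per group directly; every count should be 10000:

# count the rows per (id, day) group in left_df; every count should be 10000
left_df.groupBy("id", "day").count().orderBy("id", "day").show(5)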
Next, we add a window function right_count to the right dataset that counts the rows that have the same values for "day" and "id":
…
from pyspark.sql import Window
from pyspark.sql.functions import count

window = Window.partitionBy("day", "id")
right_with_window_df = right_df.withColumn("right_count", count(col("right")).over(window))
…
The window function right_count will compute 10000 for every row, because each row belongs to a window that contains 10,000 rows.
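This can be verified, if you like, by looking at the distinct values of right_count:

# should print a single row containing 10000
right_with_window_df.select("right_count").distinct().show()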
We now have the dataset right_with_window_df with a window function partitioned by "day", "id", while we group this dataset by "id", "day":
window = Window.partitionBy("day", "id")
…
right_grouped_df = right_with_window_df.groupBy("id", "day")
In this situation the bug occurs because one of the datasets contains a window function with partition columns in a different order than the group columns.
How to work around this bug?
The workaround is as easy as adjusting the order of your window function's partition columns to match the group column order (here "id", "day"):
window = Window.partitionBy("id", "day")
The order of the columns given to partitionBy does not affect the result computed by the window function. In other words, Window.partitionBy("id", "day") is equivalent to Window.partitionBy("day", "id").
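If in doubt, this equivalence can be checked directly (an optional sketch, using the right_df from above): compute the window count with both partition column orders and compare the results; no row should differ.

from pyspark.sql import Window
from pyspark.sql.functions import col, count

# compute the window count with both partition column orders
w1 = Window.partitionBy("id", "day")
w2 = Window.partitionBy("day", "id")
check_df = right_df \
    .withColumn("count_id_day", count(col("right")).over(w1)) \
    .withColumn("count_day_id", count(col("right")).over(w2))

# number of rows where the two counts differ, expected to be 0
check_df.where(col("count_id_day") != col("count_day_id")).count()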
Full Example
For reference, here is the full example that fails with Spark versions below 3.2.4:
import pandas as pd

from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, count

spark = SparkSession.builder.getOrCreate()

# the right numbers here are important to expose the bug
ids = 2
days = 100
vals = 10000
parts = 10

# create two example datasets left_df and right_df
id_df = spark.range(ids)
day_df = spark.range(days).withColumnRenamed("id", "day")
vals_df = spark.range(vals).withColumnRenamed("id", "value")
df = id_df.join(day_df).join(vals_df)

left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()

# SPARK-42132: this bug requires us to alias all columns from id_day_df here
right_df = df.select(
    col("id").alias("id"),
    col("day").alias("day"),
    col("value").alias("right")
).repartition(parts).cache()

# note the partitionBy column order is different to the groupBy("id", "day") column order below
window = Window.partitionBy("day", "id")
right_with_window_df = right_df.withColumn("right_count", count(col("right")).over(window))

# grouping the datasets
left_grouped_df = left_df.groupBy("id", "day")
right_grouped_df = right_with_window_df.groupBy("id", "day")

def group_sizes(key, left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    return pd.DataFrame([{
        "id": key[0],
        "day": key[1],
        "lefts": len(left.index),
        "rights": len(right.index)
    }])

df = left_grouped_df.cogroup(right_grouped_df) \
    .applyInPandas(group_sizes, schema="id long, day long, lefts integer, rights integer")

df.orderBy("id", "day").show(10)
Investigating the bug
The first step in investigating how Spark arrives at this result is to look at the query plan:
df.explain()
Let’s look at the relevant piece of this plan:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- FlatMapCoGroupsInPandas [id#0L, day#4L], [id#34L, day#35L], …
   :- Sort [id#0L ASC NULLS FIRST, day#4L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#0L, day#4L, 200), …                  <═══ ❌
   :     +- …
   +- Sort [id#34L ASC NULLS FIRST, day#35L ASC NULLS FIRST], false, 0
      +- Project [id#34L, day#35L, id#34L, day#35L, right#36L, right_count#56L]
         +- Window [count(1) windowspecdefinition(day#35L, id#34L, …) …], …
            +- Sort [day#35L ASC NULLS FIRST, id#34L ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(day#35L, id#34L, 200), …       <═══ ❌
                  +- …
In words: Spark partitions the left side of the co-group by columns id and day (hashpartitioning(id#0L, day#4L, 200)), while the right side is partitioned by day and id (hashpartitioning(day#35L, id#34L, 200)), even though we tell Spark to group the right side by "id" and then by "day". Note that the grouping column order matters for co-group to correctly collect the group data before calling the user function group_sizes.
In Spark 3.2.4 and above, the query plan looks correct:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- FlatMapCoGroupsInPandas [id#0L, day#4L], [id#34L, day#35L], …
   :- Sort [id#0L ASC NULLS FIRST, day#4L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#0L, day#4L, 200), …                  <═══ ✅
   :     +- …
   +- Sort [id#34L ASC NULLS FIRST, day#35L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(id#34L, day#35L, 200), …                <═══ ✅
         +- Project [id#34L, day#35L, id#34L, day#35L, right#36L, right_count#56L]
            +- Window [count(1) windowspecdefinition(day#35L, id#34L, …) …], …
               +- Sort [day#35L ASC NULLS FIRST, id#34L ASC NULLS FIRST], false, 0
                  +- Exchange hashpartitioning(day#35L, id#34L, 200), …
                     +- …
The right side is first partitioned by columns day and id (hashpartitioning(day#35L, id#34L, 200)) in order to compute the window function, and then partitioned again by columns id and day (hashpartitioning(id#34L, day#35L, 200)).
References: