Open Source Software 29/03/2023 14 min read

A PySpark bug makes co-grouping with window function partition-key-order-sensitive

Written by Enrico Minack, Open Source Software Contributor.

Spark is used to process tabular data of arbitrary size. One common operation is to group the data by some grouping columns. All rows that have the same values in those grouping columns (also called group keys) belong to the same group. These groups can then be further processed, for instance, all rows of a group can be aggregated into a single row per group.

As in many other systems, Spark allows you to co-group two grouped datasets. The two groups that have the same group keys can then be aggregated together or further transformed into new rows.
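To make the co-group semantics concrete, here is a minimal pure-Python sketch. It is not how Spark implements co-grouping, and the helper names group_by and cogroup are hypothetical. Each side is grouped by its key columns, and a user function is called once per key with the matching left and right rows:

```python
from collections import defaultdict

def group_by(rows, key_columns):
    """Group a list of dict rows by the values of the key columns."""
    groups = defaultdict(list)
    for row in rows:
        groups[tuple(row[c] for c in key_columns)].append(row)
    return groups

def cogroup(left_groups, right_groups, func):
    """Call func(key, left_rows, right_rows) exactly once per key found on either side."""
    keys = sorted(set(left_groups) | set(right_groups))
    return [func(key, left_groups.get(key, []), right_groups.get(key, [])) for key in keys]

left = [
    {"id": 0, "day": 0, "left": 1},
    {"id": 0, "day": 0, "left": 2},
    {"id": 0, "day": 1, "left": 3},
]
right = [
    {"id": 0, "day": 0, "right": 7},
    {"id": 1, "day": 0, "right": 8},
]

# count the rows per group on the left and right side
sizes = cogroup(
    group_by(left, ["id", "day"]),
    group_by(right, ["id", "day"]),
    lambda key, lefts, rights: (key[0], key[1], len(lefts), len(rights)),
)
print(sizes)
```

The property that matters for the rest of this article is that every key appears exactly once in the output, with an empty group on whichever side has no rows for that key.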

When one of the datasets contains a window function with the same partition columns as the group columns, the result of the operation may be incorrect in Spark versions below 3.2.4 due to bug SPARK-42168.

Incorrect results are difficult to spot as your code does not break or throw an exception. Spark happily continues its computation, which renders dependent computation incorrect as well. The final result is not what you have asked for and debugging such a situation is cumbersome and tedious.

Eliciting the PySpark co-group bug

Here is an example of how the PySpark co-group bug occurs:

First we group datasets left_df and right_with_window_df by columns "id" and "day". Then we co-group those grouped datasets and count the rows per group on the left and right side of the co-group:

…

import pandas as pd

left_grouped_df = left_df.groupBy("id", "day")
right_grouped_df = right_with_window_df.groupBy("id", "day")

def group_sizes(key, left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    return pd.DataFrame([{
        "id": key[0],
        "day": key[1],
        "lefts": len(left.index),
        "rights": len(right.index)
    }])

df = left_grouped_df.cogroup(right_grouped_df) \
         .applyInPandas(group_sizes, schema="id long, day long, lefts integer, rights integer")

df.orderBy("id", "day").show(10)

We would expect the following result (as returned by Spark 3.2.4, 3.3.0 and above):

| id | day | lefts | rights |
|----|-----|-------|--------|
| 0 | 0 | 10000 | 10000 |
| 0 | 1 | 10000 | 10000 |
| 0 | 2 | 10000 | 10000 |
| ... | ... | ... | ... |
| 1 | 99 | 10000 | 10000 |

However, with Spark 3.0, 3.1, 3.2.0, 3.2.1 and 3.2.3, we get:

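| id | day | lefts | rights |
|----|-----|-------|--------|
| 0 | 0 | 10000 | 10000 |
| 0 | 1 | 0 | 10000 |
| 0 | 1 | 10000 | 0 |
| 0 | 2 | 0 | 10000 |
| 0 | 2 | 10000 | 0 |
| 0 | 3 | 10000 | 0 |
| ... | ... | ... | ... |
| 1 | 99 | 0 | 10000 |
| 1 | 99 | 10000 | 0 |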

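For some group keys (e.g. id=0 and day=1), the result contains two rows: one that sees only the right rows and one that sees only the left rows. This is not allowed for the result of cogroup, which must call the user function exactly once per group key.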
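Because nothing fails visibly, it can help to check the result for repeated group keys. The following is a small sketch, assuming the result is small enough to bring to the driver (for example with df.collect()) and that the key columns come first in each row. The helper name duplicate_keys is hypothetical:

```python
from collections import Counter

def duplicate_keys(rows, num_key_columns=2):
    """Return the group keys that occur in more than one result row."""
    counts = Counter(tuple(row[:num_key_columns]) for row in rows)
    return sorted(key for key, n in counts.items() if n > 1)

# first rows of the incorrect result: (id, day, lefts, rights)
rows = [
    (0, 0, 10000, 10000),
    (0, 1, 0, 10000),
    (0, 1, 10000, 0),
    (0, 2, 0, 10000),
    (0, 2, 10000, 0),
]

duplicates = duplicate_keys(rows)
print(duplicates)
```

An empty list means every group key appears once, as a correct co-group requires. Any key in the list points at the symptom described here.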

In which situation does the PySpark co-grouping bug occur?

This bug surfaces when one of your grouped datasets contains a window function that is partitioned by your group columns, but in a different order.

For example, if we first define two datasets, left_df and right_df:

from pyspark.sql.functions import col

# the right numbers here are important to expose the bug
ids = 2
days = 100
vals = 10000
parts = 10

# create two example datasets left_df and right_df
id_df = spark.range(ids)
day_df = spark.range(days).withColumnRenamed("id", "day")
vals_df = spark.range(vals).withColumnRenamed("id", "value")
df = id_df.join(day_df).join(vals_df)

left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
# SPARK-42132: this bug requires us to alias all columns from df here
right_df = df.select(
    col("id").alias("id"),
    col("day").alias("day"),
    col("value").alias("right")
).repartition(parts).cache()

…

Both of these datasets, left_df and right_df, look like the following (where for each value of id and day we have 10,000 rows):

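| id | day | left / right |
|----|-----|--------------|
| 0 | 0 | 0 |
| 0 | 0 | 1 |
| 0 | 0 | 2 |
| ... | ... | ... |
| 0 | 0 | 9999 |
| 0 | 1 | 0 |
| ... | ... | ... |
| 1 | 99 | 9999 |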

Next, we add a window function right_count to the right dataset that counts rows that have the same values for "day" and "id":

…

from pyspark.sql import Window
from pyspark.sql.functions import count

window = Window.partitionBy("day", "id")
right_with_window_df = right_df.withColumn("right_count", count(col("right")).over(window))

…

The window function right_count will compute 10000 for every row, because each row belongs to a window that has 10,000 rows:

| id | day | right | right_count |
|----|-----|-------|-------------|
| 0 | 0 | 0 | 10000 |
| 0 | 0 | 1 | 10000 |
| 0 | 0 | 2 | 10000 |
| ... | ... | ... | ... |
| 1 | 99 | 9999 | 10000 |

We now have dataset right_with_window_df with a window function partitioned by "day", "id", while we group this dataset by "id", "day":

window = Window.partitionBy("day", "id")
…
right_grouped_df = right_with_window_df.groupBy("id", "day")

In this situation the bug occurs because one of the datasets contains a window function with partition columns in a different order than the group columns.

How to work around this bug?

The workaround is as easy as adjusting the order of your window function partition columns to the group column order (here "id", "day"):

window = Window.partitionBy("id", "day")

The order of columns given to partitionBy does not affect the result computed by the window function. In other words, Window.partitionBy("id", "day") is equivalent to Window.partitionBy("day", "id").
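The equivalence can be illustrated without Spark. The following pure-Python sketch mimics a count over a window partition. The helper count_over is hypothetical and only models the semantics, not Spark's execution:

```python
from collections import Counter

def count_over(rows, partition_columns):
    """Append to each row the number of rows sharing its partition-column values."""
    def key(row):
        return tuple(row[c] for c in partition_columns)
    sizes = Counter(key(row) for row in rows)
    return [{**row, "right_count": sizes[key(row)]} for row in rows]

# a scaled-down version of right_df: 2 ids, 3 days, 4 values each
rows = [{"id": i, "day": d, "right": v} for i in range(2) for d in range(3) for v in range(4)]

by_day_id = count_over(rows, ["day", "id"])
by_id_day = count_over(rows, ["id", "day"])
print(by_day_id == by_id_day)
```

Swapping the partition columns only changes how the partition key is spelled, not which rows fall into the same partition, so both calls produce the same right_count for every row.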

Full Example

For reference, here is the full example that fails with Spark versions earlier than 3.2.4:

import pandas as pd

from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, count

spark = SparkSession.builder.getOrCreate()

# the right numbers here are important to expose the bug
ids = 2
days = 100
vals = 10000
parts = 10

# create two example datasets left_df and right_df
id_df = spark.range(ids)
day_df = spark.range(days).withColumnRenamed("id", "day")
vals_df = spark.range(vals).withColumnRenamed("id", "value")
df = id_df.join(day_df).join(vals_df)
left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
# SPARK-42132: this bug requires us to alias all columns from df here
right_df = df.select(
    col("id").alias("id"),
    col("day").alias("day"),
    col("value").alias("right")
).repartition(parts).cache()

# note the partitionBy column order is different to the groupBy("id", "day") column order below
window = Window.partitionBy("day", "id")
right_with_window_df = right_df.withColumn("right_count", count(col("right")).over(window))

# grouping the datasets
left_grouped_df = left_df.groupBy("id", "day")
right_grouped_df = right_with_window_df.groupBy("id", "day")

def group_sizes(key, left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    return pd.DataFrame([{
        "id": key[0],
        "day": key[1],
        "lefts": len(left.index),
        "rights": len(right.index)
    }])

df = left_grouped_df.cogroup(right_grouped_df) \
        .applyInPandas(group_sizes, schema="id long, day long, lefts integer, rights integer")

df.orderBy("id", "day").show(10)

Investigating the bug

The first step in investigating why Spark produces that result is to look at the query plan:

df.explain()

Let’s look at the relevant piece of this plan:

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- FlatMapCoGroupsInPandas [id#0L, day#4L], [id#34L, day#35L], …
   :- Sort [id#0L ASC NULLS FIRST, day#4L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#0L, day#4L, 200), …                             <═══ ❌
   :     +- …
   +- Sort [id#34L ASC NULLS FIRST, day#35L ASC NULLS FIRST], false, 0
      +- Project [id#34L, day#35L, id#34L, day#35L, right#36L, right_count#56L]
         +- Window [count(1) windowspecdefinition(day#35L, id#34L, …) …], …
            +- Sort [day#35L ASC NULLS FIRST, id#34L ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(day#35L, id#34L, 200), …                  <═══ ❌
                  +- …

In words: Spark partitions the left side of the co-group by columns id and day (hashpartitioning(id#0L, day#4L, 200)), while the right side is partitioned by day and id (hashpartitioning(day#35L, id#34L, 200)), even though we tell Spark to group the right side by "id" and then by "day". Note that the grouping column order is important for co-group to correctly collect the group data when calling into user function group_sizes.
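Why the column order of the hash partitioning matters can be shown with a toy model. The sketch below is an illustration under loud assumptions: toy_hash is a made-up order-sensitive hash standing in for Spark's real hash function, and partition and cogroup_sizes are hypothetical helpers that co-group partition by partition, trusting that both sides were partitioned the same way:

```python
from collections import defaultdict

NUM_PARTITIONS = 200  # same number of partitions as in the plan above

def toy_hash(values):
    """Order-sensitive toy hash (assumption: NOT the hash function Spark uses)."""
    h = 0
    for v in values:
        h = 31 * h + v
    return h

def partition(rows, key_columns):
    """Assign each row to a partition by hashing the key columns in the given order."""
    parts = defaultdict(list)
    for row in rows:
        key = tuple(row[c] for c in key_columns)
        parts[toy_hash(key) % NUM_PARTITIONS].append(row)
    return parts

def cogroup_sizes(left_parts, right_parts):
    """Co-group partition by partition and emit (id, day, lefts, rights) per key seen."""
    result = []
    for p in set(left_parts) | set(right_parts):
        lefts, rights = defaultdict(int), defaultdict(int)
        for row in left_parts.get(p, []):
            lefts[(row["id"], row["day"])] += 1
        for row in right_parts.get(p, []):
            rights[(row["id"], row["day"])] += 1
        for key in set(lefts) | set(rights):
            result.append((key[0], key[1], lefts[key], rights[key]))
    return sorted(result)

# scaled-down data: 2 ids, 3 days, 4 rows per (id, day)
rows = [{"id": i, "day": d} for i in range(2) for d in range(3) for _ in range(4)]

# both sides partitioned by (id, day): every key meets its counterpart
good = cogroup_sizes(partition(rows, ["id", "day"]), partition(rows, ["id", "day"]))
# right side partitioned by (day, id): most keys end up in different partitions
bad = cogroup_sizes(partition(rows, ["id", "day"]), partition(rows, ["day", "id"]))

print(good[:3])
print(bad[:3])
```

In this model, the key id=0, day=1 hashes to partition 1 on the left but to partition 31 on the right, so the user function is called twice for it, once with an empty right group and once with an empty left group. Keys with id equal to day land in the same partition on both sides and look correct, which matches the pattern in the incorrect result.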

In Spark 3.2.4 and above, the query plan looks correct:

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- FlatMapCoGroupsInPandas [id#0L, day#4L], [id#34L, day#35L], …
   :- Sort [id#0L ASC NULLS FIRST, day#4L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#0L, day#4L, 200), …                             <═══ ✅
   :     +- …
   +- Sort [id#34L ASC NULLS FIRST, day#35L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(id#34L, day#35L, 200), …                           <═══ ✅
         +- Project [id#34L, day#35L, id#34L, day#35L, right#36L, right_count#56L]
            +- Window [count(1) windowspecdefinition(day#35L, id#34L, …) …], …
               +- Sort [day#35L ASC NULLS FIRST, id#34L ASC NULLS FIRST], false, 0
                  +- Exchange hashpartitioning(day#35L, id#34L, 200), …
                     +- …

The right side is first partitioned by columns day and id (hashpartitioning(day#35L, id#34L, 200)) in order to compute the window function, then it is partitioned again by columns id and day (hashpartitioning(id#34L, day#35L, 200)).
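Spotting such a mismatch by eye gets harder in larger plans. As a rough aid, the following sketch extracts the hashpartitioning columns from plan text such as the output of df.explain(). The helper exchange_keys is hypothetical and relies on the textual plan format shown in this article, which is an assumption that may not hold for other Spark versions:

```python
import re

def exchange_keys(plan):
    """Return the column names of every hashpartitioning(...) in a plan, top to bottom."""
    keys = []
    for args in re.findall(r"hashpartitioning\(([^)]*)\)", plan):
        columns = [a.strip() for a in args.split(",")][:-1]  # drop the partition count
        keys.append([re.sub(r"#\d+L?$", "", c) for c in columns])  # strip ids like #35L
    return keys

# relevant piece of the plan produced by the affected Spark versions
buggy_plan = """
+- FlatMapCoGroupsInPandas [id#0L, day#4L], [id#34L, day#35L], …
   :- Sort [id#0L ASC NULLS FIRST, day#4L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#0L, day#4L, 200), …
   +- Sort [id#34L ASC NULLS FIRST, day#35L ASC NULLS FIRST], false, 0
      +- Project [id#34L, day#35L, id#34L, day#35L, right#36L, right_count#56L]
         +- Window [count(1) windowspecdefinition(day#35L, id#34L, …) …], …
            +- Sort [day#35L ASC NULLS FIRST, id#34L ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(day#35L, id#34L, 200), …
"""

left_keys, right_keys = exchange_keys(buggy_plan)
print(left_keys, right_keys)
```

For the plan from the affected versions, the two lists differ in order. For the plan from Spark 3.2.4, the first two exchanges found would both be partitioned by id and day, followed by the extra exchange for the window function.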
