Caching datasets
The result of transformations can be reused by subsequent actions by calling cache first.
// mark ds4 to be cached
ds4.cache()
// executes four transformations and caches result
ds4.count()
// reuses result of the four transformations
ds4.write.csv("existing-users.csv")
ds4.collect()
The disadvantage of caching is that it occupies memory or disk storage to store the computed dataset ds4.
The storage is freed by calling unpersist:
// frees storage occupied by ds4
ds4.unpersist()
Metrics of datasets, like the number of rows, the average of a column, or the number of null values of a column, can be declared through transformations. But they have to be executed by an action in order to be retrieved (e.g. show, write, or collect). This executes all other transformations of the initial dataset as well.
Imagine a data processing job that reads a dataset and declares dozens of expensive transformations, which might take hours to process. Retrieving metrics on the result dataset through an action, before or after actually storing the result, executes all transformations a second time. So retrieving the metrics takes as long as processing your dataset, effectively multiplying the processing time.
import org.apache.spark.sql.functions.{avg, count, lit, when}

// read input data
val raw = spark.read.csv("raw-data.csv")
// process data
val clean = prepare_dataset(raw)
// write clean data
clean.write.csv("clean-data.csv")
// retrieve metrics about our clean data
val rows = clean.count()
val average = clean.select(avg($"value")).as[Double].collect().head
val nullValues = clean.select(count(when($"value".isNull, lit(1)))).as[Long].collect().head
The three metrics execute all transformations of the dataset (clean) three times. This makes your data processing job take three times longer.
The result dataset clean could be cached so that the transformation result is reused by the metrics, but this may require a vast amount of storage, which might not be available.
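As a sketch, caching clean would let the write and all three metrics share one computation of the transformations, at the cost of storing the full dataset:
// cache the result of all expensive transformations (requires enough memory or disk)
clean.cache()
// the write computes the transformations once and fills the cache
clean.write.csv("clean-data.csv")
// the metrics reuse the cached result instead of recomputing it
val rows = clean.count()
val average = clean.select(avg($"value")).as[Double].collect().head
val nullValues = clean.select(count(when($"value".isNull, lit(1)))).as[Long].collect().head
// free the storage again
clean.unpersist()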
The least expensive alternative is usually to read the written data and compute the metrics from that dataset:
val written = spark.read.csv("clean-data.csv")
val rows = written.count()
val average = written.select(avg($"value")).as[Double].collect().head
val nullValues = written.select(count(when($"value".isNull, lit(1)))).as[Long].collect().head
All of these approaches require either extra cache storage or extra processing time. This is where Spark observation metrics come into play: they extract metrics while an action is executed. Caching is not required, and computation is reduced to the minimum.
Computing metrics as observations
Spark allows declaring metrics through the transformation observe, where metrics are declared as aggregations. The difference from metrics as transformations is that they are not retrieved through the return value of an action (e.g. val rows = clean.count()), but through an Observation instance (e.g. observation.get("rows")), while executing another action (e.g. cleanWithMetrics.write.csv("clean-data.csv")):
import org.apache.spark.sql.Observation
import org.apache.spark.sql.functions.{avg, count, lit, when}
// read input data
val raw = spark.read.option("header", true).csv("raw-data.csv")
// process data, this adds loads of expensive transformations
val clean = prepare_dataset(raw)
// define observations on data
val observation = Observation()
val cleanWithMetrics = clean.observe(
  observation,
  count(lit(1)).as("rows"),
  avg($"value").as("average"),
  count(when($"value".isNull, lit(1))).as("null values")
)
// write clean data
cleanWithMetrics.write.csv("clean-data.csv")
// retrieve metrics about our clean data
val rows = observation.get("rows")
val average = observation.get("average")
val nullValues = observation.get("null values")
After executing an action on the dataset returned by observe (here cleanWithMetrics), the declared metrics can be retrieved via the Observation instance. They have been computed while executing the action and become available once the action terminates.
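Under the hood, the Observation instance simply exposes a map of all collected metrics; a minimal sketch, assuming Spark 3.3 or later, where Observation.get returns a Map[String, Any]:
// get blocks until the first action on the observed dataset has finished,
// then returns all metrics by name; observation.get("rows") is a lookup in this map
val metrics: Map[String, Any] = observation.get
val rowCount = metrics("rows").asInstanceOf[Long]
val averageValue = metrics("average").asInstanceOf[Double]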
Limitations of observation metrics
Observation metrics are restricted to aggregate functions. These are functions that return a single aggregated value for the entire dataset, like sum, avg, or collect_set. Further, observation metrics aim at providing metrics with low additional computational effort. This prohibits any aggregation that requires a shuffle stage (repartitioning the dataset), like count_distinct (while approx_count_distinct is allowed), or any window function (e.g. sum.over). We look into some ways to work around this limitation below.
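For example, an exact distinct count cannot be observed, but its approximate counterpart can; a minimal sketch, assuming the observed dataset has a column "user":
import org.apache.spark.sql.Observation
import org.apache.spark.sql.functions.approx_count_distinct

// approx_count_distinct avoids the shuffle that count_distinct would require,
// so it can be used as an observation metric (column "user" is an assumption)
val distinctUsers = Observation()
val observed = clean.observe(distinctUsers, approx_count_distinct($"user").as("approx users"))
observed.write.csv("clean-data.csv")   // any action triggers the observation
val approxUsers = distinctUsers.get("approx users")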
Another restriction is that executing an action on the observed dataset a second time will not change the observed metrics. Metrics are collected only from the first invocation of an action. Consequently, an Observation instance can only be used for one Dataset.observe call.
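A short sketch of both restrictions:
val obs = Observation()
val observed = clean.observe(obs, count(lit(1)).as("rows"))

observed.collect()   // metrics are collected during this first action
observed.collect()   // a second action does not change obs.get("rows")

// obs cannot be passed to another observe call; a fresh Observation is needed instead
val obs2 = Observation()
val observedAgain = clean.observe(obs2, count(lit(1)).as("rows"))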
Over-counting observation metrics
There are situations where stages that contain observations are executed multiple times. This executes the aggregations multiple times as well, producing over-counted metrics.
A common situation where this happens is when an observed dataset (i.e. the dataset returned by observe) is sorted:
// define observations on data
val observation = Observation()
val cleanWithMetrics = clean.observe(
  observation,
  count(lit(1)).as("rows"),
  avg($"value").as("average"),
  count(when($"value".isNull, lit(1))).as("null values")
).sort($"value")  // sort by some column of the observed dataset
Now, the sort first evaluates the observation once while sampling the observed dataset to prepare for the sort, and then evaluates it a second time to actually sort the dataset. We will see that the metric "rows" is twice as large as expected.
With a skewed dataset, some partitions are even evaluated three times and all others twice, so we get completely unpredictable observations. Let’s look at this reproducible example:
import org.apache.spark.sql.Observation
import org.apache.spark.sql.functions.{count, log2}
import org.apache.spark.sql.expressions.Window
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
spark.conf.set("spark.sql.adaptive.enabled", value = false)
val ids = spark.range(0, 1000000, 1, 100)
val once = Observation()
val twice = Observation()
val skewed = Observation()
// an observation run once
ids.observe(once, count("*").as("rows")).collect
// an observation run twice
ids.observe(twice, count("*").as("rows")).sort($"id".desc).collect
// a skewed dataset runs some partitions three times, all others twice
ids.withColumn("bits", log2($"id").cast("int"))
// this repartitions the dataset by column `bits`, which is highly skewed
.withColumn("bits-card", count("*").over(Window.partitionBy($"bits")))
.observe(skewed, count("*").as("rows"))
// sorting this skewed dataset evaluates some partitions three times, others twice
.sort($"id")
.collect
(once.get("rows"), twice.get("rows"), skewed.get("rows"))
// (1000000,2000000,2984192)
This example shows how sensitive observations are to subsequent transformations. Observations can still be used in such situations if the observed dataset is cached, but, as discussed above, the whole point of observation metrics is to avoid caching the entire dataset in the first place.
The conclusion here is to use observations only in situations where no over-counting occurs, or only with metrics that are robust against over-counting, like the minimum, the maximum, or the existence of a property. Alternatively, handle the result with the required care.
Counting null values where none are expected, for instance, produces a count of 0 if no nulls exist, so a 0 can be trusted even if over-counting occurs. A count larger than 0 does not give you the precise number of null values, but it tells you that null values exist and gives you an upper bound on their number.
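A minimal sketch of handling such a possibly over-counted metric with the required care (the messages are just illustrative):
// with possible over-counting, the observed null count is only an upper bound
val nullValues = observation.get("null values").asInstanceOf[Long]
if (nullValues == 0) {
  println("clean data contains no null values")                     // a 0 can be trusted
} else {
  println(s"clean data contains at most $nullValues null values")   // exact count unknown
}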
Observation metrics with shuffles
Metrics that involve distinct or window functions require a shuffle stage. Such a shuffle renders the metrics expensive, which defeats the purpose of cheap observation metrics. However, such metrics can still be computed cheaply when the shuffle stage is already required by a subsequent transformation. In that situation, we can observe metrics without adding extra computational effort.
We have seen in Limitations of observation metrics that observation metrics have to be aggregate functions (returning a single value), while window functions do not aggregate any rows, so using them requires an additional aggregation. We can observe such an aggregation by first materializing the window function and then observing the aggregation.
Let’s look at the following example. We want to get the maximum cardinality of values in column "id", i.e. the number of rows that have the same value in that column. Such an aggregation can be expressed by a window function, aggregated by max:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, max}

val dsWithCardinality = ds.withColumn("cardinality", count("*").over(Window.partitionBy($"id")))
// id  user    created     deleted     exists  cardinality
// 1   Alice   2023-07-01  null        yes     1
// 2   Bob     2023-07-08  2023-08-01  no      1
// 3   Charly  2023-07-15  null        yes     2
// 3   Charly  2023-08-01  null        yes     2
val maxCardinality = dsWithCardinality.select(max($"cardinality")).as[Long].head
// val maxCardinality: Long = 2
The shuffle stage (repartitioning) required by the window function can be seen via explain:
dsWithCardinality.explain()
== Physical Plan ==
Window [count(1) windowspecdefinition(id, …) AS cardinality], [id]
+- *(1) Sort [id ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(id, 200), ENSURE_REQUIREMENTS
+- FileScan csv [id,user,created,deleted] …
The query plan[3] repartitions the dataset by column id (see the step Exchange hashpartitioning(id, 200), ENSURE_REQUIREMENTS), which is an expensive operation. If our data pipeline contains a transformation that requires such a repartitioning (e.g. join, groupBy, or a window function with partitionBy), then we can compute our observation right before that transformation for free.
Here is an example where that transformation is a join by column "id", e.g. ds.join(logins, "id"):
val logins = spark.read.option("header", true).csv("user_logins.csv")
ds.join(logins, "id").explain
== Physical Plan ==
*(5) Project [id, user, created, deleted, exists, login]
+- *(5) SortMergeJoin [id], [id], Inner
:- *(2) Sort [id ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(id, 200), ENSURE_REQUIREMENTS
: +- *(1) Filter isnotnull(id)
: +- FileScan csv [id,user,created,deleted,exists] …
+- *(4) Sort [id ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(id, 200), ENSURE_REQUIREMENTS
+- *(3) Filter isnotnull(id)
+- FileScan csv [id,login] …
The query plan shows that the join requires a repartitioning by column "id", which we can reuse to materialize our window function metric and then observe the max aggregation for free.
To be more precise: the window function repartitions the dataset, the observed metric aggregates the window function result, and the subsequent join reuses the repartitioned dataset:
// a fresh Observation instance is needed for this observe call
val observation = Observation()
val dsWithMetrics = dsWithCardinality
  .observe(observation, max($"cardinality"))
  .drop("cardinality")
  .join(logins, "id")
dsWithMetrics.explain()
== Physical Plan ==
*(5) Project [id, user, created, deleted, exists, cardinality, login]
+- *(5) SortMergeJoin [id], [id], Inner
:- *(2) Filter isnotnull(id)
: +- CollectMetrics …, [max(cardinality) AS max(cardinality)]
: +- Window [count(1) windowspecdefinition(id, …) AS cardinality], [id]
: +- *(1) Sort [id ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(id, 200), ENSURE_REQUIREMENTS
: +- FileScan csv [id,user,created,deleted,exists] …
+- *(4) Sort [id ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(id, 200), ENSURE_REQUIREMENTS
+- *(3) Filter isnotnull(id)
+- FileScan csv [id,login] …
We can see in this query plan that there is no extra shuffle involved: the join reuses the repartitioning introduced for the window function, so the observation comes for free.
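After an action on dsWithMetrics, e.g. writing it to a (hypothetical) output path, the observed maximum cardinality can be retrieved as before; since the metric was not given an explicit name, it is registered under the default name max(cardinality) seen in the query plan above:
// the write triggers the observation; no caching and no extra shuffle is needed
dsWithMetrics.write.csv("joined-data.csv")

// retrieve the metric under its default name
val maxCardinality = observation.get("max(cardinality)")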