Support per-aggregation filters in Morphir SDK Aggregation API transpilation to Spark
As decided in https://github.com/finos/morphir-elm/issues/799#issuecomment-1193923862, this should be solvable with case expressions (as in SQL) or window functions.
I've investigated filters using Spark's equivalent of SQL CASE WHEN expressions. For an aggregation represented in Elm by:

```elm
source
    |> groupBy .fieldName
    |> aggregate
        (\key inputs ->
            { fieldName = key
            , aggregated1 = inputs (averageOf .otherFieldName |> withFilter (\a -> a.otherFieldName >= 1.0))
            }
        )
```
(i.e. take the average of `otherFieldName` where `otherFieldName >= 1.0`), the equivalent in Spark would be:
```scala
source.groupBy("fieldName").agg(
  org.apache.spark.sql.functions.avg(
    org.apache.spark.sql.functions.when(
      org.apache.spark.sql.functions.col("otherFieldName") >= 1.0,
      org.apache.spark.sql.functions.col("otherFieldName")
    )
  ).alias("aggregated1")
)
```
This works because `when()` sets the value to `null` if `.otherwise()` isn't used, and most aggregations ignore nulls.
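As a minimal, self-contained sketch of that null behaviour (hypothetical values and a local SparkSession, not taken from the Morphir codebase):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, col, when}

object WhenFilterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("when-filter-sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical data: only 1.0 and 2.5 satisfy the filter.
    val df = Seq(0.5, 1.0, 2.5).toDF("otherFieldName")

    // when() with no .otherwise() yields null for non-matching rows,
    // and avg() skips nulls, so this returns avg(1.0, 2.5) = 1.75.
    df.agg(avg(when(col("otherFieldName") >= 1.0, col("otherFieldName"))).alias("aggregated1"))
      .show()

    spark.stop()
  }
}
```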
`count` doesn't ignore nulls, though; that caveat comes from https://spark.apache.org/docs/3.0.0-preview/sql-ref-null-semantics.html#built-in-aggregate. However, having tested with some real data in spark-shell, I think `count` might behave as expected without any special handling:
```
scala> df.show()
+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|      Michael|     Sales|  4600|
|       Robert|     Sales|  4100|
|        Maria|   Finance|  3000|
|        James|     Sales|  3000|
|        Scott|   Finance|  3300|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
|        Kumar| Marketing|  2000|
|         Saif|     Sales|  4100|
+-------------+----------+------+

scala> df.groupBy("department").agg(count(when(col("salary") >= 2000, col("salary")))).show()
+----------+-------------------------------------------------+
|department|count(CASE WHEN (salary >= 2000) THEN salary END)|
+----------+-------------------------------------------------+
|     Sales|                                                5|
|   Finance|                                                3|
| Marketing|                                                2|
+----------+-------------------------------------------------+

scala> df.groupBy("department").agg(count(when(col("salary") >= 2001, col("salary")))).show()
+----------+-------------------------------------------------+
|department|count(CASE WHEN (salary >= 2001) THEN salary END)|
+----------+-------------------------------------------------+
|     Sales|                                                5|
|   Finance|                                                3|
| Marketing|                                                1|
+----------+-------------------------------------------------+
```
i.e. if `when` is used to turn some values into nulls, `count` doesn't count them under the circumstances we expect to use it in.
I have not been able to figure out a way to do the same with window functions.
In Spark ObjectExpressions, the above is:
```elm
Aggregate "department"
    [ ( [ "aggregated", "1" ]
      , Function "count"
            [ Function "when"
                [ BinaryOperation ">="
                    (Column "salary")
                    (Literal (WholeNumberLiteral 2001))
                , Column "salary"
                ]
            ]
      )
    ]
    (From "source")
```
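For reference, this AST would be expected to render to Scala along these lines, mirroring the spark-shell call above (using fully-qualified function names as in the earlier example):

```scala
// `source` is assumed to be a DataFrame already in scope.
source.groupBy("department").agg(
  org.apache.spark.sql.functions.count(
    org.apache.spark.sql.functions.when(
      org.apache.spark.sql.functions.col("salary") >= 2001,
      org.apache.spark.sql.functions.col("salary")
    )
  ).alias("aggregated1")
)
```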
In src/Morphir/Spark/AST.elm, `List.filter predicate sourceRelation` is transformed into `Filter fieldExpression source`, i.e. `expressionFromValue` is responsible for transforming a filter function into a Spark `Expression`. Filter functions are expected to have the form `\a -> a.fieldName >= 2001`.
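A predicate of that shape should then come out of the transpiler as something like the following sketch (hypothetical `source` DataFrame; the exact rendering depends on the backend):

```scala
// Hypothetical output for `source |> List.filter (\a -> a.fieldName >= 2001)`, i.e. for
// Filter (BinaryOperation ">=" (Column "fieldName") (Literal (WholeNumberLiteral 2001))) (From "source")
source.filter(org.apache.spark.sql.functions.col("fieldName") >= 2001)
```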