Enable pyarrow de/serialization for Tile
Given the following correct-looking code in an environment with pyarrow installed:
from scipy.stats import kurtosis
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf('double', PandasUDFType.SCALAR)
def tile_kurtosis(t):
    return kurtosis(t.cells, axis=None)

# `path` points to any raster readable by RasterFrames.
rf = spark.read.raster(path)
rf.select(tile_kurtosis(rf.proj_raster).alias('kurt')) \
    .show(33, False)
Actual Result
Results in java.lang.UnsupportedOperationException: Unsupported data type: tile. Full stack trace below.
Expected result
We expect the same behavior as a plain udf, but with the claimed performance enhancements of Pandas UDFs:
from scipy.stats import kurtosis
from pyrasterframes.rasterfunctions import rf_tile
from pyspark.sql.functions import udf

@udf('double')
def tile_kurtosis_1(t):
    return kurtosis(t.cells, axis=None)

rf.select(tile_kurtosis_1(rf_tile(rf.proj_raster)).alias('udf_kurt')) \
    .show(33, False)
Returns something like:
+------------------+
|udf_kurt |
+------------------+
|446.3521735008447 |
|617.4294389316507 |
|19.549182850744835|
|15.485447432509027|
|176.42575141589163|
Stack trace
---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
<ipython-input-47-183b506a644b> in <module>
1 rf.select(tile_kurtosis(rf.proj_raster).alias('kurt')) \
----> 2 .show(33, False)
/anaconda3/envs/pyrf/lib/python3.7/site-packages/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
350 print(self._jdf.showString(n, 20, vertical))
351 else:
--> 352 print(self._jdf.showString(n, int(truncate), vertical))
353
354 def __repr__(self):
/anaconda3/envs/pyrf/lib/python3.7/site-packages/py4j/java_gateway.py in __call__(self, *args)
1255 answer = self.gateway_client.send_command(command)
1256 return_value = get_return_value(
-> 1257 answer, self.gateway_client, self.target_id, self.name)
1258
1259 for temp_arg in temp_args:
/anaconda3/envs/pyrf/lib/python3.7/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
61 def deco(*a, **kw):
62 try:
---> 63 return f(*a, **kw)
64 except py4j.protocol.Py4JJavaError as e:
65 s = e.java_exception.toString()
/anaconda3/envs/pyrf/lib/python3.7/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
--> 328 format(target_id, ".", name), value)
329 else:
330 raise Py4JError(
Py4JJavaError: An error occurred while calling o187.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 16.0 failed 1 times, most recent failure: Lost task 0.0 in stage 16.0 (TID 19, localhost, executor driver): java.lang.UnsupportedOperationException: Unsupported data type: tile
at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowType(ArrowUtils.scala:53)
at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowField(ArrowUtils.scala:89)
at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowField$1.apply(ArrowUtils.scala:86)
at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowField$1.apply(ArrowUtils.scala:85)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowField(ArrowUtils.scala:85)
at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:113)
at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:112)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.Iterator$class.foreach(Iterator.scala:893)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1336)
at scala.collection.IterableLike$class.foreach(IterableLike.scala:72)
at org.apache.spark.sql.types.StructType.foreach(StructType.scala:99)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
at org.apache.spark.sql.types.StructType.map(StructType.scala:99)
at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowSchema(ArrowUtils.scala:112)
at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$2.writeIteratorToStream(ArrowPythonRunner.scala:70)
at org.apache.spark.api.python.BasePythonRunner$WriterThread$$anonfun$run$1.apply(PythonRunner.scala:247)
at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1992)
at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:170)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1651)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1639)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1638)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1638)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
at scala.Option.foreach(Option.scala:257)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:831)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1872)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1821)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1810)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:642)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2034)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2055)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2074)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:363)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3278)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2489)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2489)
at org.apache.spark.sql.Dataset$$anonfun$52.apply(Dataset.scala:3259)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:77)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3258)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2489)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2703)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:254)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.UnsupportedOperationException: Unsupported data type: tile
at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowType(ArrowUtils.scala:53)
at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowField(ArrowUtils.scala:89)
at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowField$1.apply(ArrowUtils.scala:86)
at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowField$1.apply(ArrowUtils.scala:85)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowField(ArrowUtils.scala:85)
at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:113)
at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:112)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.Iterator$class.foreach(Iterator.scala:893)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1336)
at scala.collection.IterableLike$class.foreach(IterableLike.scala:72)
at org.apache.spark.sql.types.StructType.foreach(StructType.scala:99)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
at org.apache.spark.sql.types.StructType.map(StructType.scala:99)
at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowSchema(ArrowUtils.scala:112)
at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$2.writeIteratorToStream(ArrowPythonRunner.scala:70)
at org.apache.spark.api.python.BasePythonRunner$WriterThread$$anonfun$run$1.apply(PythonRunner.scala:247)
at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1992)
at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:170)
For reference and motivation
@vpipkt To feel more confident in the relative performance measures in those plots, I'd like to see them incorporate some compute operation that is implemented in the JVM. Just creating a DataFrame in Python and then calling toPandas may be optimized to never send data across the runtime barrier.
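A minimal timing sketch along those lines (hypothetical, not from the thread; it reuses the `rf` DataFrame and `proj_raster` column from the examples above):

import time
from pyrasterframes.rasterfunctions import rf_tile_mean

start = time.perf_counter()
# Force a per-tile computation on the JVM side before crossing the
# runtime barrier, so the transfer cannot be optimized away.
pdf = rf.select(rf_tile_mean(rf.proj_raster).alias('mean')).toPandas()
print('JVM compute + toPandas: %.3fs for %d rows' % (time.perf_counter() - start, len(pdf)))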
As of today, there doesn't seem to be any support in Arrow for UDTs:
https://github.com/apache/spark/blob/9f8c7a280476d37fb430da0adbde5d61e8a40714/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala#L36-L57
It's too bad, since Parquet supports them, and all Spark would have to do is convert the schema of the underlying UDT encoding.
This test:
def test_pandas_udf(self):
    import numpy as np
    from pyrasterframes.rasterfunctions import rf_tile_mean, rf_tile_to_array_double
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    @pandas_udf('double', PandasUDFType.SCALAR)
    def tile_mean(cells):
        # `cells` is a Pandas `Series` of cell arrays.
        return cells.apply(np.mean)

    df = self.rf.select(
        tile_mean(rf_tile_to_array_double(self.rf.tile)).alias('pandas_udf_mean'),
        rf_tile_mean(self.rf.tile))
    df.show(truncate=False)
Generates this result:
+------------------+------------------+
|pandas_udf_mean |rf_tile_mean(tile)|
+------------------+------------------+
|10488.786144329897|10488.786144329897|
|null |10573.227770833333|
|9672.139422680413 |9672.139422680413 |
|null |10122.969770833333|
|10606.11498969072 |10606.11498969072 |
|9912.923030927835 |9912.923030927835 |
|10305.663113402063|10305.663113402063|
|9605.92806185567 |9605.92806185567 |
+------------------+------------------+
Not entirely sure, but I suspect the null values have to do with NoData values not getting handled properly.
Try `np.nanmean`? How is the performance?
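A hypothetical sketch of that `np.nanmean` variant (not verified against the thread; it assumes NoData cells surface as NaN after rf_tile_to_array_double):

import numpy as np
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf('double', PandasUDFType.SCALAR)
def tile_nanmean(cells):
    # np.nanmean ignores NaN entries, so NoData cells (assumed to arrive
    # as NaN in the array encoding) no longer poison the mean.
    return cells.apply(np.nanmean)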