Make baml_py.baml_py.BamlRuntime pickle-able
I am trying to integrate BAML with PySpark and am running into a pickle error. I suspected this might happen and want to work on resolving it. Do you know what about the class isn't pickle-able? I would like to fix it.
The code is a PySpark UDF that pickles objects to distribute compute jobs.
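The failure mode here is general: when Spark serializes a UDF with cloudpickle, any object captured by the function's closure must itself be picklable. A minimal analogy using only the standard library (a `threading.Lock` stands in for the native `BamlRuntime`, which is not part of this sketch) reproduces the same class of `TypeError`:

```python
import pickle
import threading

# A lock is a C-level object with no pickle support, just like a
# native extension runtime; attempting to serialize it fails the
# same way BamlRuntime does inside the UDF closure.
lock = threading.Lock()

try:
    pickle.dumps(lock)
except TypeError as e:
    print(e)  # cannot pickle '_thread.lock' object
```

Any UDF that references such an object, directly or through a generated client, will fail at `cloudpickle.dumps` before the job is ever shipped to an executor.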
The exception is:
Traceback (most recent call last):
File "/Users/rjurney/Software/spark/python/pyspark/serializers.py", line 459, in dumps
return cloudpickle.dumps(obj, pickle_protocol)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/rjurney/Software/spark/python/pyspark/cloudpickle/cloudpickle_fast.py", line 73, in dumps
cp.dump(obj)
File "/Users/rjurney/Software/spark/python/pyspark/cloudpickle/cloudpickle_fast.py", line 632, in dump
return Pickler.dump(self, obj)
^^^^^^^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'baml_py.baml_py.BamlRuntime' object
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File ~/Software/spark/python/pyspark/serializers.py:459, in CloudPickleSerializer.dumps(self, obj)
458 try:
--> 459 return cloudpickle.dumps(obj, pickle_protocol)
460 except pickle.PickleError:
File ~/Software/spark/python/pyspark/cloudpickle/cloudpickle_fast.py:73, in dumps(obj, protocol, buffer_callback)
70 cp = CloudPickler(
71 file, protocol=protocol, buffer_callback=buffer_callback
72 )
---> 73 cp.dump(obj)
74 return file.getvalue()
File ~/Software/spark/python/pyspark/cloudpickle/cloudpickle_fast.py:632, in CloudPickler.dump(self, obj)
631 try:
--> 632 return Pickler.dump(self, obj)
633 except RuntimeError as e:
TypeError: cannot pickle 'baml_py.baml_py.BamlRuntime' object
During handling of the above exception, another exception occurred:
PicklingError Traceback (most recent call last)
Cell In[10], line 1
----> 1 article_df = article_df.withColumn("document", F.udf(parse_article, doc_schema)("url"))
File ~/Software/spark/python/pyspark/sql/udf.py:423, in UserDefinedFunction._wrapped.<locals>.wrapper(*args)
421 @functools.wraps(self.func, assigned=assignments)
422 def wrapper(*args: "ColumnOrName") -> Column:
--> 423 return self(*args)
File ~/Software/spark/python/pyspark/sql/udf.py:400, in UserDefinedFunction.__call__(self, *cols)
398 sc.profiler_collector.add_profiler(id, memory_profiler)
399 else:
--> 400 judf = self._judf
401 jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column))
402 return Column(jPythonUDF)
File ~/Software/spark/python/pyspark/sql/udf.py:321, in UserDefinedFunction._judf(self)
314 @property
315 def _judf(self) -> JavaObject:
316 # It is possible that concurrent access, to newly created UDF,
317 # will initialize multiple UserDefinedPythonFunctions.
318 # This is unlikely, doesn't affect correctness,
319 # and should have a minimal performance impact.
320 if self._judf_placeholder is None:
--> 321 self._judf_placeholder = self._create_judf(self.func)
322 return self._judf_placeholder
File ~/Software/spark/python/pyspark/sql/udf.py:330, in UserDefinedFunction._create_judf(self, func)
327 spark = SparkSession._getActiveSessionOrCreate()
328 sc = spark.sparkContext
--> 330 wrapped_func = _wrap_function(sc, func, self.returnType)
331 jdt = spark._jsparkSession.parseDataType(self.returnType.json())
332 assert sc._jvm is not None
File ~/Software/spark/python/pyspark/sql/udf.py:59, in _wrap_function(sc, func, returnType)
57 else:
58 command = (func, returnType)
---> 59 pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
60 assert sc._jvm is not None
61 return sc._jvm.SimplePythonFunction(
62 bytearray(pickled_command),
63 env,
(...)
68 sc._javaAccumulator,
69 )
File ~/Software/spark/python/pyspark/rdd.py:5251, in _prepare_for_python_RDD(sc, command)
5248 def _prepare_for_python_RDD(sc: "SparkContext", command: Any) -> Tuple[bytes, Any, Any, Any]:
5249 # the serialized command will be compressed by broadcast
5250 ser = CloudPickleSerializer()
-> 5251 pickled_command = ser.dumps(command)
5252 assert sc._jvm is not None
5253 if len(pickled_command) > sc._jvm.PythonUtils.getBroadcastThreshold(sc._jsc): # Default 1M
5254 # The broadcast will have same life cycle as created PythonRDD
File ~/Software/spark/python/pyspark/serializers.py:469, in CloudPickleSerializer.dumps(self, obj)
467 msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
468 print_exec(sys.stderr)
--> 469 raise pickle.PicklingError(msg)
PicklingError: Could not serialize object: TypeError: cannot pickle 'baml_py.baml_py.BamlRuntime' object
There are a few things that get hard here, but I can ask the team to allocate some resources for this effort, perhaps as a future Friday project!
(Every Friday we each take a request from a user and fix it so we can ship it in the following week's release.)
What we need to do here is:
The runtime (which is a Rust-defined Python object) can be made picklable by simply pickling the source files and env vars used to initialize it. I think we just need to implement some Python methods on it that pickle will call (e.g. `__reduce__`, or `__getstate__`/`__setstate__`).
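A minimal sketch of that idea, using a hypothetical pure-Python `Runtime` class in place of the real Rust-defined `BamlRuntime` (the class name, fields, and constructor are illustrative, not BAML's actual API): instead of pickling the native state, `__reduce__` tells pickle to store only the constructor inputs and rebuild the object on load.

```python
import pickle

class Runtime:
    """Stand-in for a runtime built from source files and env vars."""

    def __init__(self, src_files, env_vars):
        self.src_files = src_files
        self.env_vars = env_vars
        # imagine heavy, unpicklable native state being created here

    def __reduce__(self):
        # Tell pickle to call Runtime(src_files, env_vars) when
        # deserializing, rather than trying to serialize native state.
        return (self.__class__, (self.src_files, self.env_vars))

rt = Runtime({"main.baml": "..."}, {"OPENAI_API_KEY": "..."})
clone = pickle.loads(pickle.dumps(rt))
assert clone.src_files == rt.src_files
assert clone.env_vars == rt.env_vars
```

With this pattern, each Spark executor reconstructs its own runtime from the sources and env vars, which is exactly what a UDF needs since native handles can't cross process boundaries anyway.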
Okay, thanks. This would facilitate using BAML with Databricks Mosaic as well.
@rjurney do you mind sharing a snippet of the code which caused this?
We merged a fix! #1990