
Make baml_py.baml_py.BamlRuntime pickle-able

Open rjurney opened this issue 7 months ago • 4 comments

I am trying to integrate BAML with PySpark and am running into a pickle error. I suspected this might happen and want to help get it working. Do you know what about the class isn't pickle-able? I would like to fix it.

The code is a PySpark UDF that pickles objects to distribute compute jobs.

The exception is:

Traceback (most recent call last):
  File "/Users/rjurney/Software/spark/python/pyspark/serializers.py", line 459, in dumps
    return cloudpickle.dumps(obj, pickle_protocol)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/rjurney/Software/spark/python/pyspark/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/Users/rjurney/Software/spark/python/pyspark/cloudpickle/cloudpickle_fast.py", line 632, in dump
    return Pickler.dump(self, obj)
           ^^^^^^^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'baml_py.baml_py.BamlRuntime' object
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/Software/spark/python/pyspark/serializers.py:459, in CloudPickleSerializer.dumps(self, obj)
    458 try:
--> 459     return cloudpickle.dumps(obj, pickle_protocol)
    460 except pickle.PickleError:

File ~/Software/spark/python/pyspark/cloudpickle/cloudpickle_fast.py:73, in dumps(obj, protocol, buffer_callback)
     70 cp = CloudPickler(
     71     file, protocol=protocol, buffer_callback=buffer_callback
     72 )
---> 73 cp.dump(obj)
     74 return file.getvalue()

File ~/Software/spark/python/pyspark/cloudpickle/cloudpickle_fast.py:632, in CloudPickler.dump(self, obj)
    631 try:
--> 632     return Pickler.dump(self, obj)
    633 except RuntimeError as e:

TypeError: cannot pickle 'baml_py.baml_py.BamlRuntime' object

During handling of the above exception, another exception occurred:

PicklingError                             Traceback (most recent call last)
Cell In[10], line 1
----> 1 article_df = article_df.withColumn("document", F.udf(parse_article, doc_schema)("url"))

File ~/Software/spark/python/pyspark/sql/udf.py:423, in UserDefinedFunction._wrapped.<locals>.wrapper(*args)
    421 @functools.wraps(self.func, assigned=assignments)
    422 def wrapper(*args: "ColumnOrName") -> Column:
--> 423     return self(*args)

File ~/Software/spark/python/pyspark/sql/udf.py:400, in UserDefinedFunction.__call__(self, *cols)
    398         sc.profiler_collector.add_profiler(id, memory_profiler)
    399 else:
--> 400     judf = self._judf
    401     jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column))
    402 return Column(jPythonUDF)

File ~/Software/spark/python/pyspark/sql/udf.py:321, in UserDefinedFunction._judf(self)
    314 @property
    315 def _judf(self) -> JavaObject:
    316     # It is possible that concurrent access, to newly created UDF,
    317     # will initialize multiple UserDefinedPythonFunctions.
    318     # This is unlikely, doesn't affect correctness,
    319     # and should have a minimal performance impact.
    320     if self._judf_placeholder is None:
--> 321         self._judf_placeholder = self._create_judf(self.func)
    322     return self._judf_placeholder

File ~/Software/spark/python/pyspark/sql/udf.py:330, in UserDefinedFunction._create_judf(self, func)
    327 spark = SparkSession._getActiveSessionOrCreate()
    328 sc = spark.sparkContext
--> 330 wrapped_func = _wrap_function(sc, func, self.returnType)
    331 jdt = spark._jsparkSession.parseDataType(self.returnType.json())
    332 assert sc._jvm is not None

File ~/Software/spark/python/pyspark/sql/udf.py:59, in _wrap_function(sc, func, returnType)
     57 else:
     58     command = (func, returnType)
---> 59 pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
     60 assert sc._jvm is not None
     61 return sc._jvm.SimplePythonFunction(
     62     bytearray(pickled_command),
     63     env,
   (...)
     68     sc._javaAccumulator,
     69 )

File ~/Software/spark/python/pyspark/rdd.py:5251, in _prepare_for_python_RDD(sc, command)
   5248 def _prepare_for_python_RDD(sc: "SparkContext", command: Any) -> Tuple[bytes, Any, Any, Any]:
   5249     # the serialized command will be compressed by broadcast
   5250     ser = CloudPickleSerializer()
-> 5251     pickled_command = ser.dumps(command)
   5252     assert sc._jvm is not None
   5253     if len(pickled_command) > sc._jvm.PythonUtils.getBroadcastThreshold(sc._jsc):  # Default 1M
   5254         # The broadcast will have same life cycle as created PythonRDD

File ~/Software/spark/python/pyspark/serializers.py:469, in CloudPickleSerializer.dumps(self, obj)
    467     msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
    468 print_exec(sys.stderr)
--> 469 raise pickle.PicklingError(msg)

PicklingError: Could not serialize object: TypeError: cannot pickle 'baml_py.baml_py.BamlRuntime' object
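The failure mode can be reproduced without Spark or baml_py: any function closure that captures an object whose type refuses pickling will fail once the serializer walks it. A minimal stand-in (hypothetical class names, for illustration only):

```python
import pickle

class NativeRuntime:
    """Stand-in for a native-extension object without pickle support.

    Hypothetical class for illustration; baml_py and PySpark are not
    needed to reproduce the failure mode.
    """
    def __reduce__(self):
        # Extension types that opt out of pickling raise TypeError,
        # analogous to what baml_py.baml_py.BamlRuntime does.
        raise TypeError(f"cannot pickle '{type(self).__name__}' object")

runtime = NativeRuntime()

def parse_article(url):
    # A UDF like this captures `runtime` in its closure, so cloudpickle
    # must serialize it when Spark ships the function to executors.
    return (runtime, url)

try:
    pickle.dumps(runtime)
except TypeError as exc:
    print(exc)  # cannot pickle 'NativeRuntime' object
```

Stdlib pickle pickles top-level functions by reference, but cloudpickle serializes them by value, which is why Spark's serializer ends up trying to pickle the captured runtime object itself.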

rjurney avatar Apr 19 '25 03:04 rjurney

There are a few things that get hard here, but I can ask the team to allocate some resources for this effort, perhaps as a future Friday project!

(Every Friday we each take a request from a user and fix it so we can ship it in the following week's release.)

hellovai avatar Apr 19 '25 15:04 hellovai

What we need to do here is:

The runtime (which is a Rust-defined Python object) can be made picklable by simply pickling the source files and env vars used to initialize it. I think we just need to implement some Python methods on it (the pickle protocol hooks) that pickle will call.
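A minimal sketch of that approach, using a pure-Python stand-in (hypothetical names; the real hooks would be implemented on the Rust-defined class): serialize only the constructor inputs and let pickle rebuild the runtime from them via `__reduce__`.

```python
import pickle

class PicklableRuntime:
    """Pure-Python stand-in for BamlRuntime (hypothetical; the real
    implementation would define these hooks on the Rust-backed class)."""

    def __init__(self, src_files, env_vars):
        # The only state pickle needs: the inputs used to initialize.
        self.src_files = dict(src_files)  # path -> BAML source text
        self.env_vars = dict(env_vars)
        # ...native runtime state would be (re)built from these here...

    def __reduce__(self):
        # Instead of serializing any native state, tell pickle to call
        # the constructor again with the same source files and env vars.
        return (self.__class__, (self.src_files, self.env_vars))

rt = PicklableRuntime({"main.baml": "class Doc { title string }"},
                      {"OPENAI_API_KEY": "dummy"})
clone = pickle.loads(pickle.dumps(rt))
assert clone.src_files == rt.src_files
assert clone.env_vars == rt.env_vars
```

cloudpickle honors the same protocol hooks as stdlib pickle, so a runtime that round-trips this way should also survive being captured in a Spark UDF closure.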

hellovai avatar Apr 19 '25 16:04 hellovai

Okay, thanks. This would facilitate using BAML with Databricks Mosaic as well.

rjurney avatar Apr 19 '25 18:04 rjurney

@rjurney do you mind sharing a snippet of the code which caused this?

ba11b0y avatar May 27 '25 18:05 ba11b0y

We merged a fix! #1990

ba11b0y avatar Jun 30 '25 21:06 ba11b0y