mlflow-snowflake

Support for custom scikit-learn transformers

Open: jonwiggins opened this issue 2 years ago • 1 comment

Hi,

I have a model that uses a custom transformer to featurize text input into various scalar values, like this:

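from sklearn.base import BaseEstimator, TransformerMixin

class MyTransformer(BaseEstimator, TransformerMixin):
    # Stateless featurizer: turns the "text" column into scalar features
    def fit(self, X, y=None):
        # Nothing to learn from the training data
        return self

    def transform(self, X):
        to_return = X.copy()
        to_return["length"] = to_return["text"].str.len()
        ...
        # feature_columns: the list of generated feature names (not shown)
        return to_return[feature_columns]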

I then use this transformer in my model like:

from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier
from mylib import MyTransformer
...
my_pipeline = Pipeline(steps=[
    ("my_transformer", MyTransformer()),
    ("dtc", DecisionTreeClassifier()),
])
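For reference, here is a minimal self-contained sketch of the same setup that fits and predicts end to end. TextFeaturizer, its two features ("length" and "word_count"), FEATURE_COLUMNS and the toy data are all hypothetical stand-ins for MyTransformer and the real feature_columns. The class also inherits from BaseEstimator, which scikit-learn expects of custom estimators so that get_params/clone work.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier

# Hypothetical feature list; the real transformer in mylib computes more features.
FEATURE_COLUMNS = ["length", "word_count"]


class TextFeaturizer(BaseEstimator, TransformerMixin):
    """Hypothetical stand-in for MyTransformer: "text" column -> scalar features."""

    def fit(self, X, y=None):
        # Stateless transformer: nothing to learn.
        return self

    def transform(self, X):
        out = X.copy()
        out["length"] = out["text"].str.len()
        out["word_count"] = out["text"].str.split().str.len()
        return out[FEATURE_COLUMNS]


pipeline = Pipeline(steps=[
    ("featurize", TextFeaturizer()),
    ("dtc", DecisionTreeClassifier(random_state=0)),
])

# Toy training data: only the long sentence is labelled 1.
train = pd.DataFrame({"text": ["hi", "hello there", "a much longer sentence here", "ok"]})
labels = [0, 0, 1, 0]
pipeline.fit(train, labels)

predictions = pipeline.predict(
    pd.DataFrame({"text": ["short", "another fairly long input text"]})
)
```

Run locally like this, the pipeline works; the trouble only starts once the serialized pipeline is loaded in an environment where the module defining the transformer cannot be imported.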

This seems reasonable to me, and AFAIK it is how these transformers are intended to be used. However, when I try to create a deployment UDF using this library I get:

  File "/root/miniconda3/lib/python3.8/site-packages/snowflake/snowpark/_internal/server_connection.py", line 317, in run_query
    results_cursor = self._cursor.execute(query, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/snowflake/connector/cursor.py", line 804, in execute
    Error.errorhandler_wrapper(self.connection, self, error_class, errvalue)
  File "/root/miniconda3/lib/python3.8/site-packages/snowflake/connector/errors.py", line 276, in errorhandler_wrapper
    handed_over = Error.hand_to_other_handler(
  File "/root/miniconda3/lib/python3.8/site-packages/snowflake/connector/errors.py", line 331, in hand_to_other_handler
    cursor.errorhandler(connection, cursor, error_class, error_value)
  File "/root/miniconda3/lib/python3.8/site-packages/snowflake/connector/errors.py", line 210, in default_errorhandler
    raise error_class(
snowflake.snowpark.exceptions.SnowparkSQLException ...
Python Interpreter Error:
ModuleNotFoundError: No module named 'mylib' in function MLFLOW$TEST_PREDICT with handler ...

Typically, when I load this model with MLflow, I can make it work by importing the transformer first, like:

import mlflow
from mylib import MyTransformer
...
model = mlflow.pyfunc.load_model(model_uri=my_model_uri)
model.predict(my_data)
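A sketch of why importing mylib matters, assuming the model is serialized with pickle or a pickle-compatible format such as cloudpickle: a class that lives in an importable module like mylib is stored by reference (module name plus class name), not by its source code. The stdlib-only snippet below reproduces the same kind of ModuleNotFoundError locally; hypothetical_mylib is a made-up module created at runtime to stand in for mylib.

```python
import pickle
import sys
import types

# Build a throwaway module at runtime to stand in for "mylib" (hypothetical name).
demo = types.ModuleType("hypothetical_mylib")


class MyTransformer:
    # Pretend this class was defined inside hypothetical_mylib.
    __module__ = "hypothetical_mylib"


demo.MyTransformer = MyTransformer
sys.modules["hypothetical_mylib"] = demo

# pickle records ("hypothetical_mylib", "MyTransformer"), not the class body.
payload = pickle.dumps(MyTransformer())
referenced = b"hypothetical_mylib" in payload

# Simulate a runtime (such as a UDF environment) where the module is not installed.
del sys.modules["hypothetical_mylib"]
try:
    pickle.loads(payload)
    error = None
except ModuleNotFoundError as exc:
    # Same failure mode as "No module named 'mylib'" in the traceback above.
    error = exc
```

If this reading is right, the UDF fails for the same reason: the serialized model refers to mylib.MyTransformer, but mylib is not importable inside the UDF's Python environment.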

It seems like there should be a way to make this work. Since the transformation is plain Python, perhaps it could be packaged into the UDF itself, or Snowpark could detect the mylib dependency and upload it along with the model when publishing the UDF. I'm interested to hear your thoughts.
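A rough sketch of the "detect this Python dependency" idea, assuming the model is available as raw pickle bytes on the machine where it was trained: a pickle.Unpickler subclass can record every module name passed to find_class while loading. The helper names below are hypothetical and not part of mlflow-snowflake or Snowpark. For a pipeline like the one above, one would expect the result to include sklearn as well as mylib, and mylib would be the module that has to be shipped with the UDF.

```python
import io
import pickle
from collections import OrderedDict
from datetime import date


class _ModuleRecorder(pickle.Unpickler):
    """Hypothetical helper: records every module a pickle asks to import."""

    def __init__(self, file):
        super().__init__(file)
        self.modules = set()

    def find_class(self, module, name):
        # Keep only the top-level package, e.g. "sklearn" for "sklearn.tree._classes".
        self.modules.add(module.split(".")[0])
        return super().find_class(module, name)


def pickled_top_level_modules(payload: bytes) -> set:
    """Return the top-level modules referenced by a pickle payload.

    This fully unpickles the payload, so every referenced module must be
    importable here, and it should only be run on trusted data.
    """
    recorder = _ModuleRecorder(io.BytesIO(payload))
    recorder.load()
    return recorder.modules


# Usage with standard-library objects standing in for a real model pickle.
example_modules = pickled_top_level_modules(
    pickle.dumps([OrderedDict(a=1), date(2023, 2, 14)])
)
```

The design trades safety for simplicity: scanning opcodes with pickletools would avoid executing the pickle, but would have to track memoized strings to resolve module names reliably.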

Thanks, Jon

jonwiggins, Feb 14 '23 16:02