dask-sql
[ENH] Support string column type specification in row UDFs
cuDF's 22.10 branch recently merged support for string UDFs:
>>> import cudf
>>> df = cudf.DataFrame({'str_col': ['a', 'bcd', 'efg']})
>>>
>>> def f(row):
...     st = row['str_col']
...     return len(st)
...
>>> result = df.apply(f, axis=1)
>>> print(result)
0 1
1 3
2 3
dtype: int32
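For readers without a GPU, the same row UDF can be sketched on CPU with pandas, whose `DataFrame.apply(..., axis=1)` also passes each row to the function as a Series indexed by column name. This is only a CPU analogue of the cuDF example above, not cuDF's JIT-compiled UDF path, and the result dtype pandas infers may differ from cuDF's `int32`.

```python
import pandas as pd

# CPU analogue of the cuDF example; pandas hands each row to f as a Series
df = pd.DataFrame({"str_col": ["a", "bcd", "efg"]})

def f(row):
    st = row["str_col"]  # a plain Python str in pandas
    return len(st)

result = df.apply(f, axis=1)
print(result.tolist())  # [1, 3, 3]
```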
I'm attempting to use these with Dask-SQL row UDFs:
>>> import numpy as np
>>> from dask_sql import Context
>>> c = Context()
>>> c.create_table("df", df)
# attempting to supply no input param types:
>>> c.register_function(f, "len", [], return_type=np.int32, row_udf=True, replace=True)
>>> c.sql("select len(str_col) from df").compute()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/opt/conda/envs/rapids/lib/python3.9/site-packages/dask_sql/context.py", line 493, in sql
rel, _ = self._get_ral(sql)
File "/opt/conda/envs/rapids/lib/python3.9/site-packages/dask_sql/context.py", line 811, in _get_ral
raise ParsingException(sql, str(pe)) from None
dask_sql.utils.ParsingException: Plan("UDF signature not found for input types [Utf8]")
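One plausible reading of this first error: registering with an empty parameter list declares a zero-argument signature, so when `len(str_col)` supplies one argument of Arrow type `Utf8`, no registered signature matches. The sketch below models that exact-match lookup; `registry`, `register`, and `resolve` are hypothetical names for illustration, not dask-sql's actual internals.

```python
# Hypothetical model of UDF signature resolution; not dask-sql's real code.
registry = {}  # UDF name -> list of registered signatures (tuples of type names)

def register(name, input_types):
    # Each registration adds one signature for the named UDF.
    registry.setdefault(name, []).append(tuple(input_types))

def resolve(name, arg_types):
    # Exact-match lookup of the call's argument types against the signatures.
    for sig in registry.get(name, []):
        if sig == tuple(arg_types):
            return sig
    raise LookupError(f"UDF signature not found for input types {arg_types}")

register("len", [])  # roughly what passing [] amounts to
try:
    resolve("len", ["Utf8"])  # len(str_col) supplies one Utf8 argument
except LookupError as e:
    print(e)  # UDF signature not found for input types ['Utf8']

register("len", ["Utf8"])  # a one-argument string signature does match
print(resolve("len", ["Utf8"]))  # ('Utf8',)
```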
# attempting to supply "Utf8" type:
>>> c.register_function(f, "len", [("str_col", "Utf8")], return_type=np.int32, row_udf=True, replace=True)
Traceback (most recent call last):
File "/opt/conda/envs/rapids/lib/python3.9/site-packages/dask_sql/mappings.py", line 112, in python_to_sql_type
return DaskTypeMap(_PYTHON_TO_SQL[python_type])
KeyError: 'Utf8'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/opt/conda/envs/rapids/lib/python3.9/site-packages/dask_sql/context.py", line 375, in register_function
self._register_callable(
File "/opt/conda/envs/rapids/lib/python3.9/site-packages/dask_sql/context.py", line 902, in _register_callable
sql_parameters = [
File "/opt/conda/envs/rapids/lib/python3.9/site-packages/dask_sql/context.py", line 903, in <listcomp>
(name, python_to_sql_type(param_type)) for name, param_type in parameters
File "/opt/conda/envs/rapids/lib/python3.9/site-packages/dask_sql/mappings.py", line 114, in python_to_sql_type
raise NotImplementedError(
NotImplementedError: The python type Utf8 is not implemented (yet)
I get a similar error when trying to register against "varchar" and "VARCHAR".
Unfortunately, it seems we can't register or use row UDFs with string columns just yet, or I'm not specifying the types correctly.
Thanks @brandon-b-miller for the tip! It should be:
c.register_function(f, "len", [("str_col", np.dtype("object"))], return_type=np.int32, row_udf=True, replace=True)
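Why `np.dtype("object")` works where `"Utf8"` did not: judging from the traceback above, `register_function` passes each parameter type through `python_to_sql_type`, which looks it up in a mapping keyed by Python/NumPy types (`_PYTHON_TO_SQL`). A plain string such as `"Utf8"` or `"VARCHAR"` is not a key, hence the `KeyError` and then the `NotImplementedError`. The sketch below models that lookup; the mapping contents and the `to_sql_type` helper are hypothetical simplifications, not dask-sql's actual table.

```python
import numpy as np

# Hypothetical, heavily simplified stand-in for dask-sql's _PYTHON_TO_SQL table.
PYTHON_TO_SQL_SKETCH = {
    np.int32: "INTEGER",
    np.int64: "BIGINT",
    np.float64: "DOUBLE",
    np.object_: "VARCHAR",  # string columns typically report object dtype
}

def to_sql_type(python_type):
    # Normalize a dtype instance (e.g. np.dtype("object")) to its scalar type,
    # since np.dtype("object") and np.object_ do not hash to the same dict key.
    if isinstance(python_type, np.dtype):
        python_type = python_type.type
    try:
        return PYTHON_TO_SQL_SKETCH[python_type]
    except KeyError:
        raise NotImplementedError(
            f"The python type {python_type} is not implemented (yet)"
        ) from None

print(to_sql_type(np.dtype("object")))  # VARCHAR
print(to_sql_type(np.int32))            # INTEGER
try:
    to_sql_type("Utf8")
except NotImplementedError as e:
    print(e)  # The python type Utf8 is not implemented (yet)
```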
Glad to see it up and running! I'll update the docs and add some tests here, if you'd like to assign me 👍