pyspark-ai icon indicating copy to clipboard operation
pyspark-ai copied to clipboard

Support PySpark Code generation for `transform` and improve usability.

Open grundprinzip opened this issue 1 year ago • 1 comments

This patch adds a new feature to the project that allows generating PySpark code instead of SQL code for a given prompt. This is valuable since it's closer to the code itself and makes it easier to reason about the behavior.

The code can be used as follows:

In [3]: from pyspark_ai import SparkAI
   ...: from langchain.chat_models import ChatOpenAI
   ...:
   ...: llm = ChatOpenAI(model="gpt-3.5-turbo")
   ...: ai = SparkAI(llm=llm, spark_session=spark)
   ...: ai.activate()
   ...: df = spark.range(10)
   ...:
   ...: df.ai.transform("count of rows", language="Python")
INFO: Python Code for the transform:
from pyspark.sql.functions import count

def transform_df(df):
    count_df = df.select(count("*").alias("count"))
    return count_df

Out[3]: DataFrame[count: bigint]

This is, in particular, useful as it allows to circumvent the issue of #75 as it does not rely on the eager evaluation of the full query plan, but can continue to be used lazily.

Secondly, this patch adds a feature that allows tracking the history of changes to a data frame based on prompts.

In [6]: x.ai.history()
Out[6]: [<AIHistoryElement prompt: count of rows>]

In [9]: x.ai.history()[0].prompt
Out[9]: 'count of rows'

In [10]: x.ai.history()[0].llm_result
Out[10]: 'Here is an example of a Python function called `transform_df` that performs the required transformation:\n\n```python\ndef transform_df(df):\n    from pyspark.sql.functions import count\n    \n    return df.agg(count("*").alias("count"))\n```\n\nThis function uses the `agg` method of the dataframe to calculate the count of rows in the dataframe. The `count` function from `pyspark.sql.functions` is used to count the rows, and the `alias` method is used to assign the column name "count" to the result. The resulting dataframe is returned.\n\nNote that this function imports the `count` function from `pyspark.sql.functions` within the function definition, as required by the specifications.'

In [11]: x.ai.history()[0].df
Out[11]: DataFrame[id: bigint]

grundprinzip avatar Jul 24 '23 21:07 grundprinzip

There is one caveat with the history as it does not allow tracking for the following situation:

df = spark.range(10)
df = df.ai.transform("prompt")
assert(len(df.ai.history()) == 1) # Works
df = df.withColumn("x", F.lit(1))
assert(len(df.ai.history()) == 1) # Fails

grundprinzip avatar Jul 24 '23 21:07 grundprinzip