
mo.ui.dataframe cannot render a Polars dataframe

Open · jsnelgro opened this issue 1 year ago • 4 comments

Describe the bug

Thanks for creating such a refreshing alternative to Jupyter! However, I'm having trouble using polars with the dataframe ui widget. Here's the simplest example:

import marimo as mo
import polars as pl
 
df = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True], "c": ["Bob", "Sally", "Jane"]})
ui_df = mo.ui.dataframe(df)
ui_df

Produces the error:

AttributeError
This cell raised an exception: AttributeError("'list' object has no attribute 'to_dict'")

with stacktrace:

Traceback (most recent call last):
  Cell , line 5
    ui_df = mo.ui.dataframe(df)
  File /Users/johndoe/Library/Caches/pypoetry/virtualenvs/python-playground-dPNtgDsu-py3.12/lib/python3.12/site-packages/marimo/_plugins/ui/_impl/dataframes/dataframe.py, line 95, in __init__
    "columns": df.dtypes.to_dict(),
AttributeError: 'list' object has no attribute 'to_dict'

Environment

pyproject.toml

[tool.poetry]
name = "python-playground"
version = "0.1.0"
description = "using poetry and marimo to make python actually nice"
authors = ["Your Name "]
readme = "README.md"
 
[tool.poetry.dependencies]
python = "^3.12"
marimo = "^0.1.76"
altair = "^5.2.0"
numpy = "^1.26.3"
polars = "^0.20.4"
pyarrow = "^14.0.2"
 
 
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

Code to reproduce

import marimo as mo
import polars as pl
 
df = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True], "c": ["Bob", "Sally", "Jane"]})
ui_df = mo.ui.dataframe(df)
ui_df

jsnelgro avatar Jan 21 '24 21:01 jsnelgro

Unfortunately, we don't support Polars at this time for this specific plugin. You'll need to call to_pandas(), but you'll then be returned a pandas df.
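A minimal sketch of that workaround (the widget will then hand back pandas objects, e.g. via ui_df.value):

import marimo as mo
import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True], "c": ["Bob", "Sally", "Jane"]})

# Convert to pandas before handing the frame to the widget;
# ui_df.value will then be a pandas DataFrame, not a Polars one
ui_df = mo.ui.dataframe(df.to_pandas())
ui_df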

We can add first-class Polars support for this though.

mscolnick avatar Jan 22 '24 01:01 mscolnick

Also, with to_pandas there is another error: TypeError('Object of type DatetimeTZDtype is not JSON serializable') for a Polars dataframe containing a column of type Datetime(time_unit='us', time_zone='UTC'). Apart from mo.ui.dataframe, Polars works excellently for me with marimo.
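A hypothetical minimal repro of that serialization error (the column name and timestamp are illustrative):

import marimo as mo
import polars as pl
from datetime import datetime, timezone

# A tz-aware datetime column maps to pandas' DatetimeTZDtype after
# to_pandas(), which the widget then fails to JSON-serialize
df = pl.DataFrame({"ts": [datetime(2024, 1, 22, tzinfo=timezone.utc)]})
mo.ui.dataframe(df.to_pandas())
# TypeError: Object of type DatetimeTZDtype is not JSON serializable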

misolietavec avatar Jan 22 '24 19:01 misolietavec

Thanks for finding the serialization bug - it will be fixed in this PR: https://github.com/marimo-team/marimo/pull/631

mscolnick avatar Jan 23 '24 04:01 mscolnick

Hi there. I've taken a pass at looking into this issue with our tool, Glide. Please see the High-Level Plan and Implementation below. Note: The plan does not cover unit tests or frontend rendering.

Plan Steps

Step 1: Modify the dataframe class to accept Polars dataframes

Description: Update the dataframe class in marimo/_plugins/ui/_impl/dataframes/dataframe.py to handle Polars dataframes by checking the type of the input dataframe and setting up the appropriate handlers.

  • Check the type of the df parameter in the __init__ method.
  • If df is a Polars dataframe, initialize a PolarsTransformHandlers instance.
  • Store the type of dataframe (Pandas or Polars) for use in other methods.

Step 2: Create an abstract base class for transform handlers

Description: Define an abstract base class TransformHandlers with abstract methods for each type of transform.

  • Define an abstract base class TransformHandlers.
  • Add abstract methods for each transform type: handle_column_conversion, handle_rename_column, handle_sort_column, handle_filter_rows, handle_group_by, handle_aggregate, handle_select_columns, handle_shuffle_rows, and handle_sample_rows.

Step 3: Implement PandasTransformHandlers and PolarsTransformHandlers subclasses

Description: Create two subclasses that inherit from TransformHandlers and implement the methods for handling transforms for Pandas and Polars dataframes respectively.

  • Create PandasTransformHandlers subclass with methods that use Pandas-specific code.
  • Create PolarsTransformHandlers subclass with methods that use Polars-specific code.

Step 4: Implement Polars transform handler methods

Description: Implement each transform handler method in PolarsTransformHandlers using the provided Polars code context.

  • For handle_shuffle_rows, use Polars' sample method with fraction=1.0 and shuffle=True (Polars has no sample_frac).
  • Implement other methods by translating the Pandas logic to Polars API calls, using the provided code context as a reference.

Step 5: Implement get_dataframe for Polars

Description: Use Polars' write_csv method to produce a CSV output that matches the format expected by Marimo.

  • Modify the get_dataframe method to check the type of dataframe.
  • If it's a Polars dataframe, use write_csv with the correct arguments to produce a CSV string.
  • Convert the CSV string to a VirtualFile and return the appropriate GetDataFrameResponse.

Step 6: Update get_column_values for Polars

Description: Ensure that the get_column_values function can retrieve unique values from a specified column in a Polars dataframe.

  • Modify the get_column_values method to handle Polars dataframes.
  • Use Polars' API to get unique values from the specified column.
  • Return a GetColumnValuesResponse with the values or indicate if there are too many values.

Additional context to gather

  • Verify the compatibility of CSV formatting options between Pandas' to_csv and Polars' write_csv methods (see the sketch after this list).
  • Ensure that the VirtualFile creation process is consistent with Marimo's handling of CSV data.
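A quick way to check that compatibility is to diff the two outputs on a frame with nulls; the most likely divergence is that pandas promotes integer columns containing nulls to floats (a sketch):

import pandas as pd
import polars as pl

data = {"a": [1, None, 3], "b": ["x", "y", None]}

# index=False is needed for pandas to match Polars' column-only output
pandas_csv = pd.DataFrame(data).to_csv(index=False)
polars_csv = pl.DataFrame(data).write_csv()  # returns a str when no target is given

# pandas writes the "a" column as 1.0,,3.0 (nulls force float64),
# while Polars writes 1,,3; a byte-for-byte comparison can therefore fail
print(pandas_csv == polars_csv)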

Watch out for:

  • Ensure that the behavior of transformations is consistent between Pandas and Polars dataframes (a parity-test sketch follows this list).
  • Make sure that the CSV output from Polars is formatted correctly for Marimo's expectations.
  • Handle any edge cases or differences in API behavior between Pandas and Polars.
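One way to enforce that consistency, assuming the TransformHandlers subclasses sketched in the Implementation Plan below, is a parity test that routes the same transform through both backends and compares via pandas:

import pandas as pd
import polars as pl

def assert_transform_parity(data: dict, transform) -> None:
    """Apply one transform via both handlers and compare the results."""
    pandas_out = PandasTransformHandlers().handle(pd.DataFrame(data), transform)
    polars_out = PolarsTransformHandlers().handle(pl.DataFrame(data), transform)
    # Normalize both results to pandas (ignoring the index) before comparing
    pd.testing.assert_frame_equal(
        pandas_out.reset_index(drop=True),
        polars_out.to_pandas(),
        check_dtype=False,  # the backends may disagree on exact dtypes
    )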

Implementation Plan

Edit 1: Update the dataframe class to accept Polars dataframes

Description: Modify the dataframe class in marimo/_plugins/ui/_impl/dataframes/dataframe.py to handle both Pandas and Polars dataframes by checking the type of the input dataframe and setting up the appropriate handlers. Code:

# Import necessary libraries at the top of the file
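# (in a real patch this import would likely need to be guarded or lazy,
# since marimo treats Polars as an optional dependency)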
import polars as pl

# Modify the __init__ method of the dataframe class
def __init__(
    self,
    df: Union[pd.DataFrame, pl.DataFrame],
    on_change: Optional[Callable[[Union[pd.DataFrame, pl.DataFrame]], None]] = None,
) -> None:
    dataframe_name = "df"
    try:
        frame = inspect.currentframe()
        if frame is not None and frame.f_back is not None:
            for (
                var_name,
                var_value,
            ) in frame.f_back.f_locals.items():
                if var_value is df:
                    dataframe_name = var_name
                    break
    except Exception:
        pass

    self._data = df
    self._transform_container = TransformsContainer(df)
    self._error: Optional[str] = None

    # Determine if the dataframe is a Pandas or Polars dataframe
    if isinstance(df, pd.DataFrame):
        self._df_type = 'pandas'
    elif isinstance(df, pl.DataFrame):
        self._df_type = 'polars'
    else:
        raise ValueError("Unsupported dataframe type. Only Pandas and Polars dataframes are supported.")

    super().__init__(
        component_name=dataframe._name,
        initial_value={
            "transforms": [],
        },
        on_change=on_change,
        label="",
        args={
            "columns": self._get_columns_dict(df),
            "dataframe-name": dataframe_name,
            "total": len(df),
        },
        functions=(
            Function(
                name=self.get_dataframe.__name__,
                arg_cls=EmptyArgs,
                function=self.get_dataframe,
            ),
            Function(
                name=self.get_column_values.__name__,
                arg_cls=GetColumnValuesArgs,
                function=self.get_column_values,
            ),
        ),
    )

# Add a new method to get the columns dictionary
def _get_columns_dict(self, df: Union[pd.DataFrame, pl.DataFrame]) -> Dict[str, str]:
    if self._df_type == 'pandas':
        return df.dtypes.to_dict()
    elif self._df_type == 'polars':
        return {name: str(dtype) for name, dtype in zip(df.columns, df.dtypes)}
    else:
        raise ValueError("Unsupported dataframe type.")

Edit 2: Create an abstract base class for transform handlers and implement subclasses

Description: Define an abstract base class TransformHandlers with abstract methods for each type of transform and create two subclasses PandasTransformHandlers and PolarsTransformHandlers. Code:

from abc import ABC, abstractmethod
from typing import Union

# Abstract base class for transform handlers
class TransformHandlers(ABC):
    def handle(self, df, transform):
        transform_type = transform.type

        if transform_type is TransformType.COLUMN_CONVERSION:
            return self.handle_column_conversion(df, transform)
        elif transform_type is TransformType.RENAME_COLUMN:
            return self.handle_rename_column(df, transform)
        elif transform_type is TransformType.SORT_COLUMN:
            return self.handle_sort_column(df, transform)
        elif transform_type is TransformType.FILTER_ROWS:
            return self.handle_filter_rows(df, transform)
        elif transform_type is TransformType.GROUP_BY:
            return self.handle_group_by(df, transform)
        elif transform_type is TransformType.AGGREGATE:
            return self.handle_aggregate(df, transform)
        elif transform_type is TransformType.SELECT_COLUMNS:
            return self.handle_select_columns(df, transform)
        elif transform_type is TransformType.SHUFFLE_ROWS:
            return self.handle_shuffle_rows(df, transform)
        elif transform_type is TransformType.SAMPLE_ROWS:
            return self.handle_sample_rows(df, transform)
        else:
            raise NotImplementedError(f"Transform type {transform_type} is not implemented")

    @abstractmethod
    def handle_column_conversion(self, df, transform): pass

    @abstractmethod
    def handle_rename_column(self, df, transform): pass

    @abstractmethod
    def handle_sort_column(self, df, transform): pass

    @abstractmethod
    def handle_filter_rows(self, df, transform): pass

    @abstractmethod
    def handle_group_by(self, df, transform): pass

    @abstractmethod
    def handle_aggregate(self, df, transform): pass

    @abstractmethod
    def handle_select_columns(self, df, transform): pass

    @abstractmethod
    def handle_shuffle_rows(self, df, transform): pass

    @abstractmethod
    def handle_sample_rows(self, df, transform): pass

# Subclass for handling Pandas dataframe transforms
class PandasTransformHandlers(TransformHandlers):
    # Each handler should receive the existing Pandas implementation that
    # currently lives in TransformHandlers; `...` marks where it moves.

    def handle_column_conversion(self, df, transform):
        ...  # existing Pandas column-conversion code moves here

    def handle_rename_column(self, df, transform):
        ...  # existing Pandas column-renaming code moves here

    def handle_sort_column(self, df, transform):
        ...  # existing Pandas column-sorting code moves here

    def handle_filter_rows(self, df, transform):
        ...  # existing Pandas row-filtering code moves here

    def handle_group_by(self, df, transform):
        ...  # existing Pandas group-by code moves here

    def handle_aggregate(self, df, transform):
        ...  # existing Pandas aggregation code moves here

    def handle_select_columns(self, df, transform):
        ...  # existing Pandas column-selection code moves here

    def handle_shuffle_rows(self, df, transform):
        ...  # existing Pandas row-shuffling code moves here

    def handle_sample_rows(self, df, transform):
        ...  # existing Pandas row-sampling code moves here

# Subclass for handling Polars dataframe transforms
from polars import col  # ideally placed with the other imports at the top of the file


class PolarsTransformHandlers(TransformHandlers):
    def handle_column_conversion(self, df, transform):
        # Use the `cast` method from the Polars API; transform.data_type
        # must be (or be mapped to) a Polars dtype
        return df.cast({transform.column_id: transform.data_type})

    def handle_rename_column(self, df, transform):
        # Use the `rename` method from the Polars API
        return df.rename({transform.column_id: transform.new_column_id})

    def handle_sort_column(self, df, transform):
        # Use the `sort` method from the Polars API; the column is passed
        # positionally (there is no `by_column` keyword)
        return df.sort(
            transform.column_id,
            descending=not transform.ascending,
            nulls_last=transform.na_position == "last",
        )

    def handle_filter_rows(self, df, transform):
        # Start with no filter (all rows included)
        filter_expr = None

        # Iterate over all conditions and build the filter expression
        for condition in transform.where:
            column = col(condition.column_id)
            value = condition.value

            # Build the expression based on the operator
            if condition.operator == "==":
                condition_expr = column == value
            elif condition.operator == "!=":
                condition_expr = column != value
            elif condition.operator == ">":
                condition_expr = column > value
            elif condition.operator == "<":
                condition_expr = column < value
            elif condition.operator == ">=":
                condition_expr = column >= value
            elif condition.operator == "<=":
                condition_expr = column <= value
            elif condition.operator == "is_true":
                condition_expr = column.is_true()
            elif condition.operator == "is_false":
                condition_expr = column.is_false()
            elif condition.operator == "is_nan":
                condition_expr = column.is_null()
            elif condition.operator == "is_not_nan":
                condition_expr = column.is_not_null()
            elif condition.operator == "equals":
                condition_expr = column == value
            elif condition.operator == "does_not_equal":
                condition_expr = column != value
            elif condition.operator == "contains":
                condition_expr = column.str_contains(value)
            elif condition.operator == "regex":
                condition_expr = column.str_contains(value, regex=True)
            elif condition.operator == "starts_with":
                condition_expr = column.str_starts_with(value)
            elif condition.operator == "ends_with":
                condition_expr = column.str_ends_with(value)
            elif condition.operator == "in":
                condition_expr = column.is_in(value)
            else:
                raise ValueError(f"Unsupported operator: {condition.operator}")

            # Combine the condition expression with the filter expression
            if filter_expr is None:
                filter_expr = condition_expr
            else:
                filter_expr = filter_expr & condition_expr

        # Apply the combined expression once, inverting it for remove_rows;
        # filtering an already-filtered frame with ~filter_expr would
        # incorrectly return no rows
        if filter_expr is None:
            return df
        if transform.operation == "keep_rows":
            return df.filter(filter_expr)
        elif transform.operation == "remove_rows":
            return df.filter(~filter_expr)
        else:
            raise ValueError(f"Unsupported operation: {transform.operation}")

    def handle_group_by(self, df, transform):
        # Use the `group_by` and `agg` methods from the Polars API; the
        # aggregations must be Polars expressions (see handle_aggregate)
        return df.group_by(transform.column_ids).agg(transform.aggregations)

    def handle_aggregate(self, df, transform):
        agg_exprs = []

        for column_id, aggregations in transform.aggregations.items():
            for agg_func in aggregations:
                if agg_func == "count":
                    agg_exprs.append(col(column_id).count().alias(f"{column_id}_count"))
                elif agg_func == "sum":
                    agg_exprs.append(col(column_id).sum().alias(f"{column_id}_sum"))
                elif agg_func == "mean":
                    agg_exprs.append(col(column_id).mean().alias(f"{column_id}_mean"))
                elif agg_func == "median":
                    agg_exprs.append(col(column_id).median().alias(f"{column_id}_median"))
                elif agg_func == "min":
                    agg_exprs.append(col(column_id).min().alias(f"{column_id}_min"))
                elif agg_func == "max":
                    agg_exprs.append(col(column_id).max().alias(f"{column_id}_max"))
                else:
                    raise ValueError(f"Unsupported aggregation function: {agg_func}")

        return df.group_by(transform.column_ids).agg(agg_exprs)

    def handle_select_columns(self, df, transform):
        # Use the `select` method from the Polars API
        return df.select(transform.column_ids)

    def handle_shuffle_rows(self, df, transform):
        # Polars has no `sample_frac`; use `sample` with fraction=1.0 and shuffle=True
        return df.sample(fraction=1.0, shuffle=True, seed=transform.seed)

    def handle_sample_rows(self, df, transform):
        # Row sampling likewise goes through `sample`, with n
        return df.sample(
            n=transform.n,
            shuffle=True,
            seed=transform.seed,
            with_replacement=transform.replace,
        )
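
As a quick sanity check of the filter handler, using SimpleNamespace objects as hypothetical stand-ins for marimo's real transform dataclasses:

from types import SimpleNamespace

import polars as pl

df = pl.DataFrame({"name": ["Bob", "Sally", "Jane"], "age": [30, 25, 40]})
keep_adults = SimpleNamespace(
    where=[SimpleNamespace(column_id="age", operator=">=", value=30)],
    operation="keep_rows",
)
print(PolarsTransformHandlers().handle_filter_rows(df, keep_adults))
# keeps the Bob and Jane rows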

Edit 3: Implement get_dataframe for Polars

Description: Modify the get_dataframe method in the dataframe class to handle Polars dataframes using the write_csv method to produce a compatible CSV output.

import io  # Make sure to import io at the top of the file

# Modify the get_dataframe method of the dataframe class
def get_dataframe(self, _args: EmptyArgs) -> GetDataFrameResponse:
    LIMIT = 100

    if self._error is not None:
        raise Exception(self._error)

    # Check if the dataframe is a Polars dataframe and handle accordingly
    if self._df_type == 'polars':
        # Create a buffer to write the CSV data
        buffer = io.BytesIO()
        # Write the CSV data to the buffer
        self._value.head(LIMIT).write_csv(buffer)
        # Seek to the start of the buffer to read its content
        buffer.seek(0)
        # Read the buffer content
        csv_data = buffer.read()
        # Create a VirtualFile from the CSV data
        url = mo_data.any_data(csv_data, ext="csv").url
    else:
        # Existing handling for Pandas dataframe
        url = mo_data.csv(self._value.head(LIMIT)).url

    total_rows = len(self._value)
    return GetDataFrameResponse(
        url=url,
        total_rows=total_rows,
        has_more=total_rows > LIMIT,
        row_headers=get_row_headers(self._value),
    )

Edit 4: Update get_column_values for Polars

Description: Update the get_column_values method in the dataframe class to work with Polars dataframes, retrieving unique values from a specified column.

# Modify the get_column_values method of the dataframe class
def get_column_values(self, args: GetColumnValuesArgs) -> GetColumnValuesResponse:
    LIMIT = 500

    # Check if the dataframe is a Polars dataframe and handle accordingly
    if self._df_type == 'polars':
        # Use Polars' API: get_column returns a Series, which (unlike a
        # DataFrame) has to_list, so unique values come back as a plain list
        unique_values = self._data.get_column(args.column).unique().to_list()
    else:
        # Existing handling for Pandas dataframe
        unique_values = self._data[args.column].unique().tolist()

    if len(unique_values) <= LIMIT:
        return GetColumnValuesResponse(
            values=list(sorted(unique_values, key=str)),
            too_many_values=False,
        )
    else:
        return GetColumnValuesResponse(
            values=[],
            too_many_values=True,
        )
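
For reference, the Polars lookup above behaves roughly like this (unique() does not guarantee order, hence the sort applied to the response values):

import polars as pl

df = pl.DataFrame({"c": ["Bob", "Sally", "Jane", "Bob"]})
print(df.get_column("c").unique().to_list())
# e.g. ['Sally', 'Bob', 'Jane'] (deduplicated; order not guaranteed)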

Generated with Glide by Agentic Labs

robmck1995 avatar Mar 11 '24 21:03 robmck1995