mo.ui.dataframe cannot render a Polars dataframe
Describe the bug
Thanks for creating such a refreshing alternative to Jupyter! However, I'm having trouble using polars with the dataframe ui widget. Here's the simplest example:
import marimo as mo
import polars as pl
df = pl.DataFrame({"a": [1,2,3], "b":[True, False, True], "c":["Bob", "Sally", "Jane"]})
ui_df = mo.ui.dataframe(df)
ui_df
Produces the error:
AttributeError
This cell raised an exception: AttributeError("'list' object has no attribute 'to_dict'")
with stacktrace:
Traceback (most recent call last):
Cell , line 5
ui_df = mo.ui.dataframe(df)
File /Users/johndoe/Library/Caches/pypoetry/virtualenvs/python-playground-dPNtgDsu-py3.12/lib/python3.12/site-packages/marimo/_plugins/ui/_impl/dataframes/dataframe.py, line 95, in __init__
"columns": df.dtypes.to_dict(),
AttributeError: 'list' object has no attribute 'to_dict'
Environment
pyproject.toml
[tool.poetry]
name = "python-playground"
version = "0.1.0"
description = "using poetry and marimo to make python actually nice"
authors = ["Your Name "]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.12"
marimo = "^0.1.76"
altair = "^5.2.0"
numpy = "^1.26.3"
polars = "^0.20.4"
pyarrow = "^14.0.2"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Code to reproduce
import marimo as mo
import polars as pl
df = pl.DataFrame({"a": [1,2,3], "b":[True, False, True], "c":["Bob", "Sally", "Jane"]})
ui_df = mo.ui.dataframe(df)
ui_df
Unfortunately we don't support Polars at this time for this specific plugin. You'll need to call to_pandas(), but you'll get back a pandas df.
We can add first-class Polars support for this, though.
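A minimal sketch of that workaround; note that ui_df.value will then be a pandas dataframe, not polars:
import marimo as mo
import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True], "c": ["Bob", "Sally", "Jane"]})
# Convert before handing the frame to the plugin (to_pandas requires pyarrow)
ui_df = mo.ui.dataframe(df.to_pandas())
ui_df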
Also, with to_pandas there is another error:
TypeError('Object of type DatetimeTZDtype is not JSON serializable')
with a polars dataframe containing a column of type Datetime(time_unit='us', time_zone='UTC').
Aside from mo.ui.dataframe, polars works excellently for me with marimo.
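For reference, a small repro of that second error, assuming polars and pyarrow are installed:
import marimo as mo
import polars as pl
from datetime import datetime, timezone

# The column is inferred as Datetime(time_unit='us', time_zone='UTC');
# after to_pandas() it becomes a pandas DatetimeTZDtype column, which
# triggered the JSON serialization error described above
df = pl.DataFrame({"ts": [datetime(2024, 1, 1, tzinfo=timezone.utc)]})
ui_df = mo.ui.dataframe(df.to_pandas())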
Thanks for finding the serialization bug - it will be fixed in this PR: https://github.com/marimo-team/marimo/pull/631
Hi there. I've taken a pass at looking into this issue with our tool, Glide. Please see the High-Level Plan and Implementation below. Note: The plan does not cover unit tests or frontend rendering.
Plan Steps
Step 1: Modify the dataframe class to accept Polars dataframes
Description: Update the dataframe class in marimo/_plugins/ui/_impl/dataframes/dataframe.py to handle Polars dataframes by checking the type of the input dataframe and setting up the appropriate handlers.
- Check the type of the df parameter in the __init__ method (the snippet below shows the dtype difference that causes the current AttributeError).
- If df is a Polars dataframe, initialize a PolarsTransformHandlers instance.
- Store the type of dataframe (Pandas or Polars) for use in other methods.
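The underlying incompatibility is easy to see by comparing the two libraries' dtypes attributes:
import pandas as pd
import polars as pl

pd_df = pd.DataFrame({"a": [1, 2, 3]})
pl_df = pl.DataFrame({"a": [1, 2, 3]})

print(type(pd_df.dtypes))  # pandas.core.series.Series -- has .to_dict()
print(type(pl_df.dtypes))  # list -- no .to_dict(), hence the AttributeError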
Step 2: Create an abstract base class for transform handlers
Description: Define an abstract base class TransformHandlers with abstract methods for each type of transform.
- Define an abstract base class TransformHandlers.
- Add abstract methods for each transform type: handle_column_conversion, handle_rename_column, handle_sort_column, handle_filter_rows, handle_group_by, handle_aggregate, handle_select_columns, handle_shuffle_rows, and handle_sample_rows.
Step 3: Implement PandasTransformHandlers and PolarsTransformHandlers subclasses
Description: Create two subclasses that inherit from TransformHandlers and implement the methods for handling transforms for Pandas and Polars dataframes respectively.
- Create a PandasTransformHandlers subclass with methods that use Pandas-specific code.
- Create a PolarsTransformHandlers subclass with methods that use Polars-specific code.
Step 4: Implement Polars transform handler methods
Description: Implement each transform handler method in PolarsTransformHandlers using the provided Polars code context.
- For handle_shuffle_rows, use Polars' sample method with fraction=1.0 and shuffle=True (see the sketch below; Polars has no sample_frac method on DataFrame).
- Implement the other methods by translating the Pandas logic to Polars API calls, using the provided code context as a reference.
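A quick standalone check of that call, as a sketch:
import polars as pl

df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})
# fraction=1.0 with shuffle=True returns all rows in a random order
shuffled = df.sample(fraction=1.0, shuffle=True, seed=42)
print(shuffled)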
Step 5: Implement get_dataframe for Polars
Description: Use Polars' write_csv method to produce a CSV output that matches the format expected by Marimo.
- Modify the get_dataframe method to check the type of dataframe.
- If it's a Polars dataframe, use write_csv with the correct arguments to produce a CSV string (see the snippet below).
- Convert the CSV string to a VirtualFile and return the appropriate GetDataFrameResponse.
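For reference, write_csv returns the CSV document as a string when called without a file argument, which is the simplest way to obtain the payload:
import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
# No file argument: the CSV is returned as a str
csv_str = df.head(100).write_csv()
print(csv_str)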
Step 6: Update get_column_values for Polars
Description: Ensure that the get_column_values function can retrieve unique values from a specified column in a Polars dataframe.
- Modify the get_column_values method to handle Polars dataframes.
- Use Polars' API to get unique values from the specified column (see the snippet below).
- Return a GetColumnValuesResponse with the values, or indicate that there are too many values.
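The Polars side of that lookup is short, since indexing a dataframe by column name yields a Series:
import polars as pl

df = pl.DataFrame({"c": ["Bob", "Sally", "Bob", "Jane"]})
# unique() does not guarantee order; sort afterwards if a stable order is needed
values = df["c"].unique().to_list()
print(sorted(values))  # ['Bob', 'Jane', 'Sally']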
Additional context to gather
- Verify the compatibility of CSV formatting options between Pandas' to_csv and Polars' write_csv methods (a spot-check follows this list).
- Ensure that the VirtualFile creation process is consistent with Marimo's handling of CSV data.
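One way to run that spot-check, under the assumption that both outputs use plain comma separation and "\n" line endings for simple frames:
import pandas as pd
import polars as pl

data = {"a": [1, 2], "b": ["x", "y"]}
# pandas needs index=False so the row index is not written as a column
pandas_csv = pd.DataFrame(data).to_csv(index=False)
polars_csv = pl.DataFrame(data).write_csv()
# Differences, if any, tend to show up around quoting, floats, and datetimes
print(pandas_csv == polars_csv)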
Watch out for:
- Ensure that the behavior of transformations is consistent between Pandas and Polars dataframes.
- Make sure that the CSV output from Polars is formatted correctly for Marimo's expectations.
- Handle any edge cases or differences in API behavior between Pandas and Polars (one concrete example follows this list).
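One concrete example of such a difference: pandas' unique() preserves first-appearance order, while Polars' unique() does not unless asked to:
import pandas as pd
import polars as pl

data = ["b", "a", "b", "c"]
print(pd.Series(data).unique())  # ['b' 'a' 'c'], appearance order
print(pl.Series(data).unique().to_list())  # order not guaranteed
print(pl.Series(data).unique(maintain_order=True).to_list())  # ['b', 'a', 'c']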
Implementation Plan
Edit 1: Update the dataframe class to accept Polars dataframes
Description: Modify the dataframe class in marimo/_plugins/ui/_impl/dataframes/dataframe.py to handle both Pandas and Polars dataframes by checking the type of the input dataframe and setting up the appropriate handlers. Code:
# Import necessary libraries at the top of the file
# (in the real codebase this import should be optional/guarded,
# since polars is not a required dependency of marimo)
import polars as pl
# Modify the __init__ method of the dataframe class
def __init__(
self,
df: Union[pd.DataFrame, pl.DataFrame],
on_change: Optional[Callable[[Union[pd.DataFrame, pl.DataFrame]], None]] = None,
) -> None:
dataframe_name = "df"
try:
frame = inspect.currentframe()
if frame is not None and frame.f_back is not None:
for (
var_name,
var_value,
) in frame.f_back.f_locals.items():
if var_value is df:
dataframe_name = var_name
break
except Exception:
pass
self._data = df
self._transform_container = TransformsContainer(df)
self._error: Optional[str] = None
# Determine if the dataframe is a Pandas or Polars dataframe
if isinstance(df, pd.DataFrame):
self._df_type = 'pandas'
elif isinstance(df, pl.DataFrame):
self._df_type = 'polars'
else:
raise ValueError("Unsupported dataframe type. Only Pandas and Polars dataframes are supported.")
super().__init__(
component_name=dataframe._name,
initial_value={
"transforms": [],
},
on_change=on_change,
label="",
args={
"columns": self._get_columns_dict(df),
"dataframe-name": dataframe_name,
"total": len(df),
},
functions=(
Function(
name=self.get_dataframe.__name__,
arg_cls=EmptyArgs,
function=self.get_dataframe,
),
Function(
name=self.get_column_values.__name__,
arg_cls=GetColumnValuesArgs,
function=self.get_column_values,
),
),
)
# Add a new method to get the columns dictionary
def _get_columns_dict(self, df: Union[pd.DataFrame, pl.DataFrame]) -> Dict[str, str]:
    if self._df_type == 'pandas':
        # pandas: df.dtypes is a Series; stringify the dtype objects to match the annotation
        return {name: str(dtype) for name, dtype in df.dtypes.to_dict().items()}
    elif self._df_type == 'polars':
        # polars: df.dtypes is a plain list aligned with df.columns
        return {name: str(dtype) for name, dtype in zip(df.columns, df.dtypes)}
    else:
        raise ValueError("Unsupported dataframe type.")
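Assuming the plan above is implemented as described, the original repro from this issue should then construct without the AttributeError:
import marimo as mo
import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True], "c": ["Bob", "Sally", "Jane"]})
# The type dispatch in __init__ now routes this through the Polars handlers
ui_df = mo.ui.dataframe(df)
ui_df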
Edit 2: Create an abstract base class for transform handlers and implement subclasses
Description: Define an abstract base class TransformHandlers with abstract methods for each type of transform and create two subclasses PandasTransformHandlers and PolarsTransformHandlers. Code:
from abc import ABC, abstractmethod
from typing import Union
# Abstract base class for transform handlers
class TransformHandlers(ABC):
def handle(self, df, transform):
transform_type = transform.type
if transform_type is TransformType.COLUMN_CONVERSION:
return self.handle_column_conversion(df, transform)
elif transform_type is TransformType.RENAME_COLUMN:
return self.handle_rename_column(df, transform)
elif transform_type is TransformType.SORT_COLUMN:
return self.handle_sort_column(df, transform)
elif transform_type is TransformType.FILTER_ROWS:
return self.handle_filter_rows(df, transform)
elif transform_type is TransformType.GROUP_BY:
return self.handle_group_by(df, transform)
elif transform_type is TransformType.AGGREGATE:
return self.handle_aggregate(df, transform)
elif transform_type is TransformType.SELECT_COLUMNS:
return self.handle_select_columns(df, transform)
elif transform_type is TransformType.SHUFFLE_ROWS:
return self.handle_shuffle_rows(df, transform)
elif transform_type is TransformType.SAMPLE_ROWS:
return self.handle_sample_rows(df, transform)
else:
raise NotImplementedError(f"Transform type {transform_type} is not implemented")
@abstractmethod
def handle_column_conversion(self, df, transform): pass
@abstractmethod
def handle_rename_column(self, df, transform): pass
@abstractmethod
def handle_sort_column(self, df, transform): pass
@abstractmethod
def handle_filter_rows(self, df, transform): pass
@abstractmethod
def handle_group_by(self, df, transform): pass
@abstractmethod
def handle_aggregate(self, df, transform): pass
@abstractmethod
def handle_select_columns(self, df, transform): pass
@abstractmethod
def handle_shuffle_rows(self, df, transform): pass
@abstractmethod
def handle_sample_rows(self, df, transform): pass
# Subclass for handling Pandas dataframe transforms
class PandasTransformHandlers(TransformHandlers):
    # The existing Pandas-specific implementations should move from the
    # current TransformHandlers into these methods; bodies elided here.
    def handle_column_conversion(self, df, transform): ...

    def handle_rename_column(self, df, transform): ...

    def handle_sort_column(self, df, transform): ...

    def handle_filter_rows(self, df, transform): ...

    def handle_group_by(self, df, transform): ...

    def handle_aggregate(self, df, transform): ...

    def handle_select_columns(self, df, transform): ...

    def handle_shuffle_rows(self, df, transform): ...

    def handle_sample_rows(self, df, transform): ...
# Subclass for handling Polars dataframe transforms
from polars import col  # in practice this import belongs at the top of the file

class PolarsTransformHandlers(TransformHandlers):
    def handle_column_conversion(self, df, transform):
        # Use the `cast` method from the Polars API
        dtypes = {transform.column_id: transform.data_type}
        return df.cast(dtypes)

    def handle_rename_column(self, df, transform):
        # Use the `rename` method from the Polars API
        return df.rename({transform.column_id: transform.new_column_id})

    def handle_sort_column(self, df, transform):
        # Use the `sort` method from the Polars API; the column is positional,
        # `descending` and `nulls_last` are keyword-only
        return df.sort(
            transform.column_id,
            descending=not transform.ascending,
            nulls_last=transform.na_position == 'last',
        )
    def handle_filter_rows(self, df, transform):
        # Start with no filter (all rows included)
        filter_expr = None
        # Iterate over all conditions and build the filter expression
        for condition in transform.where:
            column = col(condition.column_id)
            value = condition.value
            # Build the expression based on the operator
            if condition.operator == "==":
                condition_expr = column == value
            elif condition.operator == "!=":
                condition_expr = column != value
            elif condition.operator == ">":
                condition_expr = column > value
            elif condition.operator == "<":
                condition_expr = column < value
            elif condition.operator == ">=":
                condition_expr = column >= value
            elif condition.operator == "<=":
                condition_expr = column <= value
            elif condition.operator == "is_true":
                # A boolean column is itself a valid predicate expression
                condition_expr = column
            elif condition.operator == "is_false":
                condition_expr = ~column
            elif condition.operator == "is_nan":
                # Treated as null here; note Polars distinguishes NaN from null
                condition_expr = column.is_null()
            elif condition.operator == "is_not_nan":
                condition_expr = column.is_not_null()
            elif condition.operator == "equals":
                condition_expr = column == value
            elif condition.operator == "does_not_equal":
                condition_expr = column != value
            elif condition.operator == "contains":
                # String methods live in the `str` namespace; literal match here
                condition_expr = column.str.contains(value, literal=True)
            elif condition.operator == "regex":
                # str.contains treats the pattern as a regex by default
                condition_expr = column.str.contains(value)
            elif condition.operator == "starts_with":
                condition_expr = column.str.starts_with(value)
            elif condition.operator == "ends_with":
                condition_expr = column.str.ends_with(value)
            elif condition.operator == "in":
                condition_expr = column.is_in(value)
            else:
                raise ValueError(f"Unsupported operator: {condition.operator}")
            # Combine the condition expression with the filter expression
            if filter_expr is None:
                filter_expr = condition_expr
            else:
                filter_expr = filter_expr & condition_expr
        # No conditions: nothing to filter
        if filter_expr is None:
            return df
        # Handle the operation (keep_rows or remove_rows) in a single pass;
        # filtering first and then negating would operate on already-filtered rows
        if transform.operation == "keep_rows":
            return df.filter(filter_expr)
        elif transform.operation == "remove_rows":
            return df.filter(~filter_expr)
        else:
            raise ValueError(f"Unsupported operation: {transform.operation}")
    def handle_group_by(self, df, transform):
        # Use the `group_by` and `agg` methods from the Polars API
        # (`groupby` is the deprecated spelling in current Polars); note that
        # transform.aggregations still needs translating into Polars expressions
        return df.group_by(transform.column_ids).agg(transform.aggregations)
def handle_aggregate(self, df, transform):
agg_exprs = []
for column_id, aggregations in transform.aggregations.items():
for agg_func in aggregations:
if agg_func == "count":
agg_exprs.append(col(column_id).count().alias(f"{column_id}_count"))
elif agg_func == "sum":
agg_exprs.append(col(column_id).sum().alias(f"{column_id}_sum"))
elif agg_func == "mean":
agg_exprs.append(col(column_id).mean().alias(f"{column_id}_mean"))
elif agg_func == "median":
agg_exprs.append(col(column_id).median().alias(f"{column_id}_median"))
elif agg_func == "min":
agg_exprs.append(col(column_id).min().alias(f"{column_id}_min"))
elif agg_func == "max":
agg_exprs.append(col(column_id).max().alias(f"{column_id}_max"))
else:
raise ValueError(f"Unsupported aggregation function: {agg_func}")
        return df.group_by(transform.column_ids).agg(agg_exprs)
def handle_select_columns(self, df, transform):
# Use the `select` method from the Polars API
return df.select(transform.column_ids)
    def handle_shuffle_rows(self, df, transform):
        # Polars has no `sample_frac`; use `sample` with fraction=1.0 and shuffle=True
        return df.sample(fraction=1.0, shuffle=True, seed=transform.seed)

    def handle_sample_rows(self, df, transform):
        # Polars has no `sample_n`; `sample` takes n directly
        return df.sample(n=transform.n, shuffle=True, seed=transform.seed, with_replacement=transform.replace)
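To illustrate the expression-building pattern used by handle_filter_rows above: predicates compose with & and are applied in a single filter call, with ~ giving the remove_rows complement:
import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "c": ["Bob", "Sally", "Jane"]})
expr = (pl.col("a") > 1) & pl.col("c").str.starts_with("S")
print(df.filter(expr))   # keep_rows: rows matching the predicate
print(df.filter(~expr))  # remove_rows: the complement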
Edit 3: Implement get_dataframe for Polars
Description: Modify the get_dataframe method in the dataframe class to handle Polars dataframes using the write_csv method to produce a compatible CSV output.
import io # Make sure to import io at the top of the file
# Modify the get_dataframe method of the dataframe class
def get_dataframe(self, _args: EmptyArgs) -> GetDataFrameResponse:
LIMIT = 100
if self._error is not None:
raise Exception(self._error)
# Check if the dataframe is a Polars dataframe and handle accordingly
if self._df_type == 'polars':
# Create a buffer to write the CSV data
buffer = io.BytesIO()
# Write the CSV data to the buffer
self._value.head(LIMIT).write_csv(buffer)
# Seek to the start of the buffer to read its content
buffer.seek(0)
# Read the buffer content
csv_data = buffer.read()
# Create a VirtualFile from the CSV data
url = mo_data.any_data(csv_data, ext="csv").url
else:
# Existing handling for Pandas dataframe
url = mo_data.csv(self._value.head(LIMIT)).url
total_rows = len(self._value)
return GetDataFrameResponse(
url=url,
total_rows=total_rows,
has_more=total_rows > LIMIT,
row_headers=get_row_headers(self._value),
)
Edit 4: Update get_column_values for Polars
Description: Update the get_column_values method in the dataframe class to work with Polars dataframes, retrieving unique values from a specified column.
# Modify the get_column_values method of the dataframe class
def get_column_values(self, args: GetColumnValuesArgs) -> GetColumnValuesResponse:
LIMIT = 500
# Check if the dataframe is a Polars dataframe and handle accordingly
    if self._df_type == 'polars':
        # Indexing a Polars dataframe by column name returns a Series,
        # which supports unique() and to_list()
        unique_values = self._data[args.column].unique().to_list()
else:
# Existing handling for Pandas dataframe
unique_values = self._data[args.column].unique().tolist()
if len(unique_values) <= LIMIT:
return GetColumnValuesResponse(
values=list(sorted(unique_values, key=str)),
too_many_values=False,
)
else:
return GetColumnValuesResponse(
values=[],
too_many_values=True,
)
Generated with Glide by Agentic Labs