modin icon indicating copy to clipboard operation
modin copied to clipboard

PERF: `_propagate_index_objs` function does extra work

Open anmyachev opened this issue 2 years ago • 2 comments

`_propagate_index_objs` changes the index of all partitions, even if the new index is the same as the old one. We can try to avoid this by updating only the partitions whose index actually differs from the old one. This will be especially useful for the `rename` operation, which is now implemented through the `index`/`columns` properties.

This change is mainly intended to reduce serialization/deserialization costs.

anmyachev avatar Nov 14 '22 20:11 anmyachev

This code can be used as a starting point:

diff --git a/modin/core/dataframe/pandas/dataframe/dataframe.py b/modin/core/dataframe/pandas/dataframe/dataframe.py
index c1e990e2..b17a2596 100644
--- a/modin/core/dataframe/pandas/dataframe/dataframe.py
+++ b/modin/core/dataframe/pandas/dataframe/dataframe.py
@@ -174,6 +174,7 @@ class PandasDataframe(ClassLogger):
     # These properties flag whether or not we are deferring the metadata synchronization
     _deferred_index = False
     _deferred_column = False
+    _old_labels = None
 
     @property
     def __constructor__(self):
@@ -413,11 +414,15 @@ class PandasDataframe(ClassLogger):
         if self._columns_cache is None:
             self._columns_cache = ensure_index(new_columns)
         else:
+            old_columns = self._columns_cache
             new_columns = self._validate_set_axis(new_columns, self._columns_cache)
             self._columns_cache = new_columns
             if self._dtypes is not None:
                 self._dtypes.index = new_columns
-        self.synchronize_labels(axis=1)
+        if not self._deferred_column:
+            self.synchronize_labels(axis=1, old_labels=old_columns)
+        else:
+            self.synchronize_labels(axis=1)
 
     columns = property(_get_columns, _set_columns)
     index = property(_get_index, _set_index)
@@ -496,7 +501,7 @@ class PandasDataframe(ClassLogger):
         self._column_widths_cache = [w for w in self.column_widths if w != 0]
         self._row_lengths_cache = [r for r in self.row_lengths if r != 0]
 
-    def synchronize_labels(self, axis=None):
+    def synchronize_labels(self, axis=None, old_labels=None):
         """
         Set the deferred axes variables for the ``PandasDataframe``.
 
@@ -513,6 +518,7 @@ class PandasDataframe(ClassLogger):
             self._deferred_index = True
         else:
             self._deferred_column = True
+        self._old_labels = old_labels
 
     def _propagate_index_objs(self, axis=None):
         """
@@ -585,22 +591,29 @@ class PandasDataframe(ClassLogger):
             def apply_idx_objs(df, cols):
                 return df.set_axis(cols, axis="columns")
 
-            self._partitions = np.array(
-                [
-                    [
-                        self._partitions[i][j].add_to_apply_calls(
+            self._partitions = self._partitions.copy()
+            for i in range(len(self._partitions)):
+                for j in range(len(self._partitions[i])):
+                    cols = self.columns[slice(cum_col_widths[j], cum_col_widths[j + 1])]
+                    if (
+                        self._old_labels is not None
+                        and not (
+                            cols.equals(
+                                self._old_labels[
+                                    slice(cum_col_widths[j], cum_col_widths[j + 1])
+                                ]
+                            )
+                        )
+                        or self._old_labels is None
+                    ):
+                        self._partitions[i][j] = self._partitions[i][
+                            j
+                        ].add_to_apply_calls(
                             apply_idx_objs,
-                            cols=self.columns[
-                                slice(cum_col_widths[j], cum_col_widths[j + 1])
-                            ],
+                            cols=cols,
                             length=self.row_lengths[i],
                             width=self.column_widths[j],
                         )
-                        for j in range(len(self._partitions[i]))
-                    ]
-                    for i in range(len(self._partitions))
-                ]
-            )
             self._deferred_column = False
         else:
             ErrorMessage.catch_bugs_and_request_email(

anmyachev avatar Nov 15 '22 16:11 anmyachev

@anmyachev, can you revisit this?

YarShev avatar Jan 19 '24 16:01 YarShev