[FEA] Automatically calculate appropriate number of hash partitions
Is your feature request related to a problem? Please describe.
The multi-GPU Criteo benchmark hardcodes the best number of hash partitions for each categorical variable, and also specifies which columns should be stored in host memory: https://github.com/NVIDIA/NVTabular/blob/2dd4cbc94e074d2a7a319dcf05ff249c7cdec3b3/examples/dask-nvtabular-criteo-benchmark.py#L45-L54
Describe the solution you'd like
We should automatically figure out the best number of hash partitions to use, and not require customers to know how to tune NVTabular on a per-column basis.
Additional context
One potential way of doing this for Parquet files is to leverage the dictionary encoding metadata. We could also dynamically increase the number of hash partitions at runtime with some effort.
For estimating this from Parquet metadata, here is a hacky proof of concept showing that we can detect high-cardinality columns by looking at the Parquet dictionary size in bytes:
from collections import defaultdict

import pyarrow.parquet as pq


def estimate_categorical_size(parquet_filename, categorical_columns):
    """Return (column, dictionary bytes per row) pairs, largest first."""
    categorical_columns = set(categorical_columns)
    metadata = pq.read_metadata(parquet_filename)
    col_sizes = defaultdict(int)
    count = 0
    for i in range(metadata.num_row_groups):
        rg = metadata.row_group(i)
        for j in range(rg.num_columns):
            col = rg.column(j)
            col_name = col.path_in_schema
            # Skip column chunks that were not dictionary encoded
            if col_name in categorical_columns and col.has_dictionary_page:
                # The dictionary page precedes the data pages, so the
                # gap between the two offsets is the dictionary size
                col_sizes[col_name] += col.data_page_offset - col.dictionary_page_offset
        # Count each row group's rows once, not once per column
        count += rg.num_rows
    return sorted(((col, size / count) for col, size in col_sizes.items()), key=lambda x: -x[1])
I couldn't get the cardinality of the dictionary directly from the Parquet metadata, so this only reports the dictionary size in bytes per row. That does correlate well with the hardcoded column cardinalities, though:
In [35]: estimate_categorical_size(filename, [f"C{i}" for i in range(1, 25)] )
Out[35]:
[('C20', 1.1692657554403938),
('C1', 1.168845422788476),
('C22', 1.1678482506461598),
('C10', 1.1623461059300508),
('C21', 0.9883943242862369),
('C12', 0.5324453491790806),
('C11', 0.466415942352255),
('C23', 0.3989175195088401),
('C3', 0.20309658073368225),
('C2', 0.19938913980191542),
('C5', 0.18719789684979077),
('C24', 0.13401147216676054),
('C7', 0.10373114323248604),
('C15', 0.09040796659136605),
('C4', 0.07420452537233778),
('C14', 0.03077807994312741),
('C8', 0.02007240609747949),
('C18', 0.011280506492764313),
('C16', 0.0009989365481543999),
('C9', 0.0006503115045038749),
('C19', 0.0002890258388076269),
('C13', 0.00020638160809214164),
('C17', 8.422007409777208e-05),
('C6', 8.416619911034348e-05)]
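To illustrate how these estimates could drive the partition count, here is a minimal sketch that maps dictionary bytes per row to a number of hash partitions. The function name `suggest_hash_partitions` and the `bytes_per_partition` and `max_partitions` knobs are hypothetical and their defaults are made up for illustration. They are not NVTabular settings and have not been tuned against the benchmark.

```python
import math


def suggest_hash_partitions(estimates, bytes_per_partition=0.25, max_partitions=16):
    """Map (column, dictionary bytes per row) pairs to a partition count.

    Both thresholds are illustrative assumptions, not measured values.
    """
    suggestions = {}
    for col, bytes_per_row in estimates:
        # Larger dictionaries (a proxy for higher cardinality) get more partitions
        n = math.ceil(bytes_per_row / bytes_per_partition)
        # Always use at least one partition and never exceed the cap
        suggestions[col] = max(1, min(max_partitions, n))
    return suggestions


# A few of the estimates from the output above
estimates = [
    ("C20", 1.1692657554403938),
    ("C12", 0.5324453491790806),
    ("C6", 8.416619911034348e-05),
]
suggestions = suggest_hash_partitions(estimates)
print(suggestions)
```

A linear mapping is only one option. Since the estimate is a proxy for cardinality rather than the cardinality itself, the thresholds would need to be calibrated against the hardcoded values in the benchmark before relying on them.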