NVTabular
[BUG] `Groupby` with single column for grouping and calculating results throws an error
Describe the bug
This functionality is important because we often want to group by an identifier column (such as customer_id), perform some calculations on the groups, and separately return the column we grouped by (for instance, to add some sort of target using that column, or even just to be able to tell which subset of the data each calculated value relates to).
It is also useful because we often want to run additional preprocessing on the numerical values returned by the original grouping. Essentially, we need to process the identifier column separately.
Currently, doing so throws an error (see the reproducer below).
Steps/Code to reproduce bug
Reproducer code:
import cudf
import nvtabular as nvt
import numpy as np
import datetime
purchases = cudf.DataFrame(
    data={
        'customer_id': np.random.randint(0, 10, 1000),
        'purchase_date': [datetime.date(2022, np.random.randint(1,13), np.random.randint(1,29)) for i in range(1000)],
        'quantity': np.random.randint(1, 50, 1000)
    })
purchases.head()
ds = nvt.Dataset(purchases)
last_purchase_quantity = ['quantity', 'purchase_date'] >> nvt.ops.Groupby(
    groupby_cols='customer_id',
    sort_cols='purchase_date',
    aggs={'quantity': 'last'}
) >> nvt.ops.Normalize()
customer_id = ['customer_id', 'purchase_date'] >> nvt.ops.Groupby(
    groupby_cols='customer_id',
    sort_cols='purchase_date',
    aggs={'customer_id': 'last'}
)
wf = nvt.Workflow(last_purchase_quantity + customer_id)
wf.fit_transform(ds).compute()
Expected behavior
No error is raised and the values are returned correctly.
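To make "returned correctly" concrete, here is a minimal pure-Python sketch of the intended result: one row per customer_id, carrying the quantity from that customer's most recent purchase. The sample rows, the `last_quantity_per_customer` helper, and the `quantity_last` output name are made up for illustration. The sketch does not use NVTabular or cuDF and leaves out the Normalize step.

```python
import datetime


def last_quantity_per_customer(rows):
    """Hypothetical helper: return the latest quantity for each customer_id."""
    last = {}
    # Sort by purchase date so later purchases overwrite earlier ones.
    for row in sorted(rows, key=lambda r: r["purchase_date"]):
        last[row["customer_id"]] = row["quantity"]
    # Emit the grouping column alongside the aggregated value.
    return [
        {"customer_id": cid, "quantity_last": qty}
        for cid, qty in sorted(last.items())
    ]


# Made-up sample data, not taken from the reproducer.
rows = [
    {"customer_id": 1, "purchase_date": datetime.date(2022, 3, 1), "quantity": 5},
    {"customer_id": 0, "purchase_date": datetime.date(2022, 1, 15), "quantity": 7},
    {"customer_id": 1, "purchase_date": datetime.date(2022, 6, 9), "quantity": 12},
    {"customer_id": 0, "purchase_date": datetime.date(2022, 1, 2), "quantity": 3},
]

result = last_quantity_per_customer(rows)
for row in result:
    print(row)
```

The point is that the identifier column survives next to the aggregate, so each value can be traced back to its group.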
Environment details
Current NVTabular main
It looks like the column names are getting mangled by the `Groupby` op somehow:
> /nvtabular/nvtabular/workflow/workflow.py(523)_transform_partition()
-> col_series = output_df[col_name]
(Pdb) output_df
   c_u_s_t_o_m_e_r___i_d
0                      5
1                      0
2                      6
3                      2
4                      8
5                      4
6                      1
7                      9
8                      7
9                      3
(Pdb) col_name
'customer_id'
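The mangled name is exactly what Python produces when a bare string, rather than a list of strings, is joined with underscores: the string is iterated one character at a time. This is only a guess at the mechanism and is not confirmed anywhere in this issue. The `build_output_name` helper below is hypothetical and is not NVTabular code.

```python
def build_output_name(columns, sep="_"):
    """Hypothetical helper (not NVTabular's implementation):
    join a collection of column names into a single output name."""
    return sep.join(columns)


# Passing a bare string: str.join iterates over its characters.
mangled = build_output_name("customer_id")

# Passing a one-element list: the name comes through unchanged.
intact = build_output_name(["customer_id"])

print(mangled)
print(intact)
```

If something like this is happening, it would explain why the lookup `output_df[col_name]` with `'customer_id'` fails against a frame whose only column is `c_u_s_t_o_m_e_r___i_d`.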
This is being worked on under #1636