
PARSynthesizer model won't fit if sequence_index is missing

Status: Open · srinify opened this issue 1 year ago · 0 comments

Environment Details

Occurs on macOS and Colab with both SDV 1.11 and 1.12 (other versions not tested).

Error Description

When using PARSynthesizer, supplying a sequence_key without a sequence_index makes fit() raise a KeyError. Both of the following examples can be fixed by adding the sequence_index column back, updating the metadata, and then running fit() again.

Steps to reproduce

Example 1

Full Code:

import pandas as pd
from sdv.sequential import PARSynthesizer
from sdv.metadata import SingleTableMetadata
from sdv.datasets.demo import get_available_demos, download_demo

demo_data, metadata = download_demo(dataset_name='AtrialFibrillation', modality='sequential')
# Remove the column that would normally be the sequence_index
demo2 = demo_data.drop('s_index', axis=1)

demo2_metadata = SingleTableMetadata()
# Rebuild the metadata from scratch (alternatively, remove s_index from the existing metadata)
demo2_metadata.detect_from_dataframe(demo2)

demo2_metadata.update_column(column_name='e_id', sdtype='id')
demo2_metadata.set_sequence_key(column_name='e_id')

synthesizer_demo = PARSynthesizer(demo2_metadata)
synthesizer_demo.fit(demo2)

Throws this error:

/usr/local/lib/python3.10/dist-packages/sdv/single_table/base.py:80: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.
  warnings.warn(
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-25-e429dedddd24> in <cell line: 2>()
      1 synthesizer_demo = PARSynthesizer(demo2_metadata)
----> 2 synthesizer_demo.fit(demo2)

7 frames
/usr/local/lib/python3.10/dist-packages/pandas/core/indexes/base.py in _raise_if_missing(self, key, indexer, axis_name)
   5936                 if use_interval_msg:
   5937                     key = list(key)
-> 5938                 raise KeyError(f"None of [{key}] are in the [{axis_name}]")
   5939 
   5940             not_found = list(ensure_index(key)[missing_mask.nonzero()[0]].unique())

KeyError: "None of [Index(['74899b63-1f49-4701-8cdc-e9aeda8426cf'], dtype='object')] are in the [columns]"

Example 2

Generating the data from scratch:

import numpy as np
import pandas as pd

ids = np.arange(0, 50_000, 1)
ids = np.repeat(ids, 45)

# One normally distributed observation per row
obs = np.random.normal(loc=5, scale=1, size=len(ids))

df = pd.DataFrame(
    {
        "id": ids,
        "obs": obs
    }
)

from sdv.sequential import PARSynthesizer
from sdv.metadata import SingleTableMetadata

metadata = SingleTableMetadata()
metadata.detect_from_dataframe(df)
metadata.update_column(column_name='id', sdtype='id')
metadata.set_sequence_key(column_name='id')

synthesizer = PARSynthesizer(metadata, verbose=True)
synthesizer.fit(df)

Raises the same KeyError, but with a different ID:

...
KeyError: "None of [Index(['fb9aa2a7-3694-47f0-8145-2434b9196bbb'], dtype='object')] are in the [columns]"

Workaround

Create an incrementing integer column (1 to n within each sequence) that indexes the rows sharing the same sequence_key value.

seq_key | s_index | dim1
--------|---------|-----
A       | 1       | 9.1
A       | 2       | 8.1
B       | 1       | 4.1
B       | 2       | 5.1
...
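A minimal pandas sketch of how such a column can be built, using the same values as the table above. The column names `seq_key`, `s_index` and `dim1` are only illustrative. Note that `cumcount()` numbers rows in their current order, so the rows of each sequence should already be in the intended order before numbering.

```python
import pandas as pd

# Toy data with the same values as the table: 'seq_key' identifies each sequence
df = pd.DataFrame({
    "seq_key": ["A", "A", "B", "B"],
    "dim1": [9.1, 8.1, 4.1, 5.1],
})

# cumcount() numbers rows 0..n-1 within each group, in their current order;
# adding 1 makes the numbering start at 1 for every sequence
df["s_index"] = df.groupby("seq_key").cumcount() + 1

print(df[["seq_key", "s_index", "dim1"]])
```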

Then update your metadata. Here's a code snippet that does both:

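# Replace 'seq_key' with the column you're using as the sequence_key
df['s_index'] = df.groupby('seq_key').cumcount() + 1

# Register the new column and mark it as the sequence_index (it must be numerical or datetime)
metadata.add_column(column_name='s_index', sdtype='numerical')
metadata.set_sequence_index(column_name='s_index')

# Re-create the synthesizer so it uses the updated metadata; now this should work!
synthesizer = PARSynthesizer(metadata)
synthesizer.fit(df)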
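Before calling fit(), a quick sanity check can confirm that the new column really runs 1 to n within every sequence, as the workaround describes. This is a small sketch; `has_valid_sequence_index` is a hypothetical helper rather than part of SDV, and it checks only the property the workaround builds, not a documented PARSynthesizer requirement.

```python
import pandas as pd

def has_valid_sequence_index(df: pd.DataFrame, key: str, index: str) -> bool:
    """Hypothetical helper: True if `index` holds 1..n (no gaps, no repeats) within each `key` group."""
    for _, group in df.groupby(key):
        values = sorted(group[index].tolist())
        if values != list(range(1, len(values) + 1)):
            return False
    return True

# Numbering built with cumcount() passes; a repeated index value fails
good = pd.DataFrame({"seq_key": ["A", "A", "B", "B"], "s_index": [1, 2, 1, 2]})
bad = pd.DataFrame({"seq_key": ["A", "A", "B", "B"], "s_index": [1, 1, 1, 2]})
print(has_valid_sequence_index(good, "seq_key", "s_index"))  # True
print(has_valid_sequence_index(bad, "seq_key", "s_index"))   # False
```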

srinify · Apr 30 '24 15:04