
Issues saving models with TSMetaDataset Dataloader

Open · ykim opened this issue 3 years ago · 1 comment

When I try to save (export) a learner whose dataloaders were built with TSMetaDatasets(TSMetaDataset), I get the following error:

```
Traceback (most recent call last):
  File "cli.py", line 15, in <module>
    cli()
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/click/core.py", line 1659, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/click/core.py", line 1659, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "/Users/ykim/Code/Gomanzanas/tsai-explore/explore.py", line 158, in classify_directionality
    learn.export(os.path.join(run_dir, "model", "learner.pkl"))
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/fastai/learner.py", line 369, in export
    self.dls = self.dls.new_empty()
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/fastai/data/core.py", line 146, in new_empty
    loaders = [dl.new(dl.dataset.new_empty()) for dl in self.loaders]
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/fastai/data/core.py", line 146, in <listcomp>
    loaders = [dl.new(dl.dataset.new_empty()) for dl in self.loaders]
AttributeError: 'TSMetaDataset' object has no attribute 'new_empty'
```

After some searching, #215 looks related. Any suggestions on how I can work around this?
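The last frames of the traceback show where export fails: `Learner.export` replaces `self.dls` with `self.dls.new_empty()`, and `DataLoaders.new_empty` calls `dl.dataset.new_empty()` for every loader. The sketch below mimics that call chain with hypothetical stand-in classes (`FakeDataLoader`, `FakeDataLoaders`, `MetaDatasetWithoutNewEmpty`, `MetaDatasetWithNewEmpty`, `try_export`; none of them is fastai's or tsai's actual code) to show why a dataset without a `new_empty` method breaks `export`. That a fix only needs `new_empty` to return an empty dataset of the same type is an assumption made for illustration.

```python
class FakeDataLoader:
    """Hypothetical stand-in for a fastai DataLoader (illustration only)."""
    def __init__(self, dataset):
        self.dataset = dataset

    def new(self, dataset):
        # Return a loader of the same kind wrapping a different dataset
        return type(self)(dataset)


class FakeDataLoaders:
    """Hypothetical stand-in for fastai's DataLoaders container."""
    def __init__(self, *loaders):
        self.loaders = list(loaders)

    def new_empty(self):
        # Mirrors the failing line in the traceback:
        # loaders = [dl.new(dl.dataset.new_empty()) for dl in self.loaders]
        loaders = [dl.new(dl.dataset.new_empty()) for dl in self.loaders]
        return type(self)(*loaders)


class MetaDatasetWithoutNewEmpty:
    """Hypothetical dataset that, like TSMetaDataset in the traceback, lacks new_empty."""
    def __init__(self, datasets):
        self.datasets = list(datasets)

    def __len__(self):
        return sum(len(d) for d in self.datasets)


class MetaDatasetWithNewEmpty(MetaDatasetWithoutNewEmpty):
    """Hypothetical patched dataset: new_empty returns an empty instance of the same type."""
    def new_empty(self):
        return type(self)([])


def try_export(dataset):
    """Return True if the new_empty chain succeeds, False on AttributeError."""
    dls = FakeDataLoaders(FakeDataLoader(dataset))
    try:
        empty = dls.new_empty()
    except AttributeError:
        return False
    # The exported loaders should carry no data
    return len(empty.loaders[0].dataset) == 0


print(try_export(MetaDatasetWithoutNewEmpty([[1, 2], [3]])))  # False
print(try_export(MetaDatasetWithNewEmpty([[1, 2], [3]])))     # True
```

The point of the sketch is that `export` deliberately strips the data before pickling, so every dataset class reachable from the dataloaders has to know how to produce an empty copy of itself.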

ykim · Dec 16 '21 07:12

Hi @ykim, this looks like a genuine bug, so I don't think you can work around it; it needs to be fixed in the library. Unfortunately, I won't be able to work on it for the next few days.

oguiza · Dec 23 '21 17:12

Well, the next few days turned into more than a year. In any case, I've found a way to fix the issue, and I want to document it here:

1. Prepare the metadatasets:
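```python
from tsai.all import *  # tsai's usual star import (np, torch, alphabet, TSDatasets, TimeSplitter, ...)
from tsai.data.metadatasets import TSMetaDataset, TSMetaDatasets

# Build three small classification datasets of different lengths
vocab = alphabet[:10]
dsets = []
for i in range(3):
    size = np.random.randint(50, 150)
    X = torch.rand(size, 5, 50)  # (samples, variables, time steps)
    y = vocab[torch.randint(0, 10, (size,))]
    tfms = [None, TSClassification(vocab=vocab)]
    dset = TSDatasets(X, y, tfms=tfms)
    dsets.append(dset)

# Combine them into one meta dataset, split it and build the dataloaders
metadataset = TSMetaDataset(dsets)
splits = TimeSplitter()(metadataset)
metadatasets = TSMetaDatasets(metadataset, splits=splits)
dls = TSDataLoaders.from_dsets(metadatasets.train, metadatasets.valid)
xb, yb = dls.train.one_batch()  # sanity check: inspect one batch
xb, yb
```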
2. Train the model as you would any other model in tsai:
```python
learn = ts_learner(dls, arch="TSTPlus")
learn.fit_one_cycle(1)
```
3. Export the learner once training is complete:
```python
learn.export("test.pkl")  # fails with the AttributeError above before the fix; works once it is in place
```
4. For inference, prepare a dataloader the same way you did when training the original model:
```python
# Build new data the same way as for training (two datasets this time)
vocab = alphabet[:10]
dsets = []
for i in range(2):
    size = np.random.randint(50, 150)
    X = torch.rand(size, 5, 50)
    y = vocab[torch.randint(0, 10, (size,))]
    tfms = [None, TSClassification(vocab=vocab)]
    dset = TSDatasets(X, y, tfms=tfms)
    dsets.append(dset)
metadataset = TSMetaDataset(dsets)
dl = TSDataLoader(metadataset)
```
5. Load the saved learner and call fastai's `get_preds` method:
```python
learn = load_learner("test.pkl")
probas, targets = learn.get_preds(dl=dl)  # predictions and targets for every sample in dl
```
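Step 1 relies on two ideas: the meta dataset presents several datasets as one indexable sequence, and `TimeSplitter` splits that sequence without shuffling. The plain-Python sketch below illustrates both under stated assumptions. The (dataset, local index) mapping is a conceptual model similar to PyTorch's `ConcatDataset`, not tsai's actual implementation, and `time_split` assumes a chronological split that holds out the last 20% (`valid_size=0.2`, assumed to be the default), with a rounding rule chosen here only for illustration. `build_mapping`, `locate` and `time_split` are hypothetical helpers.

```python
from bisect import bisect_right
from itertools import accumulate


def build_mapping(sizes):
    """Map each global index to a (dataset_idx, local_idx) pair, in order."""
    mapping = []
    for d, size in enumerate(sizes):
        mapping.extend((d, i) for i in range(size))
    return mapping


def locate(global_idx, sizes):
    """Same lookup without a table: binary search over cumulative dataset sizes."""
    ends = list(accumulate(sizes))  # e.g. [3, 5] for sizes [3, 2]
    total = ends[-1] if ends else 0
    if not 0 <= global_idx < total:
        raise IndexError(global_idx)
    d = bisect_right(ends, global_idx)
    start = ends[d - 1] if d > 0 else 0
    return d, global_idx - start


def time_split(n, valid_size=0.2):
    """Chronological split: earlier indices for training, the last valid_size fraction for validation."""
    n_valid = int(round(n * valid_size))
    idx = list(range(n))
    return idx[: n - n_valid], idx[n - n_valid:]


sizes = [3, 2]  # two hypothetical datasets with 3 and 2 samples
print(build_mapping(sizes))  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
print(locate(3, sizes))      # (1, 0)
print(time_split(5))         # ([0, 1, 2, 3], [4])
```

Under these assumptions, all validation samples come from the tail of the concatenated sequence, that is, from the last dataset(s) in the list, which is worth keeping in mind when the datasets are ordered by source rather than by time.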

oguiza · Mar 24 '23 11:03

This issue has been fixed on GitHub. The fix will be available in the next pip/conda release (0.3.6).
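Since the fix ships in 0.3.6, one way to confirm an environment has it is to compare the installed tsai version against that release. The sketch below uses only the standard library (`importlib.metadata.version`); `parse_version` and `has_metadataset_export_fix` are hypothetical helpers, and the parser assumes plain numeric versions such as `0.3.6` (use `packaging.version` for pre-release tags).

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(v):
    """Parse a plain 'X.Y.Z' string into a tuple of ints (pre-release tags not handled)."""
    return tuple(int(part) for part in v.split(".")[:3])


def has_metadataset_export_fix(installed, fixed_in="0.3.6"):
    """True if the installed version is at least the release that contains the fix."""
    return parse_version(installed) >= parse_version(fixed_in)


try:
    installed = version("tsai")
    print(installed, has_metadataset_export_fix(installed))
except PackageNotFoundError:
    print("tsai is not installed")
```

Comparing integer tuples rather than raw strings matters here: as strings, `"0.3.10" < "0.3.6"`, which would wrongly report a newer release as lacking the fix.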

oguiza · Mar 24 '23 11:03