tsai
Issues saving models with TSMetaDataset Dataloader
On trying to save a model that uses a dataloader from TSMetaDatasets(TSMetaDataset), I got the following error:
Traceback (most recent call last):
  File "cli.py", line 15, in <module>
    cli()
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/click/core.py", line 1659, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/click/core.py", line 1659, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "/Users/ykim/Code/Gomanzanas/tsai-explore/explore.py", line 158, in classify_directionality
    learn.export(os.path.join(run_dir, "model", "learner.pkl"))
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/fastai/learner.py", line 369, in export
    self.dls = self.dls.new_empty()
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/fastai/data/core.py", line 146, in new_empty
    loaders = [dl.new(dl.dataset.new_empty()) for dl in self.loaders]
  File "/Users/ykim/.local/share/virtualenvs/tsai-explore-3_McIf8_/lib/python3.8/site-packages/fastai/data/core.py", line 146, in <listcomp>
    loaders = [dl.new(dl.dataset.new_empty()) for dl in self.loaders]
AttributeError: 'TSMetaDataset' object has no attribute 'new_empty'
After looking around, it looks like #215 is related. Any suggestions on how I can get around this?
Hi @ykim, it seems like a true bug so you won't be able to work around it. It needs to be fixed. I won't be able to work on it though for the next few days.
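For anyone reading along: the traceback shows that learn.export() calls dls.new_empty(), which in turn calls new_empty() on each loader's dataset, and TSMetaDataset did not have that method. Below is a minimal, library-free sketch of that failure mode. All the Toy* classes and try_export are hypothetical stand-ins written for illustration; they are not fastai or tsai code, and the new_empty method shown is only one plausible shape for such a fix, not the actual patch.

```python
class ToyDataset:
    """Stand-in for a dataset wrapper that lacks new_empty (hypothetical)."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)


class ToyDatasetWithNewEmpty(ToyDataset):
    """Same toy dataset, plus the method the export path expects."""

    def new_empty(self):
        # Return an object of the same type that holds no items.
        return type(self)([])


class ToyLoader:
    """Stand-in for a dataloader that can be rebuilt around another dataset."""

    def __init__(self, dataset):
        self.dataset = dataset

    def new(self, dataset):
        return type(self)(dataset)


class ToyLoaders:
    """Stand-in for a collection of dataloaders."""

    def __init__(self, *loaders):
        self.loaders = list(loaders)

    def new_empty(self):
        # Same pattern as the line shown in the traceback.
        loaders = [dl.new(dl.dataset.new_empty()) for dl in self.loaders]
        return type(self)(*loaders)


def try_export(dls):
    """Mimic the step of export that empties the dataloaders."""
    try:
        empty = dls.new_empty()
    except AttributeError as exc:
        return f"failed: {exc}"
    return f"ok: {[len(dl.dataset) for dl in empty.loaders]}"


result_broken = try_export(ToyLoaders(ToyLoader(ToyDataset(range(5)))))
result_fixed = try_export(ToyLoaders(ToyLoader(ToyDatasetWithNewEmpty(range(5)))))
print(result_broken)
print(result_fixed)
```

The first call fails with the same kind of AttributeError as in the report, while the second succeeds because the dataset can produce an empty copy of itself.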
Well, the next few days have turned into more than a year. In any case, I want to document that I've found a way to fix the issue. Here's the solution, step by step:
- You prepare the metadatasets:
from tsai.all import *  # assumed here: provides np, torch, alphabet, TSDatasets, TSClassification, etc.
from tsai.data.metadatasets import TSMetaDataset, TSMetaDatasets

vocab = alphabet[:10]
dsets = []
# Build three datasets of random size, each with 5 variables and 50 time steps
for i in range(3):
    size = np.random.randint(50, 150)
    X = torch.rand(size, 5, 50)
    y = vocab[torch.randint(0, 10, (size,))]
    tfms = [None, TSClassification(vocab=vocab)]
    dset = TSDatasets(X, y, tfms=tfms)
    dsets.append(dset)
# Combine the datasets into a single metadataset and split it
metadataset = TSMetaDataset(dsets)
splits = TimeSplitter()(metadataset)
metadatasets = TSMetaDatasets(metadataset, splits=splits)
dls = TSDataLoaders.from_dsets(metadatasets.train, metadatasets.valid)
xb, yb = dls.train.one_batch()
xb, yb
- Train the model as you'd do with any other model in tsai:
learn = ts_learner(dls, arch="TSTPlus")
learn.fit_one_cycle(1)
- Export the model when training is complete:
learn.export("test.pkl") # this has been fixed now and it should work
- For inference, you'd prepare a dataloader like you did when you trained the initial model:
# Assumes the same imports as in the training step
vocab = alphabet[:10]
dsets = []
for i in range(2):
    size = np.random.randint(50, 150)
    X = torch.rand(size, 5, 50)
    y = vocab[torch.randint(0, 10, (size,))]
    tfms = [None, TSClassification(vocab=vocab)]
    dset = TSDatasets(X, y, tfms=tfms)
    dsets.append(dset)
metadataset = TSMetaDataset(dsets)
dl = TSDataLoader(metadataset)
- You load the saved learner and use fastai's get_preds method:
learn = load_learner("test.pkl")
learn.get_preds(dl=dl)
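By default, get_preds returns a tuple of predictions and targets. If the predictions are per-class scores or probabilities, you will usually want to map them back to labels from the vocab. Here is a minimal sketch of that last step using plain Python lists in place of tensors. The decode_predictions helper, the vocab, and the numbers are all made up for illustration and do not come from a real model.

```python
def decode_predictions(probs, vocab):
    """Map each row of class scores to the label with the highest score."""
    labels = []
    for row in probs:
        best = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        labels.append(vocab[best])
    return labels


# Hypothetical stand-in for the 10-class vocab used above
vocab = list("abcdefghij")

# Hypothetical scores for two samples (one row per sample, one column per class)
probs = [
    [0.05, 0.55, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05],
    [0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.82],
]

labels = decode_predictions(probs, vocab)
print(labels)
```

With real output you would first convert the predictions tensor to a nested list (for example with .tolist()), or do the equivalent with an argmax over the class dimension.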
This issue has been fixed on GitHub. The fix will be available in the next pip/conda release (0.3.6).
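If you are not sure whether your environment already has the fix, you can compare the installed version against 0.3.6. This is a rough sketch using only the standard library. It assumes the fix shipped in 0.3.6 as announced; version_tuple, has_fix and installed_tsai_version are hypothetical helper names, and the parsing is deliberately simple (it only looks at the leading numeric parts of the version string).

```python
import re
from importlib.metadata import PackageNotFoundError, version


def version_tuple(text):
    """Turn '0.3.6' into (0, 3, 6); stop at the first non-numeric part."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def has_fix(installed, minimum="0.3.6"):
    """Return True if the installed version is at least the minimum version."""
    return version_tuple(installed) >= version_tuple(minimum)


def installed_tsai_version():
    """Return the installed tsai version string, or None if it is not installed."""
    try:
        return version("tsai")
    except PackageNotFoundError:
        return None


current = installed_tsai_version()
if current is None:
    print("tsai is not installed in this environment")
else:
    print(f"tsai {current}: includes the export fix = {has_fix(current)}")
```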