[BUG] Robust indexing & safe dtype handling in tsai’s TSDataLoaders / TfmdLists
Fixes https://github.com/sktime/sktime/issues/7885 ([ENH] interface to tsai package)
What does this implement / fix?
While wiring tsai into sktime, I hit two subtle indexing/dtype issues that can also bite vanilla tsai users.

NumPy scalar vs. Python int in TfmdLists + Subset
When self._splits is a NumPy array (dtype=int8/int32/…), the value returned by idx = self._splits[it] is a NumPy scalar rather than a Python int. Indexing self.items[idx] with that object fails, which breaks Learner.predict, dev notebooks, and downstream libraries.
Patch: Before the final lookup, convert a NumPy scalar to a Python int with .item() and a small NumPy index array to a list with .tolist(). There is no behavioural change for plain Python ints/lists.
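A minimal sketch of the coercion described above, assuming the index comes straight out of a NumPy splits array. The name to_python_index is a hypothetical helper used only for illustration; it is not tsai's or fastai's actual code, where the conversion happens inline before the lookup.

```python
import numpy as np


def to_python_index(idx):
    """Hypothetical helper: normalise an index taken from a NumPy `splits`
    array so that it behaves like a plain Python index."""
    if isinstance(idx, np.generic):   # NumPy scalar, e.g. np.int8(2)
        return idx.item()             # -> plain Python int
    if isinstance(idx, np.ndarray):   # NumPy index array
        return idx.tolist()           # -> list of Python ints (scalar if 0-d)
    return idx                        # Python ints, lists, slices pass through


splits = np.array([2, 0, 1], dtype=np.int8)
items = ["a", "b", "c"]
print(items[to_python_index(splits[0])])  # looks up items[2]
```

Plain Python ints, lists and slices are returned unchanged, which is what keeps the patch a no-op for the normal code path.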
Unsafe dtype when casting NumPy arrays to tensors in TfmdLists.__init__
Inside the in-place branch (inplace=True, tfms=None), we call typ(tl.items). If tl.items is an integer array, torch.as_tensor keeps the integer dtype (an int64 array becomes a LongTensor), which later collides with models expecting FloatTensors.
Patch: After the existing cast, explicitly re-cast any NumPy array to torch.float32.
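A minimal sketch of the second patch, assuming typ returns a torch.Tensor (torch.as_tensor stands in for it here). The helper name cast_items_to_float32 is hypothetical and only illustrates the "cast, then re-cast NumPy input to float32" order; it is not the actual tsai code.

```python
import numpy as np
import torch


def cast_items_to_float32(items, typ=torch.as_tensor):
    """Hypothetical helper: apply the existing cast `typ`, then force float32
    when the original items were a NumPy array (mirrors the described patch)."""
    was_numpy = isinstance(items, np.ndarray)
    out = typ(items)  # existing cast: torch.as_tensor keeps the NumPy dtype
    if was_numpy:
        # Assumes `out` is a tensor; .to() converts e.g. int64 -> float32
        out = out.to(torch.float32)
    return out


x = np.arange(6, dtype=np.int64).reshape(2, 3)
print(torch.as_tensor(x).dtype, cast_items_to_float32(x).dtype)
```

The override applies to every NumPy array that reaches this branch. One case that may be worth checking against the integer-tensor question for reviewers: integer class-label targets, since torch.nn.CrossEntropyLoss documents that class-index targets must be long (int64), not float.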
Dependency impact
None: pure Python changes, no new packages.
Focus for reviewers
- Sanity-check the scalar/array coercion logic in __getitem__.
- Confirm that forcing float32 won’t interfere with edge cases where integer tensors are explicitly desired (I could not find any).
I’ve run the basic_motions smoke tests plus a small custom dataset; everything trains and predicts fine on CPU and GPU.
Thanks for taking a look!