tsai models should be TorchScriptable
Context: A common way to deploy models is using TorchServe (https://pytorch.org/serve/). The simplest way to do this requires a TorchScripted model (https://pytorch.org/docs/stable/jit.html).
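For reference, a minimal sketch of that deployment path (model choice and file name here are illustrative):

```python
import torch
from tsai.models.TransformerModel import TransformerModel

# Compile the model to TorchScript and serialize it; the resulting .pt
# file is what torch-model-archiver packages into a .mar for TorchServe.
model = TransformerModel(c_in=2, c_out=1).eval()
scripted = torch.jit.script(model)
scripted.save("model.pt")
```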
Issue: While many tsai model architectures are already TorchScriptable, some are not, for various reasons. This makes it difficult to deploy those models for production use. For example, the code below shows that TransformerModel is TorchScriptable while LSTM is not.
```
In [1]: import torch

In [2]: from tsai.models.RNN import LSTM

In [3]: from tsai.models.TransformerModel import TransformerModel

In [4]: optimized_transformer = torch.jit.script(TransformerModel(c_in=2, c_out=1))
<RETRACTED>/venv/lib/python3.9/site-packages/torch/jit/_recursive.py:240: UserWarning: 'batch_first' was found in ScriptModule constants, but was not actually set in __init__. Consider removing it.
  warnings.warn("'{}' was found in ScriptModule constants, "
<RETRACTED>/venv/lib/python3.9/site-packages/torch/jit/_recursive.py:234: UserWarning: 'norm' was found in ScriptModule constants, but it is a non-constant submodule. Consider removing it.
  warnings.warn("'{}' was found in ScriptModule constants, "

In [5]: optimized_lstm = torch.jit.script(LSTM(c_in=2, c_out=1))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [5], in <cell line: 1>()
----> 1 optimized_lstm = torch.jit.script(LSTM(c_in=2, c_out=1))

File <RETRACTED>/venv/lib/python3.9/site-packages/torch/jit/_script.py:1265, in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1263 if isinstance(obj, torch.nn.Module):
   1264     obj = call_prepare_scriptable_func(obj)
-> 1265 return torch.jit._recursive.create_script_module(
   1266     obj, torch.jit._recursive.infer_methods_to_compile
   1267 )
   1269 if isinstance(obj, dict):
   1270     return create_script_dict(obj)

File <RETRACTED>/venv/lib/python3.9/site-packages/torch/jit/_recursive.py:454, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
    452 if not is_tracing:
    453     AttributeTypeIsSupportedChecker().check(nn_module)
--> 454 return create_script_module_impl(nn_module, concrete_type, stubs_fn)

File <RETRACTED>/venv/lib/python3.9/site-packages/torch/jit/_recursive.py:520, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    518 # Compile methods if necessary
    519 if concrete_type not in concrete_type_store.methods_compiled:
--> 520     create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    521 # Create hooks after methods to ensure no name collisions between hooks and methods.
    522 # If done before, hooks can overshadow methods that aren't exported.
    523 create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)

File <RETRACTED>/venv/lib/python3.9/site-packages/torch/jit/_recursive.py:371, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    368 property_defs = [p.def_ for p in property_stubs]
    369 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 371 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)

RuntimeError:
Module 'LSTM' has no attribute 'dropout' (This function exists as an attribute on the Python module, but we failed to compile it to a TorchScript function.
The error stack is reproduced here:

Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "<RETRACTED>/venv/lib/python3.9/site-packages/fastai/imports.py", line 66
    def noop (x=None, *args, **kwargs):
        ~~~~~~~ <--- HERE
        "Do nothing"
        return x
:
  File "<RETRACTED>/venv/lib/python3.9/site-packages/tsai/models/RNN.py", line 21
    output, _ = self.rnn(x) # output from all sequence steps: [batch_size x seq_len x hidden_size * (1 + bidirectional)]
    output = output[:, -1]  # output from last sequence step : [batch_size x hidden_size * (1 + bidirectional)]
    output = self.fc(self.dropout(output))
                     ~~~~~~~~~~~~ <--- HERE
    return output
```
Potential solution:
For the LSTM (and other RNN models) the issue is explicitly with this line: https://github.com/timeseriesAI/tsai/blob/df8eb53c22701e633b796f5f61b5197d6a2a0872/tsai/models/RNN.py#L14
I believe substituting noop with nn.Identity() will keep the existing behavior and make the model TorchScriptable.
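A minimal sketch of why the two are equivalent, assuming the linked line constructs the attribute as `nn.Dropout(fc_dropout) if fc_dropout else noop`:

```python
import torch
import torch.nn as nn

fc_dropout = 0.0  # assumed default in tsai's RNN models

# Before (not TorchScriptable): fastai's noop takes *args/**kwargs,
# which TorchScript cannot compile:
#     self.dropout = nn.Dropout(fc_dropout) if fc_dropout else noop

# After (TorchScriptable): nn.Identity() is a regular nn.Module that
# returns its input unchanged, so forward() behavior is preserved.
dropout = nn.Dropout(fc_dropout) if fc_dropout else nn.Identity()

x = torch.rand(4, 8)
assert torch.equal(dropout(x), x)  # identity behavior when fc_dropout == 0
```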
Generally, perhaps adding an integration test that checks all models for TorchScriptability would be a sensible first step (see the sketch below). Then individual model architectures could be addressed on a case-by-case basis.
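A rough sketch of such a test (the model list, constructor arguments, and input shape are assumptions, and torch.testing.assert_close requires torch >= 1.9):

```python
import pytest
import torch
from tsai.models.RNN import LSTM
from tsai.models.TransformerModel import TransformerModel

MODELS = [TransformerModel, LSTM]  # would be extended to every architecture

@pytest.mark.parametrize("Model", MODELS)
def test_model_is_torchscriptable(Model):
    model = Model(c_in=2, c_out=1).eval()
    scripted = torch.jit.script(model)  # raises RuntimeError if not scriptable
    x = torch.rand(8, 2, 50)            # [batch_size, c_in, seq_len]
    torch.testing.assert_close(scripted(x), model(x))
```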
Hi @ivanzvonkov , I fully agree with you. It's something I've always had in mind, and I'll start doing it from now onwards. If you look at the RNN documentation you'll see I've added some tests and show how you can convert models to TorchScript and/or ONNX. Are you aware of any other models that cannot be converted?
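For anyone landing here, the conversions mentioned above might look roughly like this (file names and the dummy input shape are illustrative, and this assumes the RNN fix has landed):

```python
import torch
from tsai.models.RNN import LSTM

model = LSTM(c_in=2, c_out=1).eval()
x = torch.rand(1, 2, 50)  # dummy input: [batch_size, c_in, seq_len]

# TorchScript
scripted = torch.jit.script(model)
scripted.save("lstm_scripted.pt")

# ONNX (trace-based export using the dummy input)
torch.onnx.export(model, x, "lstm.onnx", input_names=["x"], output_names=["y"])
```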
Thank you for making the updates to the RNN! I recall quite a few from the README were not TorchScriptable, including FCN, TCN, InceptionTime, Rocket, TST, and TabTransformer. Is there a generalized way of testing all models within this framework?
There isn't a way to test all models yet, but I'll add some functionality soon, as it makes sense.
Great, having some guarantees about which models are TorchScriptable will make it a lot easier to test different models in projects where deployment is required for full evaluation.