torch2trt
[Question] Is `check_torch_dtype` necessary in `add_missing_trt_tensors` or `trt_`?
There are several calls to `check_torch_dtype` when adding multiple constant tensors to the graph (e.g. via `add_missing_trt_tensors` or `trt_`).
This check might be useful for ops where we expect multiple inputs to share the same dtype, but is it actually required? It seems like either:
- the original PyTorch op already requires matching dtypes, so the requirement is satisfied by the time we convert, or
- TensorRT requires matching dtypes for certain arguments to a specific layer, in which case the check arguably belongs in that layer's converter (a simplified sketch of the check follows).
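For reference, the behavior in question is roughly the following. This is a simplified stand-in rather than the exact torch2trt source; it just shows what the assertion enforces:

```python
import torch

def check_torch_dtype(*tensors):
    # Simplified stand-in: assert that every torch.Tensor argument shares one
    # dtype and return that dtype (non-tensor arguments are ignored).
    dtype = None
    for t in tensors:
        if isinstance(t, torch.Tensor):
            if dtype is None:
                dtype = t.dtype
            else:
                assert dtype == t.dtype, "tensors must have a consistent dtype"
    return dtype

check_torch_dtype(torch.randn(2), torch.randn(3))       # ok: both float32
# check_torch_dtype(torch.randn(2), torch.tensor([1]))  # would raise: float32 vs int64
```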
For context: as utility functions, `add_missing_trt_tensors` and `trt_` may be called to add multiple ITensors to the graph, and the tensors supplied may be arbitrary in both number and dtype. Asserting that all inputs share a dtype seems unnecessary at a glance: the restriction can be sidestepped simply by making a second call to these functions (which can still break whatever dtype congruency a TensorRT layer requires), whereas removing it would let a converter collapse all of its conversions into a single call, improving developer ergonomics (see the sketch below).
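To make the ergonomics point concrete, here is a minimal sketch. The `to_itensor` helper, the `*_with_check` / `*_no_check` function names, and the example tensors are hypothetical stand-ins used only to contrast the two calling patterns:

```python
import torch

def to_itensor(network, t):
    # Hypothetical stand-in for the real per-tensor constant/ITensor creation.
    return f"itensor(dtype={t.dtype})"

def add_missing_trt_tensors_with_check(network, tensors):
    # Mimics current behavior: an up-front dtype assertion (as in
    # check_torch_dtype), so mixed-dtype inputs cannot share one call.
    assert len({t.dtype for t in tensors}) <= 1, "tensors must have a consistent dtype"
    return [to_itensor(network, t) for t in tensors]

def add_missing_trt_tensors_no_check(network, tensors):
    # Sketch of the alternative: convert each tensor with its own dtype; any
    # dtype congruency a TensorRT layer needs is asserted in the converter instead.
    return [to_itensor(network, t) for t in tensors]

x = torch.randn(3)              # float32
index = torch.tensor([0, 2])    # int64

# Today: two calls, grouped by dtype, to avoid tripping the assertion.
(x_trt,) = add_missing_trt_tensors_with_check(None, [x])
(index_trt,) = add_missing_trt_tensors_with_check(None, [index])

# Without the assertion: one call covers all inputs regardless of dtype.
x_trt, index_trt = add_missing_trt_tensors_no_check(None, [x, index])
```

Note that the two-call workaround offers no extra safety: it produces exactly the same ITensors while skipping the check, which is why moving any required dtype validation into the individual converters seems like the more precise place for it.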
Thoughts?