super-gradients
Bug in super-gradients/src/super_gradients/training/models/conversion.py
Describe the bug
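convert_to_onnx raises TypeError: 'NoneType' object is not iterable when it is called without the deprecated input_shape argument, because the line that builds prep_model_for_conversion_kwargs["input_size"] from input_shape runs even when input_shape is None.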
To Reproduce
Steps to reproduce the behavior:
- net = models.get(model_name, pretrained_weights="coco")
- models.convert_to_onnx(model=net, out_path=f"{model_name}.onnx", torch_onnx_export_kwargs={'input_size': (3,640,640)})
- See error
Expected behavior
Conversion should complete without raising an error.
Screenshots
models.convert_to_onnx(model=net, out_path=f"{model_name}.onnx", torch_onnx_export_kwargs={'input_size': (3,640,640)})
  File "/opt/conda/lib/python3.8/site-packages/super_gradients/common/decorators/factory_decorator.py", line 36, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/super_gradients/common/decorators/factory_decorator.py", line 36, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/super_gradients/training/models/conversion.py", line 89, in convert_to_onnx
    prep_model_for_conversion_kwargs["input_size"] = (1, *input_shape)
TypeError: 'NoneType' object is not iterable
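The last frame fails because input_shape is None when it is not passed, and unpacking None with * inside a tuple raises TypeError. A minimal plain-Python reproduction, with no super-gradients dependency (the variable names only mirror the traceback; the exact wording of the message may vary between Python versions):

```python
# Plain-Python reproduction of the failing line; no super-gradients import needed.
input_shape = None  # value convert_to_onnx sees when input_shape is not passed
prep_model_for_conversion_kwargs = {}

error = None
try:
    # Same expression as line 89 of conversion.py in the traceback
    prep_model_for_conversion_kwargs["input_size"] = (1, *input_shape)
except TypeError as exc:
    error = exc  # keep a reference; exc is cleared after the except block

print(type(error).__name__, error)
```

Because the right-hand side fails before the assignment happens, the kwargs dict is left unchanged.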
The change should be to move the assignment inside the if block, so the batch dimension is only prepended when input_shape is actually passed:

    if input_shape is not None:
        logger.warning(
            "input_shape is deprecated and will be removed in the next major release."
            " Use the input_size kwarg in prep_model_for_conversion_kwargs instead"
        )
        prep_model_for_conversion_kwargs["input_size"] = (1, *input_shape)
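A self-contained sketch of the proposed fix, pulled out of convert_to_onnx into a hypothetical helper, apply_deprecated_input_shape (not part of super-gradients), so the guarded logic can be run on its own. Unlike the library code, which updates prep_model_for_conversion_kwargs in place, this sketch returns a copy to keep the helper side-effect free:

```python
import logging
from typing import Any, Dict, Optional, Tuple

logger = logging.getLogger(__name__)


def apply_deprecated_input_shape(
    input_shape: Optional[Tuple[int, ...]],
    prep_model_for_conversion_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the proposed fix in convert_to_onnx.

    The batch dimension is only prepended when input_shape was actually given.
    """
    kwargs = dict(prep_model_for_conversion_kwargs or {})
    if input_shape is not None:
        logger.warning(
            "input_shape is deprecated and will be removed in the next major release."
            " Use the input_size kwarg in prep_model_for_conversion_kwargs instead"
        )
        # Assignment now lives inside the guard: (3, 640, 640) -> (1, 3, 640, 640)
        kwargs["input_size"] = (1, *input_shape)
    return kwargs


# Deprecated path: input_shape is converted to a batched input_size (logs a warning)
print(apply_deprecated_input_shape((3, 640, 640)))
# {'input_size': (1, 3, 640, 640)}

# New path: no input_shape, so the caller's input_size is left untouched
print(apply_deprecated_input_shape(None, {"input_size": (1, 3, 640, 640)}))
# {'input_size': (1, 3, 640, 640)}
```

Calling it with input_shape=None and no kwargs simply returns an empty dict instead of raising.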
Thanks @JagdishKolhe, we'll look into it and follow up.
I don't think you need to use torch_onnx_export_kwargs directly. Passing the input_shape argument should work just fine:
from super_gradients.common.object_names import Models
from super_gradients.training import models
net = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")
# input_shape is (C, H, W); a batch dimension of 1 is prepended internally
models.convert_to_onnx(model=net, out_path="yolo_nas.onnx", input_shape=(3, 640, 640))
@JagdishKolhe may I ask where you got the code snippet? It may be that we have outdated documentation somewhere, and I'd like to fix that to prevent such issues in the future. Thanks!
@BloodAxe Please check the convert_to_onnx method in conversion.py (line 84). It says the function's input_shape argument will be deprecated in the future:
    if input_shape is not None:
        logger.warning(
            "input_shape is deprecated and will be removed in the next major release."
            " Use the input_size kwarg in prep_model_for_conversion_kwargs instead"
        )
Ah, I see what you mean. Yes, it's definitely a bug.