Aten scatter converter
Dependency of PR #2519
The harness.py file is changed in this PR to handle the cases in which the index is passed in the forward function.
The torch scatter function accepts only int64 inputs for index.
A couple of points:
- When the index is passed as a variable outside the `forward` function, the `get_trt_tensor` function will cast to int64.
- When an int64 index is passed in the `forward` function, it throws an error in the `unified_dtype_converter` in `fx/utils.py`. This happens in `TRTInterpretor.py` while adding the placeholder nodes (`self.ctx.net.add_input`).
- Passing `truncate_long_and_double` in the compilation settings does not help, since that would repair the inputs in `_compiler.py`, outside the TRTInterpreter, and so it would not help in harness.py.
- Had the tests been in the form of lowering tests, this line would create an issue: https://github.com/pytorch/TensorRT/blob/16c031349c6a1af5a8408a817f2ef8542aa6f176/py/torch_tensorrt/dynamo/_compiler.py#L394
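The two ways of supplying the index that the points above distinguish can be sketched in plain PyTorch. This is only an illustration of the two test shapes, not code from this PR or from harness.py: the module names `ScatterIndexAsAttribute` and `ScatterIndexAsInput` are hypothetical, and the snippet runs in eager mode only, so it does not exercise the Torch-TensorRT conversion path or reproduce the dtype error.

```python
import torch


class ScatterIndexAsAttribute(torch.nn.Module):
    """Hypothetical module: the index lives outside forward()."""

    def __init__(self, index: torch.Tensor):
        super().__init__()
        self.index = index

    def forward(self, x: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
        return torch.scatter(x, 0, self.index, src)


class ScatterIndexAsInput(torch.nn.Module):
    """Hypothetical module: the index is an argument of forward()."""

    def forward(
        self, x: torch.Tensor, index: torch.Tensor, src: torch.Tensor
    ) -> torch.Tensor:
        return torch.scatter(x, 0, index, src)


x = torch.zeros(3, 5)
src = torch.ones(2, 5)
# Integer tensors built from Python ints default to int64,
# which is the dtype scatter expects for its index argument.
index = torch.tensor([[0, 1, 2, 0, 0], [1, 2, 0, 1, 2]])

out_attr = ScatterIndexAsAttribute(index)(x, src)
out_input = ScatterIndexAsInput()(x, index, src)

# An index held in a narrower integer dtype can be cast explicitly
# to int64 before it reaches scatter.
index32 = index.to(torch.int32)
out_cast = torch.scatter(x, 0, index32.to(torch.int64), src)
```

In eager mode all three calls produce the same tensor. The difference described above only appears once the index becomes a network input (a placeholder node) rather than a constant.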