TensorRT
Adding rotary embedding example, with graph rewrite for complex subgraph [WIP]
This PR:
- Adds an example for parallel rotary embedding
- Adds logic for complex graph detection
- Adds a pass for complex graph rewrite in aten_lowering_pass

Please note that this PR currently targets the single-GPU case, where there is no DTensor in the inputs of the torch module. Ideally this should not require runtime changes. This should avoid the graph breaks caused by `view_as_complex` and `view_as_real` nodes.
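For context, here is a hypothetical sketch (not the code in this PR) of the pattern involved: a rotary embedding written with `torch.view_as_complex` / `torch.view_as_real`, which is the subgraph that triggers graph breaks, next to an equivalent real-valued form that a rewrite pass could emit instead. The function names and shapes are illustrative assumptions.

```python
import torch


def rope_complex(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Rotary embedding via complex arithmetic. x: (..., dim) with even dim,
    # freqs_cis: complex tensor of shape (..., dim // 2). The view_as_complex /
    # view_as_real pair is what the lowering pass would need to rewrite.
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(xc * freqs_cis).flatten(-2).type_as(x)


def rope_real(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Same rotation with real-valued ops only, using
    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
    xr = x.float().reshape(*x.shape[:-1], -1, 2)
    a, b = xr[..., 0], xr[..., 1]
    out = torch.stack((a * cos - b * sin, a * sin + b * cos), dim=-1)
    return out.flatten(-2).type_as(x)


# Quick equivalence check between the two formulations.
x = torch.randn(2, 8)
theta = torch.randn(2, 4)
freqs_cis = torch.polar(torch.ones_like(theta), theta)  # cos(theta) + i*sin(theta)
y1 = rope_complex(x, freqs_cis)
y2 = rope_real(x, theta.cos(), theta.sin())
assert torch.allclose(y1, y2, atol=1e-5)
```

Since both functions compute the same rotation, replacing the complex subgraph with the real-valued one preserves numerics while keeping the whole graph in ops TensorRT can consume.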
I see a lint error here, `would reformat /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py`, which is unrelated to my change. Not sure why it is failing here.