
Regression: tensor_parallel.ColumnParallelLinear fails on onnx.export

Open · borisfom opened this issue on Apr 08 '22 · 2 comments

Describe the Bug

This happens while exporting one of the NeMo Megatron modules that use tensor_parallel.ColumnParallelLinear, with ToT (top-of-tree) apex. It used to work with previous releases. Apparently, the problem is that the inference/no-grad forward execution path still goes through the LinearWithGradAccumulationAndAsyncAllreduce autograd Function's forward(), which by design cannot be exported to ONNX.

E0408 21:46:43.917169 140336425469760 export.py:160] Export failed. Please make sure your NeMo model class (nemo.collections.nlp.models.question_answering.qa_model.QAModel) has working export() and that you have the latest NeMo package installed with [all] dependencies.

Traceback (most recent call last):
  File "/git/NeMo/scripts/export.py", line 176, in <module>
    nemo_export(sys.argv[1:])
  File "/git/NeMo/scripts/export.py", line 165, in nemo_export
    raise e
  File "/git/NeMo/scripts/export.py", line 151, in nemo_export
    _, descriptions = model.export(
  File "/git/NeMo/nemo/core/classes/exportable.py", line 142, in export
    torch.onnx.export(
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/__init__.py", line 332, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 113, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 790, in _export
    proto, export_map, val_use_external_data_format = graph._export_onnx(
RuntimeError: ONNX export failed: Couldn't export Python operator LinearWithGradAccumulationAndAsyncAllreduce

Defined at: /opt/conda/lib/python3.8/site-packages/apex/transformer/tensor_parallel/layers.py(315): linear_with_grad_accumulation_and_async_allreduce
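For reference, a minimal sketch of the failing path (not part of the original report): it assumes a single-process NCCL group on one GPU, apex's parallel_state initialization called positionally, and made-up layer sizes and file name; exact setup calls may differ across apex versions.

import os
import torch
from apex.transformer import parallel_state, tensor_parallel

# Single-process "distributed" setup so the parallel layer can be built.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
parallel_state.initialize_model_parallel(1, 1)  # tensor/pipeline parallel size 1

layer = tensor_parallel.ColumnParallelLinear(16, 32).cuda().eval()

with torch.no_grad():
    # Even in no-grad inference, forward() still routes through the
    # LinearWithGradAccumulationAndAsyncAllreduce autograd Function,
    # so this raises "Couldn't export Python operator ...".
    torch.onnx.export(layer, torch.randn(4, 16, device="cuda"), "column_parallel.onnx")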

Expected Behavior

Environment

borisfom · Apr 08 '22 22:04

My quick workaround was to replace instances of tensor_parallel.ColumnParallelLinear with the wrapper class below. Something like this should be implemented inside tensor_parallel.ColumnParallelLinear.forward itself:

import torch
from apex.transformer import tensor_parallel
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size

class ColumnLinear(tensor_parallel.ColumnParallelLinear):
    # Redefine forward only for non-parallel inference; training and
    # tensor-parallel execution fall back to the parent implementation.
    def forward(self, input_):
        world_size = get_tensor_model_parallel_world_size()
        if input_.requires_grad or world_size > 1:
            return tensor_parallel.ColumnParallelLinear.forward(self, input_)

        # Plain matrix multiply, bypassing the autograd Function that
        # ONNX export cannot trace.
        output = torch.matmul(input_, self.weight.t())
        if not self.skip_bias_add and self.bias is not None:
            output = output + self.bias

        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias
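
Not in the original comment, but one way to apply the wrapper without touching model code is to reassign the class on existing instances before export; since ColumnLinear adds no new state, the checkpointed weights are reused in place. replace_for_export is a hypothetical helper name:

def replace_for_export(model: torch.nn.Module) -> None:
    # Swap the class on every ColumnParallelLinear so the ONNX-safe
    # forward() above is used; no parameters are copied or re-created.
    for module in model.modules():
        if isinstance(module, tensor_parallel.ColumnParallelLinear):
            module.__class__ = ColumnLinear

replace_for_export(model)  # 'model' is whatever module you pass to torch.onnx.export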

borisfom · Apr 08 '22 22:04

seems related to https://github.com/NVIDIA/NeMo/pull/3998

crcrpar · Apr 15 '22 21:04