pytorch-model-parallel
A memory-balanced, communication-efficient model-parallel implementation of a fully connected layer with CrossEntropyLoss in PyTorch.
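For readers new to the idea, here is a framework-free sketch of one common way to compute softmax cross-entropy when the class dimension of a fully connected layer is split across devices. It is not this repository's code. Each "shard" below stands in for the slice of logits one GPU would produce from its part of the weight matrix, and the names `shard_stats`, `sharded_cross_entropy`, and `full_cross_entropy` are hypothetical. The point is that each device only needs to exchange two scalars per sample (its local max and its local sum of exponentials), not its full slice of logits.

```python
import math


def shard_stats(logits_shard):
    """Per-shard statistics: local max and sum of exp(logit - local max)."""
    local_max = max(logits_shard)
    local_sum = sum(math.exp(z - local_max) for z in logits_shard)
    return local_max, local_sum


def sharded_cross_entropy(shards, label):
    """Cross-entropy for one sample whose class logits are split into shards.

    shards: list of lists; shard k holds the logits of a contiguous block of classes.
    label: global index of the target class.
    """
    stats = [shard_stats(s) for s in shards]
    # Stands in for an all-reduce (max): global max for a numerically stable softmax.
    global_max = max(m for m, _ in stats)
    # Stands in for an all-reduce (sum): rescale each local sum to the global max.
    global_sum = sum(s * math.exp(m - global_max) for m, s in stats)
    # Only the shard that owns the label supplies the target logit.
    offset = 0
    for s in shards:
        if offset <= label < offset + len(s):
            target = s[label - offset]
            break
        offset += len(s)
    else:
        raise IndexError("label out of range")
    # loss = log(sum_j exp(z_j)) - z_target
    return global_max + math.log(global_sum) - target


def full_cross_entropy(logits, label):
    """Reference: the same loss computed on the unsplit logits."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[label]


logits = [2.0, -1.0, 0.5, 3.0, 1.5, -0.5]
shards = [logits[:2], logits[2:4], logits[4:]]  # e.g. three "GPUs"
print(sharded_cross_entropy(shards, 3), full_cross_entropy(logits, 3))
```

Both calls should print the same loss up to floating-point rounding. In a real multi-GPU setup, the two reductions would be collective operations across devices rather than Python `max`/`sum` over a list.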
Can the @amp.float_function decorator be removed? When I use it, I run into inf values.
Environment: PyTorch 1.6, CUDA 10.2, 4 × TITAN RTX

    output = self.am_branches[i](x.cuda(i), labels[i])
  File "/home/derron/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/derron/arcface-pytorch/head/metrics_parallel.py", line 102, in forward
    output[index] = phi[index]
RuntimeError:...
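As background on the inf question (this does not diagnose the truncated error above): the largest finite half-precision (fp16) value is 65504, so exp(x) overflows to inf in fp16 for any x above roughly ln(65504) ≈ 11.09. That is why exponentials in a softmax/cross-entropy are usually kept in fp32 under mixed precision; the Apex documentation describes `amp.float_function` as a decorator that forces the wrapped function to run in FP32. The sketch below emulates an fp16 cast with the standard `struct` module's `"e"` format. The helper name `to_fp16` is hypothetical, and mapping overflow to inf is an assumption made to mimic an IEEE round-to-nearest cast.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to half precision.

    struct raises OverflowError for values too large for the "e" format;
    this sketch maps that case to +/-inf, mimicking an IEEE fp16 cast.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


# exp() of moderately sized logits already exceeds the fp16 range.
for logit in (8.0, 11.0, 12.0, 64.0):
    print(logit, math.exp(logit), to_fp16(math.exp(logit)))
```

Computed in float64 (or fp32), `math.exp` stays finite for all four inputs; only the fp16 rounding step turns the last two into inf.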