Variable Name Error in SyncBatchNorm Example
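In the SyncBatchNorm example at https://github.com/NVIDIA/apex/tree/master/apex/parallel, the input tensor is created as input_t, but the module is called with input: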
import torch
import apex
input_t = torch.randn(3, 5, 20).cuda()
sbn = apex.parallel.SyncBatchNorm(5).cuda()
output_t = sbn(input)  # bug: the tensor defined above is input_t, not input
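Because input is a Python builtin, this typo does not raise a NameError. It silently passes the builtin function to the module, which then fails because it expects a tensor. It should be: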
import torch
import apex
input_t = torch.randn(3, 5, 20).cuda()
sbn = apex.parallel.SyncBatchNorm(5).cuda()
output_t = sbn(input_t)  # pass the tensor that was actually created
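To show why the name matters without needing CUDA, apex, or an initialized process group, here is a minimal sketch that swaps in torch.nn.BatchNorm1d on CPU as a stand-in for apex.parallel.SyncBatchNorm. This substitution is an assumption made for illustration: both layers normalize an (N, C, L) input over channel dimension 1, but the stand-in does not synchronize statistics across processes. The sketch runs the corrected call, then the typo, where input resolves to Python's builtin function instead of raising a NameError.

```python
import torch
import torch.nn as nn

# Stand-in for apex.parallel.SyncBatchNorm (assumption): plain BatchNorm1d on CPU,
# so this runs without CUDA, apex, or torch.distributed being initialized.
bn = nn.BatchNorm1d(5)  # 5 = number of channels, i.e. dim 1 of the input

input_t = torch.randn(3, 5, 20)  # (batch N=3, channels C=5, length L=20)

# Corrected call: pass the tensor defined above. Normalization keeps the shape.
output_t = bn(input_t)
print(output_t.shape)  # torch.Size([3, 5, 20])

# The typo from the original example: `input` is Python's builtin function,
# not a tensor, so no NameError is raised; the module fails on the non-tensor argument.
try:
    bn(input)
except Exception as exc:
    print(f"bn(input) failed with {type(exc).__name__}")
```

In training mode each of the 5 channels of output_t has a mean close to zero, which is a quick sanity check that the tensor, not the builtin, was normalized.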