fairseq
Problem with torch version detection
Hi,
#4513 introduced PyTorch version checks that are problematic when using NVIDIA PyTorch images.
Let's have a look at one code snippet, where the version check is performed:
https://github.com/facebookresearch/fairseq/blob/5307a0e078d7460003a86f4e2246d459d4706a1d/fairseq/modules/transformer_layer.py#L118-L138
The NVIDIA PyTorch Docker images use version identifiers like `1.12.0a0+8a1a93a` - see here. Let's pick `1.12.0a0+8a1a93a` as an example.
```python
if "+" in torch.__version__:
    self.torch_version = torch.__version__.split("+")[0]
```
The `+` is found in the PyTorch version string, so `torch_version` will be `1.12.0a0`.
But then the following cast leads to an int-casting error:
```python
self.torch_version = self.torch_version.split(".")
self.int_version = (
    int(self.torch_version[0]) * 1000
    + int(self.torch_version[1]) * 10
    + int(self.torch_version[2])
)
```
The `torch_version` string is split on `.`, yielding the list `['1', '12', '0a0']`.
More precisely, `int(self.torch_version[2])` raises:

```
ValueError: invalid literal for int() with base 10: '0a0'
```

because `int('0a0')` is not a valid integer literal.
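The failure can be reproduced without fairseq, using just the version string from the NVIDIA image:

```python
version = "1.12.0a0+8a1a93a"  # version string from the NVIDIA PyTorch image
base = version.split("+")[0]  # "1.12.0a0" - the "+..." build suffix is stripped
parts = base.split(".")       # ['1', '12', '0a0'] - last component is not numeric
try:
    int(parts[2])
except ValueError as e:
    print(e)                  # invalid literal for int() with base 10: '0a0'
```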
So I think that version check needs to be adjusted.
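One possible adjustment - a sketch only, and `safe_int_version` is a hypothetical helper, not fairseq code - would be to keep only the leading digits of each version component, so that pre-release tags like `0a0` no longer break the cast:

```python
import re


def safe_int_version(version_string):
    # Strip the local build suffix ("+8a1a93a"), then take only the
    # leading digits of each component, so "0a0" becomes 0 instead of
    # raising ValueError. Missing components default to 0.
    base = version_string.split("+")[0]
    parts = []
    for component in base.split("."):
        match = re.match(r"\d+", component)
        parts.append(int(match.group()) if match else 0)
    parts += [0] * (3 - len(parts))
    return parts[0] * 1000 + parts[1] * 10 + parts[2]


print(safe_int_version("1.12.0a0+8a1a93a"))  # 1120
print(safe_int_version("1.11.0"))            # 1110
```

Alternatively, `packaging.version.parse` already understands PEP 440 strings like `1.12.0a0`, so comparing parsed versions instead of hand-rolled integers would avoid this class of bug entirely.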