Issue 2153/ Refactor PyTorch 64-Bit Data Type Conversion #2153
Made changes so that the conversion code uses a function dtype_to_32bit that converts 64-bit data types to 32-bit data types.
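For readers skimming the thread, here is a minimal sketch of what such a helper could look like. The table name and the exact set of dtypes are assumptions for illustration only, not the actual coremltools implementation (see the commit linked later in the thread for the real change).

```python
import torch

# Illustrative sketch only: the table name and its contents are assumptions,
# not the actual coremltools code.
_TORCH_DTYPE_TO_32BIT = {
    torch.int64: torch.int32,
    torch.float64: torch.float32,
}

def dtype_to_32bit(dtype: torch.dtype) -> torch.dtype:
    """Return the 32-bit counterpart of a 64-bit torch dtype, or the dtype unchanged."""
    return _TORCH_DTYPE_TO_32BIT.get(dtype, dtype)
```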
Thanks @teelrabbit for the proposed fix. We will also need to fix the places where all those "number to dtype" mappings are used, e.g. ops.py
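For context, a "number to dtype" mapping here refers to the lookup tables in the torch frontend that resolve integer dtype codes to torch dtypes. A rough illustration is below; the keys follow torch's ScalarType codes, and the actual table in coremltools may contain more entries.

```python
import torch

# Rough illustration of a "number to dtype" table; the real NUM_TO_TORCH_DTYPE
# in coremltools may differ.
NUM_TO_TORCH_DTYPE = {
    3: torch.int32,
    4: torch.int64,
    6: torch.float32,
    7: torch.float64,
}
```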
Makes sense. I'll do some testing and see if I can use the current solution across all instances where the mapping is used 🤙🏻
"number to dtype" mapping are used
Made some changes to the occurrences of NUM_TO_TORCH_DTYPE to use "dtype_to_32bit". Can you confirm if this is what you intended by "number to dtype"? https://pastes.dev/2oAOqJeDM1 @YifanShenSZ
https://github.com/apple/coremltools/commit/db22515ad1c7a9b138459db252a9d1ac2e3c1944
```diff
 np_type = nptype_from_builtin(target_dtype.dtype)
 dtype = NUMPY_DTYPE_TO_TORCH_NUM[np_type]
-torch_dtype = NUM_TO_TORCH_DTYPE[dtype]
+torch_dtype = dtype_to_32bit(dtype)
 if isinstance(_input, Var) and _input.can_be_folded_to_const():
     # numpy -> torch -> torch cast -> numpy
     # This path is needed to use the mapping of passed in dtypes to torch dtypes.
```
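As a quick sanity check of the "numpy -> torch -> torch cast -> numpy" path shown above, here is a toy example. It is not coremltools code; it simply assumes torch_dtype has already been downcast to the 32-bit type the helper would return.

```python
import numpy as np
import torch

# Toy check of the numpy -> torch -> torch cast -> numpy round trip, assuming
# the target dtype has already been downcast to a 32-bit type.
arr = np.arange(4, dtype=np.float64)
torch_dtype = torch.float32  # e.g. the 32-bit dtype the helper would return
casted = torch.from_numpy(arr).to(torch_dtype).numpy()
assert casted.dtype == np.float32
```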
LGTM! CI ✅ Many thanks for contributing to coremltools!
Also, this issue should be closed out: https://github.com/apple/coremltools/issues/2153#issuecomment-2085832273