Fix: Correctly handle integer device_map for NPU devices in _load_state_dict_into_meta_model
Fixes #38468
Problem Description
This PR addresses the AssertionError: "Torch not compiled with CUDA enabled" that occurs when loading models with device_map="auto" on systems with Huawei Ascend NPU hardware.
The root cause is that, in transformers versions >= 4.50.0, the _load_state_dict_into_meta_model function did not correctly interpret the integer device IDs (e.g., 0, 1) that device_map="auto" produces for NPU devices. PyTorch's default behavior implicitly maps bare integers to CUDA devices, which raises the assertion error on PyTorch builds compiled for NPU rather than CUDA.
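For context, a minimal reproduction looks roughly like the sketch below. The checkpoint name is illustrative; any model loaded with device_map="auto" on an NPU-only PyTorch build hits the same code path.

```python
# Rough reproduction sketch on an Ascend NPU machine (checkpoint name is illustrative).
import torch
import torch_npu  # registers the "npu" backend with PyTorch
from transformers import AutoModelForCausalLM

# On a PyTorch build compiled for NPU (no CUDA), transformers >= 4.50.0 raised:
#   AssertionError: Torch not compiled with CUDA enabled
# because the integer device IDs in the auto device_map were treated as CUDA ordinals.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B",   # illustrative checkpoint
    device_map="auto",
)
```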
Solution
The solution involves adding robust type-checking and device-specific conversion logic within the _load_state_dict_into_meta_model function. Specifically, in the code block that handles device_map when it is an integer:
- It now explicitly checks for the availability of NPU (using hasattr(torch, 'npu'), torch.npu.is_available(), and torch.npu.device_count()), prioritizing NPU over CUDA.
- If NPU is available and the integer ID is valid, it converts the integer ID into the correct NPU-specific string representation (e.g., "npu:0").
- It includes fallback logic for CUDA and CPU, ensuring the integer is always mapped to an existing and intended device string (see the sketch below).
This ensures that when a parameter is moved from the "meta" device to its target device (e.g., via param.to(param_device)), param_device is always a correctly formatted string (like "npu:0") that PyTorch can interpret, thereby preventing the AssertionError.
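In spirit, the conversion follows the sketch below. The helper name resolve_int_device is hypothetical and only illustrates the precedence (NPU, then CUDA, then CPU); the actual change is inlined in _load_state_dict_into_meta_model.

```python
import torch


def resolve_int_device(device):
    """Hypothetical helper: map an integer device ID from device_map to an
    explicit device string, preferring NPU over CUDA on Ascend builds."""
    if not isinstance(device, int):
        return device  # already a string such as "cpu" or "npu:0"

    # Prefer NPU when the backend is present, available, and the ID is in range.
    if hasattr(torch, "npu") and torch.npu.is_available() and device < torch.npu.device_count():
        return f"npu:{device}"

    # Otherwise fall back to CUDA if it is available and the ID is valid.
    if torch.cuda.is_available() and device < torch.cuda.device_count():
        return f"cuda:{device}"

    # Last resort: keep the parameter on CPU.
    return "cpu"
```

With a mapping like this, a call such as param.to(resolve_int_device(param_device)) never hands a bare integer to a CUDA-less build.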
Testing
The fix has been validated through detailed code analysis of the transformers source and of PyTorch's device management.
I do not have access to physical Ascend NPU hardware to reproduce the AssertionError and verify the fix on the target device. However, the logic that disambiguates integer device IDs for NPU now follows PyTorch's device API.
All local make quality and make test checks passed in my Colab environment.