
Fix: Correctly handle integer device_map for NPU devices in _load_sta…

Open gspeter-max opened this issue 4 months ago • 13 comments

Fixes #38468

Problem Description

This PR fixes the AssertionError "Torch not compiled with CUDA enabled" that is raised when loading a model with device_map="auto" on systems with Huawei Ascend NPU hardware.

The root cause is that, in transformers >= 4.50.0, the _load_state_dict_into_meta_model function does not translate the integer device IDs (e.g., 0, 1) that device_map="auto" assigns to NPU devices. A bare integer carries no device type, so PyTorch implicitly treats it as a CUDA device index. On a PyTorch build compiled for NPU but not for CUDA, this triggers the assertion error.

Solution

The solution adds type checking and device-specific conversion logic to the _load_state_dict_into_meta_model function, specifically in the code path that handles a device_map entry whose value is an integer:

  • It now explicitly checks whether an NPU is available (using hasattr(torch, 'npu'), torch.npu.is_available(), and torch.npu.device_count()), giving NPU priority over CUDA.
  • If an NPU is available and the integer ID is within range (smaller than torch.npu.device_count()), it converts the ID into the NPU-specific device string (e.g., "npu:0").
  • Otherwise it falls back to CUDA and then to CPU, so the integer is always mapped to an existing, intended device string.
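The NPU-first resolution described in the bullets above can be sketched as follows. This is a minimal sketch, not the PR's actual diff: resolve_param_device and fake_backend are hypothetical names, and the helper takes the torch module as an argument (torch_mod) so that the NPU branch can be exercised with a stand-in object on machines without Ascend hardware. On a real system, torch.npu is typically only present after the torch_npu extension has been imported.

```python
from types import SimpleNamespace


def resolve_param_device(device, torch_mod):
    """Hypothetical helper: turn a bare integer device index into an explicit device string.

    Non-integer entries such as "cpu", "disk" or "npu:1" are returned unchanged.
    Integers are resolved NPU first, then CUDA, then CPU.
    """
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(device, bool) or not isinstance(device, int):
        return device

    # torch.npu only exists when the Ascend extension (torch_npu) is loaded.
    if (
        hasattr(torch_mod, "npu")
        and torch_mod.npu.is_available()
        and 0 <= device < torch_mod.npu.device_count()
    ):
        return f"npu:{device}"

    if torch_mod.cuda.is_available() and 0 <= device < torch_mod.cuda.device_count():
        return f"cuda:{device}"

    # Last resort: keep the parameter on the CPU.
    return "cpu"


def fake_backend(available, count):
    """Hypothetical stand-in for torch.npu / torch.cuda exposing is_available() and device_count()."""
    return SimpleNamespace(is_available=lambda: available, device_count=lambda: count)


if __name__ == "__main__":
    # Simulated Ascend machine: two NPUs, no CUDA.
    npu_torch = SimpleNamespace(npu=fake_backend(True, 2), cuda=fake_backend(False, 0))
    print(resolve_param_device(1, npu_torch))      # npu:1
    print(resolve_param_device("cpu", npu_torch))  # cpu
    print(resolve_param_device(5, npu_torch))      # cpu (index out of range)

    # Simulated CUDA-only machine: the fake torch has no `npu` attribute at all.
    cuda_torch = SimpleNamespace(cuda=fake_backend(True, 1))
    print(resolve_param_device(0, cuda_torch))     # cuda:0
```

With a real installation you would pass the imported torch module itself, e.g. resolve_param_device(0, torch). Falling back to "cpu" for an out-of-range index mirrors the fallback described above, but it can hide a misconfigured device_map; raising an error instead is a reasonable alternative.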

This ensures that when a parameter is moved from the "meta" device to its target device (e.g., via param.to(param_device)), param_device is always an explicit device string such as "npu:0" rather than a bare integer, so PyTorch no longer interprets it as a CUDA device and the AssertionError is avoided.
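For context, the integer that reaches param.to(param_device) comes from the device_map entry of the module that owns the parameter. The sketch below is a hypothetical, simplified stand-in for that lookup (device_for_param and the example device_map are illustrative, not transformers' internal code): it picks the longest module name that is a dotted prefix of the parameter name, then turns a bare integer into an explicit "<type>:<index>" string, assuming the accelerator type has already been detected (here defaulted to "npu").

```python
def device_for_param(param_name, device_map, accelerator="npu"):
    """Hypothetical lookup: return an explicit device string for one parameter.

    device_map maps module names (e.g. "model.layers.0") to a device, which may
    be a bare integer index or a string such as "cpu". The longest module name
    that is a dotted prefix of param_name wins; "" acts as a catch-all entry.
    """
    best = None
    for module_name in device_map:
        matches = (
            module_name == ""
            or param_name == module_name
            or param_name.startswith(module_name + ".")
        )
        if matches and (best is None or len(module_name) > len(best)):
            best = module_name
    if best is None:
        raise KeyError(f"no device_map entry covers {param_name!r}")

    device = device_map[best]
    # A bare integer carries no device type, so attach one explicitly.
    if isinstance(device, int) and not isinstance(device, bool):
        return f"{accelerator}:{device}"
    return device


if __name__ == "__main__":
    # Illustrative device_map of the shape device_map="auto" can produce.
    device_map = {
        "model.embed_tokens": 0,
        "model.layers.0": 0,
        "model.layers.1": 1,
        "lm_head": "cpu",
    }
    print(device_for_param("model.layers.1.mlp.up_proj.weight", device_map))  # npu:1
    print(device_for_param("lm_head.weight", device_map))                     # cpu
```

Matching on module_name + "." keeps "model.layers.1" from accidentally claiming parameters of "model.layers.10".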

Testing

The fix was validated by analyzing the transformers source and PyTorch's device-management API. I do not have access to physical Ascend NPU hardware, so I could not reproduce the AssertionError or verify the fix on the target hardware. However, the logic that maps integer device IDs to NPU devices now follows PyTorch's API.

All local make quality and make test checks passed in my Colab environment.

gspeter-max, Jun 04 '25 18:06