transformers

Fix: Correctly handle integer device_map for NPU devices in _load_sta…

Open gspeter-max opened this issue 6 months ago • 13 comments

Fixes #38468

Problem Description

This PR addresses the AssertionError: "Torch not compiled with CUDA enabled" that occurs when attempting to load models using device_map="auto" on systems with Ascend NPU hardware (Huawei NPU).

The root cause is that, in transformers versions >= 4.50.0, the _load_state_dict_into_meta_model function was not correctly interpreting integer device IDs (e.g., 0, 1) provided by device_map="auto" for NPU devices. Instead, PyTorch's default behavior would implicitly attempt to map these integers to CUDA devices, leading to the assertion error on PyTorch builds that are specifically compiled for NPU and not CUDA.
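The failure mode can be modeled in a few lines of pure Python. This is a simplified sketch, not PyTorch's actual code. `naive_resolve` and `check_device_exists` are hypothetical names, the `available` dict stands in for real hardware queries, and the rule that a bare integer always means a CUDA index is the simplification described above.

```python
from typing import Dict, Union

Device = Union[int, str]


def naive_resolve(device: Device) -> str:
    """Simplified model of the old behavior: a bare integer is read as a CUDA index."""
    if isinstance(device, int):
        return f"cuda:{device}"
    return device


def check_device_exists(device: str, available: Dict[str, int]) -> None:
    """Raise if `device` names a backend or index this host does not have.

    `available` maps a backend name (e.g. "npu") to its device count.
    """
    backend, _, index = device.partition(":")
    if backend == "cpu":
        return
    # A missing index (e.g. plain "cuda") is treated as index 0.
    if int(index or 0) >= available.get(backend, 0):
        raise AssertionError(f"{device!r} is not available on this host")


# Hypothetical Ascend host: 8 NPUs, no CUDA devices.
npu_host = {"npu": 8}
# Illustrative device map with bare integer IDs, as device_map="auto" can produce.
device_map = {"model.embed_tokens": 0, "lm_head": 1}

for name, dev in device_map.items():
    try:
        check_device_exists(naive_resolve(dev), npu_host)
    except AssertionError as err:
        # Both entries end up here: 'cuda:0' and 'cuda:1' do not exist on this host.
        print(f"{name}: {err}")
```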

Solution

The solution involves adding robust type-checking and device-specific conversion logic within the _load_state_dict_into_meta_model function. Specifically, in the code block that handles device_map when it is an integer:

  • It now explicitly checks for the availability of NPU (using hasattr(torch, 'npu'), torch.npu.is_available(), and torch.npu.device_count()), prioritizing NPU over CUDA.
  • If NPU is available and the integer ID is valid, it converts the integer ID into the correct NPU-specific string representation (e.g., "npu:0").
  • It includes fallback logic for CUDA and CPU, ensuring the integer is always correctly mapped to an existing and intended device string.
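The priority order in the bullets above can be sketched as a small pure-Python function. The hardware checks are replaced by a hypothetical `available` dict of backend names to device counts, so the sketch runs without NPU or CUDA hardware. `resolve_int_device` is an illustrative name, and the CPU rule (only ID 0 falls back to CPU) is an assumption borrowed from the later proposal in this thread.

```python
from typing import Dict


def resolve_int_device(device_id: int, available: Dict[str, int]) -> str:
    """Turn a bare integer device ID into a prefixed device string.

    `available` maps backend names to device counts, e.g. {"npu": 8}.
    NPU is tried before CUDA, following the priority described above.
    """
    for backend in ("npu", "cuda"):
        # Only accept the ID if that backend actually has a device with this index.
        if 0 <= device_id < available.get(backend, 0):
            return f"{backend}:{device_id}"
    # Assumed CPU fallback: ID 0 when no accelerator claimed it.
    if device_id == 0:
        return "cpu"
    raise ValueError(f"Device ID {device_id} does not match any available device")


print(resolve_int_device(0, {"npu": 8}))   # npu:0
print(resolve_int_device(0, {"cuda": 2}))  # cuda:0
print(resolve_int_device(0, {}))           # cpu
```

Passing the device counts in as data, instead of calling torch.npu.device_count() directly, is only a convenience for running the sketch anywhere; a real implementation would query the backends as the bullets describe.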

This ensures that when a parameter is moved from the "meta" device to its target device (e.g., via param.to(param_device)), param_device is always a correctly formatted string (like "npu:0") that PyTorch can interpret, thereby preventing the AssertionError.

Testing

The fix was validated through detailed code analysis of the transformers source and of PyTorch's device management. I do not have access to physical Ascend NPU hardware, so I could not reproduce the AssertionError or verify the fix on the target hardware. However, the logic for disambiguating integer device IDs for NPU now follows PyTorch's device API.

All local make quality and make test checks passed in my Colab environment.

gspeter-max, Jun 04 '25 18:06

cc @sunmarc

Rocketknight1, Jun 05 '25 12:06

@gspeter-max thanks for your PR. I just checked it on Ascend NPUs with torch 2.1.0 and torch_npu 2.1.0, and it works. Here is my test code:

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("qwen2_7b", device_map="auto")

cc @SunMarc

jiaqiw09, Jun 09 '25 08:06

Thanks a lot for verifying @jiaqiw09! Great to hear it works well on Ascend NPUs with torch 2.1.0 and torch_npu 2.1.0. Let me know if there’s anything else needed from my side.

gspeter-max, Jun 09 '25 08:06


When torch >= 2.5, Ascend NPUs also support integer indices. However, I don't think it's necessary to add another if statement to the code, since "npu:0" works for both versions.

jiaqiw09, Jun 09 '25 09:06

Can you tell me exactly where you would like to see this function? cc @SunMarc

gspeter-max, Jun 10 '25 16:06

I think this is better:


import torch
from typing import Union


def __add_prefix_to_device(param_device: Union[str, torch.device, int], param_name: str) -> str:
    """Takes an integer device and adds the correct hardware prefix (e.g., 'npu:', 'mlu:') to it.
    The logic is similar to `infer_auto_device_map` in the accelerate library.
    """
    if isinstance(param_device, int):
        device_id = param_device
        # Probe the backends in priority order (NPU first) and keep the first one that has this index.
        if hasattr(torch, "npu") and torch.npu.is_available() and device_id < torch.npu.device_count():
            device_type_str = "npu"
        elif torch.cuda.is_available() and device_id < torch.cuda.device_count():
            device_type_str = "cuda"
        elif hasattr(torch, "mlu") and torch.mlu.is_available() and device_id < torch.mlu.device_count():
            device_type_str = "mlu"
        elif hasattr(torch, "xpu") and torch.xpu.is_available() and device_id < torch.xpu.device_count():
            device_type_str = "xpu"
        elif hasattr(torch, "sdaa") and torch.sdaa.is_available() and device_id < torch.sdaa.device_count():
            device_type_str = "sdaa"
        elif hasattr(torch, "hpu") and torch.hpu.is_available() and device_id < torch.hpu.device_count():
            device_type_str = "hpu"
        elif device_id == -1 or device_id == 0:
            device_type_str = "cpu"
        else:
            # The device_id does not match any available device, so raise an error.
            raise ValueError(
                f"Invalid integer device_map '{device_id}' for parameter '{param_name}'. "
                "Cannot find a matching device type (NPU, CUDA, CPU) for this ID."
            )

        if device_type_str == "cpu":
            final_param_device = "cpu"
        else:
            final_param_device = f"{device_type_str}:{device_id}"

        return final_param_device

    # Non-integer devices (str or torch.device) are already fully specified; return them unchanged.
    return param_device

If you like, I can change it to this; otherwise, let me know what you'd prefer. cc @SunMarc, I'm waiting for your reply.

gspeter-max, Jun 10 '25 16:06

Let's keep this simple. The device id will never be -1, by the way.

def _add_prefix_to_device(param_device: Union[str, torch.device, int], param_name: str) -> str:
    """
    Takes an integer device and adds the correct hardware prefix (e.g., 'npu:', 'mlu:') to it.
    logic is similar to `infer_auto_device_map` in 'accelerate library'.
    """
    if isinstance(param_device, int):
        if is_npu_available():
            device_type_str = "npu"
        elif is_cuda_available():
            device_type_str = "cuda"
        elif is_mlu_available():
            device_type_str = "mlu"
        elif is_xpu_available():
            device_type_str = "xpu"
        elif is_sdaa_available():
            device_type_str = "sdaa"
        elif is_musa_available():
            device_type_str = "musa"
        elif is_hpu_available():
            device_type_str = "hpu"
        else:
            raise ValueError(
                f"Invalid integer device_map '{param_device}' for parameter '{param_name}'. "
                "Cannot find a matching device type (CUDA, XPU...) for this ID."
            )
        return f"{device_type_str}:{param_device}"
    return param_device
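The simplified helper picks one prefix from whichever backend is available and leaves string and torch.device values untouched. A minimal pure-Python sketch of that pass-through behavior, applied over a whole device map, follows. `add_prefix` and `prefix_device_map` are hypothetical names, and the device type is passed in explicitly instead of being detected with the `is_*_available()` helpers.

```python
from typing import Dict, Union

DeviceSpec = Union[int, str]


def add_prefix(param_device: DeviceSpec, device_type: str) -> DeviceSpec:
    """Prefix a bare integer with `device_type`; pass anything else through unchanged."""
    if isinstance(param_device, int):
        return f"{device_type}:{param_device}"
    return param_device


def prefix_device_map(device_map: Dict[str, DeviceSpec], device_type: str) -> Dict[str, DeviceSpec]:
    """Apply `add_prefix` to every entry of a device map."""
    return {name: add_prefix(dev, device_type) for name, dev in device_map.items()}


# Illustrative map: integer IDs for some modules, plus a string entry that must stay as-is.
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": "cpu"}
print(prefix_device_map(device_map, "npu"))
# {'model.embed_tokens': 'npu:0', 'model.layers.0': 'npu:0', 'model.layers.1': 'npu:1', 'model.norm': 'cpu'}
```

Unlike the earlier proposal, this version does no per-ID device-count check, so an out-of-range ID is presumably only rejected later, when the parameter is actually moved to that device.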

SunMarc, Jun 11 '25 12:06

please fix the code quality with make style

SunMarc, Jun 12 '25 13:06

Thank you so much for guiding me through this 🫡 We need to change something.

gspeter-max, Jun 12 '25 13:06

@bot /style

SunMarc, Jun 12 '25 13:06

@bot /style

SunMarc, Jun 12 '25 14:06

@gspeter-max, can you fix the style issue with make style and check why it fails? Thanks!

SunMarc, Jun 12 '25 14:06

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.