
[WIP] Fix incorrect CUDA stream synchronization in activation_offload async offload

Open · Copilot opened this issue 3 months ago · 1 comment

Thanks for asking me to work on this. I will get started on it and keep this PR's description up to date as I form a plan and make progress.

Original description:

Summary

There is a bug in xtuner/v1/utils/activation_offload.py that can cause incorrect async ordering between CUDA streams during activation offloading/prefetching. This manifests as potential race conditions, deadlocks (mutual stream waits), or consuming tensors before H2D/D2H copies complete.

Root Causes

  • Events are recorded on the wrong stream or without strong producer/consumer linkage, e.g., creating an event on the current stream and making another stream wait for it without ensuring the data producer actually recorded completion.
  • Mutual waits between streams (working_stream.wait_stream(h2d_stream) followed by h2d_stream.wait_stream(working_stream)) in _unpack_from_cpu can lead to deadlock or no-ops depending on stream identity.
  • Mixing wait_stream and event waits inconsistently (including waits on default/current streams) causes unclear dependencies and potential global syncs.
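
For illustration, a minimal sketch of the mutual-wait pattern described above (variable names are illustrative, not the exact original code):

working_stream = torch.cuda.current_stream()
working_stream.wait_stream(h2d_stream)   # consumer waits on the copy stream
h2d_stream.wait_stream(working_stream)   # copy stream immediately waits back

If the two streams are identical, these calls are no-ops; if they differ, the circular wait establishes no useful producer/consumer ordering and, as noted above, can deadlock.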

Goals

  • Make stream synchronization explicit and correct using events recorded on the stream that executes the copy, and have consumer streams wait on those events only.
  • Remove mutual wait_stream patterns and avoid unnecessary default stream waits.
  • Bind tensor lifetimes to the stream(s) performing the copies or consuming the tensors via record_stream to avoid premature reuse.
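
A minimal sketch of the intended event-based pattern (copy_stream, consumer_stream, src, and dst are illustrative names):

with torch.no_grad(), torch.cuda.stream(copy_stream):
    dst.copy_(src, non_blocking=True)  # copy is enqueued on copy_stream
    done = torch.cuda.Event()
    done.record(copy_stream)           # event marks completion of the copy
consumer_stream.wait_event(done)       # only the consumer waits on the event
src.record_stream(copy_stream)         # tie source lifetime to the copy stream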

Changes Requested

Edit xtuner/v1/utils/activation_offload.py as follows:

  1. SwapTensor
  • Create dedicated events: self.d2h_event and self.h2d_event.
  • In launch_d2h(d2h_stream):
    • Launch the copy on d2h_stream, record d2h_event on d2h_stream immediately after the copy is enqueued, and set stat to "host".
    • Remove the temporary forward_event and cross-stream wait.
  • In wait_d2h_finished():
    • Wait on d2h_event from the current stream, then resize device storage to 0.
  • In launch_h2d(h2d_stream, resize_storage, consumer_stream):
    • Optionally resize storage, enqueue H2D copy on h2d_stream, record h2d_event on h2d_stream, set stat to "device".
    • Make consumer_stream wait on h2d_event and record_stream(consumer_stream) on tensor to tie its lifetime.
  • In prefetch_launch_h2d(h2d_stream, resize_storage):
    • Same as launch_h2d but without consumer binding; record_stream(h2d_stream) to keep lifetime until the copy completes.
  • In wait_h2d_finished():
    • Wait on h2d_event from the current stream if needed.
  2. OffloadManager
  • del_npu_tensor: call act.wait_d2h_finished() for all matching keys; do not rely on stream.wait_stream.
  • prefetch_get: remove cross wait between d2h and h2d streams; just call prefetch_launch_h2d(h2d_stream, True) and rely on events.
  3. async_save_on_cpu hooks
  • _pack_to_cpu:
    • When after_block, ensure previous block’s tensors finish D2H and shrink storage via OffloadManager().del_npu_tensor.
    • Before calling launch_d2h, make d2h_stream wait on the producing stream (torch.cuda.current_stream()).
  • _unpack_from_cpu:
    • Remove mutual wait_stream calls. Instead obtain consumer_stream = torch.cuda.current_stream(), then call swap_tensor.launch_h2d(h2d_stream, True, consumer_stream). If prefetch is enabled, compute keys and call OffloadManager().prefetch_get.

Proposed Implementation

Please replace the current file with the following implementation, which applies the above corrections and keeps the original API surface compatible:

"""
This file is adapted from: https://gitee.com/ascend/MindSpeed-MM/blob/master/mindspeed_mm/utils/async_offload.py
Original Author: liyx616
Original License: MIT
Modifications: To enable compatibility on both GPU and NPU, all torch.npu calls are replaced with torch.cuda, and the transfer_to_npu interface is used on NPU devices
"""

import torch
from torch.autograd.graph import saved_tensors_hooks

from xtuner.v1.utils.device import get_device


if get_device() == "npu":
    from torch_npu.contrib import transfer_to_npu  # noqa


def base_check_fn(tensor):
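    # Skip parameters (and views of parameters) as well as tensors whose storage
    # has already been freed; only plain activation tensors are offloaded.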
    if isinstance(tensor._base, torch.nn.parameter.Parameter) or isinstance(tensor, torch.nn.parameter.Parameter):
        return False
    if tensor.storage().size() <= 0:
        return False
    return True


class GetCnt:
    def __init__(self):
        self._block_idx = -1
        self._block_tensor_nums = {}  # offload tensors per block

    def get_cnt(self, block_idx):
        after_block = False
        if block_idx > self._block_idx:
            self._block_tensor_nums[block_idx] = 1
            if block_idx != 0:
                after_block = True
            self._block_idx = block_idx
        elif block_idx == self._block_idx:
            self._block_tensor_nums[block_idx] += 1
        else:
            # block_idx decreased: the previous step ended, reset per-block counters
            self._block_idx = block_idx
            self._block_tensor_nums = {block_idx: 1}

        offload_tensor_key = f"{self._block_idx}_{self._block_tensor_nums[self._block_idx] - 1}"
        return offload_tensor_key, after_block

    def get_prefetch_keys(self, block_idx, tensor_idx):
        prefetch_block_idx = max((idx for idx in self._block_tensor_nums.keys() if idx < block_idx), default=None)

        if prefetch_block_idx is None:
            return []

        prefetch_block_tensor_nums = self._block_tensor_nums[prefetch_block_idx]
        block_tensor_nums = self._block_tensor_nums[block_idx]
        start = tensor_idx * prefetch_block_tensor_nums // block_tensor_nums
        end = (tensor_idx + 1) * prefetch_block_tensor_nums // block_tensor_nums
        prefetch_idxs = list(range(start, end))
        return [f"{block_idx - 1}_{prefetch_idx}" for prefetch_idx in prefetch_idxs]


class SwapTensor:
    def __init__(self, tensor, key):
        self.tensor = tensor
        self.size = tensor.size()
        self.storage_size = tensor.storage().size()
        self.tensor_cpu = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True, device="cpu")

        self.is_slice_tensor = tensor.storage().size() != tensor.numel()
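        # is_slice_tensor marks views whose storage holds more elements than the
        # tensor itself; these are copied with Tensor.copy_ so only the viewed
        # elements transfer, while contiguous tensors use a raw storage copy.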
        self.stat = "device"
        self.key = key

        # events marking copy completion
        self.d2h_event = None
        self.h2d_event = None

    # device to host
    def launch_d2h(self, d2h_stream: torch.cuda.Stream):
        if self.stat != "device":
            return

        with torch.no_grad():
            with torch.cuda.stream(d2h_stream):
                if self.is_slice_tensor:
                    self.tensor_cpu.copy_(self.tensor, non_blocking=True)
                else:
                    self.tensor_cpu.storage().copy_(self.tensor.storage(), non_blocking=True)

                if self.d2h_event is None:
                    self.d2h_event = torch.cuda.Event()
                self.d2h_event.record(d2h_stream)

                self.stat = "host"

    # synchronize d2h and resize 0
    def wait_d2h_finished(self):
        if self.stat != "host":
            return
        if self.d2h_event is not None:
            torch.cuda.current_stream().wait_event(self.d2h_event)
        self.tensor.storage().resize_(0)

    # resize storage_size and host to device
    def launch_h2d(self, h2d_stream: torch.cuda.Stream, resize_storage: bool, consumer_stream: torch.cuda.Stream):
        if self.stat != "host":
            return

        if resize_storage:
            self.tensor.storage().resize_(self.storage_size)
        with torch.no_grad():
            with torch.cuda.stream(h2d_stream):
                if self.is_slice_tensor:
                    self.tensor.copy_(self.tensor_cpu, non_blocking=True)
                else:
                    self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=True)

                if self.h2d_event is None:
                    self.h2d_event = torch.cuda.Event()
                self.h2d_event.record(h2d_stream)

                self.stat = "device"

                consumer_stream.wait_event(self.h2d_event)
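                # record_stream keeps the caching allocator from reusing this
                # memory until work already queued on consumer_stream completes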
                self.tensor.record_stream(consumer_stream)

    # prefetch host to device without binding to a consumer stream
    def prefetch_launch_h2d(self, h2d_stream: torch.cuda.Stream, resize_storage: bool):
        if self.stat != "host":
            return

        if resize_storage:
            self.tensor.storage().resize_(self.storage_size)
        with torch.no_grad():
            with torch.cuda.stream(h2d_stream):
                if self.is_slice_tensor:
                    self.tensor.copy_(self.tensor_cpu, non_blocking=True)
                else:
                    self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=True)

                if self.h2d_event is None:
                    self.h2d_event = torch.cuda.Event()
                self.h2d_event.record(h2d_stream)

                self.stat = "device"
                self.tensor.record_stream(h2d_stream)

    # synchronize h2d
    def wait_h2d_finished(self):
        if self.stat != "device":
            return
        if self.h2d_event is not None:
            torch.cuda.current_stream().wait_event(self.h2d_event)


class SingletonMeta(type):
    """Single meta class."""

    _instances = {}  # type: ignore

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance

        return cls._instances[cls]


class OffloadItem:
    """Class for offload item."""

    def __init__(self, act=None, ref_cnt=0, event=None):
        self.act = act
        self.ref_cnt = ref_cnt
        self.event = event

    def get_event(self):
        return self.event

    def has_event(self):
        return self.event is not None


class OffloadManager(metaclass=SingletonMeta):
    """Class for offload manager."""

    def __init__(self, check=False):
        self.items = {}
        self.check = check
        self.device_item = []
        self.getcnt = GetCnt()

    def get_cnt(self, block_idx):
        return self.getcnt.get_cnt(block_idx)

    def assert_exist(self, key):
        if key not in self.items:
            raise RuntimeError(f"Key {key} does not exist in items")

    def exist(self, key):
        return key in self.items

    def assert_not_exist(self, key):
        if key in self.items:
            raise RuntimeError(f"Key {key} already exists in items")

    def put(self, key, act, event=None):
        if key in self.items:
            self.items[key].act = act
            self.items[key].ref_cnt += 1
            self.items[key].event = event
        else:
            self.items[key] = OffloadItem(act, 1, event)

    def put_npu_tensor(self, act):
        self.device_item.append(act)

    def del_npu_tensor(self, prefix_key, d2h_stream):
        # d2h_stream is retained for API compatibility; synchronization now relies
        # on per-tensor d2h events rather than stream waits
        for key in list(self.items.keys()):
            if key.startswith(prefix_key):
                self.items[key].act.wait_d2h_finished()

    def get(self, key):
        self.assert_exist(key)
        item = self.items[key]

        act = item.act
        if item.has_event():
            item.get_event().wait()

        item.ref_cnt -= 1
        if item.ref_cnt == 0:
            self.clear(key)
        return act

    def prefetch_get(self, block_idx, tensor_idx, h2d_stream, d2h_stream):
        prefetch_keys = self.getcnt.get_prefetch_keys(block_idx, tensor_idx)
        for prefetch_key in prefetch_keys:
            if self.exist(prefetch_key):
                prefetch_swap_tensor = self.get(prefetch_key)
                prefetch_swap_tensor.prefetch_launch_h2d(h2d_stream, True)

    def empty(self):
        return len(self.items) == 0

    def clear(self, key=None):
        if key is None:
            self.items.clear()
        else:
            self.assert_exist(key)
            self.items.pop(key)

    # event interface #

    def get_event(self, key):
        self.assert_exist(key)
        item = self.items[key]
        event = item.get_event()
        return event

    def has_event(self, key):
        if not self.exist(key):
            return False
        item = self.items[key]
        return item.has_event()


class async_save_on_cpu(saved_tensors_hooks):
    def __init__(self, h2d_stream, d2h_stream, block_idx, depth, custom_check_fn=None, prefetch=True) -> None:
        def _pack_to_cpu(tensor):
            if not base_check_fn(tensor):
                return tensor

            if (custom_check_fn is not None) and (not custom_check_fn(tensor)):
                return tensor

            key, after_block = OffloadManager().get_cnt(block_idx)

            if after_block:
                OffloadManager().del_npu_tensor(f"{block_idx - 1}_", d2h_stream)

            swap_tensor = SwapTensor(tensor, key)

            if block_idx < depth - 1:
                producing_stream = torch.cuda.current_stream()
                d2h_stream.wait_stream(producing_stream)
                swap_tensor.launch_d2h(d2h_stream)

            OffloadManager().put(key, swap_tensor)
            return swap_tensor

        def _unpack_from_cpu(swap_tensor) -> torch.Tensor:
            if isinstance(swap_tensor, torch.Tensor):
                return swap_tensor

            consumer_stream = torch.cuda.current_stream()
            swap_tensor.launch_h2d(h2d_stream, True, consumer_stream)

            if prefetch:
                block_idx_str, tensor_idx_str = swap_tensor.key.split("_")
                OffloadManager().prefetch_get(int(block_idx_str), int(tensor_idx_str), h2d_stream, d2h_stream)
            return swap_tensor.tensor

        super().__init__(_pack_to_cpu, _unpack_from_cpu)
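
For context, a hedged usage sketch of how these hooks might be driven from a training loop (blocks, x, and the stream setup are assumptions, not part of the original file):

h2d_stream = torch.cuda.Stream()
d2h_stream = torch.cuda.Stream()

def forward_blocks(blocks, x):
    depth = len(blocks)
    for block_idx, block in enumerate(blocks):
        # pack hooks offload this block's activations during forward;
        # unpack hooks copy them back on demand during backward
        with async_save_on_cpu(h2d_stream, d2h_stream, block_idx, depth):
            x = block(x)
    return x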

This pull request was created as a result of a prompt from Copilot chat; the prompt is identical to the issue description above.

Copilot · Sep 12 '25 14:09

@copilot Please edit the file directly.

pppppM · Sep 12 '25 14:09