[WIP] Fix incorrect CUDA stream synchronization in activation_offload async offload
Thanks for asking me to work on this. I will get started on it and keep this PR's description up to date as I form a plan and make progress.
Original description:
Summary
There is a bug in xtuner/v1/utils/activation_offload.py that can cause incorrect async ordering between CUDA streams during activation offloading and prefetching. It can manifest as race conditions, deadlocks (mutual stream waits), or tensors being consumed before their H2D/D2H copies complete.
Root Causes
- Events are recorded on the wrong stream or without strong producer/consumer linkage, e.g., creating an event on the current stream and making another stream wait for it without ensuring the data producer actually recorded completion.
- Mutual waits between streams (working_stream.wait_stream(h2d_stream) followed by h2d_stream.wait_stream(working_stream)) in _unpack_from_cpu can lead to deadlock or no-ops depending on stream identity; see the sketch after this list.
- Mixing wait_stream and event waits inconsistently (including waits on default/current streams) causes unclear dependencies and potential global syncs.
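For reference, a minimal sketch of the mutual-wait anti-pattern described in the second root cause; the stream names are illustrative stand-ins, not the actual variables in activation_offload.py:

```python
import torch

# Illustrative streams; these stand in for the ones async_save_on_cpu receives.
h2d_stream = torch.cuda.Stream()
working_stream = torch.cuda.current_stream()

# Mutual waits: each stream is told to wait on the other's pending work. If the two
# streams are distinct this serializes them without ever expressing "the H2D copy must
# finish before the consumer reads"; if they are the same stream, both calls are no-ops.
working_stream.wait_stream(h2d_stream)
h2d_stream.wait_stream(working_stream)
```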
Goals
- Make stream synchronization explicit and correct using events recorded on the stream that executes the copy, and have consumer streams wait on those events only (sketched after this list).
- Remove mutual wait_stream patterns and avoid unnecessary default stream waits.
- Bind tensor lifetimes to the stream(s) performing the copies or consuming the tensors via record_stream to avoid premature reuse.
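Taken together, these goals amount to the standard event-based pattern below. This is a minimal sketch using a D2H copy and made-up tensor names, not code from the repository:

```python
import torch

copy_stream = torch.cuda.Stream()               # stream that executes the copy
consumer_stream = torch.cuda.current_stream()   # stream that produced / will reuse the data

src = torch.randn(1024, device="cuda")
dst_cpu = torch.empty(src.shape, dtype=src.dtype, device="cpu", pin_memory=True)

# Producer ordering: the copy stream must not read `src` before the producer wrote it.
copy_stream.wait_stream(consumer_stream)

with torch.cuda.stream(copy_stream):
    dst_cpu.copy_(src, non_blocking=True)
    copy_done = torch.cuda.Event()
    copy_done.record(copy_stream)               # event recorded on the stream doing the copy

# Consumer ordering: wait on the event only, not on the whole stream.
consumer_stream.wait_event(copy_done)
# Bind the device tensor's lifetime to the copy stream so the caching allocator does not
# reuse its memory before the async copy has finished.
src.record_stream(copy_stream)
```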
Changes Requested
Edit xtuner/v1/utils/activation_offload.py as follows:
- SwapTensor
- Create dedicated events: self.d2h_event and self.h2d_event.
- In launch_d2h(d2h_stream):
- Launch the copy on d2h_stream, record d2h_event on d2h_stream immediately after the copy is enqueued, and set stat to "host".
- Remove the temporary forward_event and cross-stream wait.
- In wait_d2h_finished():
- Wait on d2h_event from the current stream, then resize device storage to 0.
- In launch_h2d(h2d_stream, resize_storage, consumer_stream):
- Optionally resize storage, enqueue H2D copy on h2d_stream, record h2d_event on h2d_stream, set stat to "device".
- Make consumer_stream wait on h2d_event and record_stream(consumer_stream) on tensor to tie its lifetime.
- In prefetch_launch_h2d(h2d_stream, resize_storage):
- Same as launch_h2d but without consumer binding; record_stream(h2d_stream) to keep lifetime until the copy completes.
- In wait_h2d_finished():
- Wait on h2d_event from the current stream if needed.
- OffloadManager
- del_npu_tensor: call act.wait_d2h_finished() for all matching keys; do not rely on stream.wait_stream.
- prefetch_get: remove cross wait between d2h and h2d streams; just call prefetch_launch_h2d(h2d_stream, True) and rely on events.
- async_save_on_cpu hooks
- _pack_to_cpu:
- When after_block, ensure previous block’s tensors finish D2H and shrink storage via OffloadManager().del_npu_tensor.
- Before calling launch_d2h, make d2h_stream wait on the producing stream (torch.cuda.current_stream()).
- _unpack_from_cpu:
- Remove mutual wait_stream calls. Instead obtain consumer_stream = torch.cuda.current_stream(), then call swap_tensor.launch_h2d(h2d_stream, True, consumer_stream). If prefetch is enabled, compute keys and call OffloadManager().prefetch_get.
Proposed Implementation
Please replace the current file with the following implementation, which applies the corrections above and keeps the original API surface compatible:
""" This file is adapted from: https://gitee.com/ascend/MindSpeed-MM/blob/master/mindspeed_mm/utils/async_offload.py Original Author: liyx616 Original License: MIT Modifications: To enable compatibility on both GPU and NPU, replace all torch.npu with torch.cuda, and then use the transfer_to_npu interface """ import torch from torch.autograd.graph import saved_tensors_hooks from xtuner.v1.utils.device import get_device if get_device() == "npu": from torch_npu.contrib import transfer_to_npu # noqa def base_check_fn(tensor): if isinstance(tensor._base, torch.nn.parameter.Parameter) or isinstance(tensor, torch.nn.parameter.Parameter): return False if tensor.storage().size() <= 0: return False return True class GetCnt: def __init__(self): self._block_idx = -1 self._block_tensor_nums = {} # offload tensors per block def get_cnt(self, block_idx): after_block = False if block_idx > self._block_idx: self._block_tensor_nums[block_idx] = 1 if block_idx != 0: after_block = True self._block_idx = block_idx elif block_idx == self._block_idx: self._block_tensor_nums[block_idx] += 1 else: # one step end self._block_idx = block_idx self._block_tensor_nums = {block_idx: 1} offload_tensor_key = f"{self._block_idx}_{self._block_tensor_nums[self._block_idx] - 1}" return offload_tensor_key, after_block def get_prefetch_keys(self, block_idx, tensor_idx): prefetch_block_idx = max((idx for idx in self._block_tensor_nums.keys() if idx < block_idx), default=None) if prefetch_block_idx is None: return [] prefetch_block_tensor_nums = self._block_tensor_nums[prefetch_block_idx] block_tensor_nums = self._block_tensor_nums[block_idx] start = tensor_idx * prefetch_block_tensor_nums // block_tensor_nums end = (tensor_idx + 1) * prefetch_block_tensor_nums // block_tensor_nums prefetch_idxs = list(range(start, end)) return [f"{block_idx - 1}_{prefetch_idx}" for prefetch_idx in prefetch_idxs] class SwapTensor: def __init__(self, tensor, key): self.tensor = tensor self.size = tensor.size() self.storage_size = tensor.storage().size() self.tensor_cpu = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True, device="cpu") self.is_slice_tensor = tensor.storage().size() != tensor.numel() self.stat = "device" self.key = key # events marking copy completion self.d2h_event = None self.h2d_event = None # device to host def launch_d2h(self, d2h_stream: torch.cuda.Stream): if self.stat != "device": return with torch.no_grad(): with torch.cuda.stream(d2h_stream): if self.is_slice_tensor: self.tensor_cpu.copy_(self.tensor, non_blocking=True) else: self.tensor_cpu.storage().copy_(self.tensor.storage(), non_blocking=True) if self.d2h_event is None: self.d2h_event = torch.cuda.Event() self.d2h_event.record(d2h_stream) self.stat = "host" # synchronize d2h and resize 0 def wait_d2h_finished(self): if self.stat != "host": return if self.d2h_event is not None: torch.cuda.current_stream().wait_event(self.d2h_event) self.tensor.storage().resize_(0) # resize storage_size and host to device def launch_h2d(self, h2d_stream: torch.cuda.Stream, resize_storage: bool, consumer_stream: torch.cuda.Stream): if self.stat != "host": return if resize_storage: self.tensor.storage().resize_(self.storage_size) with torch.no_grad(): with torch.cuda.stream(h2d_stream): if self.is_slice_tensor: self.tensor.copy_(self.tensor_cpu, non_blocking=True) else: self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=True) if self.h2d_event is None: self.h2d_event = torch.cuda.Event() self.h2d_event.record(h2d_stream) self.stat = "device" 
consumer_stream.wait_event(self.h2d_event) self.tensor.record_stream(consumer_stream) # prefetch host to device without binding to a consumer stream def prefetch_launch_h2d(self, h2d_stream: torch.cuda.Stream, resize_storage: bool): if self.stat != "host": return if resize_storage: self.tensor.storage().resize_(self.storage_size) with torch.no_grad(): with torch.cuda.stream(h2d_stream): if self.is_slice_tensor: self.tensor.copy_(self.tensor_cpu, non_blocking=True) else: self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=True) if self.h2d_event is None: self.h2d_event = torch.cuda.Event() self.h2d_event.record(h2d_stream) self.stat = "device" self.tensor.record_stream(h2d_stream) # synchronize h2d def wait_h2d_finished(self): if self.stat != "device": return if self.h2d_event is not None: torch.cuda.current_stream().wait_event(self.h2d_event) class SingletonMeta(type): """Single meta class.""" _instances = {} # type: ignore def __call__(cls, *args, **kwargs): if cls not in cls._instances: instance = super().__call__(*args, **kwargs) cls._instances[cls] = instance return cls._instances[cls] class OffloadItem: """Class for offload item.""" def __init__(self, act=None, ref_cnt=0, event=None): self.act = act self.ref_cnt = ref_cnt self.event = event def get_event(self): return self.event def has_event(self): return self.event is not None class OffloadManager(metaclass=SingletonMeta): """Class for offload manager.""" def __init__(self, check=False): self.items = {} self.check = check self.device_item = [] self.getcnt = GetCnt() def get_cnt(self, block_idx): return self.getcnt.get_cnt(block_idx) def assert_exist(self, key): if key not in self.items: raise RuntimeError(f"Key {key} does not exist in items") def exist(self, key): return key in self.items def assert_not_exist(self, key): if key not in self.items: raise RuntimeError(f"Key {key} already exist in items") def put(self, key, act, event=None): if key in self.items: self.items[key].act = act self.items[key].ref_cnt += 1 self.items[key].event = event else: self.items[key] = OffloadItem(act, 1, event) def put_npu_tensor(self, act): self.device_item.append(act) def del_npu_tensor(self, prefile_key, d2h_stream): for key in list(self.items.keys()): if key.startswith(prefile_key): self.items[key].act.wait_d2h_finished() def get(self, key): self.assert_exist(key) item = self.items[key] act = item.act if item.has_event(): item.get_event().wait() item.ref_cnt -= 1 if item.ref_cnt == 0: self.clear(key) return act def prefetch_get(self, block_idx, tensor_idx, h2d_stream, d2h_stream): prefetch_keys = self.getcnt.get_prefetch_keys(block_idx, tensor_idx) for prefetch_key in prefetch_keys: if self.exist(prefetch_key): prefetch_swap_tensor = self.get(prefetch_key) prefetch_swap_tensor.prefetch_launch_h2d(h2d_stream, True) def empty(self): return len(self.items) == 0 def clear(self, key=None): if key is None: self.items.clear() else: self.assert_exist(key) self.items.pop(key) # event interface # def get_event(self, key): self.assert_exist(key) item = self.items[key] event = item.get_event() return event def has_event(self, key): if not self.exist(key): return False item = self.items[key] return item.has_event() class async_save_on_cpu(saved_tensors_hooks): def __init__(self, h2d_stream, d2h_stream, block_idx, depth, custom_check_fn=None, prefetch=True) -> None: def _pack_to_cpu(tensor): if not base_check_fn(tensor): return tensor if (custom_check_fn is not None) and (not custom_check_fn(tensor)): return tensor key, after_block = 
OffloadManager().get_cnt(block_idx) if after_block: OffloadManager().del_npu_tensor(f"{block_idx - 1}_", d2h_stream) swap_tensor = SwapTensor(tensor, key) if block_idx < depth - 1: producing_stream = torch.cuda.current_stream() d2h_stream.wait_stream(producing_stream) swap_tensor.launch_d2h(d2h_stream) OffloadManager().put(key, swap_tensor) return swap_tensor def _unpack_from_cpu(swap_tensor) -> torch.Tensor: if isinstance(swap_tensor, torch.Tensor): return swap_tensor consumer_stream = torch.cuda.current_stream() swap_tensor.launch_h2d(h2d_stream, True, consumer_stream) if prefetch: block_idx_str, tensor_idx_str = swap_tensor.key.split("_") OffloadManager().prefetch_get(int(block_idx_str), int(tensor_idx_str), h2d_stream, d2h_stream) return swap_tensor.tensor super().__init__(_pack_to_cpu, _unpack_from_cpu)
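For context, a hedged usage sketch of the hook class above; the block list, stream names, and per-block loop are assumptions about the call site, not code from the repository:

```python
import torch

# Hypothetical wiring around async_save_on_cpu defined above.
h2d_stream = torch.cuda.Stream()
d2h_stream = torch.cuda.Stream()

def forward_with_offload(blocks, hidden_states):
    depth = len(blocks)
    for block_idx, block in enumerate(blocks):
        # saved_tensors_hooks subclasses are context managers: tensors saved for backward
        # inside this block are packed via _pack_to_cpu and restored via _unpack_from_cpu.
        with async_save_on_cpu(h2d_stream, d2h_stream, block_idx, depth):
            hidden_states = block(hidden_states)
    return hidden_states
```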
This pull request was created from the Copilot chat prompt quoted above as the original description.
@copilot Please edit the file directly.