intel-extension-for-pytorch
Communication and compute on separate Streams do not overlap
Describe the bug
Communication and computation do not appear to overlap when launching kernels in different xpu.Streams (on Intel GPU Max 1550s). Being able to overlap communication and computation is crucial for efficiency; DeepSpeed and FSDP both use Stream objects for this purpose, for instance.
To test this, I am launching communication and compute in various permutations of using Streams or not. Driver code which operates on both xpu and cuda:
"""
Tests the ability of Stream objects to overlap computation and compute.
Compute: bfloat16 matmuls
Comms: bfloat16 all_reduce
The script first times the comms and compute operations separately. Then, comms and compute
operations are launched together in various ways:
* All kernels sent to the default stream
* Comms and compute kernels sent to separate streams
Expectation:
* No overlap when all kernels are in the default stream (since they run sequentially).
Total time is approximately equal to the sum of the individually measured comms and compute
times.
* Comms and compute overlap when processed by different streams. Total time is less than the sum
of the individually comms and compute times.
The ratio of the various times are printed out to test overlap.
Example of running with two gpus on one node:
torchrun --nnodes=1 --nproc-per-node=2 streams_overlap_test.py
"""
import io
import os
from contextlib import contextmanager
from dataclasses import dataclass
from time import perf_counter
from typing import Optional
import torch
import torch.distributed as dist
if torch.cuda.is_available():
assert torch.cuda.is_available()
from torch import cuda as accel # noqa
DEVICE_TYPE = "cuda"
BACKEND = "nccl"
else:
import intel_extension_for_pytorch as ipex # noqa
from torch import xpu as accel # noqa
import oneccl_bindings_for_pytorch # noqa
DEVICE_TYPE = "xpu"
BACKEND = "ccl"
# Matrix sizes, iterations, and warmups. Dimensions chosen to make the compute and comms times
# similar.
COMPUTE_DIM = 2**14
COMMS_DIM = 4 * COMPUTE_DIM
ITERS = 20
WARMUPS = 3
RANK = int(os.environ["RANK"])
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])
DEVICE = torch.device(f"{DEVICE_TYPE}:{LOCAL_RANK}")
DTYPE = torch.bfloat16
accel.set_device(DEVICE)
compute_stream = accel.Stream(device=DEVICE)
comms_stream = accel.Stream(device=DEVICE)
compute_matrix = torch.randn(COMPUTE_DIM, COMPUTE_DIM, device=DEVICE, dtype=DTYPE)
comms_matrix = torch.randn(COMMS_DIM, COMMS_DIM, device=DEVICE, dtype=DTYPE)
# Simple timer class via a context manager. Time w/ perf_counter rather than Events, due to
# https://github.com/intel/intel-extension-for-pytorch/issues/568
@dataclass
class Time:
s: int = 0.0
@contextmanager
def timer():
t = Time()
accel.synchronize()
start = perf_counter()
yield t
# Barrier to ensure all comms are finished on all ranks
dist.barrier()
# An sync CPU to all kernels in all streams.
accel.synchronize()
stop = perf_counter()
# Update the elapsed time in the yielded Time object.
t.s = stop - start
def compute(stream: Optional[accel.Stream] = None) -> None:
with accel.stream(stream):
for _ in range(ITERS):
compute_matrix @ compute_matrix
def comms(stream: Optional[accel.Stream] = None) -> None:
with accel.stream(stream):
for _ in range(ITERS):
dist.all_reduce(comms_matrix)
def main() -> None:
for _ in range(WARMUPS):
compute()
comms()
# Perform computation and comms in different permutations, sometimes using Streams.
with timer() as t_compute_only:
compute()
with timer() as t_comms_only:
comms()
with timer() as t_total_default_stream:
compute()
comms()
with timer() as t_total_compute_stream:
compute(compute_stream)
comms()
with timer() as t_total_comms_stream:
compute()
comms(comms_stream)
with timer() as t_total_compute_and_comms_stream:
compute(compute_stream)
comms(comms_stream)
# Print out results
str_buffer = io.StringIO()
str_buffer.write(f"{RANK=}\n")
str_buffer.write(f"\t Compute matrix shape: {compute_matrix.shape}\n")
str_buffer.write(f"\t Comms matrix shape: {comms_matrix.shape}\n")
# Compare the case of submitting all work to the default stream to performing the operations
# independently. Expect they should take approximately the same amount of time, since all
# kernels run sequentially (ratio ~= 1).
str_buffer.write("\n")
str_buffer.write(f"\t {t_total_default_stream.s / (t_compute_only.s + t_comms_only.s)=}\n")
# Performing the compute in a non-default stream should allow for overlap (ratio < 1).
str_buffer.write("\n")
str_buffer.write(f"\t {t_total_compute_stream.s / t_total_default_stream.s =}\n")
# Performing the communication in a non-default stream should allow for overlap (ratio < 1).
str_buffer.write("\n")
str_buffer.write(f"\t {t_total_comms_stream.s / t_total_default_stream.s=}\n")
# Performing the compute and computation in separate, non-default streams should allow for
# overlap (ratio < 1).
str_buffer.write("\n")
str_buffer.write(f"\t {t_total_compute_and_comms_stream.s / t_total_default_stream.s =}\n")
print(str_buffer.getvalue(), flush=True)
if __name__ == "__main__":
try:
dist.init_process_group(backend=BACKEND)
main()
finally:
dist.destroy_process_group()
Running the above on two A100s, I get:
# On CUDA:
RANK=1
Compute matrix shape: torch.Size([16384, 16384])
Comms matrix shape: torch.Size([65536, 65536])
t_total_default_stream.s / (t_compute_only.s + t_comms_only.s)=0.999107484420899
t_total_compute_stream.s / t_total_default_stream.s =0.8255087794284478
t_total_comms_stream.s / t_total_default_stream.s=0.8239232889706464
t_total_compute_and_comms_stream.s / t_total_default_stream.s =0.820933508932193
RANK=0
Compute matrix shape: torch.Size([16384, 16384])
Comms matrix shape: torch.Size([65536, 65536])
t_total_default_stream.s / (t_compute_only.s + t_comms_only.s)=0.999110786949743
t_total_compute_stream.s / t_total_default_stream.s =0.8255076235561873
t_total_comms_stream.s / t_total_default_stream.s=0.8239272185173557
t_total_compute_and_comms_stream.s / t_total_default_stream.s =0.820936611580174
Running on two Intel GPU Max 1550s, I get:
# XPU
RANK=0
Compute matrix shape: torch.Size([16384, 16384])
Comms matrix shape: torch.Size([65536, 65536])
t_total_default_stream.s / (t_compute_only.s + t_comms_only.s)=0.9993644113989368
t_total_compute_stream.s / t_total_default_stream.s =1.0017862128444763
t_total_comms_stream.s / t_total_default_stream.s=0.9987232523971512
t_total_compute_and_comms_stream.s / t_total_default_stream.s =0.9996738417529752
RANK=1
Compute matrix shape: torch.Size([16384, 16384])
Comms matrix shape: torch.Size([65536, 65536])
t_total_default_stream.s / (t_compute_only.s + t_comms_only.s)=0.99933541624957
t_total_compute_stream.s / t_total_default_stream.s =1.001785787678655
t_total_comms_stream.s / t_total_default_stream.s=0.9987192462460256
t_total_compute_and_comms_stream.s / t_total_default_stream.s =0.9996800354416536
A clear speed-up can be seen when using Streams in their various permutations on A100s, while no speedup is visible on xpu. Absolute timings are not included above, but I have verified that the individual compute and comms times are comparable to each other in all cases.
Is this expected? Is there anything clearly wrong with the test code? The SYCL docs seem to imply that overlap should be possible.
Are there any relevant environment variables that I might need to set?
Versions
PyTorch version: 2.1.0a0+cxx11.abi
PyTorch CXX11 ABI: Yes
IPEX version: 2.1.10+xpu
IPEX commit: a12f9f650
Build type: Release
OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: N/A
IGC version: N/A
CMake version: N/A
Libc version: glibc-2.35
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.14.21-150500.55.31_13.0.62-cray_shasta_c-x86_64-with-glibc2.35
Is XPU available: False
DPCPP runtime version: N/A
MKL version: N/A
GPU models and configuration:
Intel OpenCL ICD version: 23.30.26918.50-736~22.04
Level Zero version: 1.3.26918.50-736~22.04
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 256
On-line CPU(s) list: 0-255
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7713 64-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU max MHz: 3720.7029
CPU min MHz: 1500.0000
BogoMIPS: 3992.49
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization: AMD-V
L1d cache: 4 MiB (128 instances)
L1i cache: 4 MiB (128 instances)
L2 cache: 64 MiB (128 instances)
L3 cache: 512 MiB (16 instances)
NUMA node(s): 8
NUMA node0 CPU(s): 0-15,128-143
NUMA node1 CPU(s): 16-31,144-159
NUMA node2 CPU(s): 32-47,160-175
NUMA node3 CPU(s): 48-63,176-191
NUMA node4 CPU(s): 64-79,192-207
NUMA node5 CPU(s): 80-95,208-223
NUMA node6 CPU(s): 96-111,224-239
NUMA node7 CPU(s): 112-127,240-255
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.1.10+xpu
[pip3] mypy==1.5.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] torch==2.1.0a0+cxx11.abi
[pip3] torch-tb-profiler==0.4.3
[pip3] torchaudio==2.1.0a0+cxx11.abi
[pip3] torchvision==0.16.0a0+cxx11.abi
[conda] N/A
CC @jingxu10 @tye1, thank you!
Hello, thanks for reporting this issue. I will look into this issue and get back to you.
Thank you @YuningQiu, greatly appreciated!
Hello @garrett361, regarding the specific script mentioned in this GitHub issue: it does not currently achieve overlap on PVC.
How it operates on the A100 GPU:
- The script dispatches a series of compute tasks followed by collective operations. These are issued to the GPU without blocking the host, meaning that the compute kernels and collectives are queued before most of them are executed.
- On the A100 GPU, the compute and collective kernels are initiated in an alternating pattern and are executed concurrently. Additional information: On the A100, collectives are executed within kernels that utilize only a few threads. As the first compute kernel nears completion and hardware resources free up, the first independent allreduce from a separate stream is scheduled (while the second compute kernel, which is dependent, waits for its complete execution). Once the first compute kernel finishes, the threads from the second compute kernel begin to operate simultaneously with the collective, as the collective kernel occupies only a limited number of streaming multiprocessors.
Reasons for incompatibility with PVC:
- By default, the initiation of the second allreduce is implicitly delayed until the first allreduce is complete. At this point, several compute tasks but only one collective have been sent to the PVC. Additional information: When using the default (scheduled) path in oneCCL, the destruction of the event at the end of the collective submission code snippet triggers an artificial wait for the collective to complete within the event destructor. This wait blocks the host thread from continuing.
- On PVC, non-dependent kernels from multiple streams are executed in the order they were submitted. The reduction kernel in the first allreduce cannot start until the final compute kernel has finished. Note: Even though oneCCL might use the copy command for data transfer by default, the copy and reduction operations are still interdependent. Therefore, the possibility of overlapping is restricted to the last compute task and a portion of the first allreduce.
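One way to sanity-check the host-side blocking described above is to time how long a collective takes to return on the host versus how long it takes to complete on the device. The following is only a sketch: it reuses the accel, RANK, comms_matrix, and comms_stream names from the driver script earlier in this issue, and the 0.9 threshold is an arbitrary illustration, not part of any API.

# Minimal host-blocking probe: if launching the all_reduce takes nearly as long
# as running it to completion, the collective submission itself is blocking the
# host thread. Assumes the driver script's setup (process group initialized;
# accel, RANK, comms_matrix, comms_stream defined as above).
from time import perf_counter

import torch.distributed as dist


def probe_host_blocking(tensor, stream=None) -> None:
    accel.synchronize()
    start = perf_counter()
    with accel.stream(stream):
        dist.all_reduce(tensor)
    launch_s = perf_counter() - start   # host-side cost of submitting the collective
    accel.synchronize()
    total_s = perf_counter() - start    # submission plus device execution
    print(
        f"{RANK=} launch={launch_s:.4f}s total={total_s:.4f}s "
        f"host_blocked={launch_s > 0.9 * total_s}"
    )


probe_host_blocking(comms_matrix, comms_stream)

On CUDA the launch time should be a tiny fraction of the total; if the submission blocks as described, the two numbers will be close.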
Hi @YuningQiu, thank you for the very detailed response! I have a few follow-ups.
By default, the initiation of the second allreduce is implicitly delayed until the first allreduce is complete. At this point, several compute tasks but only one collective have been sent to the PVC
- Ah, you mean even the launch of the second allreduce kernel is delayed?
the destruction of the event at the end of the collective submission code snippet triggers an artificial wait for the collective to complete within the event destructor. This wait blocks the host thread from continuing.
- And this means that the collective blocks any additional kernels being launched, irrespective of what Stream they were sent to?
non-dependent kernels from multiple streams are executed in the order they were submitted.
- This means that kernels are executed in launch order regardless of what stream they are put into? If so, I don't understand the utility of Streams.
Note: Even though oneCCL might use the copy command for data transfer by default, the copy and reduction operations are still interdependent. Therefore, the possibility of overlapping is restricted to the last compute task and a portion of the first allreduce.
- I didn't quite understand this. What is the importance of the copy operation here with respect to overlapping?
Finally: I am a little confused about where in the stack the issue lies. Is there an obstruction to overlapping compute and comms at the hardware level? Or is it something in ipex, torch-ccl, or elsewhere?
And for more color, all of the above seems consistent with what I have seen from the pytorch profiler.
These are traces of a very similar workload where I attempted to overlap comms and compute for two iterations on cuda (A100) and xpu (1550).
CUDA
cuda: both compute and comms operations launch kernels and return immediately on the host, as seen in the minuscule vertical lines preceding the cudaDeviceSynchronize.
XPU
xpu: compute launches kernels and returns immediately, but collectives block and span a long time period until the collective finishes.
Isolated Compute and Comms on XPU
I also isolated the xpu cases where I perform only the compute or the comms individually. The same effects can be seen.
Compute only:
Comms only:
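For reference, a minimal sketch of one way traces like these can be captured with torch.profiler is below. It reuses the compute/comms helpers, stream objects, accel, and RANK from the driver script above. Only CPU activity is requested so the snippet stays backend-agnostic; adding device-side activities (e.g. ProfilerActivity.CUDA on CUDA builds) is possible where the installed torch/ipex version supports it, so treat that part as an assumption.

# Sketch: capture a Chrome trace of the attempted overlap with torch.profiler.
# Reuses compute/comms/compute_stream/comms_stream/accel/RANK from the driver
# script above. Only CPU activity is requested here.
from torch.profiler import ProfilerActivity, profile

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    for _ in range(2):  # two iterations, as in the traces above
        compute(compute_stream)
        comms(comms_stream)
    accel.synchronize()

# Open the resulting file in chrome://tracing or Perfetto.
prof.export_chrome_trace(f"overlap_trace_rank{RANK}.json")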
Hello @garrett361, thanks for providing more details. We will take them back and discuss internally. We will keep you posted with any updates.
Also, could you please share with us the PyTorch profiling file that you are showing above? Thanks a lot!
@YuningQiu hi, could you tell me why this was closed please?
I also see I never followed up with the profiling script, my apologies. I can do that next week.
Hi @garrett361, I heard that one of our teams from Intel has been directly in touch with you on this issue, and you also created an issue on the intel/torch-ccl GitHub repo. Do you want to keep this issue open? Thanks a lot!
Hi @YuningQiu yes, I had a very helpful chat with members of the team. We also said we’d track progress through these GitHub issues, so could you please reopen it?
I cross posted to torch-ccl since I wasn’t sure how that team and ipex interact, nor if they also track ipex issues.
Thanks!
Adding more traces of attempted overlap with other collectives, per Intel's request via direct communication. The results are all qualitatively similar:
- Compute and communication do not overlap
- Launching collective kernels blocks the host thread
- Gaps between comms kernels
All traces taken on Sunspot with versions: torch.__version__='2.1.0a0+cxx11.abi', ipex.__version__='2.1.10+xpu', torch_ccl.__version__='2.1.100+xpu'
All profiles were created using the profile_comms_compute_overlap.py script here with different --collective args, and otherwise default values, on a single Sunspot (1550) node.
All Gather
torch.distributed.all_gather
All Gather Into Tensor
torch.distributed.all_gather_into_tensor
Reduce Scatter Tensor
torch.distributed.reduce_scatter_tensor
All Reduce
torch.distributed.all_reduce
(I'm not sure why this one uses more streams than the above all_reduce trace. The previous trace was taken on a different machine from Sunspot.)
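For completeness, a sketch of how each collective listed above can be dispatched on a non-default stream, using the standard torch.distributed signatures. The shapes and names here are illustrative only and are not taken from profile_comms_compute_overlap.py; DEVICE, DTYPE, accel, and comms_stream are assumed to be set up as in the driver script earlier in this issue.

# Illustrative dispatch of each profiled collective on a non-default stream.
# Shapes are placeholders; DEVICE, DTYPE, accel, and comms_stream follow the
# driver script above, and the process group is assumed to be initialized.
import torch
import torch.distributed as dist

world_size = dist.get_world_size()
x = torch.randn(4096, 4096, device=DEVICE, dtype=DTYPE)

with accel.stream(comms_stream):
    # all_gather: gathers x from every rank into a list of tensors.
    dist.all_gather([torch.empty_like(x) for _ in range(world_size)], x)

    # all_gather_into_tensor: gathers into a single pre-allocated tensor.
    gathered = torch.empty(world_size * x.shape[0], x.shape[1], device=DEVICE, dtype=DTYPE)
    dist.all_gather_into_tensor(gathered, x)

    # reduce_scatter_tensor: the input is world_size times larger than the output.
    scattered = torch.empty(x.shape[0] // world_size, x.shape[1], device=DEVICE, dtype=DTYPE)
    dist.reduce_scatter_tensor(scattered, x)

    # all_reduce: in-place sum across ranks.
    dist.all_reduce(x)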