
[WIP] Streaming DiLoCo prototype

H-Huang opened this issue on May 28, 2025 · 0 comments

Creating a small script to quickly hack on the implementation for streaming DiLoCo.

Run with the following (start the lighthouse first; see the command in README.md):

cd streaming_diloco_prototype
torchx run

Issues found:

  1. Quantization only supports 2D tensors (can be worked around, e.g. by viewing tensors as 2D; see the sketch after this list):
replica_0/0     File "/home/howardhuang/.conda/envs/torchft/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
replica_0/0       return f(*args, **kwargs)
replica_0/0     File "/data/users/howardhuang/torchft/streaming_diloco_prototype/train.py", line 275, in streaming_diloco
replica_0/0       fut = allreduce_quantized(params_data, ReduceOp.AVG, pg)
replica_0/0     File "/data/users/howardhuang/torchft/torchft/collectives.py", line 104, in allreduce_quantized
replica_0/0       quantized_tensors = fused_quantize_into_fp8(tensors, world_size)
replica_0/0     File "/data/users/howardhuang/torchft/torchft/quantization.py", line 520, in fused_quantize_into_fp8
replica_0/0       ) = _prepare_quantize_fp8(inputs, all_reduce_group_size)
replica_0/0     File "/data/users/howardhuang/torchft/torchft/quantization.py", line 450, in _prepare_quantize_fp8
replica_0/0       assert len(inputs[i].shape) == 2, "Only 2D tensors are supported"
replica_0/0   AssertionError: Only 2D tensors are supported
  2. Triton runtime JIT issue when calling _fused_kernel_quantize_into_fp8[grid] (a possible workaround is sketched after this list):
replica_0/0     File "/data/users/howardhuang/torchft/streaming_diloco_prototype/train.py", line 275, in streaming_diloco
replica_0/0       fut = allreduce_quantized(params_data, ReduceOp.AVG, pg)
replica_0/0     File "/data/users/howardhuang/torchft/torchft/collectives.py", line 104, in allreduce_quantized
replica_0/0       quantized_tensors = fused_quantize_into_fp8(tensors, world_size)
replica_0/0     File "/data/users/howardhuang/torchft/torchft/quantization.py", line 531, in fused_quantize_into_fp8
replica_0/0       _fused_kernel_quantize_into_fp8[grid](
replica_0/0     File "/home/howardhuang/.conda/envs/torchft/lib/python3.10/site-packages/triton/runtime/jit.py", line 499, in run
replica_0/0       if key not in self.cache[device]:
replica_0/0   TypeError: unhashable type: 'constexpr'
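
A minimal sketch of the workaround mentioned for issue 1, assuming allreduce_quantized averages the input tensors in place (as the torch.distributed-style call in train.py suggests) and that the parameters are contiguous. The wrapper name is hypothetical; params_data, pg, and ReduceOp.AVG mirror the names in the traceback, and the ReduceOp import path is an assumption.

```python
import torch
from torch.distributed import ReduceOp  # import path is an assumption

from torchft.collectives import allreduce_quantized


def allreduce_quantized_any_rank(params_data: list[torch.Tensor], pg):
    """Hypothetical wrapper: view non-2D tensors as 2D before the quantized all-reduce."""
    # _prepare_quantize_fp8 asserts len(shape) == 2, so hand it 2D views.
    # .view() shares storage with the originals (requires contiguity), so an
    # in-place average lands back in params_data.
    views = [t if t.dim() == 2 else t.view(1, -1) for t in params_data]
    return allreduce_quantized(views, ReduceOp.AVG, pg)
```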
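For issue 2, the TypeError comes from Triton's JIT specialization cache trying to hash the launch arguments, which fails if a tl.constexpr object (rather than a plain Python value) reaches the launch. A possible workaround, sketched below under that assumption, is to unwrap such values before the call; the helper and the commented call site are hypothetical, not existing torchft code.

```python
import triton.language as tl


def unwrap_constexpr(arg):
    """Return the plain Python value if arg is a tl.constexpr, else arg unchanged."""
    return arg.value if isinstance(arg, tl.constexpr) else arg


# Hypothetical use inside fused_quantize_into_fp8, just before the launch:
#   clean_args = [unwrap_constexpr(a) for a in launch_args]
#   _fused_kernel_quantize_into_fp8[grid](*clean_args)
```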
