torchft
[WIP] Streaming DiLoCo prototype
Creates a small script for quickly hacking on a streaming DiLoCo implementation.
To run it, first start the lighthouse using the command in README.md, then:

```
cd streaming_diloco_prototype
torchx run
```
Issues found:
- Quantization only supports 2D tensors (this can be worked around):
```
replica_0/0   File "/home/howardhuang/.conda/envs/torchft/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
replica_0/0     return f(*args, **kwargs)
replica_0/0   File "/data/users/howardhuang/torchft/streaming_diloco_prototype/train.py", line 275, in streaming_diloco
replica_0/0     fut = allreduce_quantized(params_data, ReduceOp.AVG, pg)
replica_0/0   File "/data/users/howardhuang/torchft/torchft/collectives.py", line 104, in allreduce_quantized
replica_0/0     quantized_tensors = fused_quantize_into_fp8(tensors, world_size)
replica_0/0   File "/data/users/howardhuang/torchft/torchft/quantization.py", line 520, in fused_quantize_into_fp8
replica_0/0     ) = _prepare_quantize_fp8(inputs, all_reduce_group_size)
replica_0/0   File "/data/users/howardhuang/torchft/torchft/quantization.py", line 450, in _prepare_quantize_fp8
replica_0/0     assert len(inputs[i].shape) == 2, "Only 2D tensors are supported"
replica_0/0 AssertionError: Only 2D tensors are supported
```
- Triton runtime JIT error when launching `_fused_kernel_quantize_into_fp8[grid]`:
```
replica_0/0   File "/data/users/howardhuang/torchft/streaming_diloco_prototype/train.py", line 275, in streaming_diloco
replica_0/0     fut = allreduce_quantized(params_data, ReduceOp.AVG, pg)
replica_0/0   File "/data/users/howardhuang/torchft/torchft/collectives.py", line 104, in allreduce_quantized
replica_0/0     quantized_tensors = fused_quantize_into_fp8(tensors, world_size)
replica_0/0   File "/data/users/howardhuang/torchft/torchft/quantization.py", line 531, in fused_quantize_into_fp8
replica_0/0     _fused_kernel_quantize_into_fp8[grid](
replica_0/0   File "/home/howardhuang/.conda/envs/torchft/lib/python3.10/site-packages/triton/runtime/jit.py", line 499, in run
replica_0/0     if key not in self.cache[device]:
replica_0/0 TypeError: unhashable type: 'constexpr'
```
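For the 2D-only quantization issue, one plausible workaround is to view each tensor as 2D before quantizing. The sketch below uses a hypothetical helper, `shape_as_2d` (not part of torchft), that folds all leading dimensions into rows and keeps the last dimension as columns, so the element count is unchanged; 0-D and 1-D shapes become a single row. Whether this row layout suits the FP8 scaling in `torchft/quantization.py` is an assumption that has not been checked.

```python
from math import prod


def shape_as_2d(shape):
    """Hypothetical helper: collapse a tensor shape into (rows, cols).

    All leading dimensions are folded into rows and the last dimension is
    kept as cols, so the total element count is unchanged. 0-D and 1-D
    shapes become a single row.
    """
    shape = tuple(shape)
    if len(shape) == 0:
        return (1, 1)
    if len(shape) == 1:
        return (1, shape[0])
    return (prod(shape[:-1]), shape[-1])


# Example shapes: scalar, bias vector, linear weight, conv weight.
for s in [(), (7,), (4, 8), (16, 3, 3, 3)]:
    print(s, "->", shape_as_2d(s))
```

Assuming every tensor in `params_data` is contiguous, each one could be passed to `allreduce_quantized` as `t.view(*shape_as_2d(t.shape))`. Because `view` shares storage with `t`, an all-reduce that writes its result in place (an assumption worth confirming in `torchft/collectives.py`) would also update the original parameter; a non-contiguous tensor would instead need `reshape` plus an explicit copy back.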
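The Triton error is raised while the JIT looks up its kernel cache: `key not in self.cache[device]` has to hash `key`, and hashing fails on a `constexpr` object. In plain Python, a class that defines `__eq__` without `__hash__` gets `__hash__ = None`, so any tuple containing an instance is unhashable. The sketch below reproduces that mechanism with a hypothetical stand-in class named `constexpr` (not Triton's real class), plus a hypothetical `unwrap` step that replaces wrapped values with their plain payloads before the lookup. It only illustrates how this `TypeError` arises; that a `constexpr` object is actually reaching the launch arguments of `_fused_kernel_quantize_into_fp8` is a guess the traceback does not confirm.

```python
class constexpr:
    """Hypothetical stand-in for a value-wrapper class such as Triton's constexpr.

    It defines __eq__ but not __hash__, so Python sets __hash__ to None and
    instances (and any tuple containing them) cannot be used as dict keys.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        other_value = other.value if isinstance(other, constexpr) else other
        return self.value == other_value


def lookup(cache, key):
    """Mimics the `key not in self.cache[device]` check from the traceback."""
    return key in cache


def unwrap(args):
    """Hypothetical fix: replace wrapped values with their plain payloads."""
    return tuple(a.value if isinstance(a, constexpr) else a for a in args)


cache = {}
try:
    # Hashing the tuple hashes each element, which fails on the wrapper.
    lookup(cache, (1024, constexpr(128)))
except TypeError as err:
    print(err)  # unhashable type: 'constexpr'

# With plain ints the key hashes fine; the cache is empty, so this is False.
print(lookup(cache, unwrap((1024, constexpr(128)))))
```

If this is the cause, a cheap first check would be whether any host-side argument at the `_fused_kernel_quantize_into_fp8[grid](...)` call is itself a `tl.constexpr` instance rather than a plain Python value.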