PiPPy
TP+PiPPy failing on HF examples.
I installed from source with the PT nightlies, and when I try to add TP to the HF inference example it fails with:
RuntimeError: aten.add.Tensor: got mixed distributed and non-distributed tensors
I am wondering if I am missing any step here.
env
(PT-nightlies) ubuntu@ip-172-31-44-234:~$ python -m "torch.utils.collect_env"
Collecting environment information...
PyTorch version: 2.1.0.dev20230328+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.31
Python version: 3.9.16 | packaged by conda-forge | (main, Feb 1 2023, 21:39:03) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1030-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A10G
GPU 1: NVIDIA A10G
GPU 2: NVIDIA A10G
GPU 3: NVIDIA A10G
Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 48
On-line CPU(s) list: 0-47
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 1
NUMA node(s): 1
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7R32
Stepping: 0
CPU MHz: 2799.870
BogoMIPS: 5599.74
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 768 KiB
L1i cache: 768 KiB
L2 cache: 12 MiB
L3 cache: 96 MiB
NUMA node0 CPU(s): 0-47
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid
Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.1.0+e650d3708b
[pip3] torch==2.1.0.dev20230328+cu117
[pip3] torchaudio==2.1.0.dev20230328+cu117
[pip3] torchvision==0.16.0.dev20230328+cu117
[pip3] vit-pytorch==1.2.0
[conda] numpy 1.24.1 pypi_0 pypi
[conda] pytorch-triton 2.1.0+e650d3708b pypi_0 pypi
[conda] torch 2.1.0.dev20230328+cu117 pypi_0 pypi
[conda] torchaudio 2.1.0.dev20230328+cu117 pypi_0 pypi
[conda] torchvision 0.16.0.dev20230328+cu117 pypi_0 pypi
[conda] vit-pytorch 1.2.0 pypi_0 pypi
Repro
# Copyright (c) Meta Platforms, Inc. and affiliates
import argparse
import os
import time

import torch
import pippy
import pippy.fx
from pippy import run_pippy
from pippy.hf import PiPPyHFTracer
from pippy import split_on_size_threshold, split_into_equal_size
from transformers import AutoModelForSeq2SeqLM
from transformers import OPTModel, BloomModel
from PIL import Image
import requests
from transformers import AutoFeatureExtractor, RegNetModel
from torch.distributed._tensor import (
    DeviceMesh,
)
from torch.distributed.tensor.parallel import (
    PairwiseParallel,
    parallelize_module,
)

pippy.fx.Tracer.proxy_buffer_attributes = True

gigabyte_size = 1024 ** 3
megabyte_size = 1024 ** 2


def format_to_gb(item, precision=4):
    """quick function to format numbers to gigabyte and round to (default) 4 digit precision"""
    metric_num = item / gigabyte_size
    metric_num = round(metric_num, ndigits=precision)
    return metric_num


def print_mem_usage():
    memory_reserved = format_to_gb(torch.cuda.memory_reserved())
    memory_allocated = format_to_gb(torch.cuda.memory_allocated())
    print(
        f"memory_reserved: {memory_reserved} GB, "
        f"memory_allocated: {memory_allocated} GB"
    )


def get_number_of_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def generate_input(args):
    bs = args.batch_size * args.chunks
    seq_length = args.seq_length
    model_config = args.model.config
    torch.manual_seed(args.rank)

    # preparing inputs based on the model choice
    if 't5' in args.model_name:
        inp = torch.empty(bs, seq_length, dtype=torch.long, device=args.device).random_(model_config.vocab_size)
        model_input_dict = {'input_ids': inp, 'decoder_input_ids': inp}
    elif 'opt' in args.model_name or 'bloom' in args.model_name:
        inp = torch.empty(bs, seq_length, dtype=torch.long, device=args.device).random_(model_config.vocab_size)
        model_input_dict = {'input_ids': inp}
    elif 'regnet' in args.model_name:
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image = Image.open(requests.get(url, stream=True).raw)
        inputs = args.feature_extractor(image, return_tensors="pt")
        inputs["pixel_values"] = inputs["pixel_values"]
        model_input_dict = {'pixel_values': inputs["pixel_values"]}

    return model_input_dict
def run_all(pp_ranks, args):
    model = args.model
    model.eval()
    model.config.use_cache = False  # don't output `past_key_values`
    num_ranks = len(pp_ranks)
    device_type = "cuda" if args.cuda else "cpu"
    pp_rank = args.rank // args.tp_group_size

    if args.rank == 0:
        print("Using schedule:", args.schedule)
        print(model.config)
        print(f"model total number of params = {get_number_of_params(model) // 10 ** 6}M")

    if args.auto_split == "threshold":
        split_policy = split_on_size_threshold(490 * 1e6)
    elif args.auto_split == "equal_size":
        split_policy = split_into_equal_size(num_ranks)

    model_input_dict = generate_input(args)

    # Use default value for other kwargs than those in `model_input_dict`
    concrete_args = pippy.create_default_args(
        model,
        except_keys=model_input_dict.keys(),
    )

    model_init_start = time.time()
# model.to("cuda") # with/out running into the same error
    pipe_driver, stage_mod = pippy.all_compile(
        model,
        num_ranks,
        args.chunks,
        schedule=args.schedule,
        split_policy=split_policy,
        tracer=PiPPyHFTracer(),
        concrete_args=concrete_args,
    )

    # Create TP device mesh
    my_device_mesh = None
    for stage in range(args.pp_group_size):
        start_rank = stage * args.tp_group_size
        tp_ranks = list(range(start_rank, start_rank + args.tp_group_size))
        tp_device_mesh = DeviceMesh(
            device_type,
            tp_ranks,
        )
        if stage == pp_rank:
            my_device_mesh = tp_device_mesh

    # Tensor parallelize submodules
    print(f"Rank {args.rank} calling parallelize_module with {my_device_mesh}")
    parallelize_module(stage_mod, my_device_mesh, PairwiseParallel())

    model_init_end = time.time()

    params = get_number_of_params(stage_mod)
    print(f"submod_{args.rank} {params // 10 ** 6}M params")

    if args.rank == 0:
        print(f"Model init time: {model_init_end - model_init_start} s")
        print_mem_usage()
        print('Running model pipeline.')
        for _ in range(args.num_batches):
            pipe_driver(**model_input_dict)
        print('Inference is finished')
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4)))
    parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1)))
    parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost'))
    parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500'))
    parser.add_argument('--model_name', type=str, default='facebook/opt-350m')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--chunks', type=int, default=int(os.getenv("WORLD_SIZE", 4)))
    parser.add_argument('--num_batches', type=int, default=1)
    parser.add_argument('--seq_length', type=int, default=16)
    parser.add_argument('--avg_seqlen', type=int, default=16)
    parser.add_argument('--max_seqlen', type=int, default=16)
    parser.add_argument('--seqlen-stdev', type=int, default=10)
    parser.add_argument('-s', '--schedule', type=str, default="FillDrain")
    parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available()))
    parser.add_argument('--visualize', type=int, default=1, choices=[0, 1])
    parser.add_argument('--pp_group_size', type=int, default=int(os.getenv("WORLD_SIZE", 4)))
    parser.add_argument('--auto_split', type=str, default="equal_size")

    args = parser.parse_args()

    assert args.world_size % args.pp_group_size == 0
    args.tp_group_size = args.world_size // args.pp_group_size
    print(f"Using tensor parallel group size: {args.tp_group_size}")

    # Main process loads model
    print(f"Loading model {args.model_name}")
    if 't5' in args.model_name:
        model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name, use_cache=False)
    if 'opt' in args.model_name:
        model = OPTModel.from_pretrained(args.model_name, use_cache=False)
    if 'bloom' in args.model_name:
        model = BloomModel.from_pretrained(args.model_name, use_cache=False)
    if 'regnet' in args.model_name:
        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/regnet-y-10b-seer")
        model = RegNetModel.from_pretrained("facebook/regnet-y-10b-seer")
        args.feature_extractor = feature_extractor

    args.model = model
    args.gspmd = 1
    run_pippy(run_all, args)
Error log
The run then kind of hangs after hitting the exception below:
(PT-nightlies) ubuntu@ip-172-31-44-234:~/tau/examples/inference$ python TP_HF.py
Using tensor parallel group size: 1
Loading model facebook/opt-350m
[PiPPy] World size: 4, DP group size: 1, PP group size: 4
rank = 0 host/pid/device = ip-172-31-44-234/10535/cuda:0
rank = 2 host/pid/device = ip-172-31-44-234/10537/cuda:2
rank = 1 host/pid/device = ip-172-31-44-234/10536/cuda:1
rank = 3 host/pid/device = ip-172-31-44-234/10538/cuda:3
Using schedule: FillDrain
OPTConfig {
"_name_or_path": "facebook/opt-350m",
"_remove_final_layer_norm": false,
"activation_dropout": 0.0,
"activation_function": "relu",
"architectures": [
"OPTForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 2,
"do_layer_norm_before": false,
"dropout": 0.1,
"enable_bias": true,
"eos_token_id": 2,
"ffn_dim": 4096,
"hidden_size": 1024,
"init_std": 0.02,
"layer_norm_elementwise_affine": true,
"layerdrop": 0.0,
"max_position_embeddings": 2048,
"model_type": "opt",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"pad_token_id": 1,
"prefix": "</s>",
"torch_dtype": "float16",
"transformers_version": "4.27.3",
"use_cache": false,
"vocab_size": 50272,
"word_embed_proj_dim": 512
}
model total number of params = 331M
Rank 2 calling parallelize_module with DeviceMesh:([2])
Rank 3 calling parallelize_module with DeviceMesh:([3])
Rank 1 calling parallelize_module with DeviceMesh:([1])
Rank 0 calling parallelize_module with DeviceMesh:([0])
submod_0 81M params
Model init time: 2.4849045276641846 s
memory_reserved: 1.2598 GB, memory_allocated: 1.2338 GB
Running model pipeline.
submod_1 80M params
submod_2 81M params
submod_3 86M params
Exception in thread worker_0:
Traceback (most recent call last):
File "/opt/conda/envs/PT-nightlies/lib/python3.9/threading.py", line 980, in _bootstrap_inner
self.run()
File "/opt/conda/envs/PT-nightlies/lib/python3.9/threading.py", line 917, in run
self._target(*self._args, **self._kwargs)
File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/pippy-0.1.0a0+27e4111-py3.9.egg/pippy/PipelineDriver.py", line 485, in worker_loop
out_val, flat_tensor_args = forward(
File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/pippy-0.1.0a0+27e4111-py3.9.egg/pippy/PipelineDriver.py", line 448, in forward
out_val = forward_maybe_with_ddp(args, kwargs)
File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/pippy-0.1.0a0+27e4111-py3.9.egg/pippy/PipelineDriver.py", line 432, in forward_maybe_with_ddp
out_val = stage_executor.mod(*args, **kwargs)
File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/pippy-0.1.0a0+27e4111-py3.9.egg/pippy/fx/graph_module.py", line 662, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/pippy-0.1.0a0+27e4111-py3.9.egg/pippy/fx/graph_module.py", line 281, in __call__
raise e
File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/pippy-0.1.0a0+27e4111-py3.9.egg/pippy/fx/graph_module.py", line 271, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "<eval_with_key>.13", line 62, in forward
File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/torch/_tensor.py", line 1296, in __torch_function__
ret = func(*args, **kwargs)
File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/torch/distributed/_tensor/api.py", line 231, in __torch_dispatch__
raise RuntimeError(
RuntimeError: aten.add.Tensor: got mixed distributed and non-distributed tensors.
Hey @HamidShojanazeri, looking at the error log, you are using TP size = 1. Can you try a TP size greater than 1, like 2?
If TP size = 1, there is no reason to parallelize the model.
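(For reference, in this script the TP size is derived as args.world_size // args.pp_group_size, so on the 4-GPU box above one way to get TP size = 2 should be to shrink the pipeline group, e.g. something like:
python TP_HF.py --world_size 4 --pp_group_size 2
using the flags already defined in the repro's argparse section.)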
Hey @HamidShojanazeri, I am curious about your example too (going to attempt something very similar :) ). Did you get this to work, either with a different TP size or some other fix?
Hi @fduwjj, I tried this with TP size = 2 and am still running into the same RuntimeError: aten.add.Tensor: got mixed distributed and non-distributed tensors. Any ideas what could be going wrong here?
@Vatshank Can you try using ColwiseParallel and RowwiseParallel instead? You need to specify the module paths within the model, though. Let me know if that helps.
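For anyone following along, a minimal sketch of what that per-module plan could look like with the torch.distributed.tensor.parallel API (the paths below are only illustrative names for one OPT decoder layer; the real FQNs depend on how PiPPy traced and split the model, so inspect stage_mod.named_modules() first):

# Sketch only: replace the PairwiseParallel() call in run_all with an explicit plan.
# The module paths are hypothetical examples; substitute the actual submodule
# names reported by stage_mod.named_modules().
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

parallelize_plan = {
    "decoder.layers.0.self_attn.q_proj": ColwiseParallel(),
    "decoder.layers.0.self_attn.k_proj": ColwiseParallel(),
    "decoder.layers.0.self_attn.v_proj": ColwiseParallel(),
    "decoder.layers.0.self_attn.out_proj": RowwiseParallel(),
    "decoder.layers.0.fc1": ColwiseParallel(),
    "decoder.layers.0.fc2": RowwiseParallel(),
}
parallelize_module(stage_mod, my_device_mesh, parallelize_plan)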