
TP+PiPPy failing on HF examples.

HamidShojanazeri opened this issue 1 year ago • 4 comments

Installing PiPPy from source with PT nightlies and trying to add TP to the HF inference example, it fails with:

RuntimeError: aten.add.Tensor: got mixed distributed and non-distributed tensors

I am wondering if I am missing a step here.
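
For context, this error class is easy to reproduce outside PiPPy by mixing a DTensor with a plain tensor in a single op. A minimal sketch of that failure mode (not the exact PiPPy code path; run with torchrun --nproc_per_node=2):

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# Standalone illustration: one DTensor operand and one plain torch.Tensor
# operand reach the same aten op.
dist.init_process_group("gloo")
mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))

dtensor = distribute_tensor(torch.randn(4, 4), mesh, [Shard(0)])
plain = torch.randn(4, 4)

# Expected to raise on this nightly:
# RuntimeError: aten.add.Tensor: got mixed distributed and non-distributed tensors
out = dtensor + plain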

env

(PT-nightlies) ubuntu@ip-172-31-44-234:~$ python -m "torch.utils.collect_env"
Collecting environment information...
PyTorch version: 2.1.0.dev20230328+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.31

Python version: 3.9.16 | packaged by conda-forge | (main, Feb  1 2023, 21:39:03)  [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1030-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A10G
GPU 1: NVIDIA A10G
GPU 2: NVIDIA A10G
GPU 3: NVIDIA A10G

Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   48 bits physical, 48 bits virtual
CPU(s):                          48
On-line CPU(s) list:             0-47
Thread(s) per core:              2
Core(s) per socket:              24
Socket(s):                       1
NUMA node(s):                    1
Vendor ID:                       AuthenticAMD
CPU family:                      23
Model:                           49
Model name:                      AMD EPYC 7R32
Stepping:                        0
CPU MHz:                         2799.870
BogoMIPS:                        5599.74
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       768 KiB
L1i cache:                       768 KiB
L2 cache:                        12 MiB
L3 cache:                        96 MiB
NUMA node0 CPU(s):               0-47
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.1.0+e650d3708b
[pip3] torch==2.1.0.dev20230328+cu117
[pip3] torchaudio==2.1.0.dev20230328+cu117
[pip3] torchvision==0.16.0.dev20230328+cu117
[pip3] vit-pytorch==1.2.0
[conda] numpy                     1.24.1                   pypi_0    pypi
[conda] pytorch-triton            2.1.0+e650d3708b          pypi_0    pypi
[conda] torch                     2.1.0.dev20230328+cu117          pypi_0    pypi
[conda] torchaudio                2.1.0.dev20230328+cu117          pypi_0    pypi
[conda] torchvision               0.16.0.dev20230328+cu117          pypi_0    pypi
[conda] vit-pytorch               1.2.0                    pypi_0    pypi

Repro

# Copyright (c) Meta Platforms, Inc. and affiliates
import argparse
import os
import time

import torch
import pippy
import pippy.fx
from pippy import run_pippy
from pippy.hf import PiPPyHFTracer
from pippy import split_on_size_threshold, split_into_equal_size
from transformers import AutoModelForSeq2SeqLM
from transformers import OPTModel, BloomModel
from PIL import Image
import requests
from transformers import AutoFeatureExtractor, RegNetModel 
from torch.distributed._tensor import (
    DeviceMesh,
)
from torch.distributed.tensor.parallel import (
    PairwiseParallel,
    parallelize_module,
)


pippy.fx.Tracer.proxy_buffer_attributes = True

gigabyte_size = 1024 ** 3
megabyte_size = 1024 ** 2


def format_to_gb(item, precision=4):
    """quick function to format numbers to gigabyte and round to (default) 4 digit precision"""
    metric_num = item / gigabyte_size
    metric_num = round(metric_num, ndigits=precision)
    return metric_num


def print_mem_usage():
    memory_reserved = format_to_gb(torch.cuda.memory_reserved())
    memory_allocated = format_to_gb(torch.cuda.memory_allocated())
    print(
        f"memory_reserved: {memory_reserved} GB, "
        f"memory_allocated: {memory_allocated} GB"
    )


def get_number_of_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def generate_input(args):
    bs = args.batch_size * args.chunks
    seq_length = args.seq_length
    model_config = args.model.config
    torch.manual_seed(args.rank)

    # preparing inputs based on the model choice
    if 't5' in args.model_name:
        inp = torch.empty(bs, seq_length, dtype=torch.long, device=args.device).random_(model_config.vocab_size)
        model_input_dict = {'input_ids': inp, 'decoder_input_ids': inp}
    elif 'opt' in args.model_name or 'bloom' in args.model_name:
        inp = torch.empty(bs, seq_length, dtype=torch.long, device=args.device).random_(model_config.vocab_size)
        model_input_dict = {'input_ids': inp}
    elif 'regnet' in args.model_name:
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image = Image.open(requests.get(url, stream=True).raw)
        inputs = args.feature_extractor(image, return_tensors="pt")
        model_input_dict = {'pixel_values': inputs["pixel_values"]}
    else:
        raise ValueError(f"Unsupported model: {args.model_name}")

    return model_input_dict


def run_all(pp_ranks, args):
    model = args.model
    model.eval()
    model.config.use_cache = False  # don't output `past_key_values`
    num_ranks = len(pp_ranks)
    device_type = "cuda" if args.cuda else "cpu"
    pp_rank = args.rank // args.tp_group_size

    if args.rank == 0:
        print("Using schedule:", args.schedule)
        print(model.config)
        print(f"model total number of params = {get_number_of_params(model) // 10 ** 6}M")

    if args.auto_split == "threshold":
        split_policy = split_on_size_threshold(490 * 1e6)
    elif args.auto_split == "equal_size":
        split_policy = split_into_equal_size(num_ranks)

    model_input_dict = generate_input(args)
    # Use default values for kwargs other than those in `model_input_dict`
    concrete_args = pippy.create_default_args(
        model,
        except_keys=model_input_dict.keys(),
    )

    model_init_start = time.time()
    # model.to("cuda")  # with or without this line, the same error occurs
    pipe_driver, stage_mod = pippy.all_compile(
        model,
        num_ranks,
        args.chunks,
        schedule=args.schedule,
        split_policy=split_policy,
        tracer=PiPPyHFTracer(),
        concrete_args=concrete_args,
    )

    # Create TP device mesh
    my_device_mesh = None
    for stage in range(args.pp_group_size):
        start_rank = stage * args.tp_group_size
        tp_ranks = list(range(start_rank, start_rank + args.tp_group_size))
        tp_device_mesh = DeviceMesh(
            device_type,
            tp_ranks,
        )
        if stage == pp_rank:
            my_device_mesh = tp_device_mesh

    # Tensor parallelize submodules
    print(f"Rank {args.rank} calling parallelize_module with {my_device_mesh}")
    parallelize_module(stage_mod, my_device_mesh, PairwiseParallel())


    model_init_end = time.time()

    params = get_number_of_params(stage_mod)
    print(f"submod_{args.rank} {params // 10 ** 6}M params")

    if args.rank == 0:
        print(f"Model init time: {model_init_end - model_init_start} s")
        print_mem_usage()
        print('Running model pipeline.')

        for _ in range(args.num_batches):
            pipe_driver(**model_input_dict)

        print('Inference is finished')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4)))
    parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1)))
    parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost'))
    parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500'))

    parser.add_argument('--model_name', type=str, default='facebook/opt-350m')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--chunks', type=int, default=int(os.getenv("WORLD_SIZE", 4)))
    parser.add_argument('--num_batches', type=int, default=1)
    parser.add_argument('--seq_length', type=int, default=16)
    parser.add_argument('--avg_seqlen', type=int, default=16)
    parser.add_argument('--max_seqlen', type=int, default=16)
    parser.add_argument('--seqlen-stdev', type=int, default=10)

    parser.add_argument('-s', '--schedule', type=str, default="FillDrain")
    parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available()))
    parser.add_argument('--visualize', type=int, default=1, choices=[0, 1])
    parser.add_argument('--pp_group_size', type=int, default=int(os.getenv("WORLD_SIZE", 4)))
    parser.add_argument('--auto_split', type=str, default="equal_size")

    args = parser.parse_args()

    assert args.world_size % args.pp_group_size == 0
    args.tp_group_size = args.world_size // args.pp_group_size
    print(f"Using tensor parallel group size: {args.tp_group_size}")

    # Main process loads model
    print(f"Loading model {args.model_name}")
    if 't5' in args.model_name:
        model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name, use_cache=False)
    elif 'opt' in args.model_name:
        model = OPTModel.from_pretrained(args.model_name, use_cache=False)
    elif 'bloom' in args.model_name:
        model = BloomModel.from_pretrained(args.model_name, use_cache=False)
    elif 'regnet' in args.model_name:
        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/regnet-y-10b-seer")
        model = RegNetModel.from_pretrained("facebook/regnet-y-10b-seer")
        args.feature_extractor = feature_extractor
    else:
        raise ValueError(f"Unsupported model: {args.model_name}")
    args.model = model

    args.gspmd = 1
    run_pippy(run_all, args)

Error log

After the exception below, the run just hangs there.

(PT-nightlies) ubuntu@ip-172-31-44-234:~/tau/examples/inference$ python TP_HF.py 
Using tensor parallel group size: 1
Loading model facebook/opt-350m
[PiPPy] World size: 4, DP group size: 1, PP group size: 4
rank = 0 host/pid/device = ip-172-31-44-234/10535/cuda:0
rank = 2 host/pid/device = ip-172-31-44-234/10537/cuda:2
rank = 1 host/pid/device = ip-172-31-44-234/10536/cuda:1
rank = 3 host/pid/device = ip-172-31-44-234/10538/cuda:3
Using schedule: FillDrain
OPTConfig {
  "_name_or_path": "facebook/opt-350m",
  "_remove_final_layer_norm": false,
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "architectures": [
    "OPTForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "do_layer_norm_before": false,
  "dropout": 0.1,
  "enable_bias": true,
  "eos_token_id": 2,
  "ffn_dim": 4096,
  "hidden_size": 1024,
  "init_std": 0.02,
  "layer_norm_elementwise_affine": true,
  "layerdrop": 0.0,
  "max_position_embeddings": 2048,
  "model_type": "opt",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 1,
  "prefix": "</s>",
  "torch_dtype": "float16",
  "transformers_version": "4.27.3",
  "use_cache": false,
  "vocab_size": 50272,
  "word_embed_proj_dim": 512
}

model total number of params = 331M
Rank 2 calling parallelize_module with DeviceMesh:([2])
Rank 3 calling parallelize_module with DeviceMesh:([3])
Rank 1 calling parallelize_module with DeviceMesh:([1])
Rank 0 calling parallelize_module with DeviceMesh:([0])
submod_0 81M params
Model init time: 2.4849045276641846 s
memory_reserved: 1.2598 GB, memory_allocated: 1.2338 GB
Running model pipeline.
submod_1 80M params
submod_2 81M params
submod_3 86M params
Exception in thread worker_0:
Traceback (most recent call last):
  File "/opt/conda/envs/PT-nightlies/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/opt/conda/envs/PT-nightlies/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/pippy-0.1.0a0+27e4111-py3.9.egg/pippy/PipelineDriver.py", line 485, in worker_loop
    out_val, flat_tensor_args = forward(
  File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/pippy-0.1.0a0+27e4111-py3.9.egg/pippy/PipelineDriver.py", line 448, in forward
    out_val = forward_maybe_with_ddp(args, kwargs)
  File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/pippy-0.1.0a0+27e4111-py3.9.egg/pippy/PipelineDriver.py", line 432, in forward_maybe_with_ddp
    out_val = stage_executor.mod(*args, **kwargs)
  File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/pippy-0.1.0a0+27e4111-py3.9.egg/pippy/fx/graph_module.py", line 662, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/pippy-0.1.0a0+27e4111-py3.9.egg/pippy/fx/graph_module.py", line 281, in __call__
    raise e
  File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/pippy-0.1.0a0+27e4111-py3.9.egg/pippy/fx/graph_module.py", line 271, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.13", line 62, in forward
  File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/torch/_tensor.py", line 1296, in __torch_function__
    ret = func(*args, **kwargs)
  File "/opt/conda/envs/PT-nightlies/lib/python3.9/site-packages/torch/distributed/_tensor/api.py", line 231, in __torch_dispatch__
    raise RuntimeError(
RuntimeError: aten.add.Tensor: got mixed distributed and non-distributed tensors.

HamidShojanazeri avatar Mar 31 '23 04:03 HamidShojanazeri

Hey @HamidShojanazeri, looking at the error log, you are using TP size = 1. Can you try a TP size greater than 1, like 2?

Because if TP size = 1, there is no reason to tensor-parallelize the model.
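
In the script above, TP size is derived as world_size // pp_group_size, so with the script's default WORLD_SIZE of 4 you can get TP size 2 by shrinking the PP group:

python TP_HF.py --pp_group_size 2

That should launch with PP group size 2 and TP group size 2 across the 4 ranks.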

fduwjj avatar May 03 '23 02:05 fduwjj

Hey @HamidShojanazeri, I am curious about your example too (going to attempt something very similar :) ). Did you get this to work, either with a different TP size or some other fix?

Vatshank avatar May 15 '23 06:05 Vatshank

Hi @fduwjj, I tried this with TP size = 2 and am still running into the same RuntimeError: aten.add.Tensor: got mixed distributed and non-distributed tensors. Any ideas what could be going wrong here?

Vatshank avatar May 21 '23 16:05 Vatshank

@Vatshank Can you try using ColwiseParallel and RowwiseParallel instead? You need to specify the paths of the submodules within the model, though. Let me know if that helps.
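
A rough sketch of what I mean, with illustrative module paths (the actual FQNs depend on how PiPPy split the model; print(stage_mod) to inspect them):

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Hypothetical plan: the keys are module paths (FQNs) inside the stage
# module and must be adapted to the names that exist after the PiPPy
# split. Colwise shards the first linear's output dimension, rowwise
# shards the second linear's input dimension, so the pair composes
# without an intermediate all-gather.
parallelize_plan = {
    "decoder.layers.0.fc1": ColwiseParallel(),
    "decoder.layers.0.fc2": RowwiseParallel(),
}
parallelize_module(stage_mod, my_device_mesh, parallelize_plan)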

fduwjj avatar May 25 '23 21:05 fduwjj