
Inference with DeepSpeed

Open afiaka87 opened this issue 3 years ago • 9 comments

Trying to run generate.py on a DeepSpeed checkpoint currently breaks. Inference with DeepSpeed should be relatively simple, I think, but I couldn't quite figure it out and realized most of the code I was writing actually belonged in the DeepSpeedBackend code, which I hadn't yet grokked. Anyway, so I don't forget: here is some very broken code I had written before giving up last night:

Edit: pretend I never wrote this. 

afiaka87 avatar Jun 12 '21 00:06 afiaka87

Looking at train_dalle.py provides some insight from @janEbert's prior grokking of DeepSpeed. The first mistake I'm making here is loading the checkpoint like this:

dalle.load_state_dict(weights)

which is apparently a no-no for DeepSpeed's engine.
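
For reference, a rough sketch of what I think the intended pattern is, based on my reading of train_dalle.py (the checkpoint directory is made up, and args / deepspeed_config are assumed to come from the training script):

import deepspeed

# Let DeepSpeed wrap the raw model and restore its own checkpoint format,
# rather than calling dalle.load_state_dict(weights) on the bare module.
engine, _, _, _ = deepspeed.initialize(
    args=args,
    model=dalle,
    model_parameters=dalle.parameters(),
    config_params=deepspeed_config,
)
# tag=None loads whatever tag DeepSpeed recorded as "latest" in that directory.
load_path, client_state = engine.load_checkpoint('checkpoints/', tag=None)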

afiaka87 avatar Jun 12 '21 10:06 afiaka87

Okay - I did things the way they're meant to be done (I believe) @rom1504 @janEbert @mehdidc


if args.fp16:
    engine = deepspeed.init_inference(dalle, dtype=torch.half)
else:
    engine = deepspeed.init_inference(dalle)
# training

for epoch in range(EPOCHS):
    if data_sampler:
        data_sampler.set_epoch(epoch)
    for i, (text, images) in enumerate(distr_dl):
        if args.fp16:
            images = images.half()
        text, images = map(lambda t: t.cuda(), (text, images))
        loss = engine(text, images, return_loss=True)
        
        # update everything
        # ...
        
        if i % 100 == 0:
            if distr_backend.is_root_worker():
                sample_text = text[:1]
                token_list = sample_text.masked_select(sample_text != 0).tolist()
                decoded_text = tokenizer.decode(token_list)

                image = dalle.generate_images(text[:1], filter_thres=0.9)  # topk sampling at 0.9
                log = {
                    **log,
                }
                if not avoid_model_calls:
                    log['image'] = wandb.Image(image, caption=decoded_text)

And this properly runs backpropagation via some automated strategy parameter I haven't understood yet. It's all undocumented, so I'm just reading their code at this point. This may be an instance where their 'auto' policy is passing values in the range [-1, 1] where the VQGAN expects values in the range [0, 1]? Bit out of my depth on this one.

Traceback (most recent call last):
  File "DALLE-pytorch/train_dalle.py", line 456, in <module>
    loss = engine(text, images, return_loss=True)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/deepspeed/inference/engine.py", line 222, in forward
    return self.module(*inputs, **kwargs)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/evo_internal_1TB/DALLE-pytorch/dalle_pytorch/dalle_pytorch.py", line 486, in forward
    image = self.vae.get_codebook_indices(image)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/evo_internal_1TB/DALLE-pytorch/dalle_pytorch/vae.py", line 173, in get_codebook_indices
    _, _, [_, _, indices] = self.model.encode(img)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/taming_transformers-0.0.1-py3.7.egg/taming/models/vqgan.py", line 54, in encode
    quant, emb_loss, info = self.quantize(h)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/evo_internal_1TB/.anaconda/envs/sparse/lib/python3.7/site-packages/taming_transformers-0.0.1-py3.7.egg/taming/modules/vqvae/quantize.py", line 42, in forward
    torch.sum(self.embedding.weight**2, dim=1) - 2 * \
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
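
A rough debugging sketch I'd try next (assumes dalle, images, and args from the snippet above are in scope): call the VAE directly, outside the inference engine, and print what the quantizer actually receives, to check whether init_inference's auto replacement is changing the input.

import torch

with torch.no_grad():
    sample = images[:1].float()
    print('image batch:', sample.shape, sample.dtype,
          sample.min().item(), sample.max().item())
    # Bypass the engine entirely; this is the call that blows up in the traceback.
    codes = dalle.vae.get_codebook_indices(sample)
    print('codebook indices:', codes.shape)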

afiaka87 avatar Jun 12 '21 13:06 afiaka87

As always, apologies to Jan, who I'm sure has already explained this issue ;) I'll admit to some laziness with regard to doing the due diligence on all this 🤷

afiaka87 avatar Jun 12 '21 14:06 afiaka87

Generating seems to be tricky because DeepSpeed, DataParallel, etc. only work through an nn.Module's forward. But the following code works for me to balance the GPUs (model trained with stage 1) through distr_dl:

(distr_dalle, _, distr_dl, _) = distr_backend.distribute(
    args=args,
    model=dalle,
    optimizer=None,
    model_parameters=None,
    training_data=ds if using_deepspeed else dl,
    lr_scheduler=None,
    config_params=deepspeed_config,
)

for i, text in enumerate(distr_dl):
    t = time.time()
    text = text.cuda()
    print(f"generating {i} batch ...batch size {text.shape[0]}")
    image = dalle.generate_images(
        text, filter_thres=0.9)  # topk sampling at 0.9
    sec_per_sample = (time.time() - t) / BATCH_SIZE
    print(i, f'second_per_sample - {sec_per_sample}')

Note that I use dalle because distr_dalle does not work. In my small test case (16 heads / depth 16), the load on both GPUs is exactly the same.
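
A rough sketch of how the per-GPU generations from the loop above could be collected, assuming torch.distributed was initialized by the backend (the output paths are made up):

import os

import torch.distributed as dist
from torchvision.utils import save_image

rank = dist.get_rank() if dist.is_initialized() else 0
outdir = f'outputs/rank_{rank}'
os.makedirs(outdir, exist_ok=True)
# `image` is the batch of generated images from the last iteration above.
for j, img in enumerate(image):
    save_image(img, os.path.join(outdir, f'{j}.png'), normalize=True)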

richcmwang avatar Jun 15 '21 17:06 richcmwang

Yeah, I guess that's useful if you want to generate many samples. That doesn't help improve the speed of one batch, though.

rom1504 avatar Jun 15 '21 18:06 rom1504

It does not improve the speed of one batch per GPU, but with 2 (or more) GPUs it does improve the overall speed. In my test case, the ratio of running time for 1 GPU over 2 GPUs is 1.55.

richcmwang avatar Jun 15 '21 19:06 richcmwang

Thanks @richcmwang! I'll work on this later unless you wanna make the PR.

@rom1504 The DeepSpeed docs do indeed claim faster inference with the inference engine. Not sure how though.

afiaka87 avatar Jun 15 '21 22:06 afiaka87

@richcmwang Exactly, they need the forward call, which I'm pretty sure is also the reason why FP16 generation fails. They recommended using a simple if-switch in the forward method, like do_generations=True: if it's given, skip the normal forward calculations, just do generations, and exit. I haven't found the time to try it yet, though. Something like this minimal sketch is what I have in mind; the wrapper class and the do_generations flag are illustrative, not part of the dalle_pytorch API:
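
import torch

class DALLEForwardSwitch(torch.nn.Module):
    def __init__(self, dalle):
        super().__init__()
        self.dalle = dalle

    def forward(self, text, images=None, return_loss=False, do_generations=False):
        if do_generations:
            # Route sampling through forward() so the DeepSpeed/DataParallel hooks apply.
            return self.dalle.generate_images(text, filter_thres=0.9)
        return self.dalle(text, images, return_loss=return_loss)

# e.g. after wrapping/distributing:
#   images = engine(text, do_generations=True)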

Aside from inference being parallelizable, I think the biggest benefit is being able to do inference with models that don't fit into memory.

janEbert avatar Jun 15 '21 22:06 janEbert

@afiaka87 Please feel free to incorporate this. I tried inference but get either an incorrect key "checkpoint_path" or an unknown type "DeepSpeed" error message. Not sure the docs are accurate.

"checkpoint.json":
{
  "type": "DeepSpeed",
    "version": 0.3,
    "checkpoint_path": "path_to_checkpoints",
}

richcmwang avatar Jun 15 '21 23:06 richcmwang