
CLIP preprocess hangs when using multiprocessing

theahura opened this issue on Jul 22, 2021 · 6 comments

When running CLIP inside a multiprocessing.Process, the program hangs as soon as it reaches the preprocess step (in practice I assume this applies to any torch operation). A minimal example:

import torch
import clip
from PIL import Image
import multiprocessing as mp
import sys

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)


def infer():
  print("PREPROCESSING")
  sys.stdout.flush()
  image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)

  print("TOKENIZING")
  sys.stdout.flush()
  text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

  print("INFERRING")
  sys.stdout.flush()
  with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    print(f'{image_features.shape}')
    print(f'{text_features.shape}')
    sys.stdout.flush()

    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

  print(f"Label probs: {probs}")  # prints: [[0.9927937  0.00421068 0.00299572]]
  sys.stdout.flush()


p = mp.Process(target=infer, daemon=True)
p.start()
p.join()

This is equivalent to the starter how-to-use example in the README, but with the model inference wrapped in a Process.

The output of this code is:

/home/amol/code/soot/debugging/clip_tests/env/lib/python3.8/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:100.)
  return torch._C._cuda_getDeviceCount() > 0
PREPROCESSING

after which it hangs. Any suggestions?

theahura · Jul 22 '21 22:07

Digging into this a bit more, I managed to reduce it to the following, even more minimal, example that hangs with clip:

import os
import urllib.request

from tqdm import tqdm

import torch
import clip
from PIL import Image
import multiprocessing as mp


def _download(url, root=os.path.expanduser("~/.cache/clip")):
  os.makedirs(root, exist_ok=True)
  download_target = os.path.join(root, os.path.basename(url))

  with urllib.request.urlopen(url) as source, open(download_target,
                                                   "wb") as output:
    with tqdm(total=int(source.info().get("Content-Length")),
              ncols=80,
              unit='iB',
              unit_scale=True) as loop:
      while True:
        buffer = source.read(8192)
        if not buffer:
          break

        output.write(buffer)
        loop.update(len(buffer))

  return download_target


def load(device="cpu"):
  model_path = _download(
      "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"  # noqa
  )
  model = torch.jit.load(model_path, map_location="cpu").eval()
  # The built model is never returned or used; building it is enough to hang.
  model = clip.model.build_model(model.state_dict()).to(device)


load()


def test():
  print("GETTING IMAGE")
  im = Image.open("CLIP.png")
  print("CONVERTING")
  im = im.convert('RGB')
  print("MADE TENSOR")
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(im.tobytes()))
  print("VIEW")
  img = img.view(im.size[1], im.size[0], len(im.getbands()))
  print("PERMUTING")
  img = img.permute((2, 0, 1)).contiguous()
  print("DIV")
  img = img.float().div(255)
  print("UNSQUEEZE")
  img = img.unsqueeze(0)


p = mp.Process(target=test, daemon=True)
p.start()
p.join()

Interestingly, if you comment out the load() call or the clip.model.build_model() call, everything just works. Note that nothing the load() call produces is actually being used; what is build_model() doing to the global scope that causes these hangs?
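If it's torch state rather than something clip-specific, clip shouldn't even be needed to trigger it; any torch compute in the parent before the fork might do. An untested, clip-free guess (the matmul is a hypothetical stand-in for the work build_model() does):

import torch
import multiprocessing as mp

# Untested stand-in for build_model(): just enough torch compute in the
# parent to exercise torch's internals before the fork.
_ = torch.randn(256, 256) @ torch.randn(256, 256)


def test():
  print("CONTIGUOUS")
  img = torch.randn(224, 224, 3).permute((2, 0, 1)).contiguous()
  print("DONE", img.shape)


p = mp.Process(target=test, daemon=True)
p.start()
p.join()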

theahura · Jul 23 '21 06:07

It'll depend on the platform and the start method, but in general you shouldn't expect CUDA-related code to work across the process boundary. Initialize the model in the process where you'll be using it. It's best not to call any CUDA-related code (e.g. anything that accesses torch.cuda) from the parent process.
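For the original example, that means moving the load into the child, roughly like this (a minimal sketch based on the snippet above):

import multiprocessing as mp

import torch
import clip
from PIL import Image


def infer():
  # Load the model inside the child, so no torch/CUDA state crosses the fork.
  device = "cuda" if torch.cuda.is_available() else "cpu"
  model, preprocess = clip.load("ViT-B/32", device=device)

  image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
  text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

  with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)
    print(logits_per_image.softmax(dim=-1).cpu().numpy())


p = mp.Process(target=infer, daemon=True)
p.start()
p.join()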

jongwook · Jul 23 '21 06:07

So as far as I can tell, in the second code block the child process is not referencing anything in the parent process. And I'm not using CUDA; all of this is running on the CPU.

Loading the same model once per process seems prohibitively expensive from an I/O perspective; surely there's a way to load the model once and share it?

Is the clip load process using multithreading or locking or something? I'm confused about why it would hang only when an unreferenced line is present.

theahura · Jul 23 '21 13:07

Got it down to an even simpler version than before:

import torch
import clip
from PIL import Image
import multiprocessing as mp

# ViT-B/32 hyperparameters: embed_dim, image_resolution, vision_layers,
# vision_width, vision_patch_size, context_length, vocab_size,
# transformer_width, transformer_heads, transformer_layers
model = clip.model.CLIP(512, 224, 12, 768, 32, 77, 49408, 512, 8, 12)


def test():
  print("GETTING IMAGE")
  im = Image.open("CLIP.png")
  print("CONVERTING")
  im = im.convert('RGB')
  print("MADE TENSOR")
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(im.tobytes()))
  print("VIEW")
  img = img.view(im.size[1], im.size[0], len(im.getbands()))
  print("PERMUTING")
  img = img.permute((2, 0, 1))
  print("CONTINGUOUS")
  img = img.contiguous()
  print("DIV")
  img = img.float().div(255)
  print("UNSQUEEZE")
  img = img.unsqueeze(0)
  return img


p = mp.Process(target=test, daemon=True)
p.start()
p.join()

It looks like it's actually blocking on the contiguous() call, but only if the clip model has been initialized. Notice that I've removed all state management, and the child process isn't using the underlying model at all :thinking:
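If the culprit is torch's intra-op (OpenMP) thread pool being initialized in the parent and then inherited in a bad state by the fork()ed child, then pinning torch to a single thread before forking might sidestep it. An untested sketch of that workaround:

import multiprocessing as mp

import torch
import clip

# Untested hypothesis: building the model runs torch ops in the parent, which
# spins up torch's intra-op (OpenMP) thread pool; the fork()ed child inherits
# that pool's state and deadlocks on its next parallelized op (here,
# contiguous()). Pinning torch to one thread avoids creating the pool at all.
torch.set_num_threads(1)

model = clip.model.CLIP(512, 224, 12, 768, 32, 77, 49408, 512, 8, 12)


def test():
  img = torch.randn(224, 224, 3).permute((2, 0, 1)).contiguous()
  print("DONE", img.shape)


p = mp.Process(target=test, daemon=True)
p.start()
p.join()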

theahura · Jul 23 '21 16:07

These issues can be solved by changing the Process start method. In your latest example you can make it work by creating the process like this:

ctx = mp.get_context("spawn")
p = ctx.Process(target=test, daemon=True)

and everything should run fine.
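One caveat worth adding: with the spawn start method the child re-imports the main module, so the process creation (and anything else you don't want re-executed in the child) must sit under an if __name__ == "__main__": guard, or the child will error out trying to spawn again. A full sketch of the last example with that guard:

import multiprocessing as mp

import torch
import clip
from PIL import Image


def test():
  im = Image.open("CLIP.png").convert('RGB')
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(im.tobytes()))
  img = img.view(im.size[1], im.size[0], len(im.getbands()))
  img = img.permute((2, 0, 1)).contiguous().float().div(255).unsqueeze(0)
  print(img.shape)


if __name__ == "__main__":
  # Only the parent runs this block; the spawned child re-imports the module
  # but skips it, so it neither rebuilds the model nor re-spawns itself.
  model = clip.model.CLIP(512, 224, 12, 768, 32, 77, 49408, 512, 8, 12)

  ctx = mp.get_context("spawn")
  p = ctx.Process(target=test, daemon=True)
  p.start()
  p.join()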

Also, when PyTorch is involved it's better to use PyTorch's drop-in replacement for Python's multiprocessing module: import torch.multiprocessing as mp.
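torch.multiprocessing also registers custom picklers so that CPU tensors sent to a child process are moved into shared memory instead of copied, which should help with the earlier concern about paying the model-loading I/O once per process. A rough, untested sketch of that pattern (jit=False is assumed here so the model is a plain nn.Module and can be pickled):

import torch
import torch.multiprocessing as mp  # drop-in replacement for multiprocessing
import clip


def worker(model):
  # The model's tensors arrive through shared memory, not a fresh disk load.
  with torch.no_grad():
    text = clip.tokenize(["a diagram", "a dog", "a cat"])
    print(model.encode_text(text).shape)


if __name__ == "__main__":
  model, _ = clip.load("ViT-B/32", device="cpu", jit=False)
  model.share_memory()  # moves parameters and buffers into shared memory

  ctx = mp.get_context("spawn")
  p = ctx.Process(target=worker, args=(model,), daemon=True)
  p.start()
  p.join()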

godimarcovr · Dec 24 '21 10:12

If I understand right, changing the start method to spawn causes the model to reload, which sorta defeats the point of loading the model in the parent. In that scenario I could just as easily load the model after the process has started, and it would take about the same time?

theahura · Jan 05 '22 06:01