jax icon indicating copy to clipboard operation
jax copied to clipboard

PyTorch Dataloading doesn't work with >0 workers

Open SarthakYadav opened this issue 3 years ago • 13 comments

Hi!

I'm new to the JAX ecosystem, have used PyTorch and TensorFlow extensively for over 5 years.

My issue is that I can't get PyTorch data loading to work with jax/flax with num_workers>0. Following is a minimal example to reproduce my issues

import argparse
from typing import Sequence
from functools import partial
import flax
from typing import Any
import optax
from flax.training import train_state
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import tqdm
from torchvision.datasets import CIFAR10
from flax.training import common_utils
import torch
import torchvision.transforms as transforms
import torch.multiprocessing as multiprocessing
multiprocessing.set_start_method('spawn')


NUM_CLASSES = 10
NUM_EPOCHS = 50
BATCH_SIZE = 512


parser = argparse.ArgumentParser()
parser.add_argument("--num_workers", default=0, type=int)


def collate_fn(batch):
    inputs_np = []
    targets_np = []
    for item in batch:
        inp_np = item[0].permute(1, 2, 0).detach().numpy()
        tgts_np = item[1]
        inputs_np.append(inp_np)
        targets_np.append(tgts_np)
    inputs_np = np.asarray(inputs_np)
    targets_np = np.asarray(targets_np)
    return inputs_np, targets_np


class CNN(nn.Module):
    @nn.compact
    def __call__(self, inputs, train=False):
        conv = partial(nn.Conv, kernel_size=(3, 3), strides=(2, 2), 
                       use_bias=False, kernel_init=jax.nn.initializers.kaiming_normal())
        bn = partial(nn.BatchNorm, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5)
        x = conv(features=32)(inputs)
        x = bn()(x)
        x = nn.relu(x)
        x = conv(features=64)(x)
        x = bn()(x)
        x = nn.relu(x)
        x = conv(features=128)(x)
        x = bn()(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(4, 4), strides=(1, 1))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(NUM_CLASSES)(x)
        return x


def initialize(key, inp_shape, model):
  input_shape = (1,) + inp_shape
  @jax.jit
  def init(*args):
    return model.init(*args)
  variables = init({'params': key}, jnp.ones(input_shape))
  return variables['params'], variables['batch_stats']


@jax.jit
def cross_entropy_loss(logits, labels):
    one_hot_labels = common_utils.onehot(labels, num_classes=NUM_CLASSES)
    xentropy = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
    return jnp.mean(xentropy)

@jax.jit
def calculate_accuracy(logits, labels):
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return accuracy


@jax.jit
def train_step(state, images, labels):
    step = state.step
    @jax.jit
    def cost_fn(params):
        logits, new_model_state = state.apply_fn(
            {"params": params, "batch_stats": state.batch_stats},
            images,
            mutable=['batch_stats'],
            train=True
        )
        loss = cross_entropy_loss(logits, labels)
        weight_penalty_params = jax.tree_leaves(params)
        weight_l2 = sum([jnp.sum(x ** 2)
                        for x in weight_penalty_params
                        if x.ndim > 1])
        weight_decay=0.0001
        weight_penalty = weight_decay * 0.5 * weight_l2
        loss = loss + weight_penalty
        return loss, (new_model_state, logits)
    grad_fn = jax.value_and_grad(cost_fn, has_aux=True)
    aux, grads = grad_fn(state.params)
    new_model_state, logits = aux[1]
    acc = calculate_accuracy(logits, labels)
    new_state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
    return new_state, aux[0], acc

@jax.jit
def eval_step(state, images, labels):
    logits = state.apply_fn(
        {"params": state.params, 
        "batch_stats": state.batch_stats}, 
        images, train=False, mutable=False)
    return calculate_accuracy(logits, labels)


class TrainState(train_state.TrainState):
    batch_stats: Any


if __name__ == "__main__":
  args = parser.parse_args()
  cnn = CNN()
  key = jax.random.PRNGKey(0)
  key, *subkeys = jax.random.split(key, 4)
  params, batch_stats = initialize(subkeys[0], (32, 32, 3), cnn)
  tx = optax.adam(
    1e-3
  )
  state = TrainState.create(
      apply_fn=cnn.apply,
      params=params,
      tx=tx,
      batch_stats=batch_stats
  )
  transform = transforms.Compose(
  [transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

  batch_size = BATCH_SIZE
  trainset = CIFAR10(root='./data', train=True,
                                          download=True, transform=transform)
  trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, drop_last=True,
                                          shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn)
  num_tr_steps = len(trainloader)
  testset = CIFAR10(root='./data', train=False,
                                      download=True, transform=transform)
  testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, drop_last=True,
                                          shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn)
  num_test_steps = len(testloader)
  
  for epoch in range(1, NUM_EPOCHS+1):
    print("Starting epoch {}".format(epoch))
    train_loss = []
    train_acc = []
    itercnt = 0
    for batch in trainloader:
      images, labels = batch
      state, loss, acc = train_step(state, images, labels)
      if itercnt == 0:
        print("Input shape:", images.shape)
        print("labels shape:", labels.shape)
      if itercnt % 25 == 0:
        print("[{:03d}] | Step: [{:04d}/{:04d}] | Loss: {:.04f} | Acc: {:.04f}".format(
          epoch, itercnt, num_tr_steps, loss, acc
        ))
      train_loss.append(jax.device_get(loss))
      train_acc.append(jax.device_get(acc))
      itercnt += 1
    print("Validating...")
    val_accs = []
    for batch in testloader:
      images, labels = batch
      acc = eval_step(state, images, labels)
      val_accs.append(jax.device_get(acc))

    print("Epoch {:03d} done...".format(epoch))
    print("\t Train loss: {:.04f} | Train Acc: {:.04f}".format(
      np.mean(train_loss), np.mean(train_acc)))
    print("\t Val Acc: {:.04f}".format(np.mean(val_accs)))

Problem encountered:

I've tried running the script on both TPU and GPU: it works fine when num_workers = 0, but doesn't work with num_workers > 0.

An earlier issue from 2020 recommended setting torch.multiprocessing.set_start_method('spawn'), but that didn't fix the issue for me. Unlike the author of that issue, I'm not using jax primitives in the data loading pipeline at all (as can be seen in the collate_fn() function)

With num_workers>0, I get the following errors:

On GPU

  • With torch.multiprocessing.set_start_method('spawn') throws RuntimeError: context has already been set
  • With torch.multiprocessing.set_start_method('fork') throws Failed setting context: CUDA_ERROR_NOT_INITIALIZED: initialization error

On TPUv2-8 VM

  • With torch.multiprocessing.set_start_method('spawn') throws libtpu.so already in use by another process, followed by RuntimeError: context has already been set later in the stack trace.
  • With torch.multiprocessing.set_start_method('fork') I get no error, the dataloader hangs indefinitely.

Following are the packages being used:

torch==1.9.0+cu111
jax==0.2.26
jaxlib==0.1.75       #+cuda11.cudnn82 for GPU

Any help is appreciated!

SarthakYadav avatar Jan 13 '22 15:01 SarthakYadav

You might get more traction asking about this in the torch project.

jakevdp avatar Jan 14 '22 17:01 jakevdp

You might get more traction asking about this in the torch project.

I'll give that a try, but this only happens when the model itself is implemented in Jax.

In the same env, everything-torch works absolutely fine.

SarthakYadav avatar Jan 14 '22 17:01 SarthakYadav

Hi! I think that it's going to be really hard to make this work. We generally don't try to support python-multiprocessing: all the internal C++ libs we use aren't written to be fork-safe, and I'm not sure that TPU libtpu.so can be used with multiprocessing at all.

Usually we recommend that people use TFDS / tf.data based dataloaders as they're far more CPU efficient for feeding multiple GPUs or TPUs than torch dataloaders with multiprocessing.

levskaya avatar Jan 15 '22 01:01 levskaya

Hi! I think that it's going to be really hard to make this work. We generally don't try to support python-multiprocessing: all the internal C++ libs we use aren't written to be fork-safe, and I'm not sure that TPU libtpu.so can be used with multiprocessing at all.

Usually we recommend that people use TFDS / tf.data based dataloaders as they're far more CPU efficient for feeding multiple GPUs or TPUs than torch dataloaders with multiprocessing.

Thanks for the detailed reply. I'll see how much time moving all data operations to tensorflow based ops will take.

However, I believe this information should be added to the official tutorial on using pytorch dataloaders with Jax, as this is quite the limitation. Using multiple workers in dataloaders is a standard practice in the PyTorch realm.

It will work for small datasets and a quick proof of concept for researchers/teams thinking about making the move to Jax, sure, but for full-bore training, using torch data loaders with Jax would not be feasible. Adding this as a disclaimer to the above-mentioned tutorial will save valuable time in my opinion.

SarthakYadav avatar Jan 15 '22 10:01 SarthakYadav

+1 I agree these gotchas are major and should be mentioned front and center, it took me a long time and many lost hours this week to figure this out for myself. It's a limitation of the jax ecosystem right now!

jaanli avatar Feb 04 '22 21:02 jaanli

This issue appears to be a regression compared to one year ago. I was using multi-worker data loaders in sabertooth and they worked fine at the time, but no longer work with newly started TPU VMs. I want to emphasize that the data workers are not using JAX nor accessing the TPUs in any way, just doing pure numpy computation.

torch.multiprocessing.set_start_method('spawn') sort of works as a work around. I've managed to avoid the error RuntimeError: context has already been set with the idiom if __name__ == '__main__': torch.multiprocessing.set_start_method('spawn') -- I had to wrap it so that each spawned worker don't itself attempt to set the start method. However this workaround still has issues: each worker takes a really long time to spawn, and generates a bunch of libtpu.so already in use by another process messages. Setting persistent_workers=True helps cut down on these but it's still annoying.

Given that this is a regression, is it really the case that it can't be fixed? None of the child processes are actually doing anything with the TPU.

nikitakit avatar Feb 10 '22 07:02 nikitakit

Agreed, would be great to find a solution ASAP thank you @nikitakit !

jaanli avatar Feb 10 '22 11:02 jaanli

Are there any updates on this? It is frustrating to find out that PyTorch data loader cannot work with Jax on TPU despite it is used in Jax's official examples.

haoliuhl avatar Mar 03 '22 05:03 haoliuhl

@nikitakit - do you perhaps have a small repro for the failure?

levskaya avatar Mar 04 '22 02:03 levskaya

@levskaya there is a code snippet for the failure in #9767.

haoliuhl avatar Mar 13 '22 17:03 haoliuhl

Following up on this

jaanli avatar Apr 04 '22 12:04 jaanli

I also met similar problems on A100 GPU. But I have no idea how to fix it.

rensushan avatar Feb 11 '23 09:02 rensushan

Same problem on NV T4 GPU, it's a disaster when training a tiny model with huge datasets.

noahzhy avatar Mar 14 '24 21:03 noahzhy