PyTorch Dataloading doesn't work with >0 workers
Hi!
I'm new to the JAX ecosystem, but I've used PyTorch and TensorFlow extensively for over 5 years.
My issue is that I can't get PyTorch data loading to work with JAX/Flax when `num_workers > 0`.
The following is a minimal example that reproduces the issue:
```python
import argparse
from typing import Sequence
from functools import partial
import flax
from typing import Any
import optax
from flax.training import train_state
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import tqdm
from torchvision.datasets import CIFAR10
from flax.training import common_utils
import torch
import torchvision.transforms as transforms
import torch.multiprocessing as multiprocessing

multiprocessing.set_start_method('spawn')

NUM_CLASSES = 10
NUM_EPOCHS = 50
BATCH_SIZE = 512

parser = argparse.ArgumentParser()
parser.add_argument("--num_workers", default=0, type=int)


def collate_fn(batch):
    inputs_np = []
    targets_np = []
    for item in batch:
        inp_np = item[0].permute(1, 2, 0).detach().numpy()
        tgts_np = item[1]
        inputs_np.append(inp_np)
        targets_np.append(tgts_np)
    inputs_np = np.asarray(inputs_np)
    targets_np = np.asarray(targets_np)
    return inputs_np, targets_np


class CNN(nn.Module):
    @nn.compact
    def __call__(self, inputs, train=False):
        conv = partial(nn.Conv, kernel_size=(3, 3), strides=(2, 2),
                       use_bias=False, kernel_init=jax.nn.initializers.kaiming_normal())
        bn = partial(nn.BatchNorm, use_running_average=not train, momentum=0.9,
                     epsilon=1e-5)
        x = conv(features=32)(inputs)
        x = bn()(x)
        x = nn.relu(x)
        x = conv(features=64)(x)
        x = bn()(x)
        x = nn.relu(x)
        x = conv(features=128)(x)
        x = bn()(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(4, 4), strides=(1, 1))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(NUM_CLASSES)(x)
        return x


def initialize(key, inp_shape, model):
    input_shape = (1,) + inp_shape

    @jax.jit
    def init(*args):
        return model.init(*args)

    variables = init({'params': key}, jnp.ones(input_shape))
    return variables['params'], variables['batch_stats']


@jax.jit
def cross_entropy_loss(logits, labels):
    one_hot_labels = common_utils.onehot(labels, num_classes=NUM_CLASSES)
    xentropy = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
    return jnp.mean(xentropy)


@jax.jit
def calculate_accuracy(logits, labels):
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return accuracy


@jax.jit
def train_step(state, images, labels):
    step = state.step

    @jax.jit
    def cost_fn(params):
        logits, new_model_state = state.apply_fn(
            {"params": params, "batch_stats": state.batch_stats},
            images,
            mutable=['batch_stats'],
            train=True
        )
        loss = cross_entropy_loss(logits, labels)
        weight_penalty_params = jax.tree_leaves(params)
        weight_l2 = sum([jnp.sum(x ** 2)
                         for x in weight_penalty_params
                         if x.ndim > 1])
        weight_decay = 0.0001
        weight_penalty = weight_decay * 0.5 * weight_l2
        loss = loss + weight_penalty
        return loss, (new_model_state, logits)

    grad_fn = jax.value_and_grad(cost_fn, has_aux=True)
    aux, grads = grad_fn(state.params)
    new_model_state, logits = aux[1]
    acc = calculate_accuracy(logits, labels)
    new_state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
    return new_state, aux[0], acc


@jax.jit
def eval_step(state, images, labels):
    logits = state.apply_fn(
        {"params": state.params,
         "batch_stats": state.batch_stats},
        images, train=False, mutable=False)
    return calculate_accuracy(logits, labels)


class TrainState(train_state.TrainState):
    batch_stats: Any


if __name__ == "__main__":
    args = parser.parse_args()

    cnn = CNN()
    key = jax.random.PRNGKey(0)
    key, *subkeys = jax.random.split(key, 4)
    params, batch_stats = initialize(subkeys[0], (32, 32, 3), cnn)

    tx = optax.adam(1e-3)
    state = TrainState.create(
        apply_fn=cnn.apply,
        params=params,
        tx=tx,
        batch_stats=batch_stats
    )

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    batch_size = BATCH_SIZE
    trainset = CIFAR10(root='./data', train=True,
                       download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, drop_last=True,
                                              shuffle=True, num_workers=args.num_workers,
                                              collate_fn=collate_fn)
    num_tr_steps = len(trainloader)

    testset = CIFAR10(root='./data', train=False,
                      download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, drop_last=True,
                                             shuffle=False, num_workers=args.num_workers,
                                             collate_fn=collate_fn)
    num_test_steps = len(testloader)

    for epoch in range(1, NUM_EPOCHS + 1):
        print("Starting epoch {}".format(epoch))
        train_loss = []
        train_acc = []
        itercnt = 0
        for batch in trainloader:
            images, labels = batch
            state, loss, acc = train_step(state, images, labels)
            if itercnt == 0:
                print("Input shape:", images.shape)
                print("labels shape:", labels.shape)
            if itercnt % 25 == 0:
                print("[{:03d}] | Step: [{:04d}/{:04d}] | Loss: {:.04f} | Acc: {:.04f}".format(
                    epoch, itercnt, num_tr_steps, loss, acc
                ))
            train_loss.append(jax.device_get(loss))
            train_acc.append(jax.device_get(acc))
            itercnt += 1

        print("Validating...")
        val_accs = []
        for batch in testloader:
            images, labels = batch
            acc = eval_step(state, images, labels)
            val_accs.append(jax.device_get(acc))

        print("Epoch {:03d} done...".format(epoch))
        print("\t Train loss: {:.04f} | Train Acc: {:.04f}".format(
            np.mean(train_loss), np.mean(train_acc)))
        print("\t Val Acc: {:.04f}".format(np.mean(val_accs)))
```
Problem encountered:

I've tried running the script on both TPU and GPU: it works fine when `num_workers = 0`, but fails when `num_workers > 0`.

An earlier issue from 2020 recommended setting `torch.multiprocessing.set_start_method('spawn')`, but that didn't fix the issue for me. Unlike the author of that issue, I'm not using JAX primitives in the data loading pipeline at all (as can be seen in the `collate_fn()` function).
With `num_workers > 0`, I get the following errors:

On GPU:
- With `torch.multiprocessing.set_start_method('spawn')`, it throws `RuntimeError: context has already been set`.
- With `torch.multiprocessing.set_start_method('fork')`, it throws `Failed setting context: CUDA_ERROR_NOT_INITIALIZED: initialization error`.

On a TPUv2-8 VM:
- With `torch.multiprocessing.set_start_method('spawn')`, it throws `libtpu.so already in use by another process`, followed by `RuntimeError: context has already been set` later in the stack trace.
- With `torch.multiprocessing.set_start_method('fork')`, I get no error, but the dataloader hangs indefinitely.
Following are the packages being used:

```
torch==1.9.0+cu111
jax==0.2.26
jaxlib==0.1.75  # +cuda11.cudnn82 for GPU
```
Any help is appreciated!
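For reference, I reproduce this by saving the script above (as, say, `repro.py`; the filename is arbitrary) and running `python repro.py --num_workers 2`. With the default `--num_workers 0`, the same script trains without problems.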
You might get more traction asking about this in the torch project.
> You might get more traction asking about this in the torch project.
I'll give that a try, but this only happens when the model itself is implemented in JAX. In the same environment, an all-PyTorch version works absolutely fine.
Hi! I think that it's going to be really hard to make this work. We generally don't try to support Python multiprocessing: none of the internal C++ libraries we use are written to be fork-safe, and I'm not sure that the TPU's `libtpu.so` can be used with multiprocessing at all.

Usually we recommend that people use TFDS / `tf.data`-based dataloaders, as they're far more CPU-efficient for feeding multiple GPUs or TPUs than torch dataloaders with multiprocessing.
> Hi! I think that it's going to be really hard to make this work. We generally don't try to support Python multiprocessing: none of the internal C++ libraries we use are written to be fork-safe, and I'm not sure that the TPU's `libtpu.so` can be used with multiprocessing at all. Usually we recommend that people use TFDS / `tf.data`-based dataloaders, as they're far more CPU-efficient for feeding multiple GPUs or TPUs than torch dataloaders with multiprocessing.
Thanks for the detailed reply. I'll see how much time moving all data operations to TensorFlow-based ops will take.

However, I believe this information should be added to the official tutorial on using PyTorch dataloaders with JAX, as it's quite a limitation: using multiple workers in dataloaders is standard practice in the PyTorch world. Torch data loaders will work for small datasets and for a quick proof of concept by researchers/teams thinking about making the move to JAX, sure, but for full-bore training they are not feasible with JAX. Adding this as a disclaimer to the above-mentioned tutorial would save valuable time, in my opinion.
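For anyone else going down this route, here is roughly the kind of `tf.data` pipeline I'm planning to try in place of the torch `DataLoader` (just a sketch under my own assumptions; the preprocessing mirrors the torchvision transform above, and none of this is taken from an official example):

```python
# Rough sketch of a TFDS / tf.data replacement for the torch DataLoader above.
# Assumes tensorflow and tensorflow-datasets are installed; everything here is
# illustrative rather than copied from the original script.
import tensorflow as tf
import tensorflow_datasets as tfds

# Keep TensorFlow off the GPU so it doesn't compete with JAX for device memory.
tf.config.set_visible_devices([], 'GPU')


def make_dataset(split, batch_size, shuffle):
    ds = tfds.load('cifar10', split=split, as_supervised=True)

    def preprocess(image, label):
        # Mirror transforms.ToTensor() + Normalize((0.5,)*3, (0.5,)*3):
        # scale to [0, 1], then to [-1, 1]. Layout stays HWC, as the model expects.
        image = tf.cast(image, tf.float32) / 255.0
        image = (image - 0.5) / 0.5
        return image, label

    ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    if shuffle:
        ds = ds.shuffle(10_000)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    # tfds.as_numpy yields NumPy batches that can go straight into train_step.
    return tfds.as_numpy(ds)


# Usage, replacing the trainloader loop:
# for images, labels in make_dataset('train', BATCH_SIZE, shuffle=True):
#     state, loss, acc = train_step(state, images, labels)
```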
+1. I agree these gotchas are major and should be mentioned front and center; it took me many lost hours this week to figure this out for myself. It's a limitation of the JAX ecosystem right now!
This issue appears to be a regression compared to one year ago. I was using multi-worker data loaders in sabertooth and they worked fine at the time, but they no longer work with newly started TPU VMs. I want to emphasize that the data workers are not using JAX or accessing the TPUs in any way; they are doing pure NumPy computation.

`torch.multiprocessing.set_start_method('spawn')` sort of works as a workaround. I've managed to avoid the `RuntimeError: context has already been set` error with the idiom `if __name__ == '__main__': torch.multiprocessing.set_start_method('spawn')`; I had to wrap it so that each spawned worker doesn't itself attempt to set the start method. However, this workaround still has issues: each worker takes a really long time to spawn and generates a bunch of `libtpu.so already in use by another process` messages. Setting `persistent_workers=True` helps cut down on these, but it's still annoying.

Given that this is a regression, is it really the case that it can't be fixed? None of the child processes are actually doing anything with the TPU.
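To make the workaround concrete, it looks roughly like the sketch below. The synthetic dataset is only a stand-in for a real NumPy-producing dataset; the `if __name__ == '__main__'` guard around `set_start_method('spawn')` and the `persistent_workers=True` flag are the relevant parts.

```python
# Sketch of the guarded-spawn workaround described above. RandomImages is a
# synthetic placeholder dataset, not part of the original repro.
import numpy as np
import torch.multiprocessing as multiprocessing
from torch.utils.data import DataLoader, Dataset


class RandomImages(Dataset):
    """Synthetic stand-in that yields NumPy arrays and never touches JAX."""

    def __len__(self):
        return 1024

    def __getitem__(self, idx):
        image = np.random.rand(32, 32, 3).astype(np.float32)
        label = np.int64(idx % 10)
        return image, label


def numpy_collate(batch):
    images, labels = zip(*batch)
    return np.stack(images), np.asarray(labels)


if __name__ == '__main__':
    # Guard the call: spawned workers re-import this module, and an unguarded
    # set_start_method() would raise "RuntimeError: context has already been set".
    multiprocessing.set_start_method('spawn')

    loader = DataLoader(
        RandomImages(),
        batch_size=256,
        num_workers=4,
        collate_fn=numpy_collate,
        persistent_workers=True,  # reuse workers instead of re-spawning them each epoch
    )
    for images, labels in loader:
        pass  # feed the NumPy batches to train_step(...) here
```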
Agreed, it would be great to find a solution ASAP. Thank you @nikitakit!
Are there any updates on this? It's frustrating to find out that the PyTorch data loader can't work with JAX on TPU, despite being used in JAX's official examples.
@nikitakit - do you perhaps have a small repro for the failure?
@levskaya there is a code snippet for the failure in #9767.
Following up on this
I also ran into similar problems on an A100 GPU, but I have no idea how to fix it.
Same problem on an NVIDIA T4 GPU; it's a disaster when training a tiny model on a huge dataset.