Model trains, but only if I don't JIT the step function?
Edit from the Future
I fixed it and it's described here.
Preface (not super relevant; can be skipped)
Ok, so this is a weird one, which took me HOURS to find. I wanted to implement ResNet and train it on CIFAR-10 for another YT video, so I started hacking away, porting the PyTorch implementation to Equinox. So far, so good. But I couldn't for the love of God get it to train. The PyTorch version had no problem training - even without any preprocessing of the data, no fancy-schmancy learning rate schedulers, just the most straightforward implementation you can think of.
I thought I was going crazy; I thought maybe it was because of the BatchNorm (because I saw a couple of open issues), so I implemented a slightly different version that matches the PyTorch version EXACTLY. But to no avail. I started to check the intermediate outputs of the network - maybe something was off there? No. Then I even turned off BatchNorm entirely in both networks. The PyTorch one - even without BatchNorm - trained with no problems at all. But not my version; so it's definitely not because of BatchNorm discrepancies. At this point it's basically just a large CNN with some residual connections.
Copy-pastable version of ResNet (jaxtyping required)
from typing import Type
import equinox as eqx
import jax
import jax.numpy as jnp
import jaxtyping as jt
# from jaxonmodels.layers.batch_norm import BatchNorm
class Downsample(eqx.Module):
conv: eqx.nn.Conv2d
# bn: BatchNorm
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int,
key: jt.PRNGKeyArray,
):
_, subkey = jax.random.split(key)
self.conv = eqx.nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=stride,
use_bias=False,
key=subkey,
)
# self.bn = BatchNorm(out_channels, axis_name="batch")
def __call__(
self, x: jt.Float[jt.Array, "c_in h w"], state: eqx.nn.State
) -> tuple[jt.Float[jt.Array, "c_out*e h/s w/s"], eqx.nn.State]:
x = self.conv(x)
# x, state = self.bn(x, state)
return x, state
class BasicBlock(eqx.Module):
downsample: Downsample | None
conv1: eqx.nn.Conv2d
# bn1: BatchNorm
conv2: eqx.nn.Conv2d
# bn2: BatchNorm
expansion: int = eqx.field(static=True, default=1)
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int,
downsample: Downsample | None,
groups: int,
base_width: int,
dilation: int,
key: jt.PRNGKeyArray,
):
key, *subkeys = jax.random.split(key, 3)
self.conv1 = eqx.nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
use_bias=False,
key=subkeys[0],
)
# self.bn1 = BatchNorm(input_size=out_channels, axis_name="batch")
self.conv2 = eqx.nn.Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
use_bias=False,
key=subkeys[1],
)
# self.bn2 = BatchNorm(input_size=out_channels, axis_name="batch")
self.downsample = downsample
def __call__(self, x: jt.Float[jt.Array, "c h w"], state: eqx.nn.State):
i = x
x = self.conv1(x)
# x, state = self.bn1(x, state)
x = jax.nn.relu(x)
x = self.conv2(x)
# x, state = self.bn2(x, state)
if self.downsample:
i, state = self.downsample(i, state)
x += i
x = jax.nn.relu(x)
return x, state
class Bottleneck(eqx.Module):
downsample: Downsample | None
conv1: eqx.nn.Conv2d
# bn1: BatchNorm
conv2: eqx.nn.Conv2d
# bn2: BatchNorm
conv3: eqx.nn.Conv2d
# bn3: BatchNorm
expansion: int = eqx.field(static=True, default=4)
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int,
downsample: Downsample | None,
groups: int,
base_width: int,
dilation: int,
key: jt.PRNGKeyArray,
):
_, *subkeys = jax.random.split(key, 4)
width = int(out_channels * (base_width / 64.0)) * groups
self.conv1 = eqx.nn.Conv2d(
in_channels, width, kernel_size=1, use_bias=False, key=subkeys[0]
)
# self.bn1 = BatchNorm(width, axis_name="batch")
self.conv2 = eqx.nn.Conv2d(
width,
width,
kernel_size=3,
stride=stride,
groups=groups,
dilation=dilation,
padding=dilation,
use_bias=False,
key=subkeys[1],
)
# self.bn2 = BatchNorm(width, axis_name="batch")
self.conv3 = eqx.nn.Conv2d(
width,
out_channels * self.expansion,
kernel_size=1,
key=subkeys[2],
use_bias=False,
)
# self.bn3 = BatchNorm(out_channels * self.expansion, axis_name="batch")
self.downsample = downsample
def __call__(
self, x: jt.Float[jt.Array, "c_in h w"], state: eqx.nn.State
) -> tuple[jt.Float[jt.Array, "c_out*e h/s w/s"], eqx.nn.State]:
i = x
x = self.conv1(x)
# x, state = self.bn1(x, state)
x = jax.nn.relu(x)
x = self.conv2(x)
# x, state = self.bn2(x, state)
x = jax.nn.relu(x)
x = self.conv3(x)
# x, state = self.bn3(x, state)
if self.downsample:
i, state = self.downsample(i, state)
x += i
x = jax.nn.relu(x)
return x, state
class ResNet(eqx.Module):
conv1: eqx.nn.Conv2d
# bn: BatchNorm
mp: eqx.nn.MaxPool2d
layer1: list[BasicBlock | Bottleneck]
layer2: list[BasicBlock | Bottleneck]
layer3: list[BasicBlock | Bottleneck]
layer4: list[BasicBlock | Bottleneck]
avg: eqx.nn.AdaptiveAvgPool2d
fc: eqx.nn.Linear
running_internal_channels: int = eqx.field(static=True, default=64)
dilation: int = eqx.field(static=True, default=1)
def __init__(
self,
block: Type[BasicBlock | Bottleneck],
layers: list[int],
n_classes: int,
zero_init_residual: bool,
groups: int,
width_per_group: int,
replace_stride_with_dilation: list[bool] | None,
key: jt.PRNGKeyArray,
input_channels: int = 3,
):
key, *subkeys = jax.random.split(key, 10)
if replace_stride_with_dilation is None:
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
f"`replace_stride_with_dilation` should either be `None` "
f"or have a length of 3, got {replace_stride_with_dilation} instead."
)
self.conv1 = eqx.nn.Conv2d(
in_channels=input_channels,
out_channels=self.running_internal_channels,
kernel_size=7,
stride=2,
padding=3,
use_bias=False,
key=subkeys[0],
)
# self.bn = BatchNorm(self.running_internal_channels, axis_name="batch")
self.mp = eqx.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(
block,
64,
layers[0],
stride=1,
dilate=False,
groups=groups,
base_width=width_per_group,
key=subkeys[1],
)
self.layer2 = self._make_layer(
block,
128,
layers[1],
stride=2,
dilate=replace_stride_with_dilation[0],
groups=groups,
base_width=width_per_group,
key=subkeys[2],
)
self.layer3 = self._make_layer(
block,
256,
layers[2],
stride=2,
dilate=replace_stride_with_dilation[1],
groups=groups,
base_width=width_per_group,
key=subkeys[3],
)
self.layer4 = self._make_layer(
block,
512,
layers[3],
stride=2,
dilate=replace_stride_with_dilation[2],
groups=groups,
base_width=width_per_group,
key=subkeys[4],
)
self.avg = eqx.nn.AdaptiveAvgPool2d(target_shape=(1, 1))
self.fc = eqx.nn.Linear(512 * block.expansion, n_classes, key=subkeys[-1])
if zero_init_residual:
# todo: init last bn layer with zero weights
pass
def _make_layer(
self,
block: Type[BasicBlock | Bottleneck],
out_channels: int,
blocks: int,
stride: int,
dilate: bool,
groups: int,
base_width: int,
key: jt.PRNGKeyArray,
) -> list[BasicBlock | Bottleneck]:
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if (
stride != 1
or self.running_internal_channels != out_channels * block.expansion
):
key, subkey = jax.random.split(key)
downsample = Downsample(
self.running_internal_channels,
out_channels * block.expansion,
stride,
subkey,
)
layers = []
key, subkey = jax.random.split(key)
layers.append(
block(
in_channels=self.running_internal_channels,
out_channels=out_channels,
stride=stride,
downsample=downsample,
groups=groups,
base_width=base_width,
dilation=previous_dilation,
key=subkey,
)
)
self.running_internal_channels = out_channels * block.expansion
for _ in range(1, blocks):
key, subkey = jax.random.split(key)
layers.append(
block(
in_channels=self.running_internal_channels,
out_channels=out_channels,
groups=groups,
base_width=base_width,
dilation=self.dilation,
stride=1,
downsample=None,
key=subkey,
)
)
return layers
def __call__(
self, x: jt.Float[jt.Array, "c h w"], state: eqx.nn.State
) -> tuple[jt.Float[jt.Array, " n_classes"], eqx.nn.State]:
x = self.conv1(x)
# x, state = self.bn(x, state)
x = jax.nn.relu(x)
x = self.mp(x)
for layer in self.layer1:
x, state = layer(x, state)
for layer in self.layer2:
x, state = layer(x, state)
for layer in self.layer3:
x, state = layer(x, state)
for layer in self.layer4:
x, state = layer(x, state)
x = self.avg(x)
x = jnp.ravel(x)
x = self.fc(x)
return x, state
def resnet18(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
key, subkey = jax.random.split(key)
resnet, state = eqx.nn.make_with_state(ResNet)(
BasicBlock,
[2, 2, 2, 2],
n_classes,
zero_init_residual=False,
groups=1,
width_per_group=64,
replace_stride_with_dilation=None,
key=key,
)
# initializer = jax.nn.initializers.he_normal()
# is_conv2d = lambda x: isinstance(x, eqx.nn.Conv2d)
# get_weights = lambda m: [
# x.weight for x in jax.tree.leaves(m, is_leaf=is_conv2d) if is_conv2d(x)
# ]
# weights = get_weights(resnet)
# new_weights = [
# initializer(subkey, weight.shape, jnp.float32)
# for weight, subkey in zip(weights, jax.random.split(key, len(weights)))
# ]
# resnet = eqx.tree_at(get_weights, resnet, new_weights)
return resnet, state
def resnet34(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
return eqx.nn.make_with_state(ResNet)(
BasicBlock,
[3, 4, 6, 3],
n_classes,
zero_init_residual=False,
groups=1,
width_per_group=64,
replace_stride_with_dilation=None,
key=key,
)
def resnet50(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
return eqx.nn.make_with_state(ResNet)(
Bottleneck,
[3, 4, 6, 3],
n_classes,
zero_init_residual=False,
groups=1,
width_per_group=64,
replace_stride_with_dilation=None,
key=key,
)
def resnet101(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
return eqx.nn.make_with_state(ResNet)(
Bottleneck,
[3, 4, 23, 3],
n_classes,
zero_init_residual=False,
groups=1,
width_per_group=64,
replace_stride_with_dilation=None,
key=key,
)
def resnet152(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
return eqx.nn.make_with_state(ResNet)(
Bottleneck,
[3, 8, 36, 3],
n_classes,
zero_init_residual=False,
groups=1,
width_per_group=64,
replace_stride_with_dilation=None,
key=key,
)
def resnext50_32x4d(
key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
return eqx.nn.make_with_state(ResNet)(
Bottleneck,
[3, 4, 6, 3],
n_classes,
zero_init_residual=False,
groups=32,
width_per_group=4,
replace_stride_with_dilation=None,
key=key,
)
def resnext101_32x8d(
key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
return eqx.nn.make_with_state(ResNet)(
Bottleneck,
[3, 4, 23, 3],
n_classes,
zero_init_residual=False,
groups=32,
width_per_group=8,
replace_stride_with_dilation=None,
key=key,
)
def resnext101_64x4d(
key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
return eqx.nn.make_with_state(ResNet)(
Bottleneck,
[3, 4, 23, 3],
n_classes,
zero_init_residual=False,
groups=64,
width_per_group=4,
replace_stride_with_dilation=None,
key=key,
)
def wide_resnet50_2(
key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
return eqx.nn.make_with_state(ResNet)(
Bottleneck,
[3, 4, 6, 3],
n_classes,
zero_init_residual=False,
groups=1,
width_per_group=64 * 2,
replace_stride_with_dilation=None,
key=key,
)
def wide_resnet101_2(
key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
return eqx.nn.make_with_state(ResNet)(
Bottleneck,
[3, 4, 23, 3],
n_classes,
zero_init_residual=False,
groups=1,
width_per_group=64 * 2,
replace_stride_with_dilation=None,
key=key,
)
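Quick sanity check of the single-example call signature (a small sketch; with BatchNorm commented out, the model can be called outside of vmap):
model, state = resnet18(jax.random.key(0), n_classes=10)
dummy = jnp.zeros((3, 32, 32))  # CIFAR-10 sized input: (channels, height, width)
logits, state = model(dummy, state)
print(logits.shape)  # (10,)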
The issue
I was starting to get desperate and began checking the gradients, for which I needed to un-JIT the step function to see the gradient print statements. And there it was: after removing the eqx.filter_jit wrapper around the step function, the network started to train.
Copy-pasteable training loop code (requires tensorflow, tensorflow_datasets, clu, tqdm, jaxtyping)
import equinox as eqx
import jax
import jax.numpy as jnp
import jaxtyping as jt
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from clu import metrics
from tqdm import tqdm
# from jaxonmodels.models.resnet import ResNet, resnet18
# copy paste here the code from the details above
(train, test), info = tfds.load(
"cifar10", split=["train", "test"], with_info=True, as_supervised=True
) # pyright: ignore
def preprocess(
img: jt.Float[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "h w c"], jt.Int[tf.Tensor, "1 n_classes"]]:
img = tf.cast(img, tf.float32) / 255.0 # pyright: ignore
mean = tf.constant([0.4914, 0.4822, 0.4465])
std = tf.constant([0.2470, 0.2435, 0.2616])
img = (img - mean) / std # pyright: ignore
img = tf.transpose(img, perm=[2, 0, 1])
# label = tf.one_hot(label, depth=10)
return img, label
def preprocess_train(
img: jt.Float[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "h w c"], jt.Int[tf.Tensor, "1 n_classes"]]:
img = tf.pad(img, [[4, 4], [4, 4], [0, 0]], mode="REFLECT")
img = tf.image.random_crop(img, [32, 32, 3])
img = tf.image.random_flip_left_right(img) # pyright: ignore
return preprocess(img, label)
train_dataset = train.map(preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
SHUFFLE_VAL = len(train_dataset) // 1000
BATCH_SIZE = 128
train_dataset = train_dataset.shuffle(SHUFFLE_VAL)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
test_dataset = test.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)
train_dataset = tfds.as_numpy(train_dataset)
test_dataset = tfds.as_numpy(test_dataset)
def loss_fn(
resnet: ResNet,
x: jt.Array,
y: jt.Array,
state: eqx.nn.State,
) -> tuple[jt.Array, tuple[jt.Array, eqx.nn.State]]:
logits, state = eqx.filter_vmap(
resnet, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(x, state)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
return jnp.mean(loss), (logits, state)
# @eqx.filter_jit
def step(
resnet: jt.PyTree,
state: eqx.nn.State,
x: jt.Array,
y: jt.Array,
optimizer: optax.GradientTransformation,
opt_state: optax.OptState,
):
(loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
loss_fn, has_aux=True
)(resnet, x, y, state)
updates, opt_state = optimizer.update(grads, opt_state, resnet)
resnet = eqx.apply_updates(resnet, updates)
return resnet, state, opt_state, loss_value, logits
class TrainMetrics(eqx.Module, metrics.Collection):
loss: metrics.Average.from_output("loss") # pyright: ignore
accuracy: metrics.Accuracy
def eval(
resnet: ResNet, test_dataset, state, key: jt.PRNGKeyArray
) -> TrainMetrics:
eval_metrics = TrainMetrics.empty()
for x, y in test_dataset:
y = jnp.array(y, dtype=jnp.int32)
loss, (logits, state) = loss_fn(resnet, x, y, state)
eval_metrics = eval_metrics.merge(
TrainMetrics.single_from_model_output(
logits=logits, labels=y, loss=loss
)
)
return eval_metrics
train_metrics = TrainMetrics.empty()
resnet, state = resnet18(key=jax.random.key(0), n_classes=10)
learning_rate = 0.1
weight_decay = 5e-4
optimizer = optax.sgd(learning_rate)
opt_state = optimizer.init(eqx.filter(resnet, eqx.is_inexact_array_like))
key = jax.random.key(99)
n_epochs = 100
for epoch in range(n_epochs):
batch_count = len(train_dataset)
pbar = tqdm(enumerate(train_dataset), total=batch_count, desc=f"Epoch {epoch}")
for i, (x, y) in pbar:
y = jnp.array(y, dtype=jnp.int32)
resnet, state, opt_state, loss, logits = step(
resnet, state, x, y, optimizer, opt_state
)
train_metrics = train_metrics.merge(
TrainMetrics.single_from_model_output(
logits=logits, labels=y, loss=loss
)
)
vals = train_metrics.compute()
pbar.set_postfix(
{"loss": f"{vals['loss']:.4f}", "acc": f"{vals['accuracy']:.4f}"}
)
key, subkey = jax.random.split(key)
eval_metrics = eval(resnet, test_dataset, state, subkey)
evals = eval_metrics.compute()
print(
f"Epoch {epoch}: "
f"test_loss={evals['loss']:.4f}, "
f"test_acc={evals['accuracy']:.4f}"
)
The relevant part (tl;dr)
The issue is here:
def loss_fn(
resnet: ResNet,
x: jt.Array,
y: jt.Array,
state: eqx.nn.State,
) -> tuple[jt.Array, tuple[jt.Array, eqx.nn.State]]:
logits, state = eqx.filter_vmap(
resnet, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(x, state)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
return jnp.mean(loss), (logits, state)
# @eqx.filter_jit
def step(
resnet: jt.PyTree,
state: eqx.nn.State,
x: jt.Array,
y: jt.Array,
optimizer: optax.GradientTransformation,
opt_state: optax.OptState,
):
(loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
loss_fn, has_aux=True
)(resnet, x, y, state)
updates, opt_state = optimizer.update(grads, opt_state, resnet)
resnet = eqx.apply_updates(resnet, updates)
return resnet, state, opt_state, loss_value, logits
From my perspective, this looks just like standard JAX "boilerplate" code. I see no reason why JITting the step function would interfere with training the model.
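(A convenient way to compare the two on identical inputs is to keep step undecorated and create the jitted variant explicitly - a small sketch:)
step_jit = eqx.filter_jit(step)
# resnet, state, opt_state, loss, logits = step(resnet, state, x, y, optimizer, opt_state)      # trains
# resnet, state, opt_state, loss, logits = step_jit(resnet, state, x, y, optimizer, opt_state)  # does not train for me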
My other attempts
So perhaps I could get rid of the state, I thought, since I don't even use BatchNorm anymore. But that made no difference. I also tried JITting a smaller portion, as shown in the RNN example (my equivalent split is sketched below the snippet):
@eqx.filter_value_and_grad
def compute_loss(model, x, y):
pred_y = jax.vmap(model)(x)
# Trains with respect to binary cross-entropy
return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))
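To be concrete, the equivalent split for my setup would be something like this (just a sketch; the names compute_loss_and_grads / step_split are only for illustration, and only the forward/backward pass is jitted):
@eqx.filter_jit
def compute_loss_and_grads(resnet, x, y, state):
    # jit only the loss + gradient computation
    return eqx.filter_value_and_grad(loss_fn, has_aux=True)(resnet, x, y, state)

def step_split(resnet, state, x, y, optimizer, opt_state):
    (loss_value, (logits, state)), grads = compute_loss_and_grads(resnet, x, y, state)
    updates, opt_state = optimizer.update(grads, opt_state, resnet)
    resnet = eqx.apply_updates(resnet, updates)
    return resnet, state, opt_state, loss_value, logits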
But this equivalent version didn't improve the model either.
I spent all day on this and am now out of options; in German we'd say "es ist wie verhext" ("it's as if it were jinxed"), so perhaps anyone here has an idea? ANY help is HIGHLY appreciated.
I think the code can be further simplified
code
# (identical Downsample / BasicBlock / Bottleneck / ResNet / resnet18 definitions as in the code block above)
import equinox as eqx
import jax
import jax.numpy as jnp
import jaxtyping as jt
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm
tf.random.set_seed(0)
(train, test), info = tfds.load(
"cifar10", split=["train", "test"], with_info=True, as_supervised=True
) # pyright: ignore
def preprocess(
img: jt.Float[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "h w c"], jt.Int[tf.Tensor, "1 n_classes"]]:
img = tf.cast(img, tf.float32) / 255.0 # pyright: ignore
mean = tf.constant([0.4914, 0.4822, 0.4465])
std = tf.constant([0.2470, 0.2435, 0.2616])
img = (img - mean) / std # pyright: ignore
img = tf.transpose(img, perm=[2, 0, 1])
# label = tf.one_hot(label, depth=10)
return img, label
def preprocess_train(
img: jt.Float[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "h w c"], jt.Int[tf.Tensor, "1 n_classes"]]:
img = tf.pad(img, [[4, 4], [4, 4], [0, 0]], mode="REFLECT")
img = tf.image.random_crop(img, [32, 32, 3])
img = tf.image.random_flip_left_right(img) # pyright: ignore
return preprocess(img, label)
train_dataset = train.map(preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
SHUFFLE_VAL = len(train_dataset) // 1000
BATCH_SIZE = 2
train_dataset = train_dataset.shuffle(SHUFFLE_VAL)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
test_dataset = test.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)
train_dataset = tfds.as_numpy(train_dataset)
test_dataset = tfds.as_numpy(test_dataset)
def loss_fn(
resnet: ResNet,
x: jt.Array,
y: jt.Array,
state: eqx.nn.State,
) -> tuple[jt.Array, tuple[jt.Array, eqx.nn.State]]:
logits, state = eqx.filter_vmap(
resnet, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(x, state)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
return jnp.mean(loss), (logits, state)
@eqx.filter_jit
def step(
resnet: jt.PyTree,
state: eqx.nn.State,
x: jt.Array,
y: jt.Array,
optimizer: optax.GradientTransformation,
opt_state: optax.OptState,
):
(loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
loss_fn, has_aux=True
)(resnet, x, y, state)
updates, opt_state = optimizer.update(grads, opt_state, resnet)
new_r = eqx.apply_updates(resnet, updates)
return new_r, state, opt_state, loss_value, logits, grads, updates
def step_nj(
resnet: jt.PyTree,
state: eqx.nn.State,
x: jt.Array,
y: jt.Array,
optimizer: optax.GradientTransformation,
opt_state: optax.OptState,
):
(loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
loss_fn, has_aux=True
)(resnet, x, y, state)
updates, opt_state = optimizer.update(grads, opt_state, resnet)
new_r = eqx.apply_updates(resnet, updates)
return new_r, state, opt_state, loss_value, logits, grads, updates
resnet, state = resnet18(key=jax.random.key(0), n_classes=10)
learning_rate = 0.1
weight_decay = 5e-4
optimizer = optax.sgd(learning_rate)
opt_state = optimizer.init(eqx.filter(resnet, eqx.is_inexact_array_like))
key = jax.random.key(99)
n_epochs = 100
for epoch in range(n_epochs):
batch_count = len(train_dataset)
for i, (x, y) in enumerate(train_dataset):
y = jnp.array(y, dtype=jnp.int32)
new_j, state, os_j, loss, logits, g, u = step(
resnet, state, x, y, optimizer, opt_state
)
print("\n WJ", state, loss, logits)
y = jnp.array(y, dtype=jnp.int32)
new_nj, state, os_nj, loss, logits_nj, g_nj, u_nj = step_nj(
resnet, state, x, y, optimizer, opt_state
)
print("\n NJ", state, loss, logits_nj)
print(jnp.allclose(logits, logits_nj))
print(eqx.tree_equal(g, g_nj))
print(eqx.tree_equal(u, u_nj))
print(eqx.tree_equal(os_j, os_nj))
print(eqx.tree_equal(new_j, new_nj))
print(eqx.tree_equal(jax.tree.leaves(new_j), jax.tree.leaves(new_nj)))
l = jax.tree.leaves(new_j)
lj = jax.tree.leaves(new_nj)
print(len(l), len(lj))
for i in range(len(l)):
try:
print(i, jnp.linalg.norm(lj[i] - l[i]), jnp.allclose(lj[i], l[i]))
if jnp.isnan(jnp.linalg.norm(lj[i] - l[i])):
print(lj[i], l[i])
except:
print(i, lj[i], l[i], lj[i] == l[i])
break
break
The only difference I saw under JIT is that the gradients are slightly different (with a norm difference around 1e-9, which I assume is just floating-point precision), but not always. I would be surprised if this were the source of the problem, but I'm just trying to narrow it down, since the setup is currently very large. (A more compact leaf-by-leaf comparison is sketched after the output below.)
WJ State() 2.333593 [[-0.04504904 0.03363977 -0.03377024 -0.02731024 -0.00935392 -0.07189757
0.0312047 0.04465355 -0.03419897 0.01027971]
[-0.04424123 0.03772916 -0.02880739 -0.01243044 -0.00705145 -0.06763338
0.03016024 0.04261264 -0.03077785 0.0194398 ]]
NJ State() 2.333593 [[-0.04504904 0.03363976 -0.03377024 -0.02731025 -0.00935392 -0.07189757
0.0312047 0.04465355 -0.03419897 0.01027971]
[-0.04424123 0.03772916 -0.02880739 -0.01243044 -0.00705145 -0.06763338
0.03016024 0.04261264 -0.03077785 0.0194398 ]]
True
False
False
True
False
False
22 22
0 2.735993e-08 True
1 1.10683e-08 True
2 1.1603525e-08 True
3 1.2323362e-08 True
4 1.2514869e-08 True
5 1.3517954e-08 True
6 2.0453514e-08 True
7 2.238987e-08 True
8 1.1969572e-08 True
9 9.68429e-09 True
10 1.0157422e-08 True
11 1.7057491e-08 True
12 1.5252258e-08 True
13 1.0765473e-08 True
14 6.8855863e-09 True
15 6.1813026e-09 True
16 9.383501e-09 True
17 5.7024003e-09 True
18 4.0014845e-09 True
19 1.3972744e-09 True
20 2.6966664e-09 True
21 0.0 True
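A more compact way to do the leaf-by-leaf comparison (a sketch, reusing the new_j / new_nj trees from the loop above):
diffs = jax.tree.map(
    lambda a, b: jnp.max(jnp.abs(a - b)),
    eqx.filter(new_j, eqx.is_inexact_array),
    eqx.filter(new_nj, eqx.is_inexact_array),
)
print(max(jax.tree.leaves(diffs)))  # largest per-leaf deviation between the jitted and non-jitted update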
This is indeed a large setup. I'll focus today on making it smaller and narrowing down the issue further. (It was late at night when I posted it.)
Progress!
The issue lies somewhere in the BasicBlock | Bottleneck layers: as soon as I comment those out in the __call__ function, the model trains (even when JITted).
Details
class ResNet(eqx.Module):
conv1: eqx.nn.Conv2d
bn: BatchNorm
mp: eqx.nn.MaxPool2d
layer1: list[BasicBlock | Bottleneck]
layer2: list[BasicBlock | Bottleneck]
layer3: list[BasicBlock | Bottleneck]
layer4: list[BasicBlock | Bottleneck]
avg: eqx.nn.AdaptiveAvgPool2d
fc: eqx.nn.Linear
running_internal_channels: int = eqx.field(static=True, default=64)
dilation: int = eqx.field(static=True, default=1)
def __init__(
self,
block: Type[BasicBlock | Bottleneck],
layers: list[int],
n_classes: int,
zero_init_residual: bool,
groups: int,
width_per_group: int,
replace_stride_with_dilation: list[bool] | None,
key: jt.PRNGKeyArray,
input_channels: int = 3,
):
key, *subkeys = jax.random.split(key, 10)
if replace_stride_with_dilation is None:
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
f"`replace_stride_with_dilation` should either be `None` "
f"or have a length of 3, got {replace_stride_with_dilation} instead."
)
self.conv1 = eqx.nn.Conv2d(
in_channels=input_channels,
out_channels=self.running_internal_channels,
kernel_size=7,
stride=2,
padding=3,
use_bias=False,
key=subkeys[0],
)
self.bn = BatchNorm(self.running_internal_channels, axis_name="batch")
self.mp = eqx.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(
block,
64,
layers[0],
stride=1,
dilate=False,
groups=groups,
base_width=width_per_group,
key=subkeys[1],
)
self.layer2 = self._make_layer(
block,
128,
layers[1],
stride=2,
dilate=replace_stride_with_dilation[0],
groups=groups,
base_width=width_per_group,
key=subkeys[2],
)
self.layer3 = self._make_layer(
block,
256,
layers[2],
stride=2,
dilate=replace_stride_with_dilation[1],
groups=groups,
base_width=width_per_group,
key=subkeys[3],
)
self.layer4 = self._make_layer(
block,
512,
layers[3],
stride=2,
dilate=replace_stride_with_dilation[2],
groups=groups,
base_width=width_per_group,
key=subkeys[4],
)
self.avg = eqx.nn.AdaptiveAvgPool2d(target_shape=(1, 1))
# self.fc = eqx.nn.Linear(512 * block.expansion, n_classes, key=subkeys[-1])
# change the last layer, otherwise the dims won't match :s
self.fc = eqx.nn.Linear(64, n_classes, key=subkeys[-1])
if zero_init_residual:
# todo: init last bn layer with zero weights
pass
def _make_layer(
self,
block: Type[BasicBlock | Bottleneck],
out_channels: int,
blocks: int,
stride: int,
dilate: bool,
groups: int,
base_width: int,
key: jt.PRNGKeyArray,
) -> list[BasicBlock | Bottleneck]:
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if (
stride != 1
or self.running_internal_channels != out_channels * block.expansion
):
key, subkey = jax.random.split(key)
downsample = Downsample(
self.running_internal_channels,
out_channels * block.expansion,
stride,
subkey,
)
layers = []
key, subkey = jax.random.split(key)
layers.append(
block(
in_channels=self.running_internal_channels,
out_channels=out_channels,
stride=stride,
downsample=downsample,
groups=groups,
base_width=base_width,
dilation=previous_dilation,
key=subkey,
)
)
self.running_internal_channels = out_channels * block.expansion
for _ in range(1, blocks):
key, subkey = jax.random.split(key)
layers.append(
block(
in_channels=self.running_internal_channels,
out_channels=out_channels,
groups=groups,
base_width=base_width,
dilation=self.dilation,
stride=1,
downsample=None,
key=subkey,
)
)
return layers
def __call__(
self, x: jt.Float[jt.Array, "c h w"], state: eqx.nn.State
) -> tuple[jt.Float[jt.Array, " n_classes"], eqx.nn.State]:
x = self.conv1(x)
x, state = self.bn(x, state)
x = jax.nn.relu(x)
x = self.mp(x)
# Comment these out
# for layer in self.layer1:
# x, state = layer(x, state)
# for layer in self.layer2:
# x, state = layer(x, state)
# for layer in self.layer3:
# x, state = layer(x, state)
# for layer in self.layer4:
# x, state = layer(x, state)
x = self.avg(x)
x = jnp.ravel(x)
x = self.fc(x)
return x, state
I'll need to investigate why though.
Edit:
Ok, so at this point I'm almost convinced that it must be related to how JAX treats these modules when they are in a Python list. I tried using eqx.nn.Sequential, but that didn't work. jax.lax.scan wouldn't work because the list is not a JAX array, and jax.lax.while_loop also fails because JAX can't index the list with a traced integer.
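For completeness: the usual trick to make jax.lax.scan work over repeated layers is to stack identically-shaped blocks with eqx.filter_vmap and scan over the stacked parameters. A generic sketch with a hypothetical Block module (the ResNet stages above can't be scanned directly, because the first block of each stage has a different shape / a Downsample):
class Block(eqx.Module):
    linear: eqx.nn.Linear

    def __init__(self, key):
        self.linear = eqx.nn.Linear(8, 8, key=key)

    def __call__(self, x):
        return jax.nn.relu(self.linear(x))

make_block = lambda k: Block(k)
keys = jax.random.split(jax.random.key(0), 4)
stacked = eqx.filter_vmap(make_block)(keys)          # stack of 4 identical blocks
params, static = eqx.partition(stacked, eqx.is_array)

def body(x, layer_params):
    layer = eqx.combine(layer_params, static)        # rebuild one block per scan step
    return layer(x), None

out, _ = jax.lax.scan(body, jnp.ones(8), params)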
Ok, strangely enough, I got it.
Out of desperation, I just ran uv pip install -e . -U and uv pip install --upgrade "jax[cuda12]", and now it works.
The reason was this:
After I had installed JAX via
uv pip install --upgrade "jax[cuda12]"
I also installed torch and torchvision afterwards like so:
uv pip install torch torchvision
But this had the unfortunate side effect of overwriting the NVIDIA CUDA libraries that JAX had installed. So, once I reinstalled the JAX CUDA libraries with
uv pip install --upgrade "jax[cuda12]"
it fetched the JAX-compatible CUDA libraries, and then it worked.
In other words, don't mix and match your CUDA libraries! I'll need to find a way to have both installed and working at the same time, but that's a topic for another time.
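A quick sanity check I now run after (re)installing, just to confirm JAX actually sees the GPU (minimal snippet; the exact device class name depends on the JAX version):
import jax
print(jax.__version__)
print(jax.devices())  # should list CUDA/GPU devices, not just CpuDevice(id=0)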
What's left unanswered is why the model was in fact improving - even under JIT - when I removed those layers, but perhaps that's out of scope here.
Edit 2: Install PyTorch first and then JAX; from my tests, PyTorch works fine with the CUDA libraries that JAX installs, but not the other way around.
This sounds pretty weird :) But I'm glad you got to the bottom of it!