
Model trains, but only if I don't JIT the step function?

Open Artur-Galstyan opened this issue 10 months ago • 5 comments

Edit from the Future

I fixed it and it's described here.

Preface (not super relevant; can be skipped)

Ok, so this is a weird one, and it took me HOURS to track down. I wanted to implement ResNet and train it on CIFAR-10 for another YouTube video, so I started hacking away and ported the PyTorch implementation to Equinox. So far, so good. But I couldn't for the love of God get it to train. The PyTorch version had no problem training, even without any data preprocessing or fancy-schmancy learning rate schedulers: just the most straightforward implementation you can think of.

I thought I was going crazy. Maybe it was the BatchNorm (I had seen a couple of open issues about it), so I implemented a slightly different version that matches the PyTorch version EXACTLY. But to no avail. Next I checked the intermediate outputs of the network; maybe something was off there? No. Then I turned off BatchNorm entirely in both networks. The PyTorch one trained without any problems, even without BatchNorm. Mine did not, so it's definitely not caused by the BatchNorm discrepancies. At this point it's basically just a large CNN with some residual connections.

Copy-pastable version of ResNet (jaxtyping required)

from typing import Type

import equinox as eqx
import jax
import jax.numpy as jnp
import jaxtyping as jt

# from jaxonmodels.layers.batch_norm import BatchNorm


class Downsample(eqx.Module):
    conv: eqx.nn.Conv2d
    # bn: BatchNorm

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        key: jt.PRNGKeyArray,
    ):
        _, subkey = jax.random.split(key)
        self.conv = eqx.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=stride,
            use_bias=False,
            key=subkey,
        )

        # self.bn = BatchNorm(out_channels, axis_name="batch")

    def __call__(
        self, x: jt.Float[jt.Array, "c_in h w"], state: eqx.nn.State
    ) -> tuple[jt.Float[jt.Array, "c_out*e h/s w/s"], eqx.nn.State]:
        x = self.conv(x)
        # x, state = self.bn(x, state)

        return x, state


class BasicBlock(eqx.Module):
    downsample: Downsample | None

    conv1: eqx.nn.Conv2d
    # bn1: BatchNorm

    conv2: eqx.nn.Conv2d
    # bn2: BatchNorm

    expansion: int = eqx.field(static=True, default=1)

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        downsample: Downsample | None,
        groups: int,
        base_width: int,
        dilation: int,
        key: jt.PRNGKeyArray,
    ):
        key, *subkeys = jax.random.split(key, 3)

        self.conv1 = eqx.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            use_bias=False,
            key=subkeys[0],
        )
        # self.bn1 = BatchNorm(input_size=out_channels, axis_name="batch")

        self.conv2 = eqx.nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_bias=False,
            key=subkeys[1],
        )
        # self.bn2 = BatchNorm(input_size=out_channels, axis_name="batch")

        self.downsample = downsample

    def __call__(self, x: jt.Float[jt.Array, "c h w"], state: eqx.nn.State):
        i = x

        x = self.conv1(x)
        # x, state = self.bn1(x, state)

        x = jax.nn.relu(x)

        x = self.conv2(x)
        # x, state = self.bn2(x, state)

        if self.downsample:
            i, state = self.downsample(i, state)

        x += i
        x = jax.nn.relu(x)

        return x, state


class Bottleneck(eqx.Module):
    downsample: Downsample | None

    conv1: eqx.nn.Conv2d
    # bn1: BatchNorm

    conv2: eqx.nn.Conv2d
    # bn2: BatchNorm

    conv3: eqx.nn.Conv2d
    # bn3: BatchNorm

    expansion: int = eqx.field(static=True, default=4)

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        downsample: Downsample | None,
        groups: int,
        base_width: int,
        dilation: int,
        key: jt.PRNGKeyArray,
    ):
        _, *subkeys = jax.random.split(key, 4)

        width = int(out_channels * (base_width / 64.0)) * groups
        self.conv1 = eqx.nn.Conv2d(
            in_channels, width, kernel_size=1, use_bias=False, key=subkeys[0]
        )
        # self.bn1 = BatchNorm(width, axis_name="batch")

        self.conv2 = eqx.nn.Conv2d(
            width,
            width,
            kernel_size=3,
            stride=stride,
            groups=groups,
            dilation=dilation,
            padding=dilation,
            use_bias=False,
            key=subkeys[1],
        )

        # self.bn2 = BatchNorm(width, axis_name="batch")

        self.conv3 = eqx.nn.Conv2d(
            width,
            out_channels * self.expansion,
            kernel_size=1,
            key=subkeys[2],
            use_bias=False,
        )

        # self.bn3 = BatchNorm(out_channels * self.expansion, axis_name="batch")

        self.downsample = downsample

    def __call__(
        self, x: jt.Float[jt.Array, "c_in h w"], state: eqx.nn.State
    ) -> tuple[jt.Float[jt.Array, "c_out*e h/s w/s"], eqx.nn.State]:
        i = x

        x = self.conv1(x)
        # x, state = self.bn1(x, state)
        x = jax.nn.relu(x)

        x = self.conv2(x)
        # x, state = self.bn2(x, state)
        x = jax.nn.relu(x)

        x = self.conv3(x)
        # x, state = self.bn3(x, state)

        if self.downsample:
            i, state = self.downsample(i, state)

        x += i
        x = jax.nn.relu(x)
        return x, state


class ResNet(eqx.Module):
    conv1: eqx.nn.Conv2d
    # bn: BatchNorm
    mp: eqx.nn.MaxPool2d

    layer1: list[BasicBlock | Bottleneck]
    layer2: list[BasicBlock | Bottleneck]
    layer3: list[BasicBlock | Bottleneck]
    layer4: list[BasicBlock | Bottleneck]

    avg: eqx.nn.AdaptiveAvgPool2d
    fc: eqx.nn.Linear

    running_internal_channels: int = eqx.field(static=True, default=64)
    dilation: int = eqx.field(static=True, default=1)

    def __init__(
        self,
        block: Type[BasicBlock | Bottleneck],
        layers: list[int],
        n_classes: int,
        zero_init_residual: bool,
        groups: int,
        width_per_group: int,
        replace_stride_with_dilation: list[bool] | None,
        key: jt.PRNGKeyArray,
        input_channels: int = 3,
    ):
        key, *subkeys = jax.random.split(key, 10)

        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                f"`replace_stride_with_dilation` should either be `None` "
                f"or have a length of 3, got {replace_stride_with_dilation} instead."
            )

        self.conv1 = eqx.nn.Conv2d(
            in_channels=input_channels,
            out_channels=self.running_internal_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            use_bias=False,
            key=subkeys[0],
        )

        # self.bn = BatchNorm(self.running_internal_channels, axis_name="batch")
        self.mp = eqx.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(
            block,
            64,
            layers[0],
            stride=1,
            dilate=False,
            groups=groups,
            base_width=width_per_group,
            key=subkeys[1],
        )
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
            groups=groups,
            base_width=width_per_group,
            key=subkeys[2],
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
            groups=groups,
            base_width=width_per_group,
            key=subkeys[3],
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2],
            groups=groups,
            base_width=width_per_group,
            key=subkeys[4],
        )

        self.avg = eqx.nn.AdaptiveAvgPool2d(target_shape=(1, 1))
        self.fc = eqx.nn.Linear(512 * block.expansion, n_classes, key=subkeys[-1])

        if zero_init_residual:
            # todo: init last bn layer with zero weights
            pass

    def _make_layer(
        self,
        block: Type[BasicBlock | Bottleneck],
        out_channels: int,
        blocks: int,
        stride: int,
        dilate: bool,
        groups: int,
        base_width: int,
        key: jt.PRNGKeyArray,
    ) -> list[BasicBlock | Bottleneck]:
        downsample = None
        previous_dilation = self.dilation

        if dilate:
            self.dilation *= stride
            stride = 1

        if (
            stride != 1
            or self.running_internal_channels != out_channels * block.expansion
        ):
            key, subkey = jax.random.split(key)
            downsample = Downsample(
                self.running_internal_channels,
                out_channels * block.expansion,
                stride,
                subkey,
            )
        layers = []

        key, subkey = jax.random.split(key)
        layers.append(
            block(
                in_channels=self.running_internal_channels,
                out_channels=out_channels,
                stride=stride,
                downsample=downsample,
                groups=groups,
                base_width=base_width,
                dilation=previous_dilation,
                key=subkey,
            )
        )

        self.running_internal_channels = out_channels * block.expansion

        for _ in range(1, blocks):
            key, subkey = jax.random.split(key)
            layers.append(
                block(
                    in_channels=self.running_internal_channels,
                    out_channels=out_channels,
                    groups=groups,
                    base_width=base_width,
                    dilation=self.dilation,
                    stride=1,
                    downsample=None,
                    key=subkey,
                )
            )

        return layers

    def __call__(
        self, x: jt.Float[jt.Array, "c h w"], state: eqx.nn.State
    ) -> tuple[jt.Float[jt.Array, " n_classes"], eqx.nn.State]:
        x = self.conv1(x)
        # x, state = self.bn(x, state)
        x = jax.nn.relu(x)
        x = self.mp(x)

        for layer in self.layer1:
            x, state = layer(x, state)

        for layer in self.layer2:
            x, state = layer(x, state)

        for layer in self.layer3:
            x, state = layer(x, state)

        for layer in self.layer4:
            x, state = layer(x, state)

        x = self.avg(x)
        x = jnp.ravel(x)

        x = self.fc(x)

        return x, state


def resnet18(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
    key, subkey = jax.random.split(key)
    resnet, state = eqx.nn.make_with_state(ResNet)(
        BasicBlock,
        [2, 2, 2, 2],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        key=key,
    )

    # initializer = jax.nn.initializers.he_normal()
    # is_conv2d = lambda x: isinstance(x, eqx.nn.Conv2d)
    # get_weights = lambda m: [
    #     x.weight for x in jax.tree.leaves(m, is_leaf=is_conv2d) if is_conv2d(x)
    # ]
    # weights = get_weights(resnet)
    # new_weights = [
    #     initializer(subkey, weight.shape, jnp.float32)
    #     for weight, subkey in zip(weights, jax.random.split(key, len(weights)))
    # ]
    # resnet = eqx.tree_at(get_weights, resnet, new_weights)

    return resnet, state


def resnet34(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        BasicBlock,
        [3, 4, 6, 3],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        key=key,
    )


def resnet50(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 6, 3],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        key=key,
    )


def resnet101(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 23, 3],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        key=key,
    )


def resnet152(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 8, 36, 3],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        key=key,
    )


def resnext50_32x4d(
    key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 6, 3],
        n_classes,
        zero_init_residual=False,
        groups=32,
        width_per_group=4,
        replace_stride_with_dilation=None,
        key=key,
    )


def resnext101_32x8d(
    key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 23, 3],
        n_classes,
        zero_init_residual=False,
        groups=32,
        width_per_group=8,
        replace_stride_with_dilation=None,
        key=key,
    )


def resnext101_64x4d(
    key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 23, 3],
        n_classes,
        zero_init_residual=False,
        groups=64,
        width_per_group=4,
        replace_stride_with_dilation=None,
        key=key,
    )


def wide_resnet50_2(
    key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 6, 3],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64 * 2,
        replace_stride_with_dilation=None,
        key=key,
    )


def wide_resnet101_2(
    key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 23, 3],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64 * 2,
        replace_stride_with_dilation=None,
        key=key,
    )
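As a sanity check on the port, the arithmetic below traces the spatial size of one image through the ImageNet-style stem (7×7 stride-2 conv, 3×3 stride-2 max pool) and the four stages defined above, using the standard floor-mode conv output-size formula. This is only a shape sketch and says nothing about the JIT problem itself, but it shows that on 32×32 CIFAR-10 inputs the feature map is already 1×1 after `layer4`, versus 7×7 for 224×224 ImageNet inputs. The helper names `conv_out` and `trace_resnet_spatial` are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int, dilation: int = 1) -> int:
    """Output length of a conv/pool along one spatial axis (floor mode, no ceil)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def trace_resnet_spatial(size: int) -> dict[str, int]:
    """Spatial size after each part of the ResNet defined above (no dilation)."""
    shapes = {"input": size}
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    shapes["conv1"] = size
    size = conv_out(size, kernel=3, stride=2, padding=1)  # MaxPool2d
    shapes["maxpool"] = size
    shapes["layer1"] = size  # layer1 uses stride 1, so the size is unchanged
    for name in ("layer2", "layer3", "layer4"):
        # Only the first 3x3 conv (stride 2, padding 1) of each stage downsamples.
        size = conv_out(size, kernel=3, stride=2, padding=1)
        shapes[name] = size
    return shapes


print(trace_resnet_spatial(32))
# {'input': 32, 'conv1': 16, 'maxpool': 8, 'layer1': 8, 'layer2': 4, 'layer3': 2, 'layer4': 1}
print(trace_resnet_spatial(224))
# {'input': 224, 'conv1': 112, 'maxpool': 56, 'layer1': 56, 'layer2': 28, 'layer3': 14, 'layer4': 7}
```

The same formula applied to the 1×1 stride-2 `Downsample` convs (padding 0) gives 8 → 4 → 2 → 1 as well, so the residual branches line up with the main path.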

The issue

I was getting desperate and started checking the gradients, for which I needed to un-JIT the step function so that my print statements would show actual gradient values. And there it was: after removing the eqx.filter_jit wrapper around the step function, the network started to train.

Copy-pastable training loop code (requires tensorflow, tensorflow_datasets, clu, optax, tqdm, jaxtyping)

import equinox as eqx
import jax
import jax.numpy as jnp
import jaxtyping as jt
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from clu import metrics
from tqdm import tqdm

# from jaxonmodels.models.resnet import ResNet, resnet18
# copy paste here the code from the details above

(train, test), info = tfds.load(
    "cifar10", split=["train", "test"], with_info=True, as_supervised=True
) # pyright: ignore


def preprocess(
    img: jt.Float[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "c h w"], jt.Int[tf.Tensor, ""]]:
    img = tf.cast(img, tf.float32) / 255.0 # pyright: ignore
    mean = tf.constant([0.4914, 0.4822, 0.4465])
    std = tf.constant([0.2470, 0.2435, 0.2616])
    img = (img - mean) / std # pyright: ignore

    img = tf.transpose(img, perm=[2, 0, 1])

    # label = tf.one_hot(label, depth=10)

    return img, label


def preprocess_train(
    img: jt.Float[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "c h w"], jt.Int[tf.Tensor, ""]]:
    img = tf.pad(img, [[4, 4], [4, 4], [0, 0]], mode="REFLECT")
    img = tf.image.random_crop(img, [32, 32, 3])
    img = tf.image.random_flip_left_right(img)  # pyright: ignore

    return preprocess(img, label)


train_dataset = train.map(preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
SHUFFLE_VAL = len(train_dataset) // 1000
BATCH_SIZE = 128
train_dataset = train_dataset.shuffle(SHUFFLE_VAL)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

test_dataset = test.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

train_dataset = tfds.as_numpy(train_dataset)
test_dataset = tfds.as_numpy(test_dataset)



def loss_fn(
    resnet: ResNet,
    x: jt.Array,
    y: jt.Array,
    state: eqx.nn.State,
) -> tuple[jt.Array, tuple[jt.Array, eqx.nn.State]]:
    logits, state = eqx.filter_vmap(
        resnet, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
    )(x, state)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
    return jnp.mean(loss), (logits, state)

# @eqx.filter_jit
def step(
    resnet: jt.PyTree,
    state: eqx.nn.State,
    x: jt.Array,
    y: jt.Array,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
):
    (loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
        loss_fn, has_aux=True
    )(resnet, x, y, state)
    updates, opt_state = optimizer.update(grads, opt_state, resnet)
    resnet = eqx.apply_updates(resnet, updates)
    return resnet, state, opt_state, loss_value, logits



class TrainMetrics(eqx.Module, metrics.Collection):
    loss: metrics.Average.from_output("loss")  # pyright: ignore
    accuracy: metrics.Accuracy


def eval(
    resnet: ResNet, test_dataset, state, key: jt.PRNGKeyArray
) -> TrainMetrics:
    eval_metrics = TrainMetrics.empty()
    for x, y in test_dataset:
        y = jnp.array(y, dtype=jnp.int32)
        loss, (logits, state) = loss_fn(resnet, x, y, state)
        eval_metrics = eval_metrics.merge(
            TrainMetrics.single_from_model_output(
                logits=logits, labels=y, loss=loss
            )
        )

    return eval_metrics


train_metrics = TrainMetrics.empty()

resnet, state = resnet18(key=jax.random.key(0), n_classes=10)

learning_rate = 0.1
weight_decay = 5e-4
optimizer = optax.sgd(learning_rate)

opt_state = optimizer.init(eqx.filter(resnet, eqx.is_inexact_array_like))

key = jax.random.key(99)
n_epochs = 100


for epoch in range(n_epochs):
    batch_count = len(train_dataset)

    pbar = tqdm(enumerate(train_dataset), total=batch_count, desc=f"Epoch {epoch}")
    for i, (x, y) in pbar:
        y = jnp.array(y, dtype=jnp.int32)
        resnet, state, opt_state, loss, logits = step(
            resnet, state, x, y, optimizer, opt_state
        )
        train_metrics = train_metrics.merge(
            TrainMetrics.single_from_model_output(
                logits=logits, labels=y, loss=loss
            )
        )

        vals = train_metrics.compute()
        pbar.set_postfix(
            {"loss": f"{vals['loss']:.4f}", "acc": f"{vals['accuracy']:.4f}"}
        )
    key, subkey = jax.random.split(key)
    eval_metrics = eval(resnet, test_dataset, state, subkey)
    evals = eval_metrics.compute()
    print(
        f"Epoch {epoch}: "
        f"test_loss={evals['loss']:.4f}, "
        f"test_acc={evals['accuracy']:.4f}"
    )

The relevant part (tl;dr)

The issue is here:

def loss_fn(
    resnet: ResNet,
    x: jt.Array,
    y: jt.Array,
    state: eqx.nn.State,
) -> tuple[jt.Array, tuple[jt.Array, eqx.nn.State]]:
    logits, state = eqx.filter_vmap(
        resnet, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
    )(x, state)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
    return jnp.mean(loss), (logits, state)

# @eqx.filter_jit
def step(
    resnet: jt.PyTree,
    state: eqx.nn.State,
    x: jt.Array,
    y: jt.Array,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
):
    (loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
        loss_fn, has_aux=True
    )(resnet, x, y, state)
    updates, opt_state = optimizer.update(grads, opt_state, resnet)
    resnet = eqx.apply_updates(resnet, updates)
    return resnet, state, opt_state, loss_value, logits

From my perspective, this looks just like standard JAX "boilerplate" code. I see no reason why JIT-compiling the step function would interfere with training the model.

My other attempts

Perhaps I could get rid of the state, I thought, since I don't even use BatchNorm anymore. But that made no difference. I also tried JITting a smaller portion, as shown in the RNN example:

@eqx.filter_value_and_grad
def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    # Trains with respect to binary cross-entropy
    return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))

But the equivalent version for my model didn't train either.

I spent all day on this and am now out of options. In German we'd say "es ist wie verhext" ("it's as if it were jinxed"), so perhaps someone here has an idea? ANY help is HIGHLY appreciated.

Artur-Galstyan, Mar 12 '25 23:03

I think the code can be further simplified

code

import equinox as eqx
import jax
import jax.numpy as jnp
import jaxtyping as jt
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm

tf.random.set_seed(0)

(train, test), info = tfds.load(
    "cifar10", split=["train", "test"], with_info=True, as_supervised=True
) # pyright: ignore


def preprocess(
    img: jt.UInt8[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "c h w"], jt.Int[tf.Tensor, ""]]:
    img = tf.cast(img, tf.float32) / 255.0 # pyright: ignore
    mean = tf.constant([0.4914, 0.4822, 0.4465])
    std = tf.constant([0.2470, 0.2435, 0.2616])
    img = (img - mean) / std # pyright: ignore

    img = tf.transpose(img, perm=[2, 0, 1])

    # label = tf.one_hot(label, depth=10)

    return img, label


def preprocess_train(
    img: jt.UInt8[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "c h w"], jt.Int[tf.Tensor, ""]]:
    img = tf.pad(img, [[4, 4], [4, 4], [0, 0]], mode="REFLECT")
    img = tf.image.random_crop(img, [32, 32, 3])
    img = tf.image.random_flip_left_right(img)  # pyright: ignore

    return preprocess(img, label)


train_dataset = train.map(preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
SHUFFLE_VAL = len(train_dataset) // 1000
BATCH_SIZE = 2
train_dataset = train_dataset.shuffle(SHUFFLE_VAL)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

test_dataset = test.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

train_dataset = tfds.as_numpy(train_dataset)
test_dataset = tfds.as_numpy(test_dataset)

def loss_fn(
    resnet: ResNet,
    x: jt.Array,
    y: jt.Array,
    state: eqx.nn.State,
) -> tuple[jt.Array, tuple[jt.Array, eqx.nn.State]]:
    logits, state = eqx.filter_vmap(
        resnet, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
    )(x, state)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
    return jnp.mean(loss), (logits, state)

@eqx.filter_jit
def step(
    resnet: jt.PyTree,
    state: eqx.nn.State,
    x: jt.Array,
    y: jt.Array,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
):
    (loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
        loss_fn, has_aux=True
    )(resnet, x, y, state)
    updates, opt_state = optimizer.update(grads, opt_state, resnet)
    new_r = eqx.apply_updates(resnet, updates)
    return new_r, state, opt_state, loss_value, logits, grads, updates

def step_nj(
    resnet: jt.PyTree,
    state: eqx.nn.State,
    x: jt.Array,
    y: jt.Array,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
):
    (loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
        loss_fn, has_aux=True
    )(resnet, x, y, state)
    updates, opt_state = optimizer.update(grads, opt_state, resnet)
    new_r = eqx.apply_updates(resnet, updates)
    return new_r, state, opt_state, loss_value, logits, grads, updates

resnet, state = resnet18(key=jax.random.key(0), n_classes=10)

learning_rate = 0.1
weight_decay = 5e-4
optimizer = optax.sgd(learning_rate)

opt_state = optimizer.init(eqx.filter(resnet, eqx.is_inexact_array_like))

key = jax.random.key(99)
n_epochs = 100

for epoch in range(n_epochs):
    batch_count = len(train_dataset)

    for i, (x, y) in enumerate(train_dataset):
        y = jnp.array(y, dtype=jnp.int32)
        # Run the same step with and without JIT, on identical inputs
        new_j, state_j, os_j, loss_j, logits, g, u = step(
            resnet, state, x, y, optimizer, opt_state
        )
        print("\n WJ", state_j, loss_j, logits)
        new_nj, state_nj, os_nj, loss_nj, logits_nj, g_nj, u_nj = step_nj(
            resnet, state, x, y, optimizer, opt_state
        )
        print("\n NJ", state_nj, loss_nj, logits_nj)
        print(jnp.allclose(logits, logits_nj))
        print(eqx.tree_equal(g, g_nj))
        print(eqx.tree_equal(u, u_nj))
        print(eqx.tree_equal(os_j, os_nj))
        print(eqx.tree_equal(new_j, new_nj))
        print(eqx.tree_equal(jax.tree.leaves(new_j), jax.tree.leaves(new_nj)))
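        leaves_j = jax.tree.leaves(new_j)
        leaves_nj = jax.tree.leaves(new_nj)
        print(len(leaves_j), len(leaves_nj))
        # Per-leaf norm of the difference between non-jitted and jitted parameters
        for k, (a, b) in enumerate(zip(leaves_nj, leaves_j)):
            try:
                diff = jnp.linalg.norm(a - b)
                print(k, diff, jnp.allclose(a, b))
                if jnp.isnan(diff):
                    print(a, b)
            except Exception:
                # Non-array leaves can't be subtracted; compare them directly
                print(k, a, b, a == b)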
        break
    break

The only difference I saw under JIT is that the gradients are slightly different (the difference has a norm of around 1e-9, which I assume is just within floating-point precision), but not always. I would be surprised if this were the source of the problem, but I'm just trying to narrow it down, since it's a very large setup currently.

WJ State() 2.333593 [[-0.04504904  0.03363977 -0.03377024 -0.02731024 -0.00935392 -0.07189757
   0.0312047   0.04465355 -0.03419897  0.01027971]
 [-0.04424123  0.03772916 -0.02880739 -0.01243044 -0.00705145 -0.06763338
   0.03016024  0.04261264 -0.03077785  0.0194398 ]]

 NJ State() 2.333593 [[-0.04504904  0.03363976 -0.03377024 -0.02731025 -0.00935392 -0.07189757
   0.0312047   0.04465355 -0.03419897  0.01027971]
 [-0.04424123  0.03772916 -0.02880739 -0.01243044 -0.00705145 -0.06763338
   0.03016024  0.04261264 -0.03077785  0.0194398 ]]
True
False
False
True
False
False
22 22
0 2.735993e-08 True
1 1.10683e-08 True
2 1.1603525e-08 True
3 1.2323362e-08 True
4 1.2514869e-08 True
5 1.3517954e-08 True
6 2.0453514e-08 True
7 2.238987e-08 True
8 1.1969572e-08 True
9 9.68429e-09 True
10 1.0157422e-08 True
11 1.7057491e-08 True
12 1.5252258e-08 True
13 1.0765473e-08 True
14 6.8855863e-09 True
15 6.1813026e-09 True
16 9.383501e-09 True
17 5.7024003e-09 True
18 4.0014845e-09 True
19 1.3972744e-09 True
20 2.6966664e-09 True
21 0.0 True

lockwo, Mar 13 '25 00:03

This is indeed a large setup. I'll focus today on making it smaller and narrowing down the issue further. (It was late at night when I posted it.)

Artur-Galstyan, Mar 13 '25 08:03

Progress!

The issue lies somewhere in the BasicBlock | Bottleneck layers: as soon as I commented those out in the __call__ function, the model trained (even when jitted).


class ResNet(eqx.Module):
    conv1: eqx.nn.Conv2d
    bn: BatchNorm
    mp: eqx.nn.MaxPool2d

    layer1: list[BasicBlock | Bottleneck]
    layer2: list[BasicBlock | Bottleneck]
    layer3: list[BasicBlock | Bottleneck]
    layer4: list[BasicBlock | Bottleneck]

    avg: eqx.nn.AdaptiveAvgPool2d
    fc: eqx.nn.Linear

    running_internal_channels: int = eqx.field(static=True, default=64)
    dilation: int = eqx.field(static=True, default=1)

    def __init__(
        self,
        block: Type[BasicBlock | Bottleneck],
        layers: list[int],
        n_classes: int,
        zero_init_residual: bool,
        groups: int,
        width_per_group: int,
        replace_stride_with_dilation: list[bool] | None,
        key: jt.PRNGKeyArray,
        input_channels: int = 3,
    ):
        key, *subkeys = jax.random.split(key, 10)

        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                f"`replace_stride_with_dilation` should either be `None` "
                f"or have a length of 3, got {replace_stride_with_dilation} instead."
            )

        self.conv1 = eqx.nn.Conv2d(
            in_channels=input_channels,
            out_channels=self.running_internal_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            use_bias=False,
            key=subkeys[0],
        )

        self.bn = BatchNorm(self.running_internal_channels, axis_name="batch")
        self.mp = eqx.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(
            block,
            64,
            layers[0],
            stride=1,
            dilate=False,
            groups=groups,
            base_width=width_per_group,
            key=subkeys[1],
        )
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
            groups=groups,
            base_width=width_per_group,
            key=subkeys[2],
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
            groups=groups,
            base_width=width_per_group,
            key=subkeys[3],
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2],
            groups=groups,
            base_width=width_per_group,
            key=subkeys[4],
        )

        self.avg = eqx.nn.AdaptiveAvgPool2d(target_shape=(1, 1))
        # self.fc = eqx.nn.Linear(512 * block.expansion, n_classes, key=subkeys[-1])
        # Changed the last layer: with layer1-layer4 commented out, only 64 channels reach it
        self.fc = eqx.nn.Linear(64, n_classes, key=subkeys[-1])

        if zero_init_residual:
            # todo: init last bn layer with zero weights
            pass

    def _make_layer(
        self,
        block: Type[BasicBlock | Bottleneck],
        out_channels: int,
        blocks: int,
        stride: int,
        dilate: bool,
        groups: int,
        base_width: int,
        key: jt.PRNGKeyArray,
    ) -> list[BasicBlock | Bottleneck]:
        downsample = None
        previous_dilation = self.dilation

        if dilate:
            self.dilation *= stride
            stride = 1

        if (
            stride != 1
            or self.running_internal_channels != out_channels * block.expansion
        ):
            key, subkey = jax.random.split(key)
            downsample = Downsample(
                self.running_internal_channels,
                out_channels * block.expansion,
                stride,
                subkey,
            )
        layers = []

        key, subkey = jax.random.split(key)
        layers.append(
            block(
                in_channels=self.running_internal_channels,
                out_channels=out_channels,
                stride=stride,
                downsample=downsample,
                groups=groups,
                base_width=base_width,
                dilation=previous_dilation,
                key=subkey,
            )
        )

        self.running_internal_channels = out_channels * block.expansion

        for _ in range(1, blocks):
            key, subkey = jax.random.split(key)
            layers.append(
                block(
                    in_channels=self.running_internal_channels,
                    out_channels=out_channels,
                    groups=groups,
                    base_width=base_width,
                    dilation=self.dilation,
                    stride=1,
                    downsample=None,
                    key=subkey,
                )
            )

        return layers

    def __call__(
        self, x: jt.Float[jt.Array, "c h w"], state: eqx.nn.State
    ) -> tuple[jt.Float[jt.Array, " n_classes"], eqx.nn.State]:
        x = self.conv1(x)
        x, state = self.bn(x, state)
        x = jax.nn.relu(x)
        x = self.mp(x)


        # Comment these out
        # for layer in self.layer1:
        #     x, state = layer(x, state)

        # for layer in self.layer2:
        #     x, state = layer(x, state)

        # for layer in self.layer3:
        #     x, state = layer(x, state)

        # for layer in self.layer4:
        #     x, state = layer(x, state)

        x = self.avg(x)
        x = jnp.ravel(x)

        x = self.fc(x)

        return x, state

I'll need to investigate why, though.

Edit: OK, so at this point I'm almost convinced that it must be related to how JAX treats these modules when they are in a Python list. I tried using eqx.nn.Sequential, but that didn't work. jax.lax.scan wouldn't work because the list is not a JAX array, and jax.lax.while_loop also fails because JAX can't index a Python list with a traced integer.

Artur-Galstyan, Mar 13 '25 08:03

Ok, strangely enough, I got it.

Out of desperation, I ran uv pip install -e . -U and uv pip install --upgrade "jax[cuda12]", and now it works.

The reason was this:

After I had installed JAX via

uv pip install --upgrade "jax[cuda12]"

I also installed torch and torchvision afterwards like so:

uv pip install torch torchvision

But this had the unfortunate side effect of overwriting the NVIDIA CUDA libraries that JAX had installed. So once I reinstalled JAX's CUDA libraries with this

uv pip install --upgrade "jax[cuda12]"

it fetched the JAX-compatible CUDA libraries, and then it worked.

In other words, don't mix and match your CUDA libraries! I'll need to find a way to have both installed and working at the same time, but that's a topic for another time.

What's left unanswered is why the model did improve (even under JIT) when I removed those layers, but perhaps that's out of scope here.

Edit 2: Install PyTorch first and then JAX; from my tests, PyTorch works fine with JAX's CUDA libraries, but not the other way around.
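To see which CUDA library wheels are present in an environment (for example, before and after installing torch), a small sketch can list every installed distribution whose name starts with nvidia-. This assumes the CUDA libraries were installed as pip wheels, as jax[cuda12] does; installed_nvidia_wheels is a hypothetical helper name, and the script only reads package metadata, so it runs even without a GPU.

```python
from importlib import metadata


def installed_nvidia_wheels() -> dict[str, str]:
    """Hypothetical helper: map each installed 'nvidia-*' distribution to its version."""
    found: dict[str, str] = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name") or ""
        if name.lower().startswith("nvidia-"):
            found[name] = dist.version
    # Sort by package name so two listings can be diffed line by line
    return dict(sorted(found.items()))


if __name__ == "__main__":
    wheels = installed_nvidia_wheels()
    if not wheels:
        print("No nvidia-* wheels found in this environment.")
    for name, version in wheels.items():
        print(f"{name}=={version}")
```

Saving this output before and after uv pip install torch torchvision and diffing the two files would show whether any of the CUDA wheels JAX relies on changed version.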

Artur-Galstyan, Mar 13 '25 11:03

This sounds pretty weird :) But I'm glad you got to the bottom of it!

patrick-kidger, Mar 13 '25 23:03