SimonKoop
The following code

```python
import jax
from jax import numpy as jnp
!pip install equinox
import equinox as eqx

print(f"{eqx.filter_hessian(jax.nn.relu)(jnp.ones(()))=}")
print(f"{jax.hessian(jax.nn.relu)(jnp.ones(()))=}")
```

results in

```
eqx.filter_hessian(jax.nn.relu)(jnp.ones(()))=Array(1., dtype=float32)
jax.hessian(jax.nn.relu)(jnp.ones(()))=Array(0., dtype=float32)
```

...
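ReLU is piecewise linear, so its second derivative is 0 everywhere except at the origin, where it is undefined. At `x = 1.0` the expected Hessian is therefore 0, which matches the `jax.hessian` output. As a cross-check that uses only core JAX transforms (and assumes nothing about how `eqx.filter_hessian` is implemented internally), the sketch below computes the second derivative three ways. The helper `second_derivatives` is hypothetical and exists only for illustration; a cubic, whose second derivative at 1 is 6, is included as a control.

```python
import jax
import jax.numpy as jnp


def second_derivatives(f, x):
    """Hypothetical helper: second derivative of a scalar function f at x, three ways."""
    return {
        "jax.hessian": jax.hessian(f)(x),            # jax.hessian is jacfwd(jacrev(f))
        "jacfwd(grad)": jax.jacfwd(jax.grad(f))(x),  # forward-over-reverse, written out
        "jacrev(grad)": jax.jacrev(jax.grad(f))(x),  # reverse-over-reverse
    }


x = jnp.ones(())
relu_d2 = second_derivatives(jax.nn.relu, x)        # ReLU at 1: expected 0
relu_d2_neg = second_derivatives(jax.nn.relu, -x)   # ReLU at -1: also 0
cubic_d2 = second_derivatives(lambda t: t ** 3, x)  # control: d2/dt2 t^3 = 6t = 6 at t = 1

for name in relu_d2:
    print(f"{name}: relu(1)={relu_d2[name]}, relu(-1)={relu_d2_neg[name]}, t^3 at 1={cubic_d2[name]}")
```

All three transforms should agree on 0 for ReLU and 6 for the cubic; a result of 1 for ReLU at `x = 1.0` looks like a first derivative rather than a second.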
### Description

I sometimes, but not consistently, get the following `jaxlib.xla_extension.XlaRuntimeError` when training a neural network with sine activations in JAX on an NVIDIA A100 GPU:

```
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: All...
```
Since version 0.11.6, the following results in a strange `AttributeError`:

```python
import jax
import equinox as eqx

class Parent(eqx.Module):
    abs_cls_var: eqx.AbstractClassVar[str]

    def __init__(self, **kwargs):
        pass

    def __init_subclass__(cls):
        """__init_subclass__ tries to...
```
The `__call__` method of `PReLU` doesn't take a `key` argument, which makes it incompatible with `Sequential`. The easiest fix is to add an (ignored) `key` argument to the `__call__`...