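chex: Consider supporting static attributes in `chex.dataclass`

With `tjax.dataclasses.dataclass`, a field declared with `field(static=True)` is kept out of the traced pytree, so `jax.lax.scan` receives a concrete `length` inside `jit`. `chex.dataclass` offers no equivalent, so the same code fails when the dataclass comes from chex: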
from jax import jit
from jax.lax import scan
from tjax import IntegralNumeric, RealNumeric
from tjax.dataclasses import dataclass, field
import chex
def f(carry, _):
    """Scan step: add 1.0 to the carry and emit no per-step output.

    Args:
        carry: The running value threaded through ``scan``.
        _: The per-step input slice; unused.

    Returns:
        A ``(new_carry, None)`` pair, the form ``jax.lax.scan`` expects
        from its step function.
    """
    return carry + 1.0, None
@jit
def do_scan(c):
    """Apply ``f`` to ``c.x`` for ``c.y`` iterations and return the final carry.

    Args:
        c: A dataclass-like pytree with an initial carry ``x`` and an
            iteration count ``y``. Because this function is jitted, ``y``
            must be a concrete Python int at trace time (e.g. a static
            field); ``scan`` cannot take its ``length`` from a traced value.

    Returns:
        The carry after ``c.y`` applications of ``f``.
    """
    # xs is None, so scan needs an explicit length; pass it by keyword so
    # its role is obvious rather than relying on the fourth positional slot.
    final, _ = scan(f, c.x, None, length=c.y)
    return final
@dataclass
class C:
    """Scan inputs whose iteration count is declared as a static field (tjax)."""

    # Initial carry for the scan.
    x: RealNumeric
    # Number of scan steps. static=True keeps this field out of the traced
    # pytree leaves, so do_scan sees a concrete int under jit.
    y: IntegralNumeric = field(static=True)
print(do_scan(C(x=1.0, y=10)))  # works: y is static, so scan gets a concrete length
@chex.dataclass
class D:
    """Same fields as ``C``, but declared with ``chex.dataclass``."""

    # Initial carry for the scan.
    x: RealNumeric
    # Intended scan length. chex.dataclass has no static-field option, so
    # y is an ordinary pytree leaf and is traced under jit; presumably this
    # is why do_scan(D(...)) fails (scan needs a concrete length).
    y: IntegralNumeric
print(do_scan(D(y=10, x=1.0)))  # fails: y is traced under jit, so scan has no concrete length
This appears to be a duplicate of https://github.com/deepmind/chex/issues/64.