
Consider supporting static attributes in chex.dataclass

Open · NeilGirdhar opened this issue 4 years ago · 1 comment

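```python
from jax import jit
from jax.lax import scan
from tjax import IntegralNumeric, RealNumeric
from tjax.dataclasses import dataclass, field
import chex

def f(carry, _):
  # Scan body: add 1.0 to the carry at every step; no per-step output.
  return carry + 1.0, None

@jit
def do_scan(c):
  # scan's length argument (c.y) must be a concrete Python int at trace time.
  final, _ = scan(f, c.x, None, c.y)
  return final

@dataclass
class C:
  x: RealNumeric
  y: IntegralNumeric = field(static=True)  # static: not traced by jit

print(do_scan(C(1.0, 10)))  # works

@chex.dataclass
class D:
  x: RealNumeric
  y: IntegralNumeric  # no static option, so y is traced just like x

print(do_scan(D(x=1.0, y=10)))  # fails: c.y is a tracer, not a concrete length
```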

NeilGirdhar, Oct 25 '21 14:10
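The difference between C and D comes down to how the field y is flattened. A static field, as with tjax's `field(static=True)`, is kept out of the pytree leaves and stored in the auxiliary data, so inside the jitted `do_scan` it remains a concrete Python int. `chex.dataclass` flattens every field as a leaf, so y arrives as a tracer and `scan` cannot use it as a length. The sketch below reproduces the working case with plain `jax.tree_util.register_pytree_node`, without tjax; `StaticLength`, `_flatten` and `_unflatten` are hypothetical names for illustration, not part of chex or tjax.

```python
import jax
from jax.lax import scan


class StaticLength:
    """Hypothetical container: `x` is a traced leaf, `y` is static metadata."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def _flatten(c):
    # Children (leaves) are traced by jit; aux_data (here c.y) stays a
    # concrete Python value and becomes part of the treedef / jit cache key.
    return (c.x,), c.y


def _unflatten(y, leaves):
    # JAX passes aux_data first, then the children.
    (x,) = leaves
    return StaticLength(x, y)


jax.tree_util.register_pytree_node(StaticLength, _flatten, _unflatten)


def f(carry, _):
    return carry + 1.0, None


@jax.jit
def do_scan(c):
    # c.y is a plain int here, so scan receives a concrete length.
    final, _ = scan(f, c.x, None, c.y)
    return final


print(do_scan(StaticLength(1.0, 10)))  # 11.0
```

Because y lives in the treedef, calling `do_scan` with a different y triggers a retrace, which is the same trade-off a static dataclass field would make.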

I guess this is a duplicate of https://github.com/deepmind/chex/issues/64!

NeilGirdhar, Nov 21 '21 00:11
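Until `chex.dataclass` can mark a field static, one possible workaround is to keep the scan length out of the dataclass and hand it to `jit` as a static argument. This is a minimal sketch using `functools.partial` with `jax.jit`'s `static_argnames`; splitting the call into `(x, length)` is an assumption about how the caller can restructure the code, not a chex feature.

```python
from functools import partial

import jax
from jax.lax import scan


def f(carry, _):
    # Same scan body as in the issue: add 1.0 per step.
    return carry + 1.0, None


@partial(jax.jit, static_argnames="length")
def do_scan(x, length):
    # `length` is static, so it is a concrete Python int during tracing.
    final, _ = scan(f, x, None, length)
    return final


print(do_scan(1.0, length=10))  # 11.0
```

As with a static field, each distinct `length` value compiles a separate version of `do_scan`.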