merging multiple eqx.Module classes
Hi all,
I am building a multi-physics solver using JAX and equinox, and I want to merge multiple dataclasses that inherit from eqx.Module. With standard Python classes, the following merge function works:
import jax.numpy as jnp

def merge(ob1, ob2):
    # Copy every instance attribute of ob2 into ob1 (mutates ob1 in place)
    ob1.__dict__.update(ob2.__dict__)
    return ob1

class State1:
    def __init__(self, x):
        self.x = x

class State2:
    def __init__(self, y):
        self.y = y

state1 = State1(jnp.asarray([2.0]))
state2 = State2(jnp.asarray([3.0]))
state3 = merge(state1, state2)
print(state3.x)  # x from state1
print(state3.y)  # y from state2
If I define the state classes using eqx.Module instead, the merge function also seems to work:
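import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Float

Float_1D = Float[Array, "n"]  # assumed definition: a 1-D float array

class State1(eqx.Module):
    x: Float_1D

class State2(eqx.Module):
    y: Float_1D

state1 = State1(jnp.asarray([2.0]))
state2 = State2(jnp.asarray([3.0]))
state3 = merge(state1, state2)  # same merge() as above
print(state3.x)  # x from state1
print(state3.y)  # y from state2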
Is this a good way to merge two dataclasses? Please let me know if there is a better approach.