merging multiple eqx.Module classes

Open ToshiyukiBandai opened this issue 2 months ago • 2 comments

Hi all,

I am building a multi-physics solver using JAX and Equinox, and I want to merge multiple dataclasses that inherit from eqx.Module. With standard Python classes, the following merge function works:

import jax.numpy as jnp

def merge(ob1, ob2):
    # Copy every instance attribute of ob2 onto ob1 (mutates ob1 in place).
    ob1.__dict__.update(ob2.__dict__)
    return ob1

class State1():
    def __init__(self, x):
        self.x = x

class State2():
    def __init__(self, y):
        self.y = y

state1 = State1(jnp.asarray([2.0]))
state2 = State2(jnp.asarray([3.0]))
state3 = merge(state1, state2)
print(state3.x) # x from state1
print(state3.y) # y from state2

If I define the state classes using eqx.Module instead, the merge function still seems to work:

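import equinox as eqx
from jaxtyping import Array, Float

# Assumed definition: Float_1D is not defined in the original post.
Float_1D = Float[Array, "n"]

class State1(eqx.Module):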
    x: Float_1D

class State2(eqx.Module):
    y: Float_1D

state1 = State1(jnp.asarray([2.0]))
state2 = State2(jnp.asarray([3.0]))
state3 = merge(state1, state2)
print(state3.x) # x from state1
print(state3.y) # y from state2

Is merging two dataclasses this way good practice? Please let me know if there is a better way to do it.
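
One way to probe the question is to look at what the merged object reports as its dataclass fields. The sketch below uses frozen standard-library dataclasses as a stand-in for eqx.Module, so it does not exercise Equinox itself. Since eqx.Module subclasses are dataclasses, attributes copied in through `__dict__` are presumably not registered as fields there either, and may be invisible to anything that walks the fields, but that is an assumption. `merge_new` is a hypothetical helper, not an Equinox API: it builds a new class holding the fields of both inputs instead of mutating the first one.

```python
from dataclasses import dataclass, fields, make_dataclass


@dataclass(frozen=True)
class State1:
    x: float


@dataclass(frozen=True)
class State2:
    y: float


def merge_inplace(ob1, ob2):
    # Same approach as in the question: copies attributes and mutates ob1.
    # Writing to __dict__ bypasses the frozen-dataclass __setattr__ check.
    ob1.__dict__.update(ob2.__dict__)
    return ob1


def merge_new(ob1, ob2):
    # Hypothetical helper: collect the declared fields of both objects and
    # build a new frozen dataclass that owns all of them.
    specs = [(f.name, f.type) for ob in (ob1, ob2) for f in fields(ob)]
    values = {f.name: getattr(ob, f.name) for ob in (ob1, ob2) for f in fields(ob)}
    merged_cls = make_dataclass("Merged", specs, frozen=True)
    return merged_cls(**values)


s1, s2 = State1(2.0), State2(3.0)

patched = merge_inplace(s1, s2)
print(patched is s1)                       # True: s1 itself was modified
print([f.name for f in fields(patched)])   # ['x']: y is an attribute, not a field

merged = merge_new(State1(2.0), State2(3.0))
print([f.name for f in fields(merged)])    # ['x', 'y']
```

If keeping the two states separate is acceptable, a simpler design is a container class that declares one field per state, which avoids generating classes at runtime.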

ToshiyukiBandai, Apr 30 '24 19:04