
initialization when which parameters are used depends on a runtime input

Roger-luo opened this issue 1 year ago · 1 comment

Hi, I'm trying to port the following PyTorch model to Flax, where which parameter gets used depends on a runtime input. IIUC, in Flax initialization is done lazily, so I don't see an obvious way to initialize all the parameters with a single call like init(key, x, 1) (see the sketch after the attempted Flax implementation below). Is there a way to force-initialize all the parameters?

PyTorch code
import torch
from torch import nn


class Anstaz(nn.Module):
    def __init__(self, input_size: int, output_size: int, *args, **kwargs) -> None:
        super().__init__()
        assert input_size >= output_size
        self.input_size = input_size
        self.output_size = output_size
        self._args = args
        self._kwargs = kwargs
        self.__init_network__(*args, **kwargs)

    def __init_network__(self, *args, **kwargs):
        raise NotImplementedError

    @property
    def name(self) -> str:
        raise NotImplementedError


class EachSite(Anstaz):
    def __init__(
        self,
        ansatz: type[Anstaz],
        input_size: int,
        output_size: int,
        n_rg_steps: int,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(input_size, output_size, *args, **kwargs)
        self.networks = nn.ModuleList()
        for _ in range(n_rg_steps):
            self.networks.append(ansatz(input_size, output_size, *args, **kwargs))

    def __init_network__(self, *args, **kwargs):
        pass

    def forward(self, x: torch.Tensor, idx: int) -> torch.Tensor:
        return self.networks[idx](x)

    @property
    def name(self) -> str:
        return "EachSite(" + self.networks[0].name + ")"


class Linear(Anstaz):
    def __init_network__(self):
        Q = torch.linalg.qr(torch.randn(self.input_size, self.output_size))[0]
        self.projector = nn.Parameter(Q)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        P = self.projector
        T = P.H @ x @ P
        return (T + T.mH) / 2

    @property
    def name(self) -> str:
        return "Linear"
Attempted Flax implementation
import flax.linen as nn
from jax import numpy as jnp, Array
from typing import Callable, Tuple, Any
from flax.linen import initializers

PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any
default_proj_init = initializers.lecun_normal()

class LinearQR(nn.Module):
    output_size: int
    param_dtype: Dtype = jnp.float32
    proj_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_proj_init
 
    @nn.compact
    def __call__(self, op: Array) -> Array:
        proj = self.param(
            'proj',
            self.proj_init,
            (jnp.shape(op)[-1], self.output_size),
            self.param_dtype,
        )
        proj = jnp.linalg.qr(proj)[0]
        return proj.T @ op @ proj


class EachSite(nn.Module):
    site_init: Callable[[], nn.Module]
    n_iterations: int

    def setup(self) -> None:
        self.networks = [self.site_init() for _ in range(self.n_iterations)]

    def __call__(self, op: Array, idx: int) -> Array:
        return self.networks[idx](op)
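
To make the problem concrete, a minimal sketch (assuming the LinearQR and EachSite definitions above; the 3x3 input is arbitrary): a single init call only traces the branch selected by idx, so only that sub-module's parameters get created.

import jax
from jax import numpy as jnp

model = EachSite(site_init=lambda: LinearQR(output_size=2), n_iterations=4)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((3, 3)), 1)
print(jax.tree_map(lambda x: x.shape, variables))
# expected: {'params': {'networks_1': {'proj': (3, 2)}}}
# networks_0, networks_2 and networks_3 were never called during init,
# so their parameters don't exist yet.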

Roger-luo · Aug 11 '23, 17:08

Would this work:

import jax
import flax.linen as nn
from jax import numpy as jnp, Array
from typing import Callable, Tuple, Any
from flax.linen import initializers

PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any
default_proj_init = initializers.lecun_normal()

class LinearQR(nn.Module):
    output_size: int
    param_dtype: Dtype = jnp.float32
    proj_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_proj_init
 
    @nn.compact
    def __call__(self, op: Array) -> Array:
        proj = self.param(
            'proj',
            self.proj_init,
            (jnp.shape(op)[-1], self.output_size),
            self.param_dtype,
        )
        proj = jnp.linalg.qr(proj)[0]
        return proj.T @ op @ proj


class EachSite(nn.Module):
    site_init: Callable[..., nn.Module]
    n_iterations: int
    output_size: int
    param_dtype: Dtype = jnp.float32
    proj_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_proj_init

    def setup(self) -> None:
        self.networks = [
            self.site_init(self.output_size, self.param_dtype, self.proj_init)
            for _ in range(self.n_iterations)
        ]

    def __call__(self, op: Array, idx: int) -> Array:
        if self.is_mutable_collection('params'):  # during init, call every sub-module so all params are created
            for network in self.networks:
                network(op)
        return self.networks[idx](op)

a = EachSite(LinearQR, n_iterations=4, output_size=2)
v = a.init(jax.random.PRNGKey(0), jnp.ones((3, 3)), 0)
for i in range(4):
    print(a.apply(v, jnp.ones((3, 3)), i))
jax.tree_map(lambda x: x.shape, v)

Output:
[[ 1.6175689 -1.3340342]
 [-1.334034   1.1001985]]
[[ 0.02817053 -0.05630156]
 [-0.05630156  0.11252418]]
[[2.7759523  0.6146091 ]
 [0.6146091  0.13607739]]
[[0.04857546 0.08921713]
 [0.08921713 0.16386251]]
{'params': {'networks_0': {'proj': (3, 2)},
  'networks_1': {'proj': (3, 2)},
  'networks_2': {'proj': (3, 2)},
  'networks_3': {'proj': (3, 2)}}}
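
The is_mutable_collection('params') check is what makes this work: it is only True while init is populating the params collection, so the loop over all networks runs exactly once at init time, and a plain apply only executes the indexed network. A variant of the same idea is sketched below; it assumes your Flax version has Module.is_initializing() (a method on recent flax.linen modules), and the class is renamed EachSiteV2 only to avoid clashing with the definition above.

class EachSiteV2(nn.Module):
    site_init: Callable[..., nn.Module]
    n_iterations: int
    output_size: int
    param_dtype: Dtype = jnp.float32
    proj_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_proj_init

    def setup(self) -> None:
        self.networks = [
            self.site_init(self.output_size, self.param_dtype, self.proj_init)
            for _ in range(self.n_iterations)
        ]

    def __call__(self, op: Array, idx: int) -> Array:
        if self.is_initializing():  # True only while running under .init()
            for network in self.networks:
                network(op)  # touch every sub-module so all params are created
        return self.networks[idx](op)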

Also check out the MultiHeadSwitchExample code snippet in our docs.

chiamp · Sep 15 '23, 23:09