Initialization when which parameter is used depends on a runtime input
Hi, I'm trying to port the following PyTorch model to Flax. Which parameter is used depends on a runtime input, and IIUC initialization in Flax is done lazily, so I don't see an obvious way to initialize all the parameters with a single call like init(key, x, 1). Is there a way to force-initialize all the parameters?
PyTorch code
import torch
from torch import nn


class Anstaz(nn.Module):
    def __init__(self, input_size: int, output_size: int, *args, **kwargs) -> None:
        super().__init__()
        assert input_size >= output_size
        self.input_size = input_size
        self.output_size = output_size
        self._args = args
        self._kwargs = kwargs
        self.__init_network__(*args, **kwargs)

    def __init_network__(self, *args, **kwargs):
        raise NotImplementedError

    @property
    def name(self) -> str:
        raise NotImplementedError


class EachSite(Anstaz):
    def __init__(
        self,
        ansatz: type[Anstaz],
        input_size: int,
        output_size: int,
        n_rg_steps: int,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(input_size, output_size, *args, **kwargs)
        self.networks = nn.ModuleList()
        for _ in range(n_rg_steps):
            self.networks.append(ansatz(input_size, output_size, *args, **kwargs))

    def __init_network__(self, *args, **kwargs):
        pass

    def forward(self, x: torch.Tensor, idx: int) -> torch.Tensor:
        return self.networks[idx](x)

    @property
    def name(self) -> str:
        return "EachSite(" + self.networks[0].name + ")"


class Linear(Anstaz):
    def __init_network__(self):
        Q = torch.linalg.qr(torch.randn(self.input_size, self.output_size))[0]
        self.projector = nn.Parameter(Q)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        P = self.projector
        T = P.H @ x @ P
        return (T + T.mH) / 2

    @property
    def name(self) -> str:
        return "Linear"
Attempted Flax implementation
import flax.linen as nn
from jax import numpy as jnp, Array
from typing import Callable, Tuple, Any
from flax.linen import initializers

PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any

default_proj_init = initializers.lecun_normal()


class LinearQR(nn.Module):
    output_size: int
    param_dtype: Dtype = jnp.float32
    proj_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_proj_init

    @nn.compact
    def __call__(self, op: Array) -> Array:
        proj = self.param(
            'proj',
            self.proj_init,
            (jnp.shape(op)[-1], self.output_size),
            self.param_dtype,
        )
        proj = jnp.linalg.qr(proj)[0]
        return proj.T @ op @ proj


class EachSite(nn.Module):
    site_init: Callable[[], nn.Module]
    n_iterations: int

    def setup(self) -> None:
        self.networks = [self.site_init() for _ in range(self.n_iterations)]

    def __call__(self, op: Array, idx: int) -> Array:
        return self.networks[idx](op)
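
For reference, here is a minimal repro of the problem with this version (binding LinearQR's fields via functools.partial is my assumption for how site_init would be supplied): init only traces the branch selected by idx, so only that submodule's parameters are created, and applying any other idx afterwards fails with a missing-parameter error.

import functools
import jax

site = functools.partial(LinearQR, output_size=2)
model = EachSite(site, n_iterations=4)
# init traces __call__ with idx=0, so only networks[0] is ever called
v = model.init(jax.random.PRNGKey(0), jnp.ones((3, 3)), 0)
print(jax.tree_map(lambda x: x.shape, v))
# only networks_0 exists: {'params': {'networks_0': {'proj': (3, 2)}}}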
Would this work:
import jax
import flax.linen as nn
from jax import numpy as jnp, Array
from typing import Callable, Tuple, Any
from flax.linen import initializers

PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any

default_proj_init = initializers.lecun_normal()


class LinearQR(nn.Module):
    output_size: int
    param_dtype: Dtype = jnp.float32
    proj_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_proj_init

    @nn.compact
    def __call__(self, op: Array) -> Array:
        proj = self.param(
            'proj',
            self.proj_init,
            (jnp.shape(op)[-1], self.output_size),
            self.param_dtype,
        )
        proj = jnp.linalg.qr(proj)[0]
        return proj.T @ op @ proj


class EachSite(nn.Module):
    site_init: Callable[..., nn.Module]
    n_iterations: int
    output_size: int
    param_dtype: Dtype = jnp.float32
    proj_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_proj_init

    def setup(self) -> None:
        self.networks = [
            self.site_init(self.output_size, self.param_dtype, self.proj_init)
            for _ in range(self.n_iterations)
        ]

    def __call__(self, op: Array, idx: int) -> Array:
        if self.is_mutable_collection('params'):  # initialize params on init
            for network in self.networks:
                network(op)
        return self.networks[idx](op)


a = EachSite(LinearQR, n_iterations=4, output_size=2)
v = a.init(jax.random.PRNGKey(0), jnp.ones((3, 3)), 0)
for i in range(4):
    print(a.apply(v, jnp.ones((3, 3)), i))
jax.tree_map(lambda x: x.shape, v)
Output:

[[ 1.6175689 -1.3340342]
 [-1.334034   1.1001985]]
[[ 0.02817053 -0.05630156]
 [-0.05630156  0.11252418]]
[[2.7759523  0.6146091 ]
 [0.6146091  0.13607739]]
[[0.04857546 0.08921713]
 [0.08921713 0.16386251]]
{'params': {'networks_0': {'proj': (3, 2)},
            'networks_1': {'proj': (3, 2)},
            'networks_2': {'proj': (3, 2)},
            'networks_3': {'proj': (3, 2)}}}
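
One caveat worth noting (my observation, not from the docs): self.networks[idx] indexes a plain Python list, so idx must be a concrete Python int. Under jax.jit that means marking it static, e.g.:

import functools

# idx indexes a Python list, so it cannot be a traced value; declaring it
# static makes jit specialize (retrace) per index instead of tracing it
@functools.partial(jax.jit, static_argnums=2)
def forward(v, op, idx):
    return a.apply(v, op, idx)

print(forward(v, jnp.ones((3, 3)), 2))

Supporting a traced index would need a lax.switch-style construct, which, judging by its name, is what the example mentioned below covers.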
Also check out the MultiHeadSwitchExample code snippet in our docs.