nn.WeightNorm needs a special scale_init to match PyTorch weight_norm
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows 11 Pro using WSL 2
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
Name: flax
Version: 0.8.5
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: https://github.com/google/flax
Author:
Author-email: Flax team <[email protected]>
License:
Location: /home/braun/.local/lib/python3.10/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: audiotree, clu
---
Name: jax
Version: 0.4.31
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/braun/.local/lib/python3.10/site-packages
Requires: jaxlib, ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, clu, flax, jaxloudnorm, optax, orbax-checkpoint
---
Name: jaxlib
Version: 0.4.31
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/braun/.local/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, clu, jax, optax, orbax-checkpoint
- Python version: 3.10
- GPU/TPU model and memory: RTX 2080 (8GB)
- CUDA version (if applicable): 12.3
Problem you have encountered:
flax.linen.WeightNorm needs a special scale_init in order to match PyTorch's weight_norm. I have written an example in both PyTorch and Flax that produces closely matching output statistics.
About Conv
Before talking about WeightNorm, I first have to show that the convolutions feeding into the weight norm behave the same in both frameworks. That's the purpose of run_custom_conv() in both scripts. The torch documentation for Conv2d gives a formula for initializing the kernel and the bias: both are drawn from U(-sqrt(k), sqrt(k)) with k = groups / (in_channels * prod(kernel_size)). In my Flax script, make_initializer implements this directly, using in_channels like the fan-in operation described by the torch docs. After looking at the source code of variance_scaling, it turns out you can use kernel_init = nn.initializers.variance_scaling(1/3, "fan_in", "uniform") in JAX instead of make_initializer(...); see the sketch below. Needing to pass 1/3 is a little unintuitive, but no big deal.
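Here's a minimal check of my own (not part of the scripts further down) of why the 1/3 shows up: U(-a, a) has variance a^2 / 3, and variance_scaling with distribution="uniform" draws from a uniform whose variance is scale / fan_in, i.e. a = sqrt(3 * scale / fan_in). Setting scale = 1/3 gives a = sqrt(1 / fan_in), which is exactly PyTorch's sqrt(k) for groups = 1.

import jax
from jax import numpy as jnp
from flax import linen as nn

key = jax.random.PRNGKey(0)
in_channels, out_channels, kernel_size = 128, 512, (5, 1)
shape = (*kernel_size, in_channels, out_channels)  # Flax kernel layout: (H, W, I, O)

fan_in = in_channels * kernel_size[0] * kernel_size[1]
torch_bound = jnp.sqrt(1.0 / fan_in)  # PyTorch's sqrt(k) with groups = 1

w = nn.initializers.variance_scaling(1 / 3, "fan_in", "uniform")(key, shape)
print(torch_bound, jnp.abs(w).max())       # max |w| should sit just below sqrt(k)
print(torch_bound / jnp.sqrt(3), w.std())  # std should be close to sqrt(k) / sqrt(3)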
Other users have pointed out that you can't use variance_scaling for the bias_init (https://github.com/google/flax/issues/2749). One workaround is to keep using something like make_initializer, which closes over in_channels directly. If you need a fan-out operation instead, like how torch initializes ConvTranspose, make_initializer is also easy to adapt; a sketch follows.
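For instance, a fan-out variant might look like the following sketch (my own, untested against torch; make_fan_out_initializer is a hypothetical name, and the k formula comes from my reading of the ConvTranspose2d docs):

import jax
from jax import numpy as jnp

def make_fan_out_initializer(out_channels, kernel_size, groups):
    # Hypothetical fan-out counterpart to make_initializer: the fan is computed
    # from the output channels, following k = groups / (out_channels * prod(kernel_size))
    # as I read the torch ConvTranspose2d docs.
    k = groups / (out_channels * jnp.prod(jnp.array(kernel_size)))
    scale = jnp.sqrt(k)

    def init_fn(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, minval=-scale, maxval=scale, dtype=dtype)

    return init_fn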
About WeightNorm
My guess is that Flax's WeightNorm needs scale_init = nn.initializers.constant(1/jnp.sqrt(3)) in order to match PyTorch. I arrived at this number through a bit of trial and error; it's close to 0.5, but I'm fairly sure it isn't exactly 0.5. I would like to know if someone can explain why. A quick check follows.
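Here's a quick check of my own (separate from the two scripts below) of where that number shows up on the PyTorch side. As far as I can tell, torch.nn.utils.weight_norm initializes the magnitude parameter g to the per-output-channel norm of the freshly initialized weight, and under the default Conv2d init that norm lands near 1/sqrt(3) ≈ 0.577 regardless of fan-in, whereas Flax's WeightNorm scale seems to default to 1:

import torch
from torch import nn
from torch.nn.utils import weight_norm

torch.manual_seed(0)
for in_ch, k in [(1, (5, 1)), (32, (5, 1)), (1024, (3, 1))]:
    m = weight_norm(nn.Conv2d(in_ch, 64, k))
    g = m.weight_g.detach().reshape(-1)  # one norm per output channel
    print(in_ch, k, float(g.mean()), float(g.std()))
# the means hover around 1/sqrt(3) ~= 0.577, more tightly as fan-in grows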
Here's the PyTorch:
from einops import rearrange
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.nn.utils import weight_norm
from torch.nn import functional as F


def WNConv2d(*args, act=True, **kwargs):
    conv = weight_norm(nn.Conv2d(*args, **kwargs))
    if not act:
        return conv
    return nn.Sequential(conv, nn.LeakyReLU(0.1))


class MPD(nn.Module):
    def __init__(self, period):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList(
            [
                WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
            ]
        )
        self.conv_post = WNConv2d(
            1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
        )

    def pad_to_period(self, x):
        t = x.shape[-1]
        x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
        return x

    def forward(self, x):
        fmap = []
        x = self.pad_to_period(x)
        x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
        for layer in self.convs:
            x = layer(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return fmap


def summary_stats(name, x, ax):
    x = x.detach().cpu().numpy()
    ax.hist(x.reshape(-1), bins=100, alpha=0.5, label=name)
    ax.set_title(f'PyTorch Histogram of {name}')
    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency')
    ax.legend(loc='upper right')
    print(f'Stats for {name}:')
    print(f'shape:', list(x.shape))
    print(f'mean: {np.mean(x):,.5f} min: {np.min(x):,.5f} max: {np.max(x):,.5f} std: {np.std(x):,.5f}')


def run_MPD():
    B, C, T = 1, 1, 44100
    x = torch.rand((B, C, T)).cuda()*2-1
    period = 2
    model = MPD(period).cuda()
    fmaps = model(x)

    # Create a tall figure with one subplot for each feature map
    fig, axs = plt.subplots(len(fmaps), 1, figsize=(10, 18))
    fig.tight_layout(pad=5.0)  # Adjust the spacing between subplots

    # Plot each histogram on a different subplot
    for i, (fmap, ax) in enumerate(zip(fmaps, axs)):
        summary_stats(f"fmap {i}", fmap, ax)
        print()

    plt.show()

    from torchinfo import summary
    summary(model,
            col_names=['input_size', 'output_size', 'num_params'],
            input_size=x.shape,
            depth=5,
            verbose=1,
            )


def run_custom_conv():
    B, C, H, W = 1, 1, 25, 25
    x = torch.rand((B, C, H, W)).cuda()*2-1
    model = nn.Conv2d(C, out_channels=32, kernel_size=(3, 3), padding=0).cuda()
    fmap = model(x)

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    summary_stats(f"fmap 0", fmap.reshape(-1), ax)
    print()
    plt.show()

    from torchinfo import summary
    summary(model,
            col_names=['input_size', 'output_size', 'num_params'],
            input_size=x.shape,
            depth=5,
            verbose=1,
            )


if __name__ == '__main__':
    print('running custom conv:')
    run_custom_conv()
    print('running MPD:')
    run_MPD()
and its output:
running custom conv:
Stats for fmap 0:
shape: [16928]
mean: -0.02200 min: -1.46258 max: 1.46658 std: 0.38855
===================================================================================================================
Layer (type:depth-idx) Input Shape Output Shape Param #
===================================================================================================================
Conv2d [1, 1, 25, 25] [1, 32, 23, 23] 320
===================================================================================================================
Total params: 320
Trainable params: 320
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.17
===================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.14
Params size (MB): 0.00
Estimated Total Size (MB): 0.14
===================================================================================================================
running MPD:
C:\Python311\Lib\site-packages\torch\nn\utils\weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
Stats for fmap 0:
shape: [1, 32, 7351, 2]
mean: 0.18682 min: -0.19188 max: 1.80779 std: 0.27846
Stats for fmap 1:
shape: [1, 128, 2451, 2]
mean: 0.07238 min: -0.10415 max: 0.91518 std: 0.11993
Stats for fmap 2:
shape: [1, 512, 817, 2]
mean: 0.02764 min: -0.03801 max: 0.46201 std: 0.04894
Stats for fmap 3:
shape: [1, 1024, 273, 2]
mean: 0.01188 min: -0.01440 max: 0.14850 std: 0.02060
Stats for fmap 4:
shape: [1, 1024, 273, 2]
mean: 0.00579 min: -0.00673 max: 0.07628 std: 0.00976
Stats for fmap 5:
shape: [1, 1, 273, 2]
mean: -0.00274 min: -0.01184 max: 0.00474 std: 0.00262
===================================================================================================================
Layer (type:depth-idx) Input Shape Output Shape Param #
===================================================================================================================
MPD [1, 1, 44100] [1, 32, 7351, 2] --
├─ModuleList: 1-1 -- -- --
│ └─Sequential: 2-1 [1, 1, 22051, 2] [1, 32, 7351, 2] --
│ │ └─Conv2d: 3-1 [1, 1, 22051, 2] [1, 32, 7351, 2] 224
│ │ └─LeakyReLU: 3-2 [1, 32, 7351, 2] [1, 32, 7351, 2] --
│ └─Sequential: 2-2 [1, 32, 7351, 2] [1, 128, 2451, 2] --
│ │ └─Conv2d: 3-3 [1, 32, 7351, 2] [1, 128, 2451, 2] 20,736
│ │ └─LeakyReLU: 3-4 [1, 128, 2451, 2] [1, 128, 2451, 2] --
│ └─Sequential: 2-3 [1, 128, 2451, 2] [1, 512, 817, 2] --
│ │ └─Conv2d: 3-5 [1, 128, 2451, 2] [1, 512, 817, 2] 328,704
│ │ └─LeakyReLU: 3-6 [1, 512, 817, 2] [1, 512, 817, 2] --
│ └─Sequential: 2-4 [1, 512, 817, 2] [1, 1024, 273, 2] --
│ │ └─Conv2d: 3-7 [1, 512, 817, 2] [1, 1024, 273, 2] 2,623,488
│ │ └─LeakyReLU: 3-8 [1, 1024, 273, 2] [1, 1024, 273, 2] --
│ └─Sequential: 2-5 [1, 1024, 273, 2] [1, 1024, 273, 2] --
│ │ └─Conv2d: 3-9 [1, 1024, 273, 2] [1, 1024, 273, 2] 5,244,928
│ │ └─LeakyReLU: 3-10 [1, 1024, 273, 2] [1, 1024, 273, 2] --
├─Conv2d: 1-2 [1, 1024, 273, 2] [1, 1, 273, 2] 3,074
===================================================================================================================
Total params: 8,221,154
Trainable params: 8,221,154
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 8.23
===================================================================================================================
Input size (MB): 0.18
Forward/backward pass size (MB): 24.43
Params size (MB): 32.88
Estimated Total Size (MB): 57.49
===================================================================================================================
and its two graphs:
Here's the Flax:
from einops import rearrange
from flax import linen as nn
import jax
from jax import numpy as jnp
from matplotlib import pyplot as plt
import numpy as np


def make_initializer(in_channels, out_channels, kernel_size, groups):
    # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    k = groups / (in_channels * jnp.prod(jnp.array(kernel_size)))
    scale = jnp.sqrt(k)

    def init_fn(key, shape, dtype):
        return jax.random.uniform(key, shape, minval=-scale, maxval=scale, dtype=dtype)

    return init_fn


class CustomConv(nn.Conv):
    @nn.compact
    def __call__(self, x):
        # note: we just ignore whatever self.kernel_init is
        kernel_init = make_initializer(
            x.shape[-1], self.features, self.kernel_size, self.feature_group_count
        )
        if self.use_bias:
            # note: we just ignore whatever self.bias_init is
            bias_init = make_initializer(
                x.shape[-1], self.features, self.kernel_size, self.feature_group_count
            )
        else:
            bias_init = None
        # todo: try using these instead
        # kernel_init = nn.initializers.variance_scaling(1/3, "fan_in", "uniform")  # same as kernel_init above
        # bias_init = nn.initializers.constant(1)
        return nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            input_dilation=self.input_dilation,
            kernel_dilation=self.kernel_dilation,
            feature_group_count=self.feature_group_count,
            use_bias=self.use_bias,
            mask=self.mask,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
            kernel_init=kernel_init,
            bias_init=bias_init
        )(x)


class LeakyReLU(nn.Module):
    negative_slope: float = .01

    @nn.compact
    def __call__(self, x):
        return nn.leaky_relu(x, negative_slope=self.negative_slope)


def WNConv2d(*args, **kwargs):
    scale_init = nn.initializers.constant(1/jnp.sqrt(3))
    # scale_init = nn.initializers.constant(1)  # todo: try using this instead
    conv = nn.WeightNorm(CustomConv(*args, **kwargs), scale_init=scale_init)
    return conv


class MPD(nn.Module):
    period: int

    def pad_to_period(self, x):
        t = x.shape[-1]
        x = jnp.pad(x, pad_width=((0, 0), (0, 0), (0, self.period - t % self.period)), mode='reflect')
        return x

    @nn.compact
    def __call__(self, x):
        convs = [
            WNConv2d(features=32, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
            WNConv2d(features=128, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
            WNConv2d(features=512, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
            WNConv2d(features=1024, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
            WNConv2d(features=1024, kernel_size=(5, 1), strides=(1, 1), padding=((2, 2), (0, 0))),
            WNConv2d(features=1, kernel_size=(3, 1), strides=(1, 1), padding=((1, 1), (0, 0))),
        ]
        fmap = []
        x = self.pad_to_period(x)
        x = rearrange(x, "b c (l p) -> b l p c", p=self.period)
        for i, layer in enumerate(convs):
            x = layer(x)
            if i != (len(convs) - 1):
                x = LeakyReLU(negative_slope=0.1)(x)
            fmap.append(x)
        return fmap


def summary_stats(name, x, ax):
    x = np.array(x)
    ax.hist(x.reshape(-1), bins=100, alpha=0.5, label=name)
    ax.set_title(f'JAX Histogram of {name}')
    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency')
    ax.legend(loc='upper right')
    print(f'Stats for {name}:')
    print(f'shape:', list(x.shape))
    print(f'mean: {np.mean(x):,.5f} min: {np.min(x):,.5f} max: {np.max(x):,.5f} std: {np.std(x):,.5f}')


def run_MPD():
    key = jax.random.PRNGKey(0)
    B, C, T = 1, 1, 44100
    x = jax.random.uniform(jax.random.PRNGKey(0), shape=(B, C, T), minval=-1.0, maxval=1.0)
    period = 2
    model = MPD(period)
    fmaps, variables = model.init_with_output({"params": key}, x)

    # Create a tall figure with one subplot for each feature map
    fig, axs = plt.subplots(len(fmaps), 1, figsize=(10, 18))
    fig.tight_layout(pad=5.0)  # Adjust the spacing between subplots

    # Plot each histogram on a different subplot
    for i, (fmap, ax) in enumerate(zip(fmaps, axs)):
        summary_stats(f"fmap {i}", fmap, ax)
        print()

    plt.show()

    print(model.tabulate({"params": key}, x, console_kwargs={"width": 400}))


def run_custom_conv():
    key = jax.random.PRNGKey(0)
    B, C, H, W = 1, 1, 25, 25
    x = jax.random.uniform(jax.random.PRNGKey(0), shape=(B, H, W, C), minval=-1.0, maxval=1.0)
    model = CustomConv(features=32, kernel_size=(3, 3), padding='VALID')
    fmap, variables = model.init_with_output({"params": key}, x)

    fig, axs = plt.subplots(1, 1, figsize=(8, 6))
    fig.tight_layout(pad=5.0)  # Adjust the spacing between subplots

    # Plot each histogram on a different subplot
    summary_stats(f"fmap 0", fmap.reshape(-1), axs)
    print()
    plt.show()

    print(model.tabulate({"params": key}, x, console_kwargs={"width": 400}))


if __name__ == '__main__':
    print('running custom conv:')
    run_custom_conv()
    print('running MPD:')
    run_MPD()
Here's the Flax output:
running custom conv:
Stats for fmap 0:
shape: [16928]
mean: -0.03363 min: -1.40942 max: 1.28537 std: 0.37116
CustomConv Summary
┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ │ CustomConv │ float32[1,25,25,1] │ float32[1,23,23,32] │ │
├────────┼────────────┼────────────────────┼─────────────────────┼───────────────────────────┤
│ Conv_0 │ Conv │ float32[1,25,25,1] │ float32[1,23,23,32] │ bias: float32[32] │
│ │ │ │ │ kernel: float32[3,3,1,32] │
│ │ │ │ │ │
│ │ │ │ │ 320 (1.3 KB) │
├────────┼────────────┼────────────────────┼─────────────────────┼───────────────────────────┤
│ │ │ │ Total │ 320 (1.3 KB) │
└────────┴────────────┴────────────────────┴─────────────────────┴───────────────────────────┘
Total Parameters: 320 (1.3 KB)
running MPD:
Stats for fmap 0:
shape: [1, 7351, 2, 32]
mean: 0.15735 min: -0.14764 max: 1.46771 std: 0.26314
Stats for fmap 1:
shape: [1, 2451, 2, 128]
mean: 0.05720 min: -0.08959 max: 0.69397 std: 0.10307
Stats for fmap 2:
shape: [1, 817, 2, 512]
mean: 0.02382 min: -0.03053 max: 0.30742 std: 0.04272
Stats for fmap 3:
shape: [1, 273, 2, 1024]
mean: 0.01169 min: -0.01467 max: 0.13217 std: 0.01918
Stats for fmap 4:
shape: [1, 273, 2, 1024]
mean: 0.00552 min: -0.00686 max: 0.06548 std: 0.00945
Stats for fmap 5:
shape: [1, 273, 2, 1]
mean: 0.02159 min: 0.01417 max: 0.02803 std: 0.00248
MPD Summary
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ │ MPD │ float32[1,1,44100] │ - float32[1,7351,2,32] │ │
│ │ │ │ - float32[1,2451,2,128] │ │
│ │ │ │ - float32[1,817,2,512] │ │
│ │ │ │ - float32[1,273,2,1024] │ │
│ │ │ │ - float32[1,273,2,1024] │ │
│ │ │ │ - float32[1,273,2,1] │ │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ WeightNorm_0 │ WeightNorm │ float32[1,22051,2,1] │ float32[1,7351,2,32] │ CustomConv_0/Conv_0/kernel/scale: float32[32] │
│ │ │ │ │ │
│ │ │ │ │ 32 (128 B) │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_0 │ CustomConv │ float32[1,22051,2,1] │ float32[1,7351,2,32] │ │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_0/Conv_0 │ Conv │ float32[1,22051,2,1] │ float32[1,7351,2,32] │ bias: float32[32] │
│ │ │ │ │ kernel: float32[5,1,1,32] │
│ │ │ │ │ │
│ │ │ │ │ 192 (768 B) │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ LeakyReLU_0 │ LeakyReLU │ float32[1,7351,2,32] │ float32[1,7351,2,32] │ │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ WeightNorm_1 │ WeightNorm │ float32[1,7351,2,32] │ float32[1,2451,2,128] │ CustomConv_1/Conv_0/kernel/scale: float32[128] │
│ │ │ │ │ │
│ │ │ │ │ 128 (512 B) │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_1 │ CustomConv │ float32[1,7351,2,32] │ float32[1,2451,2,128] │ │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_1/Conv_0 │ Conv │ float32[1,7351,2,32] │ float32[1,2451,2,128] │ bias: float32[128] │
│ │ │ │ │ kernel: float32[5,1,32,128] │
│ │ │ │ │ │
│ │ │ │ │ 20,608 (82.4 KB) │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ LeakyReLU_1 │ LeakyReLU │ float32[1,2451,2,128] │ float32[1,2451,2,128] │ │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ WeightNorm_2 │ WeightNorm │ float32[1,2451,2,128] │ float32[1,817,2,512] │ CustomConv_2/Conv_0/kernel/scale: float32[512] │
│ │ │ │ │ │
│ │ │ │ │ 512 (2.0 KB) │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_2 │ CustomConv │ float32[1,2451,2,128] │ float32[1,817,2,512] │ │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_2/Conv_0 │ Conv │ float32[1,2451,2,128] │ float32[1,817,2,512] │ bias: float32[512] │
│ │ │ │ │ kernel: float32[5,1,128,512] │
│ │ │ │ │ │
│ │ │ │ │ 328,192 (1.3 MB) │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ LeakyReLU_2 │ LeakyReLU │ float32[1,817,2,512] │ float32[1,817,2,512] │ │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ WeightNorm_3 │ WeightNorm │ float32[1,817,2,512] │ float32[1,273,2,1024] │ CustomConv_3/Conv_0/kernel/scale: float32[1024] │
│ │ │ │ │ │
│ │ │ │ │ 1,024 (4.1 KB) │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_3 │ CustomConv │ float32[1,817,2,512] │ float32[1,273,2,1024] │ │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_3/Conv_0 │ Conv │ float32[1,817,2,512] │ float32[1,273,2,1024] │ bias: float32[1024] │
│ │ │ │ │ kernel: float32[5,1,512,1024] │
│ │ │ │ │ │
│ │ │ │ │ 2,622,464 (10.5 MB) │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ LeakyReLU_3 │ LeakyReLU │ float32[1,273,2,1024] │ float32[1,273,2,1024] │ │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ WeightNorm_4 │ WeightNorm │ float32[1,273,2,1024] │ float32[1,273,2,1024] │ CustomConv_4/Conv_0/kernel/scale: float32[1024] │
│ │ │ │ │ │
│ │ │ │ │ 1,024 (4.1 KB) │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_4 │ CustomConv │ float32[1,273,2,1024] │ float32[1,273,2,1024] │ │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_4/Conv_0 │ Conv │ float32[1,273,2,1024] │ float32[1,273,2,1024] │ bias: float32[1024] │
│ │ │ │ │ kernel: float32[5,1,1024,1024] │
│ │ │ │ │ │
│ │ │ │ │ 5,243,904 (21.0 MB) │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ LeakyReLU_4 │ LeakyReLU │ float32[1,273,2,1024] │ float32[1,273,2,1024] │ │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ WeightNorm_5 │ WeightNorm │ float32[1,273,2,1024] │ float32[1,273,2,1] │ CustomConv_5/Conv_0/kernel/scale: float32[1] │
│ │ │ │ │ │
│ │ │ │ │ 1 (4 B) │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_5 │ CustomConv │ float32[1,273,2,1024] │ float32[1,273,2,1] │ │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ CustomConv_5/Conv_0 │ Conv │ float32[1,273,2,1024] │ float32[1,273,2,1] │ bias: float32[1] │
│ │ │ │ │ kernel: float32[3,1,1024,1] │
│ │ │ │ │ │
│ │ │ │ │ 3,073 (12.3 KB) │
├─────────────────────┼────────────┼───────────────────────┼─────────────────────────┼─────────────────────────────────────────────────┤
│ │ │ │ Total │ 8,221,154 (32.9 MB) │
└─────────────────────┴────────────┴───────────────────────┴─────────────────────────┴─────────────────────────────────────────────────┘
Total Parameters: 8,221,154 (32.9 MB)
and its two graphs:
There's at least a partial explanation for 1/sqrt(3): the variance of a uniform distribution on (-1, 1) is 1/3, so its standard deviation is 1/sqrt(3). I hope that's a clue for figuring out why PyTorch seems to do WeightNorm one way and Flax does it another.
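To make that concrete, here is a small sanity check of my own (assuming groups = 1, the Flax kernel layout (H, W, I, O), and the same PyTorch-style uniform init as make_initializer). Each output channel's kernel slice has fan_in entries drawn from U(-sqrt(k), sqrt(k)) with k = 1/fan_in, so its expected squared norm is fan_in * k / 3 = 1/3 and its norm concentrates near 1/sqrt(3). If I understand nn.WeightNorm correctly (it normalizes the kernel over every axis except the feature axis and then multiplies by scale), a constant scale_init of 1/sqrt(3) reproduces that norm on average, while PyTorch's weight_norm reproduces it exactly per channel:

import jax
from jax import numpy as jnp

key = jax.random.PRNGKey(0)
for in_channels, kernel_size in [(32, (5, 1)), (512, (5, 1)), (1024, (3, 1))]:
    fan_in = in_channels * kernel_size[0] * kernel_size[1]
    bound = jnp.sqrt(1.0 / fan_in)  # PyTorch's sqrt(k) with groups = 1
    kernel = jax.random.uniform(
        key, (*kernel_size, in_channels, 64), minval=-bound, maxval=bound
    )
    norms = jnp.sqrt((kernel ** 2).sum(axis=(0, 1, 2)))   # one norm per output channel
    print(in_channels, kernel_size, float(norms.mean()))  # ~ 1/sqrt(3) ~= 0.577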