Prefer user specified dtype in jax.nn.initializers.orthogonal
Closes #18267
@jakevdp Good to go.
I think the logic in the uniform example looks pretty good now. We'd want to apply the same pattern in other cases, and then fix the tests according to my previous comments. Thanks!
Hi - let me know if you're interested in continuing to work on this!
Hi, sorry for the lack of progress. I got caught up in some research work at school. I should be able to get back to this later this week.
@jakevdp I updated the remaining initializers.
We need to rethink this. I think the desired semantics are:
- if the user specifies a dtype, then we should cast inputs to that dtype.
- if the user does not specify a dtype, then we should not cast the inputs, and if the uncast inputs require promotion, it should raise a promotion error in strict mode.
I think the current approach gets closer to that, but it's still problematic when the dtype is specified at initialization.
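To make the intended behavior concrete, here is a usage sketch (illustrative only: the bfloat16 choice and the strict-mode check are my examples, and the first assertion describes the desired post-fix behavior, not what the current code does):

import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.PRNGKey(0)
init = initializers.orthogonal(scale=2.0)

# User specifies a dtype: inputs should be cast, and the result should
# come back in that dtype.
w = init(key, (3, 4), jnp.bfloat16)
assert w.dtype == jnp.bfloat16

# No dtype specified: inputs stay uncast; under strict promotion, any
# promotion the uncast inputs require should raise rather than silently upcast.
with jax.numpy_dtype_promotion('strict'):
  w = init(key, (3, 4))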
The initializers currently allow specifying dtypes at two levels. When you say "the user specifies a dtype", I take it you mean the user supplied the dtype at either level (outer or inner)?
If that's the case, we could change both the outer and inner dtype defaults to None. If either of them is overridden we perform a cast; otherwise we leave the parameter as is.
A rough sketch of what this might look like:

def orthogonal(scale: RealNumeric = 1.0,
               column_axis: int = -1,
               dtype=None) -> Initializer:
  # (helper names as in jax/_src/nn/initializers.py)
  default_dtype = dtype
  def init(key: KeyArray,
           shape: core.Shape,
           dtype: DTypeLikeInexact = None) -> Array:
    if default_dtype is None and dtype is None:
      # Neither is specified: use a default and cast the parameter.
      dtype = jnp.float_
      scale_ = lax.convert_element_type(scale, dtype)
    elif dtype is None and default_dtype is not None:
      # Fall back to the outer dtype and don't cast.
      dtype = default_dtype
      scale_ = scale
    elif dtype is not None and default_dtype is None:
      # Use the inner dtype and don't cast.
      scale_ = scale
    else:
      # Both outer and inner dtypes are supplied: prefer inner and don't cast.
      scale_ = scale
    # Proceed as usual without any additional casting.
    if len(shape) < 2:
      raise ValueError("orthogonal initializer requires at least a 2D shape")
    n_rows, n_cols = math.prod(shape) // shape[column_axis], shape[column_axis]
    matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
    A = random.normal(key, matrix_shape, dtype)
    Q, R = jnp.linalg.qr(A)
    diag_sign = lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim)
    Q *= diag_sign  # needed for a uniform distribution
    if n_rows < n_cols: Q = Q.T
    Q = jnp.reshape(Q, tuple(np.delete(shape, column_axis)) + (shape[column_axis],))
    Q = jnp.moveaxis(Q, -1, column_axis)
    return scale_ * Q
  return init
The logic is getting a bit complicated but maybe we can figure out how to simplify it.
That seems like a lot of boilerplate... seems like we want something like this?
def orthogonal(scale: RealNumeric = 1.0,
               column_axis: int = -1,
               dtype=None) -> Initializer:
  default_dtype = dtypes.canonicalize_dtype(float)
  outer_dtype = dtype
  def init(key: KeyArray,
           shape: core.Shape,
           dtype: DTypeLikeInexact = None) -> Array:
    inner_dtype = dtype
    dtype = dtype or outer_dtype or default_dtype
    scale_ = jnp.array(scale, dtype=None if inner_dtype is None and outer_dtype is None else dtype)
    # Proceed as usual without any additional casting
    # ...
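As a sanity check on the resolution order in that sketch, here is a minimal standalone version of just the dtype selection (the resolve_dtype helper is illustrative only, not part of the PR, and it doesn't cover the separate question of when to cast scale):

from jax import dtypes
import jax.numpy as jnp

def resolve_dtype(outer_dtype, inner_dtype):
  # Call-site (inner) dtype wins over constructor (outer) dtype,
  # which wins over the canonical float default.
  default_dtype = dtypes.canonicalize_dtype(float)
  return inner_dtype or outer_dtype or default_dtype

# With jax_enable_x64 disabled (the default), the canonical float is float32.
assert resolve_dtype(None, None) == jnp.float32
assert resolve_dtype(jnp.bfloat16, None) == jnp.bfloat16         # outer only
assert resolve_dtype(jnp.float16, jnp.bfloat16) == jnp.bfloat16  # inner wins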
@jakevdp Happy new year! Sorry for the delay; I was traveling. I've updated the initializers according to your suggestion.
It looks like there are some CI failures – please take a look!
@jakevdp Sorry for getting back to this so late. I fixed the lint error, but I'm not yet sure what's causing the failure on the 3.12 build.
You'll need to rebase on the main branch to pick up the last several months of updates.
@jakevdp I fixed the broken test cases in stax, but I'm not comfortable with the asymmetry between jnp.float_ as the default for some initializers and None for others. WDYT?
Hi - sorry for having lost context on this in the last few months, but I'm not sure I understand what asymmetry you're referring to.
After the PR, the outer dtype defaults to None for the initializers it touches, but a couple of initializers still default to jnp.float_. Do we want to update those as well?
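Concretely, the asymmetry I mean looks roughly like this (signatures abbreviated; the uniform default is my reading of the current code):

import jax.numpy as jnp

# Touched by this PR: outer dtype now defaults to None.
def orthogonal(scale=1.0, column_axis=-1, dtype=None): ...

# Untouched initializers still default to jnp.float_:
def uniform(scale=1e-2, dtype=jnp.float_): ...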