
Prefer user specified dtype in jax.nn.initializers.orthogonal

Open ethanluoyc opened this pull request 2 years ago • 15 comments

Closes #18267

ethanluoyc avatar Oct 24 '23 23:10 ethanluoyc

@jakevdp Good to go.

ethanluoyc avatar Oct 25 '23 23:10 ethanluoyc

I think the logic in the uniform example looks pretty good now. We'd want to apply the same pattern in other cases, and then fix the tests according to my previous comments. Thanks!
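
A hedged sketch of what that pattern might look like for uniform (a paraphrase for readers without the diff in front of them, not the PR's actual code): prefer the dtype passed when the initializer is called, fall back to the one passed when it was constructed, and only then to the canonical default.

import jax.numpy as jnp
from jax import random
from jax._src import dtypes  # internal module, also used later in this thread

def uniform(scale=1e-2, dtype=None):
  outer_dtype = dtype
  def init(key, shape, dtype=None):
    # Call-time dtype wins, then the construction-time dtype,
    # then the canonical default float type.
    dtype = dtype or outer_dtype or dtypes.canonicalize_dtype(float)
    return random.uniform(key, shape, dtype) * jnp.array(scale, dtype)
  return init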

jakevdp avatar Nov 15 '23 22:11 jakevdp

Hi - let me know if you're interested in continuing to work on this!

jakevdp avatar Nov 28 '23 19:11 jakevdp

Hi, sorry for the lack of progress. I've been caught up in some research at school. I should be able to get back to this later this week.

ethanluoyc avatar Nov 28 '23 20:11 ethanluoyc

@jakevdp I updated the remaining initializers.

ethanluoyc avatar Dec 04 '23 22:12 ethanluoyc

We need to rethink this. I think the desired semantics are:

  • if the user specifies a dtype, then we should cast inputs to that dtype.
  • if the user does not specify a dtype, then we should not cast the inputs, and if the uncast inputs require promotion, a promotion error should be raised in strict mode.

I think the current approach gets closer to that, but it's still problematic when the dtype is specified at initialization.
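
To make the second bullet concrete, here is a small illustration of the strict-mode behavior in question (a sketch using the public jax.numpy_dtype_promotion switch; the specific arrays are just examples):

import jax
import jax.numpy as jnp

x = jnp.ones((2, 2), dtype=jnp.bfloat16)
scale = jnp.float32(1.0)  # a concrete (non-weak) dtype

with jax.numpy_dtype_promotion('strict'):
    y = 1.0 * x    # OK: Python scalars are weakly typed, result stays bfloat16
    z = scale * x  # raises a promotion error: bfloat16/float32 needs implicit promotion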

jakevdp avatar Dec 07 '23 18:12 jakevdp

The initializers currently allow specifying dtypes at two levels. When you say the user specifies a dtype, I take it you mean the user supplied the dtype at either level (outer or inner)?

If so, I guess we can change both the outer and inner dtype defaults to None. If either is overridden, we perform a cast; otherwise we leave the parameter as-is.

A rough sketch of what this would look like:

# (This sketch assumes the context of jax/_src/nn/initializers.py, where math,
#  np, jnp, lax, random, core, and the type aliases are already in scope.)
def orthogonal(scale: RealNumeric = 1.0,
               column_axis: int = -1,
               dtype = None) -> Initializer:
  default_dtype = dtype
  def init(key: KeyArray,
           shape: core.Shape,
           dtype: DTypeLikeInexact = None) -> Array:
    if default_dtype is None and dtype is None:
      # Neither level specified a dtype: use the default and cast the scale
      dtype = jnp.float_
      scale_ = lax.convert_element_type(scale, dtype)
    elif dtype is None and default_dtype is not None:
      # Fall back to the outer dtype and don't cast
      dtype = default_dtype
      scale_ = scale
    elif dtype is not None and default_dtype is None:
      # Use the inner dtype and don't cast
      scale_ = scale
    else:
      # Both outer and inner dtypes are supplied: prefer the inner and don't cast
      scale_ = scale
    
    # Proceed as usual without any additional casting
    if len(shape) < 2:
      raise ValueError("orthogonal initializer requires at least a 2D shape")
    n_rows, n_cols = math.prod(shape) // shape[column_axis], shape[column_axis]
    matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
    A = random.normal(key, matrix_shape, dtype)
    Q, R = jnp.linalg.qr(A)
    diag_sign = lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim)
    Q *= diag_sign # needed for a uniform distribution
    if n_rows < n_cols: Q = Q.T
    Q = jnp.reshape(Q, tuple(np.delete(shape, column_axis)) + (shape[column_axis],))
    Q = jnp.moveaxis(Q, -1, column_axis)
    return scale_ * Q
  return init

The logic is getting a bit complicated, but maybe we can figure out how to simplify it.

ethanluoyc avatar Dec 07 '23 18:12 ethanluoyc

That seems like a lot of boilerplate... seems like we want something like this?

def orthogonal(scale: RealNumeric = 1.0,
               column_axis: int = -1,
               dtype = None) -> Initializer:
  default_dtype = dtypes.canonicalize_dtype(float)
  outer_dtype = dtype
  def init(key: KeyArray,
           shape: core.Shape,
           dtype: DTypeLikeInexact = None) -> Array:
    inner_dtype = dtype
    dtype = dtype or outer_dtype or default_dtype
    scale_ = jnp.array(
        scale, dtype=None if inner_dtype is None and outer_dtype is None else dtype)
    
    # Proceed as usual without any additional casting
    # ...
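
Under this sketch, the call-time dtype wins over the construction-time one, and scale is only cast when a dtype was specified at either level. Hypothetical usage:

init = orthogonal(dtype=jnp.bfloat16)     # outer dtype only
key = random.PRNGKey(0)
w = init(key, (64, 64))                   # -> bfloat16
w32 = init(key, (64, 64), jnp.float32)    # inner dtype wins -> float32
wd = orthogonal()(key, (64, 64))          # neither given -> canonical default float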

jakevdp avatar Dec 07 '23 19:12 jakevdp

@jakevdp Happy new year! Sorry for the late reply; I was traveling. I updated the initializers according to your suggestion.

ethanluoyc avatar Jan 02 '24 17:01 ethanluoyc

It looks like there are some CI failures – please take a look!

jakevdp avatar Jan 02 '24 18:01 jakevdp

@jakevdp Sorry for getting back to this so late. I fixed the lint error, but I'm not yet sure what's causing the failure in the Python 3.12 build.

ethanluoyc avatar Feb 16 '24 11:02 ethanluoyc

You'll need to rebase on the main branch to pick up the last several months of updates.

jakevdp avatar Feb 16 '24 16:02 jakevdp

@jakevdp I fixed the broken test cases in stax, but I'm not entirely comfortable with the asymmetry between defaulting to jnp.float_ for some initializers and None for others. WDYT?

ethanluoyc avatar Feb 17 '24 20:02 ethanluoyc

Hi - sorry for having lost context on this in the last few months, but I'm not sure I understand what asymmetry you're referring to.

jakevdp avatar Feb 20 '24 18:02 jakevdp

After this PR, the outer dtype defaults to None for the initializers it touches, but there are a couple of initializers that still default to jnp.float_. Do we want to update those as well?
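
For example (hypothetical signatures, purely to illustrate the asymmetry):

def orthogonal(scale=1.0, column_axis=-1, dtype=None): ...   # updated in this PR
def normal(stddev=1e-2, dtype=jnp.float_): ...               # still defaults to jnp.float_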

ethanluoyc avatar Feb 20 '24 18:02 ethanluoyc