brax

Summing param values

Open verityw opened this issue 3 years ago • 1 comment

Not sure if I should be asking here or on one of JAX's other package pages, but is there an easy way to sum param values? E.g., if I have two Params-type variables a and b with the same structure, how could I do a = a * alpha + (1 - alpha) * b, i.e. update a's values by taking an elementwise linear combination of a's and b's values? I am not planning on updating a via gradient descent at all, so I think it's fine if the method doesn't preserve differentiability.

verityw, Jan 20 '22 22:01

What are the types of a and b? For JAX primitives (e.g., if a and b are scalars or arrays), this should work exactly as you wrote it:

a = a * alpha + (1 - alpha) * b
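A minimal sketch of that case, assuming a and b are plain jax.numpy arrays of the same shape (the values and alpha below are made up for illustration). Arithmetic on JAX arrays is elementwise, so the expression needs no changes:

```python
import jax.numpy as jnp

alpha = 0.9
a = jnp.array([1.0, 2.0, 3.0])  # illustrative values
b = jnp.array([3.0, 2.0, 1.0])

# Elementwise linear combination; the same line works if a and b are scalars.
a = a * alpha + (1 - alpha) * b
print(a)
```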

If a and b are JAX pytrees, it's the same thing; just use tree_map to apply the expression to each pair of matching leaves:

a = jax.tree_util.tree_map(lambda x, y: x * alpha + (1 - alpha) * y, a, b)
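Here is a self-contained sketch of the tree_map approach. It assumes the params are nested dicts of arrays with identical structure; the helper name interpolate_params and the parameter layout are hypothetical, chosen only for illustration. tree_map walks both trees in lockstep and applies the lambda to each pair of matching leaves:

```python
import jax
import jax.numpy as jnp


def interpolate_params(a, b, alpha):
    """Hypothetical helper: alpha * a + (1 - alpha) * b, leaf by leaf.

    `a` and `b` must be pytrees with the same structure
    (e.g. nested dicts/lists/tuples of arrays).
    """
    return jax.tree_util.tree_map(
        lambda x, y: alpha * x + (1 - alpha) * y, a, b
    )


# Hypothetical parameter pytrees with identical structure.
a = {
    "dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)},
    "scale": jnp.array(4.0),
}
b = {
    "dense": {"kernel": jnp.zeros((2, 2)), "bias": jnp.ones(2)},
    "scale": jnp.array(0.0),
}

a = interpolate_params(a, b, alpha=0.75)
print(a["dense"]["kernel"])  # every entry is 0.75
print(a["dense"]["bias"])    # every entry is 0.25
print(a["scale"])            # 3.0
```

If the two trees do not have the same structure, tree_map raises an error instead of silently mismatching leaves, which is usually what you want for parameter updates like this.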

Did I misunderstand some part of your question? Please follow up if so.

erikfrey, Feb 06 '22 23:02