brax
Summing param values
Not sure if I should be asking here or on one of JAX's other package pages, but is there an easy way of summing param values? E.g., if I have two Params-type variables a and b with the same structure, how could I do a = a * alpha + (1 - alpha) * b to update a's values by taking elementwise linear combinations of a's and b's values? I am not planning on updating a via gradient descent at all, so I think it's fine if the method doesn't preserve differentiability.
What are the types of a and b? If they are plain JAX arrays of the same shape and alpha is a scalar, this should work exactly as you wrote it:
a = a * alpha + (1 - alpha) * b
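A quick sanity check of the array case. The values of alpha, a and b below are made up for illustration. The scalar alpha broadcasts over every element, so no special API is needed:

```python
import jax.numpy as jnp

alpha = 0.9  # illustrative mixing weight
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([3.0, 2.0, 1.0])

# Elementwise linear combination; the Python float alpha broadcasts over each array
a = a * alpha + (1 - alpha) * b
print(a)  # approximately [1.2, 2.0, 2.8]
```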
If a and b are JAX pytrees (nested dicts, lists, tuples, etc.), it's the same thing; just map over the leaves:
a = jax.tree_util.tree_map(lambda x, y: x * alpha + (1 - alpha) * y, a, b)
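Here is a minimal self-contained sketch of the pytree version. It assumes the params are an ordinary pytree. The nested dict below is only a toy stand-in for real Brax params, and `blend` is a hypothetical helper name. `tree_map` applies the lambda leaf by leaf and raises an error if a and b do not have the same tree structure:

```python
import jax
import jax.numpy as jnp


def blend(a, b, alpha):
    """Hypothetical helper: alpha * a + (1 - alpha) * b, applied to each pair of leaves."""
    return jax.tree_util.tree_map(lambda x, y: x * alpha + (1 - alpha) * y, a, b)


# Two toy parameter pytrees with identical structure (not real Brax params)
a = {"dense": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}}
b = {"dense": {"w": jnp.zeros((2, 2)), "b": jnp.ones(2)}}

a = blend(a, b, 0.25)
# a["dense"]["w"] is now all 0.25 and a["dense"]["b"] is all 0.75
```

Because `blend` only uses traceable operations, it can also be wrapped in `jax.jit` if this update runs inside a training loop.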
Did I misunderstand some part of your question? Please follow up if so.