liesel
Implement easier access to and manipulation of var inputs
This PR contains only a few lines of changed code, but I think it can drastically improve quality of life while building and manipulating Liesel models.
Problem statement
I have often found working with variable inputs quite cumbersome. The examples below show what I mean.
import liesel.model as lsl
import tensorflow_probability.substrates.jax.distributions as tfd
import jax.numpy as jnp
Example 1: Accessing parameters of a distribution
Let's say I want to access the loc and scale of a variable's distribution, starting from the variable.
def create_var1():
    loc = lsl.Var(0.0, name="loc")
    scale = lsl.Var(1.0, name="scale")
    y = lsl.obs(jnp.zeros(10), lsl.Dist(tfd.Normal, loc=loc, scale=scale), name="y")
    return y
y1 = create_var1()
The kwinputs attribute of Var.dist_node does not give me that access. Instead, it returns a read-only mapping of VarValue nodes. That is not helpful.
>>> y1.dist_node.kwinputs
mappingproxy({'loc': VarValue(name="loc_var_value"),
'scale': VarValue(name="scale_var_value")})
To actually get to the loc var, I have to call:
>>> y1.dist_node.kwinputs["loc"].var
Var(name="loc")
I think this is unnecessarily cumbersome. I have to know a lot about Liesel's internals or dig into the source code to find what I need. Every time I perform this action, I have to look up what to do and test my code extensively to be sure that I get it right. It is quite annoying.
Example 2: Accessing inputs of a calculator
A very similar pattern holds when you use calculators:
def create_var2():
    x = lsl.Var(0.0, name="x")
    y = lsl.Var(lsl.Calc(lambda x: x + 1, x), name="y")
    return y
y2 = create_var2()
>>> y2.value_node.inputs
(VarValue(name="x_var_value"),)
>>> y2.value_node.inputs[0].var
Var(name="x")
Proposed Solution
I implemented Node.__getitem__ and Var.__getitem__ as a remedy.
The above tasks can now be solved like this:
>>> y1.dist_node["loc"] # access by arg name, here equivalent to y1.dist_node[0]
Var(name="loc")
>>> y2[0] # access by index, equivalent to y2.value_node[0]
Var(name="x")
Some details
- The basic implementation is done in Node.__getitem__.
  a) If it receives an integer, it essentially looks up the requested item in Node.all_input_nodes(). This finds all inputs, both positional and keyword.
  b) If it receives a string, it essentially looks up the requested item in Node.kwinputs. This of course finds only inputs that are actually keyword inputs.
- Var.__getitem__ defers to its value node. To access inputs to the distribution, users can turn to Var.dist_node.__getitem__.
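The lookup rules above can be illustrated with a small, self-contained sketch. It does not use Liesel: ToyVar and ToyNode are hypothetical stand-ins, and the assumption that keyword inputs follow positional inputs in all_input_nodes() is mine, not taken from the PR. The real implementation may differ, for example in how it unwraps VarValue nodes.

```python
from types import MappingProxyType


class ToyVar:
    """Hypothetical stand-in for a variable; not Liesel's lsl.Var."""

    def __init__(self, value, name):
        self.value = value
        self.name = name

    def __repr__(self):
        return f'ToyVar(name="{self.name}")'


class ToyNode:
    """Hypothetical stand-in for a node with positional and keyword inputs."""

    def __init__(self, *inputs, **kwinputs):
        self.inputs = tuple(inputs)
        self.kwinputs = MappingProxyType(dict(kwinputs))

    def all_input_nodes(self):
        # Assumption: positional inputs first, then keyword inputs
        # in the order they were passed.
        return self.inputs + tuple(self.kwinputs.values())

    def __getitem__(self, key):
        if isinstance(key, int):
            # Integer keys index into all inputs, positional and keyword.
            return self.all_input_nodes()[key]
        if isinstance(key, str):
            # String keys only find inputs that were passed by keyword.
            return self.kwinputs[key]
        raise TypeError(f"Key must be int or str, got {type(key).__name__}.")


loc = ToyVar(0.0, name="loc")
scale = ToyVar(1.0, name="scale")
dist_node = ToyNode(loc=loc, scale=scale)  # keyword inputs only
calc_node = ToyNode(ToyVar(0.0, name="x"))  # one positional input

print(dist_node["loc"], dist_node[0])  # both refer to the same input
print(calc_node[0])
```

In this sketch, calc_node["x"] raises a KeyError, because x was passed positionally and therefore does not appear in kwinputs.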
Setitem
The implementation also provides the possibility to replace inputs via Node.__setitem__. Example:
# before
>>> y2[0]
Var(name="x")
>>> y2.value
1.0
# change input
>>> y2[0] = lsl.Var(3.0, name="new_input")
# after
>>> y2[0]
Var(name="new_input")
>>> y2.value
4.0
The same works for Dist. The implementation is a thin quality-of-life wrapper around the existing Node.set_inputs() method.
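To make the "thin wrapper" idea concrete, here is a minimal sketch that does not depend on Liesel. ToyVar and ToyCalc are hypothetical stand-ins, and the toy set_inputs(*inputs, **kwinputs) only mimics the general shape of replacing all inputs at once. The sketch also assumes non-negative integer keys and assumes that keyword inputs follow positional ones. The code in the PR may handle these details differently.

```python
class ToyVar:
    """Hypothetical stand-in for a variable; not Liesel's lsl.Var."""

    def __init__(self, value, name):
        self.value = value
        self.name = name


class ToyCalc:
    """Hypothetical stand-in for a calculator node."""

    def __init__(self, function, *inputs, **kwinputs):
        self.function = function
        self.set_inputs(*inputs, **kwinputs)

    def set_inputs(self, *inputs, **kwinputs):
        # Replaces all inputs at once.
        self.inputs = tuple(inputs)
        self.kwinputs = dict(kwinputs)

    @property
    def value(self):
        # Recomputed from the current inputs on every access.
        args = [v.value for v in self.inputs]
        kwargs = {k: v.value for k, v in self.kwinputs.items()}
        return self.function(*args, **kwargs)

    def __setitem__(self, key, new_input):
        # Copy the current inputs, swap one of them, then delegate
        # to set_inputs() so that all bookkeeping stays in one place.
        inputs = list(self.inputs)
        kwinputs = dict(self.kwinputs)
        if isinstance(key, int):
            if key < 0:
                raise IndexError("This sketch supports non-negative indices only.")
            if key < len(inputs):
                inputs[key] = new_input
            else:
                # Indices past the positional inputs address keyword inputs.
                name = list(kwinputs)[key - len(inputs)]
                kwinputs[name] = new_input
        elif isinstance(key, str):
            if key not in kwinputs:
                raise KeyError(key)
            kwinputs[key] = new_input
        else:
            raise TypeError(f"Key must be int or str, got {type(key).__name__}.")
        self.set_inputs(*inputs, **kwinputs)


calc = ToyCalc(lambda x: x + 1, ToyVar(0.0, name="x"))
before = calc.value  # 0.0 + 1
calc[0] = ToyVar(3.0, name="new_input")
after = calc.value  # 3.0 + 1
print(before, after)
```

Delegating to set_inputs() means the item-assignment syntax adds no second code path for updating inputs. It only decides which input to swap.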
Documentation
This is a draft PR for a first discussion. Even if it remains unchanged, documentation has to be added if it ends up being merged.