Implement easier access to and manipulation of var inputs

Open jobrachem opened this issue 6 months ago • 8 comments

This PR contains only a few lines of changed code, but I think it can drastically improve quality of life while building and manipulating Liesel models.

Problem statement

I have often found working with variable inputs quite cumbersome. The following examples show what I mean.

import liesel.model as lsl
import tensorflow_probability.substrates.jax.distributions as tfd
import jax.numpy as jnp

Example 1: Accessing parameters of a distribution

Let's say I want to access the loc and scale of a variable's distribution, starting from the variable.

def create_var1():
    loc = lsl.Var(0.0, name="loc")
    scale = lsl.Var(1.0, name="scale")
    y = lsl.obs(jnp.zeros(10), lsl.Dist(tfd.Normal, loc=loc, scale=scale), name="y")
    return y

y1 = create_var1()

The kwinputs attribute of Var.dist_node does not give me that access. Instead, it returns a mapping of VarValue nodes. That is not helpful.

>>> y1.dist_node.kwinputs
mappingproxy({'loc': VarValue(name="loc_var_value"),
              'scale': VarValue(name="scale_var_value")})

To actually get to the loc var, I have to call:

>>> y1.dist_node.kwinputs["loc"].var
Var(name="loc")

I think this is unnecessarily cumbersome. I have to know a lot about Liesel's internals and/or dig into the source code to find what I need. Every time I need to perform this action, I have to look up what to do and extensively test my code to be sure that I get it right. It is quite annoying.

Example 2: Accessing inputs of a calculator

A very similar pattern holds when you use calculators:

def create_var2():
    x = lsl.Var(0.0, name="x")
    y = lsl.Var(lsl.Calc(lambda x: x + 1, x), name="y")
    return y

y2 = create_var2()
>>> y2.value_node.inputs
(VarValue(name="x_var_value"),)
>>> y2.value_node.inputs[0].var
Var(name="x")
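
For comparison, the unwrapping shown in both examples could be collected in a small helper without this PR. The following is only a sketch: input_vars and resolve_var are hypothetical names, the code relies solely on the inputs, kwinputs and var attributes shown above, and it is exercised here with stand-in objects rather than real Liesel nodes.

```python
from types import SimpleNamespace


def resolve_var(inp):
    """Return the variable behind an input if it exposes one via ``.var``."""
    var = getattr(inp, "var", None)
    # Inputs without an associated variable are returned unchanged.
    return inp if var is None else var


def input_vars(node):
    """Return (positional, keyword) inputs of ``node`` with wrappers unwrapped."""
    positional = tuple(resolve_var(inp) for inp in node.inputs)
    keyword = {name: resolve_var(inp) for name, inp in node.kwinputs.items()}
    return positional, keyword


# Stand-ins that mimic only the attributes used above; not real Liesel classes.
x = SimpleNamespace(name="x")
x_value = SimpleNamespace(name="x_var_value", var=x)
calc = SimpleNamespace(inputs=(x_value,), kwinputs={})

positional, keyword = input_vars(calc)
assert positional[0] is x
```

Even with such a helper, users would still need to know that the inputs are wrapped in the first place, which is the friction the proposal below tries to remove.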

Proposed Solution

I implemented Node.__getitem__ and Var.__getitem__ as a remedy.

The above tasks can now be solved like this:

>>> y1.dist_node["loc"] # access by arg name, here equivalent to y1.dist_node[0]
Var(name="loc")
>>> y2[0] # access by index, equivalent to y2.value_node[0]
Var(name="x")

Some details

  1. The basic implementation is done in Node.__getitem__.
     a) If it receives an integer, it looks up the item by position in Node.all_input_nodes(). This covers all inputs, both positional and keyword.
     b) If it receives a string, it looks up the item by name in Node.kwinputs. This only finds inputs that were actually passed as keyword inputs.
  2. Var.__getitem__ will defer to its value node. To access inputs to the distribution, users can turn to Var.dist_node.__getitem__.
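
The dispatch described in the list above can be illustrated with a toy class. This is a minimal sketch, not the code in this PR: ToyNode is a hypothetical stand-in, it stores plain Python objects as inputs, it assumes that positional inputs are ordered before keyword inputs, and it skips the unwrapping of VarValue nodes into variables.

```python
class ToyNode:
    """Hypothetical stand-in for a node with positional and keyword inputs."""

    def __init__(self, *inputs, **kwinputs):
        self.inputs = tuple(inputs)
        self.kwinputs = dict(kwinputs)

    def all_input_nodes(self):
        # Assumption: positional inputs first, then keyword inputs in
        # insertion order.
        return (*self.inputs, *self.kwinputs.values())

    def __getitem__(self, key):
        if isinstance(key, int):
            # Integer keys index into all inputs, positional and keyword.
            return self.all_input_nodes()[key]
        if isinstance(key, str):
            # String keys only find inputs passed by keyword.
            return self.kwinputs[key]
        raise TypeError(f"Key must be int or str, got {type(key).__name__}.")


dist = ToyNode(loc="loc_input", scale="scale_input")
calc = ToyNode("x_input", offset="offset_input")

assert dist["loc"] == dist[0]        # keyword input, reachable both ways
assert calc[1] == calc["offset"]     # keyword input follows the positional one
```

In this sketch, calc["x_input"] raises a KeyError because the first input was passed positionally, which mirrors point b) above.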

Setitem

The implementation also provides the possibility to replace inputs via Node.__setitem__. Example:

# before
>>> y2[0]
Var(name="x")
>>> y2.value
1.0

# change input
>>> y2[0] = lsl.Var(3.0, name="new_input")

# after
>>> y2[0]
Var(name="new_input")
>>> y2.value
4.0

The same works for Dist. The implementation is a thin quality-of-life wrapper around the existing Node.set_inputs() method.
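
To illustrate what "thin wrapper" means here, the following toy sketch replaces a single input by rebuilding the full set of inputs and handing it to a set_inputs() method. ToyCalc is a hypothetical stand-in that works on plain numbers. It is not Liesel's Calc, and the set_inputs(*inputs, **kwinputs) signature is an assumption made for this example.

```python
class ToyCalc:
    """Hypothetical stand-in for a calculator node; not Liesel's Calc."""

    def __init__(self, function, *inputs, **kwinputs):
        self.function = function
        self.set_inputs(*inputs, **kwinputs)

    def set_inputs(self, *inputs, **kwinputs):
        # Replaces all inputs at once.
        self.inputs = tuple(inputs)
        self.kwinputs = dict(kwinputs)

    def __setitem__(self, key, new_input):
        inputs = list(self.inputs)
        kwinputs = dict(self.kwinputs)
        if isinstance(key, str):
            if key not in kwinputs:
                raise KeyError(key)
            kwinputs[key] = new_input
        elif isinstance(key, int):
            names = list(kwinputs)
            if 0 <= key < len(inputs):
                inputs[key] = new_input
            elif len(inputs) <= key < len(inputs) + len(names):
                # Integer keys past the positional inputs address keyword inputs.
                kwinputs[names[key - len(inputs)]] = new_input
            else:
                raise IndexError(key)
        else:
            raise TypeError(f"Key must be int or str, got {type(key).__name__}.")
        # Delegate the actual replacement to set_inputs().
        self.set_inputs(*inputs, **kwinputs)

    @property
    def value(self):
        return self.function(*self.inputs, **self.kwinputs)


y = ToyCalc(lambda x: x + 1, 0.0)
assert y.value == 1.0

y[0] = 3.0
assert y.value == 4.0
```

Delegating to one method keeps any validation or bookkeeping that set_inputs() performs in a single place, so item assignment cannot drift from it.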

Documentation

This is a draft PR for a first discussion. Even if the implementation remains unchanged, documentation still has to be added before it is merged.

jobrachem, Aug 17 '24 21:08