liesel
Make transform a method
This PR introduces a Var.transform method. Some notes:
Replacing GraphBuilder.transform
We currently have GraphBuilder.transform. The new method is in fact intended as a replacement for it. Having transform as a method on Var makes it easier for users to find. Since the transformation is "doing something to a Var" and does not actually require any functionality of the GraphBuilder, a method on Var is a natural home for it.
Behavior change
The method behaves similarly to GraphBuilder.transform, with a few notable differences:
- It does not use the default event space bijector from TensorFlow Probability by default. This encourages users either to choose their bijector explicitly, which is often sensible, or to request the automatic bijector explicitly. Either way, users are more aware of what they are doing.
- The new method accepts bijector instances in addition to bijector classes. In fact, passing an instance is the preferred way of passing a bijector. Passing a bijector class is only supported if you also supply *bijector_args or **bijector_kwargs to be passed to the bijector. This simplifies the code for the default case. More importantly, it fixes the graph representation after transformation, see below.
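The instance-versus-class rule above can be sketched in plain Python so that it runs without liesel or TensorFlow Probability. `resolve_bijector`, `Exp`, and `Shift` are hypothetical stand-ins for illustration, not liesel's implementation or TFP's bijectors.

```python
import inspect
import math


class Exp:
    """Hypothetical stand-in for a parameter-free bijector class."""

    def forward(self, x):
        return math.exp(x)


class Shift:
    """Hypothetical stand-in for a parameterized bijector: y = x + shift."""

    def __init__(self, shift):
        self.shift = shift

    def forward(self, x):
        return x + self.shift


def resolve_bijector(bijector, *bijector_args, **bijector_kwargs):
    """Return a ready-to-use bijector instance (sketch, not liesel's code)."""
    if not inspect.isclass(bijector):
        # Preferred path: an already configured instance is used as-is.
        return bijector
    if not (bijector_args or bijector_kwargs):
        # A bare class carries no configuration, so an instance would be simpler.
        raise RuntimeError(
            f"{bijector.__name__} is a class; pass an instance, or supply "
            "bijector_args/bijector_kwargs to initialize it."
        )
    # A class plus arguments is initialized here.
    return bijector(*bijector_args, **bijector_kwargs)
```

With this sketch, `resolve_bijector(Exp())` returns the given instance unchanged, `resolve_bijector(Shift, shift=2.0)` builds a `Shift(2.0)`, and `resolve_bijector(Exp)` raises a `RuntimeError`.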
Using a bijector instance:
Compare this to using the default event space bijector. In this case, the same bijector is used, but the graph contains spurious edges from the prior parameters "v0" and "v1" to the original variable "tau".
Deprecation and updated documentation
- I marked GraphBuilder.transform as deprecated and included directions towards the new method in its documentation.
- I updated the usage of the transformation method in the tutorials 01a-transform.md and 04-mcycle.md. As it turned out, the usage in these tutorials was outdated anyway.
Notebook for testing
You can play around with the method in this notebook:
Related issues
- This PR closes #173
- This PR closes #93
- This PR improves on the behavior described in #89: the problem described there does not occur if a bijector instance is passed.
A log message stating which default bijector is being used would be nice.
@jobrachem will transfer the core benefits of the Var.transform method to GraphBuilder.transform: 1) accepting a bijector instance, 2) logging the default event space bijector.
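A minimal sketch of the requested log message, using only Python's standard `logging` module. The logger name `transform_sketch`, the helper `pick_bijector`, and the `InverseGamma`/`Softplus` classes are hypothetical stand-ins; in liesel the default would come from the TensorFlow Probability distribution itself.

```python
import logging

# Hypothetical logger name; the real library may log under a different name.
logger = logging.getLogger("transform_sketch")


class Softplus:
    """Stand-in for a bijector onto the positive reals."""


class InverseGamma:
    """Stand-in distribution whose support is the positive reals."""

    def __init__(self, concentration, scale):
        self.concentration = concentration
        self.scale = scale

    def default_event_space_bijector(self):
        return Softplus()


def pick_bijector(dist, bijector=None):
    """Use the given bijector, or fall back to the default and log which one."""
    if bijector is not None:
        return bijector
    default = dist.default_event_space_bijector()
    logger.info(
        "No bijector given; using the default event space bijector %s of %s.",
        type(default).__name__,
        type(dist).__name__,
    )
    return default
```

After `logging.basicConfig(level=logging.INFO)`, calling `pick_bijector(InverseGamma(2.0, 1.0))` returns a `Softplus` instance and logs a message naming both `Softplus` and `InverseGamma`. Logging at INFO level keeps the fallback visible without interrupting users who rely on it.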
@jobrachem will continue as follows:
- Implement Var.transform
- Make API between Var.transform and GraphBuilder.transform consistent
- Add a note to the docs of GraphBuilder.transform, referring to Var.transform
I have now harmonized the API between Var.transform and GraphBuilder.transform. Namely:
- Var.transform now has the same default behavior as GraphBuilder.transform, i.e. it tries to use the default event space bijector from TensorFlow Probability.
- GraphBuilder.transform now also accepts a bijector instance.
There is one small difference in how the methods handle a bijector class passed without any arguments for the bijector: Var.transform raises a RuntimeError, whereas GraphBuilder.transform emits a UserWarning instead, in order to avoid breaking existing code.
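One way to keep this difference in a single place is a shared helper with a `strict` flag, sketched below with the standard `warnings` module. `init_bijector` and `Exp` are hypothetical names, and initializing the class without arguments in the lenient branch is an assumption about what "not breaking existing code" means here.

```python
import inspect
import warnings


class Exp:
    """Hypothetical parameter-free bijector class."""


def init_bijector(bijector, bijector_args=(), bijector_kwargs=None, *, strict):
    """Shared handling of the bijector argument (sketch).

    strict=True mirrors Var.transform (raise a RuntimeError);
    strict=False mirrors GraphBuilder.transform (warn, then initialize anyway).
    """
    bijector_kwargs = bijector_kwargs or {}
    if not inspect.isclass(bijector):
        return bijector  # an instance is used as-is by both methods
    if not (bijector_args or bijector_kwargs):
        msg = (
            f"{bijector.__name__} was passed as a class without arguments; "
            "pass an instance instead."
        )
        if strict:
            raise RuntimeError(msg)
        warnings.warn(msg, UserWarning, stacklevel=2)
    # Assumed legacy behavior: initialize the class with whatever was given.
    return bijector(*bijector_args, **bijector_kwargs)
```

Each method would call the same helper and differ only in the flag, so the validation logic is not duplicated between Var.transform and GraphBuilder.transform.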
@wiep gentle reminder :)
Hey @wiep, thanks a lot for your suggestions! I implemented them with only slight deviations. Your change to the docstring is fine with me, too 👍
Your comment about removing the code duplication also alerted me to two issues:
- Var.transform lacked a test for whether the default event space bijector exists. I added that test.
- There was a subtle error in my implementation of using the default event space bijector. When the default bijector is used, we need to account for the possibility that it changes depending on the distribution's arguments. So the default bijector has to be obtained and initialized anew on every .update() call. This was not the case in my previous implementation. Now it is, so this problem is fixed.
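Why the default bijector must be rebuilt on each update can be seen with a distribution whose support depends on its arguments, such as a uniform distribution on (low, high). The sketch below uses hypothetical stand-ins (`TransformedNode`, `Uniform`, `SigmoidBijector`, and a `default_event_space_bijector` method returning `None` when no default exists) rather than liesel's nodes or TFP's classes.

```python
import math


class SigmoidBijector:
    """Maps the real line onto the open interval (low, high)."""

    def __init__(self, low, high):
        self.low, self.high = low, high

    def forward(self, x):
        return self.low + (self.high - self.low) / (1.0 + math.exp(-x))


class Uniform:
    """Stand-in distribution whose default bijector depends on its arguments."""

    def __init__(self, low, high):
        self.low, self.high = low, high

    def default_event_space_bijector(self):
        return SigmoidBijector(self.low, self.high)


class TransformedNode:
    """Hypothetical node computing the original value from the unconstrained one."""

    def __init__(self, unconstrained, make_dist):
        self.unconstrained = unconstrained
        self.make_dist = make_dist  # builds the distribution from current parameters
        self.value = None

    def update(self):
        # Obtain the default bijector anew, so parameter changes are picked up.
        bijector = self.make_dist().default_event_space_bijector()
        if bijector is None:
            raise RuntimeError("The distribution has no default event space bijector.")
        self.value = bijector.forward(self.unconstrained)
        return self.value
```

With `params = {"low": 0.0, "high": 1.0}` and `node = TransformedNode(0.0, lambda: Uniform(params["low"], params["high"]))`, the first `node.update()` yields 0.5; after setting `params["high"] = 3.0`, the next update yields 1.5. A bijector cached at construction time would keep returning 0.5, which is the kind of error described above.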