Update the design of weight sampling in the BNN layers.
TFP has been updated so that tfp.distributions.* objects can be initialized in a tape-safe manner, i.e., it is now possible to create a distribution within one tape and use it in another. This test would therefore no longer fail, even if we didn't create new weight RVs on each call.
However, we still need to be able to sample new values for the weights (as also noted in https://github.com/google/edward2/commit/eb4f33c4eed9c0375623d0fd3add9c2d83b559ad), and to be able to override this sampling with tracers, for example to enforce the use of the mean.
One option is to check whether the layer weights are ed.RandomVariables and, if so, call .sample() within the layer (and its regularizers) as needed. Tracers could then be updated to override .sample() so that it returns the mean when desired.