symbolic-pymc
Introduce a somewhat usable "metatize" for TF helper functions
This PR addresses #56 in a different way: it uses a temporary TF graph that mirrors a given meta graph, with the meta tensor terms replaced by Placeholders. The temporary TF graph is passed to the TF function we want to metatize, the result is converted into a meta graph (i.e. "metatized"), and the Placeholder stand-ins are then replaced by the original meta tensors.
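A minimal, TF-free sketch of this round trip might look like the following. Everything in it is hypothetical: `MetaTerm`, `Placeholder`, `Apply`, `metatize_function` and `toy_eye` are toy stand-ins, not symbolic-pymc's or TensorFlow's API. In this toy, both representations share one node type, so "metatizing" the helper's result reduces to substituting the original meta terms back in for their stand-ins.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(frozen=True)
class MetaTerm:
    """Hypothetical stand-in for a meta tensor."""
    name: str


@dataclass(frozen=True)
class Placeholder:
    """Hypothetical stand-in for a TF Placeholder in the temporary graph."""
    name: str


@dataclass(frozen=True)
class Apply:
    """Generic op-application node shared by both toy representations."""
    op: str
    args: Tuple[Any, ...]


def substitute(term, mapping):
    """Recursively replace placeholder stand-ins using `mapping`."""
    if isinstance(term, Placeholder):
        return mapping.get(term, term)
    if isinstance(term, Apply):
        return Apply(term.op, tuple(substitute(a, mapping) for a in term.args))
    return term


def metatize_function(fn, *meta_args):
    """Toy version of the placeholder round trip."""
    # 1. Give every meta input its own placeholder stand-in.
    stand_ins = {Placeholder(f"ph_{i}"): m for i, m in enumerate(meta_args)}
    # 2. Call the helper on the temporary graph built from the placeholders.
    result = fn(*stand_ins)
    # 3. Map the result back, swapping each stand-in for its original meta term.
    return substitute(result, stand_ins)


def toy_eye(num_rows):
    """Hypothetical helper that, like tf.eye, expands into several ops."""
    diag_size = Apply("Minimum", (num_rows, num_rows))
    ones = Apply("Fill", (diag_size, 1.0))
    return Apply("MatrixDiag", (ones,))


rows = MetaTerm("A_rows")
out = metatize_function(toy_eye, rows)
```

The resulting `out` contains no `Placeholder` nodes: the helper's whole expansion is kept, but its leaves are the original meta term again.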
This seems like a worthwhile approach because Placeholders allow some flexibility with unknown shape and dtype information, so, when meta tensors use logic variables for those values, we still have a workable mapping between meta tensors and valid TF tensors.
Naturally, this approach has its limits: some TF helper functions simply do not accept inputs with unknown shape and dtype (i.e. the "variant" dtype). However, the better we are at inferring or specifying dtype and shape information for meta objects (whenever it's possible to do so), the better this approach will work.
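The dtype/shape mapping and its limit can be sketched in plain Python. This is a hypothetical model, not symbolic-pymc's code: `LogicVar` stands in for a logic variable, specs are plain `(dtype, shape)` tuples rather than TF objects, and we assume an unknown dtype maps to "variant" while an unknown rank or dimension maps to `None` (as in TF's `TensorShape`).

```python
class LogicVar:
    """Hypothetical stand-in for a logic variable (an as-yet-unknown value)."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"~{self.name}"


def placeholder_spec(dtype, shape):
    """Map a meta tensor's (dtype, shape) to a placeholder-style spec.

    Assumed mapping: an unknown dtype becomes "variant", an unknown shape
    becomes None (unknown rank), and each unknown dimension becomes None.
    """
    spec_dtype = "variant" if isinstance(dtype, LogicVar) else dtype
    if isinstance(shape, LogicVar):
        spec_shape = None
    else:
        spec_shape = tuple(None if isinstance(d, LogicVar) else d for d in shape)
    return spec_dtype, spec_shape


def toy_helper(spec):
    """Hypothetical helper that, like some TF functions, rejects unknown dtypes."""
    dtype, _ = spec
    if dtype == "variant":
        raise TypeError("this helper does not accept 'variant' (unknown dtype) inputs")
    return spec


# Unknown dimensions are fine as long as the dtype is known...
known = placeholder_spec("float64", (LogicVar("m"), LogicVar("n")))
# ...but an unknown dtype turns into "variant", which the helper refuses.
unknown = placeholder_spec(LogicVar("dtype"), LogicVar("shape"))
```

Inferring the dtype for a meta object moves it from the second case into the first, which is why better dtype/shape inference widens the set of helpers this approach can handle.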
Example
We start by building, in TF, the graph we ultimately want to obtain in meta form:
import tensorflow as tf
from tensorflow.python.eager.context import graph_mode
from symbolic_pymc.tensorflow.meta import mt
from symbolic_pymc.tensorflow.printing import tf_dprint
with graph_mode():
    # Create an identity matrix with the number of rows derived from another
    # matrix's shape in TF.
    A_tf = tf.compat.v1.placeholder(tf.float64, name='A',
                                    shape=tf.TensorShape([None, None]))
    A_shape_tf = tf.shape(A_tf)
    A_rows_tf = A_shape_tf[0]
    # The TF function for this is `tf.eye`.
    I_A_tf = tf.eye(A_rows_tf)
Ideally, there would be a single OpDef behind tf.eye; since there isn't, we would have to build an equivalent meta graph by hand. That meta graph should mirror the TF graph for I_A_tf, and we can always inspect I_A_tf to see what tf.eye constructed from its input (i.e. A_rows_tf):
>>> tf_dprint(I_A_tf)
Tensor(MatrixDiag):0, shape=[None, None] "eye/MatrixDiag:0"
| Op(MatrixDiag) "eye/MatrixDiag"
| | Tensor(Fill):0, shape=[None] "eye/ones:0"
| | | Op(Fill) "eye/ones"
| | | | Tensor(ConcatV2):0, shape=[1] "eye/concat:0"
| | | | | Op(ConcatV2) "eye/concat"
| | | | | | Tensor(Const):0, shape=[0] "eye/shape:0"
| | | | | | Tensor(Pack):0, shape=[1] "eye/concat/values_1:0"
| | | | | | | Op(Pack) "eye/concat/values_1"
| | | | | | | | Tensor(Minimum):0, shape=[] "eye/Minimum:0"
| | | | | | | | | Op(Minimum) "eye/Minimum"
| | | | | | | | | | Tensor(StridedSlice):0, shape=[] "strided_slice:0"
| | | | | | | | | | | Op(StridedSlice) "strided_slice"
| | | | | | | | | | | | Tensor(Shape):0, shape=[2] "Shape:0"
| | | | | | | | | | | | | Op(Shape) "Shape"
| | | | | | | | | | | | | | Tensor(Placeholder):0, shape=[None, None] "A:0"
| | | | | | | | | | | | Tensor(Const):0, shape=[1] "strided_slice/stack:0"
| | | | | | | | | | | | Tensor(Const):0, shape=[1] "strided_slice/stack_1:0"
| | | | | | | | | | | | Tensor(Const):0, shape=[1] "strided_slice/stack_2:0"
| | | | | | | | | | Tensor(StridedSlice):0, shape=[] "strided_slice:0"
| | | | | | | | | | | ...
| | | | | | Tensor(Const):0, shape=[] "eye/concat/axis:0"
| | | | Tensor(Const):0, shape=[] "eye/ones/Const:0"
Reconstructing graphs like this by hand essentially means reproducing each step of the tf.eye implementation.
With the TF function "metatizing" introduced in this PR, the process is much simpler:
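To make "reproducing the steps" concrete, here is a pure-Python sketch that replays the op sequence visible in the trace above (Shape, StridedSlice, Minimum, Pack/ConcatV2, Fill, MatrixDiag) on concrete numbers. The helpers `fill`, `matrix_diag` and `eye_steps` are hypothetical, handle only the unbatched case, and are not TF functions.

```python
def fill(dims, value):
    """1-D stand-in for TF's Fill: `value` repeated dims[0] times."""
    (n,) = dims
    return [value] * n


def matrix_diag(diagonal):
    """Unbatched stand-in for MatrixDiag: a square matrix with `diagonal` on it."""
    n = len(diagonal)
    return [[diagonal[i] if i == j else 0.0 for j in range(n)] for i in range(n)]


def eye_steps(a_shape):
    """Replay, on concrete values, the op sequence tf.eye emitted in the trace."""
    num_rows = a_shape[0]                   # Shape(A) -> StridedSlice
    num_columns = num_rows                  # the trace feeds the same StridedSlice twice
    diag_size = min(num_rows, num_columns)  # Minimum
    batch_shape = []                        # Const "eye/shape" (shape [0])
    diag_shape = batch_shape + [diag_size]  # Pack + ConcatV2 (shape [1])
    ones = fill(diag_shape, 1.0)            # Fill with Const "eye/ones/Const"
    return matrix_diag(ones)                # MatrixDiag
```

For example, `eye_steps((3, 5))` yields a 3x3 identity, since only the row count of `A` is used.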
with graph_mode():
    A_mt = mt(A_tf)
    A_shape_mt = mt.shape(A_mt)
    # There's still work to do to make things easier; for now, we spell out
    # the StridedSlice that `A_shape_tf[0]` produced in the TF graph.
    A_rows_mt = mt.StridedSlice(A_shape_mt, [0], [1], [1], shrink_axis_mask=1)
    I_A_mt = mt.metatize_tf_function(tf.eye, A_rows_mt)
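The explicit `mt.StridedSlice(A_shape_mt, [0], [1], [1], shrink_axis_mask=1)` call corresponds to the slice that `A_shape_tf[0]` created in the TF graph (compare the `strided_slice/stack*` constants in the first trace). A rough pure-Python model of its 1-D semantics, ignoring the begin/end/ellipsis masks, is sketched below; `strided_slice_1d` is a hypothetical helper, not a TF or symbolic-pymc function.

```python
def strided_slice_1d(values, begin, end, strides, shrink_axis_mask=0):
    """Rough 1-D model of StridedSlice, ignoring begin/end/ellipsis masks."""
    (b,), (e,), (s,) = begin, end, strides
    if shrink_axis_mask & 1:
        # Bit 0 set: axis 0 is dropped and only the element at `b` remains.
        return values[b]
    return values[b:e:s]


a_shape = [4, 7]  # concrete stand-in for the value of tf.shape(A)
rows = strided_slice_1d(a_shape, [0], [1], [1], shrink_axis_mask=1)
```

Without the shrink bit, the same arguments would return the length-1 list `[4]` instead of the scalar `4`, which is why `shrink_axis_mask=1` is needed to get a scalar row count.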
Now, if we convert the meta graph I_A_mt back into a TF graph (via reify) and print the result, we see essentially the same graph as the one produced by the original tf.eye call (only the names differ), which verifies the correspondence between the two graphs:
>>> tf_dprint(I_A_mt.reify())
Tensor(MatrixDiag):0, shape=[None, None] "eye_5/MatrixDiag_1:0"
| Op(MatrixDiag) "eye_5/MatrixDiag_1"
| | Tensor(Fill):0, shape=[None] "eye_5/ones_1:0"
| | | Op(Fill) "eye_5/ones_1"
| | | | Tensor(ConcatV2):0, shape=[1] "eye_5/concat_1:0"
| | | | | Op(ConcatV2) "eye_5/concat_1"
| | | | | | Tensor(Const):0, shape=[0] "eye_5/shape:0"
| | | | | | Tensor(Pack):0, shape=[1] "eye_5/concat/values_1_1:0"
| | | | | | | Op(Pack) "eye_5/concat/values_1_1"
| | | | | | | | Tensor(Minimum):0, shape=[] "eye_5/Minimum_1:0"
| | | | | | | | | Op(Minimum) "eye_5/Minimum_1"
| | | | | | | | | | Tensor(StridedSlice):0, shape=[] "StridedSlice_4:0"
| | | | | | | | | | | Op(StridedSlice) "StridedSlice_4"
| | | | | | | | | | | | Tensor(Shape):0, shape=[2] "Shape_5:0"
| | | | | | | | | | | | | Op(Shape) "Shape_5"
| | | | | | | | | | | | | | Tensor(Placeholder):0, shape=[None, None] "A_1:0"
| | | | | | | | | | | | Tensor(Const):0, shape=[1] "StridedSlice_4/begin:0"
| | | | | | | | | | | | Tensor(Const):0, shape=[1] "StridedSlice_4/end:0"
| | | | | | | | | | | | Tensor(Const):0, shape=[1] "StridedSlice_4/strides:0"
| | | | | | | | | | Tensor(StridedSlice):0, shape=[] "StridedSlice_4:0"
| | | | | | | | | | | ...
| | | | | | Tensor(Const):0, shape=[] "eye_5/concat/axis:0"
| | | | Tensor(Const):0, shape=[] "eye_5/ones/Const:0"
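The comparison above is done by eye. A rough way to automate it, assuming the indentation-plus-quoted-name layout that `tf_dprint` prints here, is to drop the quoted TF names (which differ, e.g. `eye/...` vs. `eye_5/...`) and compare what remains line by line. `trace_structure` is a hypothetical helper that parses the printed text; it is not part of symbolic-pymc.

```python
import re


def trace_structure(trace):
    """Reduce tf_dprint-style text to (depth, node) pairs, dropping quoted names."""
    structure = []
    for line in trace.strip().splitlines():
        body = line.lstrip("| ")
        depth = (len(line) - len(body)) // 2  # each nesting level prints "| "
        node = re.sub(r'\s*"[^"]*"\s*$', "", body)
        structure.append((depth, node))
    return structure


original = """
Tensor(MatrixDiag):0, shape=[None, None] "eye/MatrixDiag:0"
| Op(MatrixDiag) "eye/MatrixDiag"
| | Tensor(Fill):0, shape=[None] "eye/ones:0"
"""

reified = """
Tensor(MatrixDiag):0, shape=[None, None] "eye_5/MatrixDiag_1:0"
| Op(MatrixDiag) "eye_5/MatrixDiag_1"
| | Tensor(Fill):0, shape=[None] "eye_5/ones_1:0"
"""
```

With the names stripped, the two excerpts reduce to identical `(depth, node)` lists, while any difference in op type, shape or nesting would still show up.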