Not tracking tensors returned by `tf.reshape()` for data sources other than MNIST

Open khatchad opened this issue 1 year ago • 5 comments

Consider the following code:

# tf2_test_reshape.py
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/reshape

import tensorflow as tf


def f(a):
    pass


t1 = tf.ones([2, 3])
t2 = tf.reshape(t1, [6])
f(t2)

t2 should be a (reshaped) tensor, and the argument to f() should also be tracked as a tensor. Instead, I'm seeing this tensor analysis result:

[INFO] Tensor analysis: answer:
[Node: synthetic < PythonLoader, Ltensorflow/functions/reshape, do()LRoot; > Context: CallStringContext: [ script tf2_test_reshape.py.do()LRoot;@105 ], v2][{[D:Symbolic,n, D:Compound,[D:Constant,28, D:Constant,28]] of pixel}]
[Node: <Code body of function Lscript tf2_test_reshape.py> Context: CallStringContext: [ com.ibm.wala.FakeRootClass.fakeRootMethod()V@2 ], v245][{[D:Symbolic,n, D:Compound,[D:Constant,28, D:Constant,28]] of pixel}]
[Ret-V:Node: synthetic < PythonLoader, Ltensorflow/functions/ones, do()LRoot; > Context: CallStringContext: [ script tf2_test_reshape.py.do()LRoot;@100 ]][{[D:Symbolic,n, D:Compound,[D:Constant,28, D:Constant,28]] of pixel}]
[Node: synthetic < PythonLoader, Ltensorflow/functions/ones, do()LRoot; > Context: CallStringContext: [ script tf2_test_reshape.py.do()LRoot;@100 ], v5][{[D:Symbolic,n, D:Compound,[D:Constant,28, D:Constant,28]] of pixel}]

In the IR, v245 refers to the return value of tf.ones(). That's the only tensor the analysis reports for this file.

Regression

  • There is a summary for tf.reshape(), but it calls copy_data() instead of read_data().
  • There is special handling of tf.reshape() in the code, stemming from the MethodReference field com.ibm.wala.cast.python.ml.client.PythonTensorAnalysisEngine.reshape.

khatchad avatar May 08 '24 17:05 khatchad

Adding this test works:

https://github.com/ponder-lab/ML/commit/2db36e27afd9399c70a788a6800b675c19505379

I believe the problem is that the data sources are hard-coded:

https://github.com/wala/ML/blob/ddba21e7881f2a9cc825f1857aea5a5ea89f1bc3/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java#L606-L610

khatchad avatar May 08 '24 18:05 khatchad

But, it could also have something to do with the way other APIs are being constructed. For example, the points-to set for tf.ones() is empty.

khatchad avatar May 08 '24 18:05 khatchad

Looking at the summary of tf.reshape(), I see that there's a data copy:

https://github.com/wala/ML/blob/ddba21e7881f2a9cc825f1857aea5a5ea89f1bc3/com.ibm.wala.cast.python.ml/data/tensorflow.xml#L379-L388

That may mean if the data source has something wrong with it, any copied data would also have the problem. Thus, the problem may not be with the tf.reshape() operation itself but rather with how data sources other than MNIST are constructed.

khatchad avatar May 09 '24 13:05 khatchad

That being said, copy_data() in the above summary doesn't use its argument.

khatchad avatar May 09 '24 13:05 khatchad

Thus, my best guess is that the problem involves a combination of the (new) XML summaries and the hard-coded initialization of the dataflow.

khatchad avatar May 09 '24 14:05 khatchad

...

Regression

  • There is a summary for tf.reshape(), but it calls copy_data() instead of read_data().

That sounds correct since reshape() wouldn't be creating new data.

khatchad avatar Jun 17 '25 19:06 khatchad

That being said, copy_data() in the above summary doesn't use its argument.

In fact, the original code seems to just return a new MNIST dataset.

khatchad avatar Jun 17 '25 19:06 khatchad

Looking at the summary of tf.reshape(), I see that there's a data copy:

ML/com.ibm.wala.cast.python.ml/data/tensorflow.xml

Lines 379 to 388 in ddba21e

That may mean if the data source has something wrong with it, any copied data would also have the problem.

But we can see from the above summary that the data isn't actually being copied; instead, a new MNIST dataset is created and returned and the argument is ignored.

Thus, the problem may not be with the tf.reshape() operation itself but rather with how data sources other than MNIST are constructed.

If you pass a reference to an MNIST dataset, it would make sense that this summary works because, as the argument is ignored, another dataset of the same kind is returned.
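
To make the distinction concrete, here is a toy Python sketch (not the WALA summary format). The function names and the tuple-based shape encoding are hypothetical, and ("n", 28, 28) is only a simplified stand-in for the MNIST type printed in the log above.

```python
# Toy model only: these names are hypothetical and do not come from WALA/ML.
MNIST_SHAPE = ("n", 28, 28)  # simplified stand-in for the MNIST tensor type


def copy_data_ignoring_argument(source_shape):
    """Mimics the behavior described above: the argument is never used."""
    return MNIST_SHAPE


def copy_data_propagating_argument(source_shape):
    """What a real copy would do: the result keeps the source's shape."""
    return tuple(source_shape)


ones_shape = (2, 3)  # shape of tf.ones([2, 3]) in tf2_test_reshape.py
print(copy_data_ignoring_argument(ones_shape))     # MNIST shape, whatever was passed in
print(copy_data_propagating_argument(ones_shape))  # (2, 3)
```

With the first variant, an MNIST argument happens to produce the right answer, which would explain why the MNIST-based test passes.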

khatchad avatar Jun 19 '25 14:06 khatchad

...

Regression

  • There is a summary for tf.reshape(), but it calls copy_data() instead of read_data().

That sounds correct since reshape() wouldn't be creating new data.

But copy_data doesn't actually copy data; it returns a new dataset.

khatchad avatar Jun 19 '25 14:06 khatchad

Thus, my best guess is that the problem involves a combination of the (new) XML summaries and the hard-coded initialization of the dataflow.

The new summary also has some problems; see https://github.com/wala/ML/issues/265.

khatchad avatar Jun 19 '25 14:06 khatchad

Looks like we have an error when testing https://github.com/ponder-lab/ML/blob/870a2f509a5acc4a81f34979854d32108eb54323/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java#L3346-L3348, which has input file https://github.com/ponder-lab/ML/blob/870a2f509a5acc4a81f34979854d32108eb54323/com.ibm.wala.cast.python.test/data/tf2_test_reshape.py:

{[Node: <Code body of function Lscript tf2_test_reshape.py> Context: CallStringContext: [ com.ibm.wala.FakeRootClass.fakeRootMethod()V@2 ], v255]=Cannot reshape pixel[n][28 * 28] to pixel[6]}

khatchad avatar Jun 19 '25 17:06 khatchad

This error looks to be a shape mismatch, but it's also affecting the tensor type inference. If there's a tensor shape mismatch, tensor types aren't propagated as they normally would be. Is this intended/desirable?

khatchad avatar Jun 19 '25 18:06 khatchad

Looks like we have an error when testing https://github.com/ponder-lab/ML/blob/870a2f509a5acc4a81f34979854d32108eb54323/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java#L3346-L3348, which has input file https://github.com/ponder-lab/ML/blob/870a2f509a5acc4a81f34979854d32108eb54323/com.ibm.wala.cast.python.test/data/tf2_test_reshape.py:

{[Node: <Code body of function Lscript tf2_test_reshape.py> Context: CallStringContext: [ com.ibm.wala.FakeRootClass.fakeRootMethod()V@2 ], v255]=Cannot reshape pixel[n][28 * 28] to pixel[6]}

This error isn't really correct for this file in the sense that the starting tensor doesn't have these initial dimensions. These dimensions are from the MNIST dataset, but t1 doesn't refer to that.

If I run the file, the tf.reshape() operation goes through fine.
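
As a sanity check that does not need TensorFlow installed, the element-count rule a reshape has to satisfy can be sketched in plain Python. The helper can_reshape is hypothetical and only approximates the rule for concrete (non-symbolic) shapes.

```python
from math import prod


def can_reshape(source_shape, target_shape):
    """Return True if target_shape holds exactly the elements of source_shape.

    At most one target dimension may be -1, meaning "infer this dimension".
    """
    total = prod(source_shape)
    if list(target_shape).count(-1) > 1:
        return False
    known = prod(d for d in target_shape if d != -1)
    if -1 in target_shape:
        # The inferred dimension must be a whole number.
        return known != 0 and total % known == 0
    return total == known


print(can_reshape([2, 3], [6]))      # True: 2 * 3 == 6, the case in tf2_test_reshape.py
print(can_reshape([2, 3], [4]))      # False: 6 elements do not fit into 4
print(can_reshape([2, 3], [-1, 3]))  # True: the -1 is inferred as 2
```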

khatchad avatar Jun 19 '25 18:06 khatchad

Adding this test works:

ponder-lab@2db36e2

I believe the problem is that the data sources are hard-coded:

ML/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java

Lines 606 to 610 in ddba21e

TensorType mnistData = TensorType.mnistInput();
Map<PointsToSetVariable, TensorType> init = HashMapFactory.make();
for (PointsToSetVariable v : sources) {
    init.put(v, mnistData);
}

This works because this file is using MNIST.
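
The effect of that loop can be mirrored in a few lines of Python. This is a toy sketch: the source labels, the shape tuples, and init_per_source (one possible alternative, not an existing API) are all assumptions made for illustration.

```python
# Toy model only: labels and shapes below are illustrative assumptions.
MNIST_INPUT = ("n", 28, 28)  # simplified stand-in for TensorType.mnistInput()


def init_hard_coded(sources):
    """Mirrors the Java loop above: every source gets the MNIST type."""
    return {v: MNIST_INPUT for v in sources}


def init_per_source(sources, known_shapes):
    """Hypothetical alternative: use each source's own shape, None if unknown."""
    return {v: known_shapes.get(v) for v in sources}


sources = ["mnist", "tf.ones([2, 3])"]
known_shapes = {"mnist": MNIST_INPUT, "tf.ones([2, 3])": (2, 3)}

print(init_hard_coded(sources)["tf.ones([2, 3])"])                # MNIST shape, which is wrong here
print(init_per_source(sources, known_shapes)["tf.ones([2, 3])"])  # (2, 3)
```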

khatchad avatar Jun 19 '25 19:06 khatchad

Indeed, setting t2's dimensions to [-1, 28, 28, 1], which is compatible with the MNIST dimensions [28, 28], works.

khatchad avatar Jun 19 '25 19:06 khatchad

Thus, this issue is a byproduct of two different problems:

  1. Regardless of their actual initial dimensions, all tensors are initialized with the MNIST dataset dimensions. https://github.com/wala/ML/issues/267
  2. Tensor tracking stops if there's a shape mismatch during (at least) reshaping operations. #266
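
How the two problems combine can be sketched with a toy model in Python. Everything here is an assumption made for illustration: analyze_reshape is hypothetical, ("n", 784) is a simplified stand-in for the pixel[n][28 * 28] type in the error message, and returning None stands in for "tracking stops".

```python
from math import prod

MNIST = ("n", 784)  # simplified stand-in for pixel[n][28 * 28]


def analyze_reshape(source, target):
    """Toy compatibility check for a source with a symbolic batch size "n".

    Returns the target shape if the reshape can succeed, else None.
    """
    const = prod(d for d in source if d != "n")  # elements per batch item
    known = prod(d for d in target if d != -1)   # fixed part of the target
    if -1 in target:
        # n * const / known must be a whole number for every n.
        compatible = const % known == 0
    else:
        # known must equal n * const for some whole number n.
        compatible = known % const == 0
    return tuple(target) if compatible else None


t1 = MNIST  # problem 1: tf.ones([2, 3]) still starts out with the MNIST type
print(analyze_reshape(t1, [6]))               # None (problem 2: t2 and f's argument go untracked)
print(analyze_reshape(t1, [-1, 28, 28, 1]))   # (-1, 28, 28, 1)
```

In this model the reshape to [6] only fails because of the wrong starting type, which matches the observation that the real program runs fine.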

khatchad avatar Jun 19 '25 19:06 khatchad

This error looks to be a shape mismatch, but it's also affecting the tensor type inference. If there's a tensor shape mismatch, tensor types aren't propagated as they normally would be. Is this intended/desirable?

It might be. If we can't track reshapes, target tensors won't have accurate shapes. However, right now, every tensor other than those from MNIST has an inaccurate shape per https://github.com/wala/ML/issues/267.

khatchad avatar Jun 19 '25 19:06 khatchad