
tfp.stats.histogram cannot be compiled by XLA?

Open · yellowdolphin opened this issue 2 years ago · 1 comment

Summary of problem

Applying tfp.stats.histogram to data inside a tf.keras model breaks XLA compilation. The example code (see below) works on CPU/GPU, but with a TPU strategy it raises:

InvalidArgumentError: 9 root error(s) found.
  (0) INVALID_ARGUMENT: {{function_node __inference_train_function_3630}} Input 1 to node `sequential/lambda/histogram/count_integers/map/while/bincount/Bincount` with op Bincount must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.

	 [[{{node sequential/lambda/histogram/count_integers/map/while/bincount/Bincount}}]]

	 [[sequential/lambda/histogram/count_integers/map/while]]
	 [[TPUReplicate/_compile/_10395939635067805685/_4]]
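The log says that input 1 of Bincount (its output size) must be a compile-time constant. One XLA-friendly alternative, sketched here under assumptions, is to build the histogram from tf.one_hot with a Python-int depth, so every shape is fixed at trace time. The bin layout (LOW, HIGH, NUM_BINS equal-width bins) and the name fixed_width_histogram are hypothetical and not taken from the original model. Values below LOW or at/above HIGH are clamped into the first/last bin (roughly like extend_lower_interval=True and extend_upper_interval=True in tfp.stats.histogram), and the input is flattened rather than histogrammed along an axis.

```python
import tensorflow as tf

# Hypothetical bin layout: NUM_BINS equal-width bins on [LOW, HIGH).
LOW, HIGH, NUM_BINS = 0.0, 1.0, 4

@tf.function(jit_compile=True)  # request XLA compilation
def fixed_width_histogram(x):
    """Count values of x in NUM_BINS equal-width bins; out-of-range values go to the end bins."""
    x = tf.reshape(tf.cast(x, tf.float32), [-1])
    # Bin index from plain arithmetic: no op whose output shape depends on the data.
    idx = tf.floor((x - LOW) / (HIGH - LOW) * NUM_BINS)
    idx = tf.clip_by_value(tf.cast(idx, tf.int32), 0, NUM_BINS - 1)
    # one_hot with a Python-int depth has the static shape [N, NUM_BINS].
    return tf.reduce_sum(tf.one_hot(idx, depth=NUM_BINS, dtype=tf.int32), axis=0)
```

For example, fixed_width_histogram(tf.constant([0.1, 0.2, 0.5, 0.9, 1.5, -0.3])) should give the counts [3, 0, 1, 2]. The cost is an [N, NUM_BINS] intermediate, which is usually acceptable for a small number of bins.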

Reproducible example

https://colab.research.google.com/drive/1g9yHihhmcAcwEeE80wWPwyI8W6D6BGfx?usp=sharing

yellowdolphin · Oct 06 '23 20:10

Hi @yellowdolphin, did you try to dig into this a bit? That would be very helpful! As the error suggests, it looks to me like count_integers, and the bincount inside it, is causing the trouble. Can you compile these separately?

jonas-eschle · Apr 15 '24 16:04