Simplex transform upcasts float32 input to float64
This is leading to failing tests in https://github.com/pymc-devs/pymc/pull/5170
from aeppl.transforms import Simplex
import aesara.tensor as at
x = at.fmatrix('x')
assert x.dtype == "float32"
forw_x = Simplex().forward(x)
assert forw_x.dtype == "float32"  # AssertionError: forw_x.dtype is "float64"
This happens because of the division by the shape, which is int64 by default:
https://github.com/aesara-devs/aeppl/blob/808236f0b20cd126316e226e2ffe8299fa3a5533/aeppl/transforms.py#L268
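A minimal illustration of the promotion at play (Aesara's symbolic shapes are int64 by default):

import aesara.tensor as at

x = at.fmatrix("x")
n = x.shape[-1]  # symbolic shape entry, int64 by default
assert n.dtype == "int64"
assert (x / n).dtype == "float64"  # float32 / int64 is promoted to float64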
Is this a request to make all transforms return outputs with the same precision as their inputs (excluding the inputs' shapes from consideration, of course)?
If so, that would seem to impose some constraints on the shapes (i.e. that they fit into int32); otherwise, we'll need to upcast and then downcast, and—in that case—we're better off downcasting the results in PyMC, because, if we do it in AePPL, the correct/accurate results might not be available to anyone.
Yeah, I am not sure what's best. In this particular case, where the integer comes from a shape variable, it doesn't sound too restrictive to compute with int32 precision (or recast the result to value.dtype), but in other transforms this could be much more restrictive (say, when a distribution parameter is involved).
On the other hand, the change of precision makes graph replacements with transforms a little annoying, because we always have to safeguard them with a cast.
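For concreteness, such a safeguard might look like this on the PyMC side (just a sketch, reusing the snippet from the top of the issue):

import aesara.tensor as at
from aeppl.transforms import Simplex

x = at.fmatrix("x")
forw_x = at.cast(Simplex().forward(x), x.dtype)  # force the output back to the input dtype
assert forw_x.dtype == "float32"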
The point of this issue was indeed to discuss what the best strategy is.
If we can avoid unnecessary upcasting (e.g. #80), then we'll do that, but forcibly downcasting results is a very problematic and confounding endeavor.
As a rule, I'm inclined to say that we cannot make broad input/output precision equality guarantees in Aesara projects, especially since NumPy openly makes no attempt to do so themselves, and we're not going to somehow solve all the surrounding consistency problems and decisions better than them.
The situation in this issue is that one of the arguments to Simplex.forward is the shape of x, which consists of int64s, so, taken altogether, the arguments are a float32 and an int64; based on those, how would one determine the appropriate output type? Per NumPy's logic, a simple arithmetic combination of the two types is a float64, which definitely sounds safe in terms of loss-of-precision, so we're doing the consistent thing already.
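For reference, that is exactly what NumPy's promotion rule gives for this pair of types:

import numpy as np

assert np.promote_types(np.float32, np.int64) == np.float64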
Regardless, any additional logic (e.g. logic that attempts to determine when something is safe to downcast) would need to be justified, and I don't see the case for PyMC's current assumption(s), aside, perhaps, from a somewhat naive notion of memory efficiency or niche architecture support that could only be properly addressed in Aesara (e.g. shapes defaulting to int32 when int64 isn't available for some reason).
numpy is not a great guide here though, because they use Python integers for shapes.
This expression returns a float32 array (until x.shape ~ 2**16)
import numpy as np

x = np.ones(5, dtype=np.float32)
x = x / np.ones(5).shape[0]  # .shape[0] is a Python int, so no upcast happens
x.dtype  # float32
Anyway, not something straightforward.
> numpy is not a great guide here though, because they use Python integers for shapes.
> This expression returns a float32 array (until x.shape ~ 2**16)
Unfortunately, we can't know the exact shape values when constructing the graphs, so we can't do what NumPy is doing there without making some broad assumptions (e.g. an Aesara-level setting that fixes shapes to be int32) in most cases.
That example also doesn't fit the context, because—as a function—its signature is effectively (float32, int32), and not the (float32, int64) situation we're discussing.
In other words, NumPy is a perfectly fine guide as long as you account for the context and the fact that Aesara is "strictly" typed and regular NumPy/Python use isn't (in the same sense).
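To make the distinction concrete, NumPy's result_type shows both behaviors:

import numpy as np

np.result_type(np.float32, 5)         # float32: a bare Python int does not force an upcast
np.result_type(np.float32, np.int64)  # float64: an explicit int64 does, which is the case here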
Hi, I looked at this by chance because I made the PyMC transformation that may have inspired this simplex transformation.
From my perspective, I also would not like to see int64 cast to int32; however, I would not see an issue with casting value.shape[-1] to the type of log_value, which is likely float32 or float64. Do you think the resulting loss of precision would be acceptable?
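If that trade-off is acceptable, the change could look roughly like this (a sketch only; simplex_forward is illustrative and follows the discussion, not necessarily AePPL's exact code):

import aesara.tensor as at

def simplex_forward(value):
    log_value = at.log(value)
    # Cast the int64 shape entry to the float dtype of log_value before dividing,
    # so a float32 input produces a float32 output.
    n = at.cast(value.shape[-1], log_value.dtype)
    shift = at.sum(log_value, axis=-1, keepdims=True) / n
    return log_value[..., :-1] - shift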