Incorrect type of clad tape in reverse mode
Minimal reproducible example:
float func(float z, int a) {
for (int i = 1; i < a; i++){
z = z * a;
}
return z;
}
This results in the following failed assertion:
Assertion `(isGenericMethod || ((*I)->isVariablyModifiedType() || (*I).getNonReferenceType()->isObjCRetainableType() || getContext() .getCanonicalType((*I).getNonReferenceType()) .getTypePtr() == getContext() .getCanonicalType((*Arg)->getType()) .getTypePtr())) && "type mismatch in call argument!"' failed.
The likely culprit becomes apparent when looking at the generated code dump:
void func_grad(float z, int a, float *_result) {
unsigned long _t0;
int _d_i = 0;
clad::tape<float> _t1 = {}; // <-- the type here should be int instead of float, since it stores an int
clad::tape<float> _t2 = {};
_t0 = 0;
for (int i = 1; i < a; i++) {
_t0++;
z = clad::push(_t2, z) * clad::push(_t1, a); // <-- _t1 stores an int here
}
float func_return = z;
goto _label0;
_label0:
_result[0UL] += 1;
for (; _t0; _t0--) {
{
float _r_d0 = _result[0UL];
float _r0 = _r_d0 * clad::pop(_t1);
_result[0UL] += _r0;
float _r1 = clad::pop(_t2) * _r_d0;
_result[1UL] += _r1;
_result[0UL] -= _r_d0;
}
}
}
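For comparison, here is a sketch of what the reverse pass should compute once each tape's element type matches the values pushed onto it. It uses a hypothetical minimal `Tape`/`push`/`pop` stand-in (not clad's real `clad::tape` implementation) and, for brevity, only computes the derivative with respect to `z`:

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for clad::tape: a LIFO stack of values of type T.
// This is NOT clad's implementation, only enough to run the sketch.
template <typename T>
struct Tape {
  std::vector<T> data;
};

// Returns the pushed value so it can be used inline, mirroring how the
// generated code calls clad::push(_t, expr) inside an expression.
template <typename T>
T push(Tape<T>& t, T val) {
  t.data.push_back(val);
  return val;
}

template <typename T>
T pop(Tape<T>& t) {
  assert(!t.data.empty() && "pop from an empty tape");
  T val = t.data.back();
  t.data.pop_back();
  return val;
}

// Hand-written reverse pass modeled on the generated func_grad, but with
// _t1 declared as Tape<int> so it matches the int `a` pushed onto it.
float func_dz(float z, int a) {
  unsigned long _t0 = 0;
  Tape<int> _t1;    // stores `a`, so the element type is int
  Tape<float> _t2;  // stores `z`, so the element type is float
  for (int i = 1; i < a; i++) {
    _t0++;
    z = push(_t2, z) * push(_t1, a);
  }
  float d_z = 1;  // seed: d(func)/d(return value) = 1
  for (; _t0; _t0--) {
    float r_d0 = d_z;
    // d(z * a)/dz = a, read back from the int tape and promoted to float here
    d_z = r_d0 * static_cast<float>(pop(_t1));
    pop(_t2);  // the saved z is only needed for d/da, which is skipped here
  }
  return d_z;
}
```

For `a >= 1`, `func(z, a)` computes `z * a^(a - 1)`, so the derivative with respect to `z` is `a^(a - 1)`; for example `func_dz(2.0f, 3)` returns `9.0f`. With this particular `push` signature, declaring `_t1` as `Tape<float>` would make `push(_t1, a)` fail template argument deduction (float vs. int), which is a compile-time analogue of the type mismatch in the assertion above.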
The problem is that the tape is initialized with a different type from the one being pushed to it. This happens because we initialize the tape from an expression wrapped in implicit casts (the cast type depends on the expression result and on type promotion), so we miss the actual type of the variable. There are two possible solutions to this problem:
- Strip all implicit casts before deciding the type of the tape.
- Build another implicit cast while pushing to the tape.
Of these, the first option seems the more suitable solution to me.
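A minimal sketch of the first option on a toy model of the AST. The `Expr` struct and the functions `tapeTypeFromOuterNode` and `tapeTypeAfterStripping` are hypothetical, made up for illustration; they only mimic the shape of the cast chain and are not clang's or clad's API:

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical toy AST node; NOT clang's Expr.
struct Expr {
  bool isImplicitCast;       // true for an ImplicitCastExpr
  std::string type;          // type as printed in an AST dump, e.g. "int"
  std::shared_ptr<Expr> sub; // operand of an implicit cast, null otherwise
};
using ExprPtr = std::shared_ptr<Expr>;

// Behaviour described in the issue: the tape element type is read from the
// outermost node, i.e. after the int -> float promotion.
std::string tapeTypeFromOuterNode(const ExprPtr& e) {
  return "clad::tape<" + e->type + ">";
}

// Option 1: strip every implicit cast first, so the element type matches
// the variable that is actually pushed.
std::string tapeTypeAfterStripping(ExprPtr e) {
  while (e->isImplicitCast) {
    assert(e->sub && "an implicit cast must have an operand");
    e = e->sub;
  }
  return "clad::tape<" + e->type + ">";
}
```

For `a` referenced as an `int` and wrapped in an `int` cast and then a `float` cast, the first function yields `clad::tape<float>` and the second `clad::tape<int>`. On a real clang AST, `Expr::IgnoreImpCasts()` strips implicit casts in a similar way. Note that stripping every cast also removes the lvalue-to-rvalue conversion, so the node that remains is an lvalue.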
May be related to #214
@grimmmyshini Yeah, the culprit is an ImplicitCastExpr, which we can verify by calling dump on the expression that MakeCladTapeFor is used on:
ImplicitCastExpr 0x5555629ce030 'float' <IntegralToFloating>
`-ImplicitCastExpr 0x5555629ce018 'int' <LValueToRValue>
`-DeclRefExpr 0x5555629cdfe0 'int' lvalue ParmVar 0x5555629cdcd0 'a' 'int'
We can either strip casts until we reach the <LValueToRValue> ImplicitCastExpr, or add the LValueToRValue cast ourselves.
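The two variants can be sketched on a toy model of the AST from the dump above. The `Expr` struct and the helpers `declRef`, `implicitCast`, `stripUntilLValueToRValue`, and `stripAllThenAddLValueToRValue` are hypothetical and only mimic the dump; they are not clang's or clad's API:

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy AST node mimicking the dump; NOT clang's API.
struct Expr {
  bool isImplicitCast;       // true for an ImplicitCastExpr
  std::string castKind;      // "LValueToRValue", "IntegralToFloating", or ""
  std::string type;          // "int", "float", ...
  bool isLValue;             // a DeclRefExpr to a variable is an lvalue
  std::shared_ptr<Expr> sub; // operand of an implicit cast, null otherwise
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr declRef(const std::string& type) {
  return std::make_shared<Expr>(Expr{false, "", type, true, nullptr});
}

ExprPtr implicitCast(const std::string& kind, const std::string& type, ExprPtr sub) {
  return std::make_shared<Expr>(Expr{true, kind, type, false, std::move(sub)});
}

// Variant 1: peel casts from the outside and stop at the LValueToRValue
// cast, so the result is an rvalue of the variable's own type.
ExprPtr stripUntilLValueToRValue(ExprPtr e) {
  while (e->isImplicitCast && e->castKind != "LValueToRValue") e = e->sub;
  return e;
}

// Variant 2: strip every implicit cast, then add an LValueToRValue cast
// ourselves if what remains is an lvalue.
ExprPtr stripAllThenAddLValueToRValue(ExprPtr e) {
  while (e->isImplicitCast) e = e->sub;
  if (e->isLValue) e = implicitCast("LValueToRValue", e->type, e);
  return e;
}
```

Both variants turn the `float` IntegralToFloating chain into an `int` rvalue that can be used both to pick the tape type and as the pushed argument. They differ only in whether the existing LValueToRValue node is reused or a new one is built.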
Which AST node causes the type mismatch? Is it clad::push(_t1, a)? Is the type of a in clad::push(_t1, a) int (without the implicit cast to float)?
> Which AST node causes the type mismatch? Is it clad::push(_t1, a)?
Yeah, this node causes the type mismatch because the type of _t1 is determined from this AST node:
ImplicitCastExpr 0x5555629ce030 'float' <IntegralToFloating>
`-ImplicitCastExpr 0x5555629ce018 'int' <LValueToRValue>
`-DeclRefExpr 0x5555629cdfe0 'int' lvalue ParmVar 0x5555629cdcd0 'a' 'int'
However, a itself is pushed to the tape without the implicit cast to float.
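In plain C++ terms, the mismatch is between the type `a` takes inside `z * a` after the usual arithmetic conversions and the declared type of `a`. A compile-time sketch makes the two types explicit (the names `PromotedType`, `PushedType`, `a_example`, and `typesMatch` are illustrative only):

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Type of `z * a` for float z and int a: the usual arithmetic conversions
// promote the int operand to float (the IntegralToFloating cast in the dump).
using PromotedType = decltype(std::declval<float>() * std::declval<int>());

// Type of the argument in clad::push(_t1, a): the variable `a` itself.
// Illustrative variable; decltype on an unparenthesized name yields its declared type.
int a_example = 0;
using PushedType = decltype(a_example);

static_assert(std::is_same<PromotedType, float>::value, "a is promoted to float inside z * a");
static_assert(std::is_same<PushedType, int>::value, "the pushed argument is still an int");

// A tape type derived from the promoted expression is therefore float,
// while the value pushed into it is an int.
constexpr bool typesMatch = std::is_same<PromotedType, PushedType>::value;
```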
> Is the type of a in clad::push(_t1, a) int (without the implicit cast to float)?
Yeah