Enzyme.jl
Enzyme.jl copied to clipboard
Illegal type analysis error with `setindex!` + `ifelse` + mixed float types
I found this when trying to diff through https://github.com/FluxML/NNlib.jl/blob/acf87f5316e7579ac1e7eb16a278f43a9ca435dc/src/softmax.jl#L115.
MWE:
using Enzyme
# Minimal reproducer: store `x[1]` into `out[1]` (substituting `-Inf` when `x[1]` is
# `Inf`) and return the stored element.
# NOTE: `Inf` and `-Inf` are Float64 literals, so when `x` holds Float32 values (as in
# this MWE) the two value arguments of `ifelse` have different float types. That
# mixed-type result is exactly what this reproducer exercises; do not "tidy" it into a
# type-stable form here, or the issue no longer reproduces.
function f(out, x)
out[1] = ifelse(isequal(x[1], Inf), -Inf, x[1])
out[1]
end
# One-element Float32 input. The plain call below runs `f` on copies, so `x` itself is
# not modified by it.
x = ones(Float32, 1)
@show f(copy(x), copy(x))
# Shadow buffer for `x`: zeros with the same shape and element type.
dx = zero(x)
# Reverse-mode differentiation of `f` with an `Active` return; both array arguments are
# passed as `Duplicated` pairs. Per the stack trace in this report, this `autodiff` call
# is where the "illegal type analysis" compilation error is raised.
Enzyme.autodiff(
Reverse,
f,
Active,
Duplicated(copy(x), copy(dx)),
Duplicated(x, dx)
)
@show x dx
Looking at the generated IR, it appears Julia is union-splitting the result of the `ifelse`
and only truncating if it turns out to be a `Float64`. This seems a little unnecessary, since LLVM ends up promoting it to a double unconditionally anyhow. Strange codegen aside, this specific example should be relatively easy to fix on the NNlib side, but I suspect there are many more examples lurking out there.
Error:
ERROR: LoadError: Enzyme compilation failed due to illegal type analysis.
Current scope:
; Function Attrs: mustprogress willreturn
define float @preprocess_julia_f_1435_inner.1({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %0, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %1) local_unnamed_addr #6 !dbg !68 {
entry:
%2 = alloca float, align 4
%.0.sroa_cast4 = bitcast float* %2 to i8*
call void @llvm.lifetime.start.p0i8(i64 4, i8* %.0.sroa_cast4)
%3 = call {}*** @julia.get_pgcstack() #7
%4 = bitcast {} addrspace(10)* %1 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !69
%5 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %4 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !69
%6 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %5, i64 0, i32 1, !dbg !69
%7 = load i64, i64 addrspace(11)* %6, align 8, !dbg !69, !tbaa !12, !range !17, !alias.scope !18, !noalias !21
%.not = icmp eq i64 %7, 0, !dbg !69
br i1 %.not, label %oob.i, label %idxend2.i, !dbg !69
L16.i: ; preds = %idxend2.i
%8 = bitcast {} addrspace(10)* %0 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !72
%9 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %8 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !72
%10 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %9, i64 0, i32 1, !dbg !72
%11 = load i64, i64 addrspace(11)* %10, align 8, !dbg !72, !tbaa !12, !range !17, !alias.scope !18, !noalias !21
%.not10 = icmp eq i64 %11, 0, !dbg !72
br i1 %.not10, label %oob3.i, label %idxend4.i, !dbg !72
L21.i: ; preds = %idxend2.i
%12 = bitcast {} addrspace(10)* %0 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !72
%13 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %12 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !72
%14 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %13, i64 0, i32 1, !dbg !72
%15 = load i64, i64 addrspace(11)* %14, align 8, !dbg !72, !tbaa !12, !range !17, !alias.scope !18, !noalias !21
%.not13 = icmp eq i64 %15, 0, !dbg !72
br i1 %.not13, label %oob7.i, label %idxend8.i, !dbg !72
oob.i: ; preds = %entry
%16 = alloca i64, align 8, !dbg !69
store i64 1, i64* %16, align 8, !dbg !69, !noalias !73
%17 = addrspacecast {} addrspace(10)* %1 to {} addrspace(12)*, !dbg !69
call void @ijl_bounds_error_ints({} addrspace(12)* noundef %17, i64* noundef nonnull align 8 %16, i64 noundef 1) #8, !dbg !69
unreachable, !dbg !69
idxend2.i: ; preds = %entry
%18 = bitcast {} addrspace(10)* %1 to float addrspace(13)* addrspace(10)*, !dbg !69
%19 = addrspacecast float addrspace(13)* addrspace(10)* %18 to float addrspace(13)* addrspace(11)*, !dbg !69
%20 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %19, align 16, !dbg !69, !tbaa !32, !alias.scope !76, !noalias !21, !nonnull !6
%21 = load float, float addrspace(13)* %20, align 4, !dbg !69, !tbaa !35, !alias.scope !38, !noalias !39
%22 = bitcast float %21 to i32, !dbg !77
%23 = icmp slt i32 %22, 0, !dbg !80
%24 = fcmp une float %21, 0x7FF0000000000000, !dbg !81
%25 = or i1 %24, %23, !dbg !83
store float %21, float* %2, align 4, !dbg !83, !noalias !73
%.0.sroa_cast3 = addrspacecast float* %2 to double addrspace(11)*, !dbg !83
%26 = select i1 %25, double addrspace(11)* %.0.sroa_cast3, double addrspace(11)* addrspacecast (double* @_j_const1 to double addrspace(11)*), !dbg !83
br i1 %25, label %L16.i, label %L21.i, !dbg !84
oob3.i: ; preds = %L16.i
%27 = alloca i64, align 8, !dbg !72
store i64 1, i64* %27, align 8, !dbg !72, !noalias !73
%28 = addrspacecast {} addrspace(10)* %0 to {} addrspace(12)*, !dbg !72
call void @ijl_bounds_error_ints({} addrspace(12)* noundef %28, i64* noundef nonnull align 8 %27, i64 noundef 1) #8, !dbg !72
unreachable, !dbg !72
idxend4.i: ; preds = %L16.i
%29 = bitcast {} addrspace(10)* %0 to float addrspace(13)* addrspace(10)*, !dbg !72
%30 = addrspacecast float addrspace(13)* addrspace(10)* %29 to float addrspace(13)* addrspace(11)*, !dbg !72
%31 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %30, align 16, !dbg !72, !tbaa !32, !alias.scope !76, !noalias !21, !nonnull !6
%32 = bitcast double addrspace(11)* %26 to float addrspace(11)*, !dbg !72
%33 = load float, float addrspace(11)* %32, align 4, !dbg !72, !tbaa !58
store float %33, float addrspace(13)* %31, align 4, !dbg !72, !tbaa !35, !alias.scope !38, !noalias !85
br label %julia_f_1435_inner.exit, !dbg !84
oob7.i: ; preds = %L21.i
%34 = alloca i64, align 8, !dbg !72
store i64 1, i64* %34, align 8, !dbg !72, !noalias !73
%35 = addrspacecast {} addrspace(10)* %0 to {} addrspace(12)*, !dbg !72
call void @ijl_bounds_error_ints({} addrspace(12)* noundef %35, i64* noundef nonnull align 8 %34, i64 noundef 1) #8, !dbg !72
unreachable, !dbg !72
idxend8.i: ; preds = %L21.i
%36 = load double, double addrspace(11)* %26, align 4, !dbg !86, !tbaa !58
%37 = fptrunc double %36 to float, !dbg !86
%38 = bitcast {} addrspace(10)* %0 to float addrspace(13)* addrspace(10)*, !dbg !72
%39 = addrspacecast float addrspace(13)* addrspace(10)* %38 to float addrspace(13)* addrspace(11)*, !dbg !72
%40 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %39, align 16, !dbg !72, !tbaa !32, !alias.scope !76, !noalias !21, !nonnull !6
store float %37, float addrspace(13)* %40, align 4, !dbg !72, !tbaa !35, !alias.scope !38, !noalias !85
br label %julia_f_1435_inner.exit, !dbg !84
julia_f_1435_inner.exit: ; preds = %idxend8.i, %idxend4.i
%41 = phi float [ %37, %idxend8.i ], [ %33, %idxend4.i ]
%.0.sroa_cast5 = bitcast float* %2 to i8*, !dbg !89
call void @llvm.lifetime.end.p0i8(i64 4, i8* %.0.sroa_cast5), !dbg !89
ret float %41, !dbg !90
}
Type analysis state:
<analysis>
%24 = fcmp une float %21, 0x7FF0000000000000, !dbg !51: {[-1]:Integer}, intvals: {}
%25 = or i1 %24, %23, !dbg !55: {[-1]:Integer}, intvals: {}
%6 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %5, i64 0, i32 1, !dbg !7: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {}
%14 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %13, i64 0, i32 1, !dbg !26: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {}
%5 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %4 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !7: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
%7 = load i64, i64 addrspace(11)* %6, align 8, !dbg !7, !tbaa !12, !range !17, !alias.scope !18, !noalias !21: {[-1]:Integer}, intvals: {}
%32 = bitcast double addrspace(11)* %26 to float addrspace(11)*, !dbg !26: {[-1]:Pointer, [-1,0]:Float@float}, intvals: {}
%33 = load float, float addrspace(11)* %32, align 4, !dbg !26, !tbaa !58: {[-1]:Float@float}, intvals: {}
float 0x7FF0000000000000: {[-1]:Float@float}, intvals: {}
%29 = bitcast {} addrspace(10)* %0 to float addrspace(13)* addrspace(10)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
%30 = addrspacecast float addrspace(13)* addrspace(10)* %29 to float addrspace(13)* addrspace(11)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
%31 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %30, align 16, !dbg !26, !tbaa !32, !alias.scope !34, !noalias !21, !nonnull !6: {[-1]:Pointer, [-1,-1]:Float@float}, intvals: {}
%26 = select i1 %25, double addrspace(11)* %.0.sroa_cast3, double addrspace(11)* addrspacecast (double* @_j_const1 to double addrspace(11)*), !dbg !55: {[-1]:Pointer, [-1,0]:Float@float}, intvals: {}
%39 = addrspacecast float addrspace(13)* addrspace(10)* %38 to float addrspace(13)* addrspace(11)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
%40 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %39, align 16, !dbg !26, !tbaa !32, !alias.scope !34, !noalias !21, !nonnull !6: {[-1]:Pointer, [-1,-1]:Float@float}, intvals: {}
%4 = bitcast {} addrspace(10)* %1 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !7: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
%36 = load double, double addrspace(11)* %26, align 4, !dbg !60, !tbaa !58: {[-1]:Float@double}, intvals: {}
%37 = fptrunc double %36 to float, !dbg !60: {[-1]:Float@float}, intvals: {}
%38 = bitcast {} addrspace(10)* %0 to float addrspace(13)* addrspace(10)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
double 0xFFF0000000000000: {[-1]:Float@double}, intvals: {}
i64 0: {[-1]:Anything}, intvals: {0,}
%15 = load i64, i64 addrspace(11)* %14, align 8, !dbg !26, !tbaa !12, !range !17, !alias.scope !18, !noalias !21: {[-1]:Integer}, intvals: {}
%.not13 = icmp eq i64 %15, 0, !dbg !26: {[-1]:Integer}, intvals: {}
%9 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %8 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
%11 = load i64, i64 addrspace(11)* %10, align 8, !dbg !26, !tbaa !12, !range !17, !alias.scope !18, !noalias !21: {[-1]:Integer}, intvals: {}
%.not10 = icmp eq i64 %11, 0, !dbg !26: {[-1]:Integer}, intvals: {}
%12 = bitcast {} addrspace(10)* %0 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
%13 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %12 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
%2 = alloca float, align 4: {[-1]:Pointer, [-1,-1]:Float@float}, intvals: {}
%.0.sroa_cast5 = bitcast float* %2 to i8*, !dbg !66: {[-1]:Pointer, [-1,0]:Float@float}, intvals: {}
{} addrspace(10)* %0: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
{} addrspace(10)* %1: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
%21 = load float, float addrspace(13)* %20, align 4, !dbg !7, !tbaa !35, !alias.scope !38, !noalias !39: {[-1]:Float@float}, intvals: {}
%22 = bitcast float %21 to i32, !dbg !40: {[-1]:Float@float}, intvals: {}
%23 = icmp slt i32 %22, 0, !dbg !48: {[-1]:Integer}, intvals: {}
%41 = phi float [ %37, %idxend8.i ], [ %33, %idxend4.i ]: {[-1]:Float@float}, intvals: {}
%.0.sroa_cast3 = addrspacecast float* %2 to double addrspace(11)*, !dbg !55: {[-1]:Pointer, [-1,-1]:Float@float}, intvals: {}
%.0.sroa_cast4 = bitcast float* %2 to i8*: {[-1]:Pointer, [-1,-1]:Float@float}, intvals: {}
%10 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %9, i64 0, i32 1, !dbg !26: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {}
%.not = icmp eq i64 %7, 0, !dbg !7: {[-1]:Integer}, intvals: {}
%8 = bitcast {} addrspace(10)* %0 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !26: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
i32 0: {[-1]:Anything}, intvals: {0,}
%18 = bitcast {} addrspace(10)* %1 to float addrspace(13)* addrspace(10)*, !dbg !7: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
%19 = addrspacecast float addrspace(13)* addrspace(10)* %18 to float addrspace(13)* addrspace(11)*, !dbg !7: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
%20 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %19, align 16, !dbg !7, !tbaa !32, !alias.scope !34, !noalias !21, !nonnull !6: {[-1]:Pointer, [-1,-1]:Float@float}, intvals: {}
@_j_const1 = private unnamed_addr constant double 0xFFF0000000000000: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
double addrspace(11)* addrspacecast (double* @_j_const1 to double addrspace(11)*): {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
</analysis>
Illegal updateAnalysis prev:{[-1]:Pointer, [-1,0]:Float@float} new: {[-1]:Pointer, [-1,0]:Float@double}
val: %26 = select i1 %25, double addrspace(11)* %.0.sroa_cast3, double addrspace(11)* addrspacecast (double* @_j_const1 to double addrspace(11)*), !dbg !55 origin= %36 = load double, double addrspace(11)* %26, align 4, !dbg !60, !tbaa !58
Caused by:
Stacktrace:
[1] ifelse
@ ./essentials.jl:575
[2] f
@ <mwe>.jl:4
[3] f
@ <mwe>.jl:0
Stacktrace:
[1] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:5330
[2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/YBQJk/src/api.jl:124
[3] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{Bool, Bool, Bool}, returnPrimal::Bool, jlrules::Vector{String}, expectedTapeType::Type)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:6927
[4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.ThreadSafeContext, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:8177
[5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, ctx::Nothing, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:8690
[6] _thunk
@ ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:8687 [inlined]
[7] cached_compilation
@ ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:8725 [inlined]
[8] #s287#191
@ ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:8783 [inlined]
[9] var"#s287#191"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ::Any, ::Any, ::Any, ::Any, tt::Any, ::Any, ::Any, ::Any, ::Any, ::Any)
@ Enzyme.Compiler ./none:0
[10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
@ Core ./boot.jl:602
[11] thunk(::Val{0x00000000000082c2}, ::Type{Const{typeof(f)}}, ::Type{Active}, tt::Type{Tuple{Duplicated{Vector{Float32}}, Duplicated{Vector{Float32}}}}, ::Val{Enzyme.API.DEM_ReverseModeCombined}, ::Val{1}, ::Val{(false, false, false)}, ::Val{false})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/YBQJk/src/compiler.jl:8742
[12] autodiff(::EnzymeCore.ReverseMode{false}, ::Const{typeof(f)}, ::Type{Active}, ::Duplicated{Vector{Float32}}, ::Vararg{Duplicated{Vector{Float32}}})
@ Enzyme ~/.julia/packages/Enzyme/YBQJk/src/Enzyme.jl:199
[13] autodiff(::EnzymeCore.ReverseMode{false}, ::typeof(f), ::Type, ::Duplicated{Vector{Float32}}, ::Vararg{Duplicated{Vector{Float32}}})
@ Enzyme ~/.julia/packages/Enzyme/YBQJk/src/Enzyme.jl:214
[14] top-level scope
@ <mwe>.jl:12
in expression starting at <mwe>.jl:12
Yeah, for the immediate future at least, you should go fix the union in NNlib here (which would also be a generic performance benefit even without AD).
I'm (on 0.11.2) also hitting
ERROR: Enzyme compilation failed due to illegal type analysis.
Current scope:
; Function Attrs: mustprogress willreturn
define internal fastcc i64 @preprocess_julia_partition__9106({} addrspace(10)* noundef nonnull readonly align 16 dereferenceable(40) %0, i64 signext %1, i64 signext %2, i64 signext %3, { {} addrspace(10)*, i8, i8 } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(16) %4, i8 zeroext %5, { {} addrspace(10)*, i8, i8 } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(16) %6, i64 signext %7) unnamed_addr #80 !dbg !5321 {
top:
However
Caused by:
Stacktrace:
[1] setindex!
@ ./array.jl:969
[2] partition!
@ ./sort.jl:1004
Is this related or should I try to reduce to a MWE?
That is unrelated; a minimal example would be helpful, @toollu.
Closing, as an `ifelse` on a union of different types is considered unsupported at the moment.