ChainRules.jl
ChainRules.jl copied to clipboard
Pullback fails to inline with intervals
The following MWE does not inline when using intervals from IntervalArithmetic.jl:
# Minimal working example: run the `rrule` for `*` on `x` and `y`, seed the
# returned pullback with `one(x)`, and return the two unthunked results.
# NOTE(review): the REPL invocations and native-code listings in this report
# refer to `f`, not `fff` — confirm which name was actually used.
function fff(x, y)
# `rrule` returns the primal result and a pullback; `z` is not used afterwards.
z, z_pullback = rrule(*, x, y)
# Seed passed to the pullback.
z̄ = one(x)
# The first element returned by the pullback is discarded; presumably `r1` and
# `r2` are the (possibly thunked) cotangents for `x` and `y` — TODO confirm.
_, r1, r2 = z_pullback(z̄)
x̄ = unthunk(r1)
ȳ = unthunk(r2)
return (x̄, ȳ)
end
using IntervalArithmetic
julia> @code_native f(1..1, 2..2)
.section __TEXT,__text,regular,pure_instructions
; ┌ @ REPL[23]:1 within `f'
pushq %rbx
subq $160, %rsp
movq %rdi, %rbx
; │ @ REPL[23]:2 within `f'
movabsq $rrule, %rax
leaq 112(%rsp), %rdi
callq *%rax
movabsq $5516288192, %rax ## imm = 0x148CBE0C0
; │ @ REPL[23]:5 within `f'
; │┌ @ fastmath_able.jl:188 within `times_pullback'
vmovaps (%rax), %xmm0
vmovups %xmm0, 80(%rsp)
vmovups 128(%rsp), %xmm1
vmovups 144(%rsp), %xmm2
vmovups %xmm2, 96(%rsp)
vmovups %xmm0, 48(%rsp)
vmovups %xmm1, 64(%rsp)
; │└
; │ @ REPL[23]:7 within `f'
; │┌ @ thunks.jl:99 within `unthunk'
; ││┌ @ thunks.jl:98 within `Thunk'
movabsq $"#461", %rax
leaq 16(%rsp), %rdi
leaq 80(%rsp), %rsi
callq *%rax
; │└└
; │ @ REPL[23]:8 within `f'
; │┌ @ thunks.jl:99 within `unthunk'
; ││┌ @ thunks.jl:98 within `Thunk'
movabsq $"#462", %rax
movq %rsp, %rdi
leaq 48(%rsp), %rsi
callq *%rax
; │└└
; │ @ REPL[23]:10 within `f'
vmovups (%rsp), %xmm0
vmovups %xmm0, 32(%rsp)
vmovups 16(%rsp), %ymm0
vmovups %ymm0, (%rbx)
movq %rbx, %rax
addq $160, %rsp
popq %rbx
vzeroupper
retq
nopw %cs:(%rax,%rax)
nopl (%rax)
; └
Cf. the beautiful code when using floats:
julia> @code_native f(1.0, 2.0)
.section __TEXT,__text,regular,pure_instructions
; ┌ @ REPL[23]:1 within `f'
movq %rdi, %rax
; │ @ REPL[23]:10 within `f'
vmovsd %xmm1, (%rdi)
vmovsd %xmm0, 8(%rdi)
retq
nopl (%rax)
; └
I wonder if we need to open an upstream issue on Julia itself about this.
I tried
# `rrule` for `*` on two `Number`s, with both the rule and its pullback closure
# annotated `@inline`.
#
# Returns the product `x * y` together with `times_pullback`, a closure over
# `x` and `y` that maps `ΔΩ` to `(NO_FIELDS, ΔΩ * y', x' * ΔΩ)`.
@inline function rrule(::typeof(*), x::Number, y::Number)
    # Short-form closure; the `@inline` annotation on the pullback is kept.
    @inline times_pullback(ΔΩ) = (NO_FIELDS, ΔΩ * y', x' * ΔΩ)
    Ω = x * y
    return Ω, times_pullback
end
which seems to work correctly with intervals.
Time to Inline All The Things?
We can at least add it to @scalar_rule.
I am hesitant to add it to everything, or to add it to the best practices, because of the visual noise. But we might want to at least mention it in the best practices as something that can be considered?