ChainRules.jl

Pullback fails to inline with intervals

Status: Open — dpsanders opened this issue 4 years ago · 3 comments

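The following minimal working example (MWE) does not inline when called with intervals from IntervalArithmetic.jl. (In the REPL session shown below, the same function was evidently defined under the name `f`, which is why the `@code_native` calls and the assembly annotations refer to `f` rather than `fff`.)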

"""
    fff(x, y)

Minimal working example for inspecting inlining of `rrule(*, x, y)`.

Runs the reverse-mode rule for `x * y` and seeds its pullback with a unit cotangent.
Returns the unthunked cotangents `(x̄, ȳ)` for `x` and `y`. Intended to be inspected
with `@code_native` to check whether the rule and its pullback inline.
"""
function fff(x, y)
    z, z_pullback = rrule(*, x, y)

    # Seed with a unit cotangent of the *output* type (`one(z)` rather than `one(x)`),
    # so the seed still matches the primal output when `x` and `y` differ in type.
    z̄ = one(z)
    _, r1, r2 = z_pullback(z̄)

    # The pullback may hand back lazy thunks; `unthunk` forces them to values.
    x̄ = unthunk(r1)
    ȳ = unthunk(r2)

    return (x̄, ȳ)
end

using IntervalArithmetic

julia> @code_native f(1..1, 2..2)
	.section	__TEXT,__text,regular,pure_instructions
; ┌ @ REPL[23]:1 within `f'
	pushq	%rbx
	subq	$160, %rsp
	movq	%rdi, %rbx
; │ @ REPL[23]:2 within `f'
	movabsq	$rrule, %rax
	leaq	112(%rsp), %rdi
	callq	*%rax
	movabsq	$5516288192, %rax       ## imm = 0x148CBE0C0
; │ @ REPL[23]:5 within `f'
; │┌ @ fastmath_able.jl:188 within `times_pullback'
	vmovaps	(%rax), %xmm0
	vmovups	%xmm0, 80(%rsp)
	vmovups	128(%rsp), %xmm1
	vmovups	144(%rsp), %xmm2
	vmovups	%xmm2, 96(%rsp)
	vmovups	%xmm0, 48(%rsp)
	vmovups	%xmm1, 64(%rsp)
; │└
; │ @ REPL[23]:7 within `f'
; │┌ @ thunks.jl:99 within `unthunk'
; ││┌ @ thunks.jl:98 within `Thunk'
	movabsq	$"#461", %rax
	leaq	16(%rsp), %rdi
	leaq	80(%rsp), %rsi
	callq	*%rax
; │└└
; │ @ REPL[23]:8 within `f'
; │┌ @ thunks.jl:99 within `unthunk'
; ││┌ @ thunks.jl:98 within `Thunk'
	movabsq	$"#462", %rax
	movq	%rsp, %rdi
	leaq	48(%rsp), %rsi
	callq	*%rax
; │└└
; │ @ REPL[23]:10 within `f'
	vmovups	(%rsp), %xmm0
	vmovups	%xmm0, 32(%rsp)
	vmovups	16(%rsp), %ymm0
	vmovups	%ymm0, (%rbx)
	movq	%rbx, %rax
	addq	$160, %rsp
	popq	%rbx
	vzeroupper
	retq
	nopw	%cs:(%rax,%rax)
	nopl	(%rax)
; └

Compare this with the fully inlined code generated when the same function is called with floats:

julia> @code_native f(1.0, 2.0)
	.section	__TEXT,__text,regular,pure_instructions
; ┌ @ REPL[23]:1 within `f'
	movq	%rdi, %rax
; │ @ REPL[23]:10 within `f'
	vmovsd	%xmm1, (%rdi)
	vmovsd	%xmm0, 8(%rdi)
	retq
	nopl	(%rax)
; └

— dpsanders, Aug 20 '20 01:08

I wonder if we need to open an upstream issue about this on Julia itself.

— oxinabox, Aug 20 '20 08:08

I tried

# Proposed reverse rule for scalar multiplication. `@inline` is placed on both the
# rule and its pullback closure so the whole pullback can be inlined at the call
# site (the failing case reported here uses IntervalArithmetic intervals).
@inline function rrule(::typeof(*), x::Number, y::Number)
    Ω = x * y
    # Cotangents of x*y: w.r.t. x it is ΔΩ * y', w.r.t. y it is x' * ΔΩ
    # (`'` is the adjoint, i.e. the conjugate for numbers).
    @inline times_pullback(ΔΩ) = (NO_FIELDS, ΔΩ * y', x' * ΔΩ)
    return Ω, times_pullback
end

With this change, the pullback seems to inline correctly with intervals.

Time to Inline All The Things?

— dpsanders, Aug 20 '20 13:08

We can at least add `@inline` to the code generated by `@scalar_rule`.

I am hesitant to add it to everything, or to recommend it in the best practices, because of the visual noise it adds. But we might at least want to mention it in the best practices as something that can be considered?

— oxinabox, Aug 22 '20 11:08