julia icon indicating copy to clipboard operation
julia copied to clipboard

Improve the type stability of `sqrt(::Complex)`

Open mtfishman opened this issue 1 year ago • 4 comments

More specifically, this improves the type stability of Base.ssqs(x::T, y::T) where T<:Real, which is called by sqrt(::Complex).

Here's a demonstration:

`@code_warntype Base.ssqs(1.0, 2.0)`: this pull request

julia> @code_warntype Base.ssqs(1.0, 2.0)
MethodInstance for Base.ssqs(::Float64, ::Float64)
  from ssqs(x::T, y::T) where T<:Real @ Base complex.jl:516
Static Parameters
  T = Float64
Arguments
  #self#::Core.Const(Base.ssqs)
  x::Float64
  y::Float64
Locals
  yk::Float64
  xk::Float64
  m::Float64
  ρ::Float64
  k::Int64
  @_9::Float64
  @_10::Bool
  @_11::Bool
  @_12::Bool
  @_13::Int64
Body::Tuple{Float64, Int64}
1 ──        Core.NewvarNode(:(yk))
│           Core.NewvarNode(:(xk))
│           Core.NewvarNode(:(m))
│           (k = 0)
│    %5   = Base.:+::Core.Const(+)
│    %6   = (x * x)::Float64
│    %7   = (y * y)::Float64
│           (ρ = (%5)(%6, %7))
│    %9   = Base.:!::Core.Const(!)
│    %10  = ρ::Float64
│    %11  = Base.isfinite(%10)::Bool
│    %12  = (%9)(%11)::Bool
└───        goto #7 if not %12
2 ── %14  = Base.isinf(x)::Bool
└───        goto #4 if not %14
3 ──        (@_10 = %14)
└───        goto #5
4 ──        (@_10 = Base.isinf(y))
5 ┄─ %19  = @_10::Bool
└───        goto #7 if not %19
6 ── %21  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│           (ρ = Base.convert(%21, Base.Inf))
└───        goto #25
7 ┄─ %24  = ρ::Float64
│    %25  = Base.isinf(%24)::Bool
└───        goto #9 if not %25
8 ──        goto #18
9 ── %28  = ρ::Float64
│    %29  = (%28 == 0)::Bool
└───        goto #14 if not %29
10 ─ %31  = (x != 0)::Bool
└───        goto #12 if not %31
11 ─        (@_12 = %31)
└───        goto #13
12 ─        (@_12 = y != 0)
13 ┄ %36  = @_12::Bool
│           (@_11 = %36)
└───        goto #15
14 ─        (@_11 = false)
15 ┄ %40  = @_11::Bool
└───        goto #17 if not %40
16 ─        goto #18
17 ─ %43  = Base.:<::Core.Const(<)
│    %44  = ρ::Float64
│    %45  = Base.:/::Core.Const(/)
│    %46  = Base.nextfloat::Core.Const(nextfloat)
│    %47  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│    %48  = Base.zero(%47)::Core.Const(0.0)
│    %49  = (%46)(%48)::Core.Const(5.0e-324)
│    %50  = Base.:*::Core.Const(*)
│    %51  = Base.:^::Core.Const(^)
│    %52  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│    %53  = Base.eps(%52)::Core.Const(2.220446049250313e-16)
│    %54  = Core.apply_type(Base.Val, 2)::Core.Const(Val{2})
│    %55  = (%54)()::Core.Const(Val{2}())
│    %56  = Base.literal_pow(%51, %53, %55)::Core.Const(4.930380657631324e-32)
│    %57  = (%50)(2, %56)::Core.Const(9.860761315262648e-32)
│    %58  = (%45)(%49, %57)::Core.Const(5.010420900022432e-293)
│    %59  = (%43)(%44, %58)::Bool
└───        goto #25 if not %59
18 ┄ %61  = Base.max::Core.Const(max)
│    %62  = Base.abs(x)::Float64
│    %63  = Base.abs(y)::Float64
│    %64  = (%61)(%62, %63)::Float64
│           (@_9 = %64)
│    %66  = @_9::Float64
│    %67  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│    %68  = (%66 isa %67)::Core.Const(true)
└───        goto #20 if not %68
19 ─        goto #21
20 ─        Core.Const(:($(Expr(:static_parameter, 1))))
│           Core.Const(:(@_9))
│           Core.Const(:(Base.convert(%71, %72)))
│           Core.Const(:($(Expr(:static_parameter, 1))))
└───        Core.Const(:(@_9 = Core.typeassert(%73, %74)))
21 ┄ %76  = @_9::Float64
│           (m = %76)
│    %78  = m::Float64
│    %79  = (%78 == 0)::Bool
└───        goto #23 if not %79
22 ─ %81  = k::Core.Const(0)
│           (@_13 = %81)
└───        goto #24
23 ─ %84  = Base.exponent::Core.Const(exponent)
│    %85  = m::Float64
└───        (@_13 = (%84)(%85))
24 ┄ %87  = @_13::Int64
│           (k = %87)
│    %89  = Base.ldexp::Core.Const(ldexp)
│    %90  = k::Int64
│    %91  = -%90::Int64
│    %92  = (%89)(x, %91)::Float64
│    %93  = Base.ldexp::Core.Const(ldexp)
│    %94  = k::Int64
│    %95  = -%94::Int64
│    %96  = (%93)(y, %95)::Float64
│           (xk = %92)
│           (yk = %96)
│    %99  = Base.:+::Core.Const(+)
│    %100 = xk::Float64
│    %101 = xk::Float64
│    %102 = (%100 * %101)::Float64
│    %103 = yk::Float64
│    %104 = yk::Float64
│    %105 = (%103 * %104)::Float64
└───        (ρ = (%99)(%102, %105))
25 ┄ %107 = ρ::Float64
│    %108 = k::Int64
│    %109 = Core.tuple(%107, %108)::Tuple{Float64, Int64}
└───        return %109


julia> @btime sqrt(1.0 + 2.0im)
  1.250 ns (0 allocations: 0 bytes)
1.272019649514069 + 0.7861513777574233im

julia> versioninfo()
Julia Version 1.12.0-DEV.753
Commit 4d0149d160* (2024-06-20 16:00 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin23.5.0)
  CPU: 10 × Apple M1 Max
  WORD_SIZE: 64
  LLVM: libLLVM-17.0.6 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)
`@code_warntype Base.ssqs(1.0, 2.0)`: nightly

julia> @code_warntype Base.ssqs(1.0, 2.0)
MethodInstance for Base.ssqs(::Float64, ::Float64)
  from ssqs(x::T, y::T) where T<:Real @ Base complex.jl:516
Static Parameters
  T = Float64
Arguments
  #self#::Core.Const(Base.ssqs)
  x::Float64
  y::Float64
Locals
  yk::Float64
  xk::Float64
  m::Float64
  ρ::Float64
  k::Int64
  @_9::Int64
  @_10::Float64
  @_11::Union{Float64, Int64}
  @_12::Bool
  @_13::Bool
  @_14::Bool
  @_15::Union{Float64, Int64}
Body::Tuple{Float64, Int64}
1 ──        Core.NewvarNode(:(yk))
│           Core.NewvarNode(:(xk))
│           Core.NewvarNode(:(m))
│           Core.NewvarNode(:(ρ))
│           Core.NewvarNode(:(k))
│           (@_9 = 0)
│    %7   = @_9::Core.Const(0)
│    %8   = (%7 isa Base.Int)::Core.Const(true)
└───        goto #3 if not %8
2 ──        goto #4
3 ──        Core.Const(:(@_9))
│           Core.Const(:(Base.convert(Base.Int, %11)))
│           Core.Const(:(Base.Int))
└───        Core.Const(:(@_9 = Core.typeassert(%12, %13)))
4 ┄─ %15  = @_9::Core.Const(0)
│           (k = %15)
│    %17  = Base.:+::Core.Const(+)
│    %18  = (x * x)::Float64
│    %19  = (y * y)::Float64
│           (ρ = (%17)(%18, %19))
│    %21  = Base.:!::Core.Const(!)
│    %22  = ρ::Float64
│    %23  = Base.isfinite(%22)::Bool
│    %24  = (%21)(%23)::Bool
└───        goto #10 if not %24
5 ── %26  = Base.isinf(x)::Bool
└───        goto #7 if not %26
6 ──        (@_12 = %26)
└───        goto #8
7 ──        (@_12 = Base.isinf(y))
8 ┄─ %31  = @_12::Bool
└───        goto #10 if not %31
9 ── %33  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│           (ρ = Base.convert(%33, Base.Inf))
└───        goto #31
10 ┄ %36  = ρ::Float64
│    %37  = Base.isinf(%36)::Bool
└───        goto #12 if not %37
11 ─        goto #21
12 ─ %40  = ρ::Float64
│    %41  = (%40 == 0)::Bool
└───        goto #17 if not %41
13 ─ %43  = (x != 0)::Bool
└───        goto #15 if not %43
14 ─        (@_14 = %43)
└───        goto #16
15 ─        (@_14 = y != 0)
16 ┄ %48  = @_14::Bool
│           (@_13 = %48)
└───        goto #18
17 ─        (@_13 = false)
18 ┄ %52  = @_13::Bool
└───        goto #20 if not %52
19 ─        goto #21
20 ─ %55  = Base.:<::Core.Const(<)
│    %56  = ρ::Float64
│    %57  = Base.:/::Core.Const(/)
│    %58  = Base.nextfloat::Core.Const(nextfloat)
│    %59  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│    %60  = Base.zero(%59)::Core.Const(0.0)
│    %61  = (%58)(%60)::Core.Const(5.0e-324)
│    %62  = Base.:*::Core.Const(*)
│    %63  = Base.:^::Core.Const(^)
│    %64  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│    %65  = Base.eps(%64)::Core.Const(2.220446049250313e-16)
│    %66  = Core.apply_type(Base.Val, 2)::Core.Const(Val{2})
│    %67  = (%66)()::Core.Const(Val{2}())
│    %68  = Base.literal_pow(%63, %65, %67)::Core.Const(4.930380657631324e-32)
│    %69  = (%62)(2, %68)::Core.Const(9.860761315262648e-32)
│    %70  = (%57)(%61, %69)::Core.Const(5.010420900022432e-293)
│    %71  = (%55)(%56, %70)::Bool
└───        goto #31 if not %71
21 ┄ %73  = Base.max::Core.Const(max)
│    %74  = Base.abs(x)::Float64
│    %75  = Base.abs(y)::Float64
│    %76  = (%73)(%74, %75)::Float64
│           (@_10 = %76)
│    %78  = @_10::Float64
│    %79  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│    %80  = (%78 isa %79)::Core.Const(true)
└───        goto #23 if not %80
22 ─        goto #24
23 ─        Core.Const(:($(Expr(:static_parameter, 1))))
│           Core.Const(:(@_10))
│           Core.Const(:(Base.convert(%83, %84)))
│           Core.Const(:($(Expr(:static_parameter, 1))))
└───        Core.Const(:(@_10 = Core.typeassert(%85, %86)))
24 ┄ %88  = @_10::Float64
│           (m = %88)
│    %90  = m::Float64
│    %91  = (%90 == 0)::Bool
└───        goto #26 if not %91
25 ─ %93  = m::Float64
│           (@_15 = %93)
└───        goto #27
26 ─ %96  = Base.exponent::Core.Const(exponent)
│    %97  = m::Float64
└───        (@_15 = (%96)(%97))
27 ┄ %99  = @_15::Union{Float64, Int64}
│           (@_11 = %99)
│    %101 = @_11::Union{Float64, Int64}
│    %102 = (%101 isa Base.Int)::Bool
└───        goto #29 if not %102
28 ─        goto #30
29 ─ %105 = @_11::Float64
│    %106 = Base.convert(Base.Int, %105)::Int64
│    %107 = Base.Int::Core.Const(Int64)
└───        (@_11 = Core.typeassert(%106, %107))
30 ┄ %109 = @_11::Int64
│           (k = %109)
│    %111 = Base.ldexp::Core.Const(ldexp)
│    %112 = k::Int64
│    %113 = -%112::Int64
│    %114 = (%111)(x, %113)::Float64
│    %115 = Base.ldexp::Core.Const(ldexp)
│    %116 = k::Int64
│    %117 = -%116::Int64
│    %118 = (%115)(y, %117)::Float64
│           (xk = %114)
│           (yk = %118)
│    %121 = Base.:+::Core.Const(+)
│    %122 = xk::Float64
│    %123 = xk::Float64
│    %124 = (%122 * %123)::Float64
│    %125 = yk::Float64
│    %126 = yk::Float64
│    %127 = (%125 * %126)::Float64
└───        (ρ = (%121)(%124, %127))
31 ┄ %129 = ρ::Float64
│    %130 = k::Int64
│    %131 = Core.tuple(%129, %130)::Tuple{Float64, Int64}
└───        return %131


julia> @btime sqrt(1.0 + 2.0im)
  1.250 ns (0 allocations: 0 bytes)
1.272019649514069 + 0.7861513777574233im

julia> versioninfo()
Julia Version 1.12.0-DEV.753
Commit 4d0149d160b (2024-06-20 16:00 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 10 × Apple M1 Max
  WORD_SIZE: 64
  LLVM: libLLVM-17.0.6 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)

There may be something I'm missing about why it was written the way it was before, but the new version looks simpler to me, avoids a Union type in type inference, and removes the need for a few type assertions.

This showed up "in the wild" in https://github.com/JuliaGPU/Metal.jl/pull/374, where the Union type in type inference caused issues for compiling sqrt(::Complex) on Apple GPUs.

mtfishman avatar Jun 20 '24 19:06 mtfishman

Would this break a hypothetical user-defined floating-point type with a BigInt exponent?

nsajko avatar Jun 21 '24 15:06 nsajko

To ensure type stability while keeping correctness for generic code, perhaps it'd make sense to do something like

k = zero(exponent(x))

Or maybe also promote with Int?

nsajko avatar Jun 21 '24 15:06 nsajko

That's a good point; I didn't consider cases where exponent might not output something of type Int. For those cases, instead of making the code more generic, I think we should convert to Int, which matches the previous behavior. exponent isn't always called in the code, so it isn't clear to me how to initialize k to zero(exponent(x)) without either calling exponent or inferring its output type some other way, which seems pretty heavy-duty for such a low-level math function.

Also, a number x such that exponent(x) > typemax(Int) is really big: it is too big to even be represented by BigFloat/BigInt. For example:

julia> big(2)^typemax(Int)
ERROR: OutOfMemoryError()
Stacktrace:
 [1] pow_ui!(x::BigInt, a::BigInt, b::UInt64)
   @ Base.GMP.MPZ ./gmp.jl:180
 [2] pow_ui
   @ ./gmp.jl:181 [inlined]
 [3] ^
   @ ./gmp.jl:626 [inlined]
 [4] bigint_pow(x::BigInt, y::Int64)
   @ Base.GMP ./gmp.jl:647
 [5] ^(x::BigInt, y::Int64)
   @ Base.GMP ./gmp.jl:652
 [6] top-level scope
   @ REPL[32]:1

In your hypothetical example where exponent(x) outputs something like BigInt, the version of Base.ssqs on master would already have failed with a conversion error if the output couldn't be converted to Int (i.e. if exponent(x) > typemax(Int)), so it should be ok to keep that behavior.

I suppose changing the initialization of k back from k = 0 to k::Int = 0 would make sure that if exponent(m) outputs something with a type other than Int, it gets converted to Int like before, but maybe there is a better code pattern for that, such as:

k = m==0 ? k : convert(typeof(k), exponent(m))

mtfishman avatar Jun 21 '24 15:06 mtfishman

In the latest commit I changed the code to convert the output of exponent to Int in Base.ssqs.

Here's an example of a custom number type that outputs BigInt from exponent:

# Minimal custom floating-point type used to exercise `Base.ssqs` with a number
# type whose `exponent` returns a `BigInt` instead of an `Int`.
# It simply wraps a `BigFloat` and forwards the operations defined below to it.
struct MyBigFloat <: AbstractFloat
    x::BigFloat
end

# The point of this type: `exponent` yields a `BigInt`, not an `Int`.
Base.exponent(x::MyBigFloat) = big(exponent(x.x))

# Mixed operations with any other `Real` (e.g. `ρ == 0`) promote to `MyBigFloat`.
Base.promote_rule(::Type{MyBigFloat}, ::Type{<:Real}) = MyBigFloat

# Compare by numeric value. Without this method, `==` on two `MyBigFloat`s falls
# back to `===`, which compares the wrapped (mutable) `BigFloat`s by identity, so
# numerically equal values would compare unequal and the generic
# `isfinite`/`isinf`/`== 0` checks would give wrong answers.
Base.:(==)(x::MyBigFloat, y::MyBigFloat) = (x.x == y.x)
# Keep `hash` consistent with the value-based `==` above.
Base.hash(x::MyBigFloat, h::UInt) = hash(x.x, h)
Base.:<(x::MyBigFloat, y::MyBigFloat) = (x.x < y.x)

# Arithmetic is forwarded to the wrapped `BigFloat` and re-wrapped.
Base.:-(x::MyBigFloat) = MyBigFloat(-x.x)
Base.:+(x::MyBigFloat, y::MyBigFloat) = MyBigFloat(x.x + y.x)
Base.:-(x::MyBigFloat, y::MyBigFloat) = MyBigFloat(x.x - y.x)
Base.:*(x::MyBigFloat, y::MyBigFloat) = MyBigFloat(x.x * y.x)
Base.:/(x::MyBigFloat, y::MyBigFloat) = MyBigFloat(x.x / y.x)

# Floating-point helpers, each forwarded to the corresponding `BigFloat` method.
Base.nextfloat(x::MyBigFloat) = MyBigFloat(nextfloat(x.x))
Base.eps(::Type{MyBigFloat}) = MyBigFloat(eps(BigFloat))
# Accept any `Integer` exponent (not just `Int64`) so that calls passing an `Int`
# also dispatch here on platforms where `Int` is `Int32`.
Base.ldexp(x::MyBigFloat, y::Integer) = MyBigFloat(ldexp(x.x, y))
Base.sqrt(x::MyBigFloat) = MyBigFloat(sqrt(x.x))

and here is the output of @code_warntype Base.ssqs(MyBigFloat(1.0), MyBigFloat(2.0)) based on the latest commit:

`@code_warntype Base.ssqs(MyBigFloat(1.0), MyBigFloat(2.0))`: latest commit of this PR

julia> @code_warntype Base.ssqs(MyBigFloat(1.0), MyBigFloat(2.0))
MethodInstance for Base.ssqs(::MyBigFloat, ::MyBigFloat)
  from ssqs(x::T, y::T) where T<:Real @ Base complex.jl:516
Static Parameters
  T = MyBigFloat
Arguments
  #self#::Core.Const(Base.ssqs)
  x::MyBigFloat
  y::MyBigFloat
Locals
  yk::MyBigFloat
  xk::MyBigFloat
  m::MyBigFloat
  ρ::MyBigFloat
  k::Int64
  @_9::MyBigFloat
  @_10::Bool
  @_11::Bool
  @_12::Bool
  @_13::Int64
Body::Tuple{MyBigFloat, Int64}
1 ──        Core.NewvarNode(:(yk))
│           Core.NewvarNode(:(xk))
│           Core.NewvarNode(:(m))
│           (k = 0)
│    %5   = Base.:+::Core.Const(+)
│    %6   = (x * x)::MyBigFloat
│    %7   = (y * y)::MyBigFloat
│           (ρ = (%5)(%6, %7))
│    %9   = Base.:!::Core.Const(!)
│    %10  = ρ::MyBigFloat
│    %11  = Base.isfinite(%10)::Bool
│    %12  = (%9)(%11)::Bool
└───        goto #7 if not %12
2 ── %14  = Base.isinf(x)::Bool
└───        goto #4 if not %14
3 ──        (@_10 = %14)
└───        goto #5
4 ──        (@_10 = Base.isinf(y))
5 ┄─ %19  = @_10::Bool
└───        goto #7 if not %19
6 ── %21  = $(Expr(:static_parameter, 1))::Core.Const(MyBigFloat)
│           (ρ = Base.convert(%21, Base.Inf))
└───        goto #25
7 ┄─ %24  = ρ::MyBigFloat
│    %25  = Base.isinf(%24)::Bool
└───        goto #9 if not %25
8 ──        goto #18
9 ── %28  = ρ::MyBigFloat
│    %29  = (%28 == 0)::Bool
└───        goto #14 if not %29
10 ─ %31  = (x != 0)::Bool
└───        goto #12 if not %31
11 ─        (@_12 = %31)
└───        goto #13
12 ─        (@_12 = y != 0)
13 ┄ %36  = @_12::Bool
│           (@_11 = %36)
└───        goto #15
14 ─        (@_11 = false)
15 ┄ %40  = @_11::Bool
└───        goto #17 if not %40
16 ─        goto #18
17 ─ %43  = Base.:<::Core.Const(<)
│    %44  = ρ::MyBigFloat
│    %45  = Base.:/::Core.Const(/)
│    %46  = Base.nextfloat::Core.Const(nextfloat)
│    %47  = $(Expr(:static_parameter, 1))::Core.Const(MyBigFloat)
│    %48  = Base.zero(%47)::MyBigFloat
│    %49  = (%46)(%48)::MyBigFloat
│    %50  = Base.:*::Core.Const(*)
│    %51  = Base.:^::Core.Const(^)
│    %52  = $(Expr(:static_parameter, 1))::Core.Const(MyBigFloat)
│    %53  = Base.eps(%52)::MyBigFloat
│    %54  = Core.apply_type(Base.Val, 2)::Core.Const(Val{2})
│    %55  = (%54)()::Core.Const(Val{2}())
│    %56  = Base.literal_pow(%51, %53, %55)::MyBigFloat
│    %57  = (%50)(2, %56)::MyBigFloat
│    %58  = (%45)(%49, %57)::MyBigFloat
│    %59  = (%43)(%44, %58)::Bool
└───        goto #25 if not %59
18 ┄ %61  = Base.max::Core.Const(max)
│    %62  = Base.abs(x)::MyBigFloat
│    %63  = Base.abs(y)::MyBigFloat
│    %64  = (%61)(%62, %63)::MyBigFloat
│           (@_9 = %64)
│    %66  = @_9::MyBigFloat
│    %67  = $(Expr(:static_parameter, 1))::Core.Const(MyBigFloat)
│    %68  = (%66 isa %67)::Core.Const(true)
└───        goto #20 if not %68
19 ─        goto #21
20 ─        Core.Const(:($(Expr(:static_parameter, 1))))
│           Core.Const(:(@_9))
│           Core.Const(:(Base.convert(%71, %72)))
│           Core.Const(:($(Expr(:static_parameter, 1))))
└───        Core.Const(:(@_9 = Core.typeassert(%73, %74)))
21 ┄ %76  = @_9::MyBigFloat
│           (m = %76)
│    %78  = m::MyBigFloat
│    %79  = (%78 == 0)::Bool
└───        goto #23 if not %79
22 ─        (@_13 = 0)
└───        goto #24
23 ─ %83  = Base.convert::Core.Const(convert)
│    %84  = Base.Int::Core.Const(Int64)
│    %85  = Base.exponent::Core.Const(exponent)
│    %86  = m::MyBigFloat
│    %87  = (%85)(%86)::BigInt
└───        (@_13 = (%83)(%84, %87))
24 ┄ %89  = @_13::Int64
│           (k = %89)
│    %91  = Base.ldexp::Core.Const(ldexp)
│    %92  = k::Int64
│    %93  = -%92::Int64
│    %94  = (%91)(x, %93)::MyBigFloat
│    %95  = Base.ldexp::Core.Const(ldexp)
│    %96  = k::Int64
│    %97  = -%96::Int64
│    %98  = (%95)(y, %97)::MyBigFloat
│           (xk = %94)
│           (yk = %98)
│    %101 = Base.:+::Core.Const(+)
│    %102 = xk::MyBigFloat
│    %103 = xk::MyBigFloat
│    %104 = (%102 * %103)::MyBigFloat
│    %105 = yk::MyBigFloat
│    %106 = yk::MyBigFloat
│    %107 = (%105 * %106)::MyBigFloat
└───        (ρ = (%101)(%104, %107))
25 ┄ %109 = ρ::MyBigFloat
│    %110 = k::Int64
│    %111 = Core.tuple(%109, %110)::Tuple{MyBigFloat, Int64}
└───        return %111

so you can see there aren't any Union types in type inference with this version, whereas there would be on master and also on the previous commit of this PR, as @nsajko pointed out.

mtfishman avatar Jun 21 '24 16:06 mtfishman