MutableArithmetics.jl

@rewrite for arbitrary operation/function

Open mattsignorelli opened this issue 1 year ago • 7 comments

Does MutableArithmetics not work for general functions, only for the standard arithmetic operations (+, -, *, /)? I have a mutable type that overloads many of the Base functions (sin, cos, abs, sqrt, csc) as well as some that are not in Base (sinhc). I'd really like to use this interface and the @rewrite macro to speed up the evaluation of expressions, but I've found that it doesn't work for these functions.

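For example, my code for sin is:

import MutableArithmetics as MA

# TPS values can be modified in place
MA.mutability(::Type{TPS}) = MA.IsMutable()

# sin is unary, so the result type depends on a single argument type
function MA.promote_operation(::typeof(sin), ::Type{TPS})
  return TPS
end

# In place: overwrite a with sin(a)
function MA.operate!(::typeof(sin), a::TPS)
  mad_tpsa_sin!(a.tpsa, a.tpsa)
  return a
end

# Write sin(a) into the preallocated output
function MA.operate_to!(output::TPS, ::typeof(sin), a::TPS)
  mad_tpsa_sin!(a.tpsa, output.tpsa)
  return output
end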

mattsignorelli, Dec 26 '23 23:12

Specifically, I see neither a speedup nor a reduction in memory allocation when evaluating, for example,

t = @rewrite sin(x*sin(y)) + sin(z)

mattsignorelli, Dec 26 '23 23:12

Does MutableArithmetics not work for general functions, only for the standard arithmetic operations (+,-,/,*)?

Correct. The @rewrite macro works only for a limited subset of the standard arithmetic operations.

odow, Dec 27 '23 00:12

Correct. The @rewrite macro works only for a limited subset of the standard arithmetic operations.

Would it be a lot of work to generalize this macro to any overloaded function? Other parts of the interface seem to handle arbitrary functions (I see abs implemented for BigInt), so expressions with many calls to sin, cos, sqrt, log, etc. could gain a lot of performance.

mattsignorelli, Dec 27 '23 01:12

The macro is defined here (x-ref https://github.com/jump-dev/MutableArithmetics.jl/pull/254): https://github.com/jump-dev/MutableArithmetics.jl/blob/0b3f7d14c3326d5e7721fc41738254e2a93c05bc/src/rewrite.jl#L7-L29 It isn't very complicated.

Here are the actual rewrites: https://github.com/jump-dev/MutableArithmetics.jl/blob/0b3f7d14c3326d5e7721fc41738254e2a93c05bc/src/rewrite_generic.jl#L58-L218
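To give a rough idea of what extending such a rewrite to unary functions involves, here is a small Base-only sketch. It does not use MutableArithmetics and is not how its rewrite is implemented; the macro `@inplace_sketch`, the function `inplace!`, the constant `REWRITABLE`, and the `Box` type are all hypothetical names made up for this illustration. The rule it encodes is that only intermediate results created inside the expression may be overwritten, while variables owned by the caller (such as `x`) must never be mutated. It also assumes that every function listed in `REWRITABLE` returns a fresh value when called out of place.

```julia
# Hypothetical mutable scalar type, used only to exercise the sketch.
mutable struct Box
    v::Float64
end
Base.sin(a::Box) = Box(sin(a.v))
Base.:+(a::Box, b::Box) = Box(a.v + b.v)
Base.:*(a::Box, b::Box) = Box(a.v * b.v)

# Hypothetical in-place hooks: overwrite the first argument and return it.
inplace!(::typeof(sin), a::Box) = (a.v = sin(a.v); a)
inplace!(::typeof(+), a::Box, b::Box) = (a.v += b.v; a)
inplace!(::typeof(*), a::Box, b::Box) = (a.v *= b.v; a)
# Fallback for anything without an in-place method: allocate as usual.
inplace!(f, args...) = f(args...)

# Functions assumed to return a fresh value when called out of place.
const REWRITABLE = (:sin, :+, :*)

# Returns (new_expr, owned). `owned` is true when the value is a temporary
# created inside the expression, so it is safe to overwrite.
function rewrite_sketch(ex)
    if !(ex isa Expr && ex.head === :call)
        return (ex, false)  # symbols and literals belong to the caller
    end
    f = ex.args[1]
    parts = [rewrite_sketch(a) for a in ex.args[2:end]]
    args = Any[first(p) for p in parts]
    if !(f in REWRITABLE)
        # Unknown function: keep the call and make no ownership claim.
        return (Expr(:call, f, args...), false)
    end
    if !isempty(parts) && last(parts[1])
        # First argument is our own temporary: reuse its storage.
        return (Expr(:call, :inplace!, f, args...), true)
    end
    # First argument may belong to the caller: call out of place.
    return (Expr(:call, f, args...), true)
end

macro inplace_sketch(ex)
    return esc(first(rewrite_sketch(ex)))
end
```

For `sin(x*sin(y)) + sin(z)` the sketch emits `inplace!(+, inplace!(sin, x * sin(y)), sin(z))`: the outer `sin` and the final `+` reuse temporaries, while `x`, `y`, and `z` are left untouched.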

I don't know if we want to add a generalized rewrite. The main purpose of MutableArithmetics is for JuMP. @blegat is the one who would need to decide.

odow, Dec 27 '23 01:12

Thanks. Worst case, I suppose I could import it and modify it. The use case is GTPSA.jl, a package wrapping a C library for manipulating truncated power series. Each mutable struct is a truncated power series, and all the Base math functions are overloaded. Currently, every intermediate value in an expression allocates a new struct. Preliminary tests using a preallocated temporary buffer show quite a big speedup, so a macro or some other method that could do this automatically would be very advantageous.
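As an illustration of the preallocated-buffer idea, here is a self-contained, Base-only sketch that evaluates `sin(x*sin(y)) + sin(z)` both naively and with two reusable temporaries. `Series` is a hypothetical stand-in for a TPS (its arithmetic is elementwise, not real power-series arithmetic), and `sin_to!`, `mul_to!`, and `add_to!` are hypothetical in-place helpers playing roughly the role that `operate_to!` methods play in MutableArithmetics.

```julia
# Hypothetical stand-in for a truncated power series (not the GTPSA.jl
# TPS type). Its arithmetic is elementwise, purely for illustration.
mutable struct Series
    c::Vector{Float64}
end

# Out-of-place operations: every call allocates a new Series and Vector.
Base.sin(a::Series) = Series(sin.(a.c))
Base.:*(a::Series, b::Series) = Series(a.c .* b.c)
Base.:+(a::Series, b::Series) = Series(a.c .+ b.c)

# Hypothetical in-place helpers that write into a preallocated `out`.
sin_to!(out::Series, a::Series) = (out.c .= sin.(a.c); out)
mul_to!(out::Series, a::Series, b::Series) = (out.c .= a.c .* b.c; out)
add_to!(out::Series, a::Series, b::Series) = (out.c .= a.c .+ b.c; out)

# Naive evaluation: each intermediate result allocates.
naive(x, y, z) = sin(x * sin(y)) + sin(z)

# Buffered evaluation of the same expression with two reusable
# temporaries `t1` and `t2`; no output argument aliases an input.
function buffered!(out, t1, t2, x, y, z)
    sin_to!(t1, y)        # t1 = sin(y)
    mul_to!(t2, x, t1)    # t2 = x * sin(y)
    sin_to!(t1, t2)       # t1 = sin(x * sin(y))
    sin_to!(t2, z)        # t2 = sin(z)
    add_to!(out, t1, t2)  # out = sin(x * sin(y)) + sin(z)
    return out
end
```

The temporaries can be allocated once and reused across many evaluations; generating this kind of code automatically is what a generalized rewrite macro would have to do.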

mattsignorelli, Dec 27 '23 02:12

I'm not opposed to adding support for these unary functions. PR welcome.

blegat, Dec 28 '23 09:12

We'd just need to be very careful to ensure that the new rewrite doesn't break JuMP's nonlinear code.

odow, Dec 28 '23 22:12