EnzymeAD/Enzyme.jl

Fix type escaping in `@import_frule`, `@import_rrule`

Open mofeing opened this issue 1 year ago • 4 comments

While trying to import some ChainRules from a package, I got the following error:

julia> using Enzyme, OMEinsum

julia> Enzyme.@import_rrule(typeof(einsum), OMEinsum.EinCode, Any, Any)
UndefVarError: `OMEinsum` not defined

Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Enzyme/srACB/ext/EnzymeChainRulesCoreExt.jl:181 [inlined]
 [2] top-level scope
   @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum-from-chainrules.ipynb:1

By expanding the macro, I saw that the type annotations were not correctly escaped. For example, for augmented_primal we get:

... augmented_primal(...) where {... var"#378#AN_1" <: Enzyme.Annotation{<:(Enzyme.OMEinsum).EinCode}, var"#379#AN_2" <: Enzyme.Annotation{<:Enzyme.Any}, var"#380#AN_3" <: Enzyme.Annotation{<:Enzyme.Any}}

Note that `Enzyme.OMEinsum` and `Enzyme.Any` are wrong: the names were resolved inside the Enzyme module instead of the caller's module.

I haven't tested it thoroughly, but this PR should fix it.

mofeing, May 14 '24 16:05
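The escaping problem described in the PR can be reproduced without Enzyme. The sketch below is a minimal, hypothetical example: the modules `MacroHome`, `Caller`, `Inner` and the macros `@unescaped_bound`/`@escaped_bound` are invented for illustration and are not Enzyme's code. A type expression spliced into a macro's output without `esc` is resolved in the macro's own module, just as `OMEinsum.EinCode` became `(Enzyme.OMEinsum).EinCode`.

```julia
# Hypothetical module standing in for the package that defines the macros.
module MacroHome
# Buggy: the user-supplied type expression `T` is spliced in without `esc`,
# so macro hygiene resolves its free names (e.g. `Inner`) inside MacroHome.
macro unescaped_bound(T)
    return :(Vector{<:$T})
end

# Fixed: `esc(T)` makes the type expression resolve in the caller's module.
macro escaped_bound(T)
    return :(Vector{<:$(esc(T))})
end
end # module MacroHome

# Hypothetical caller, standing in for user code that imports a rule for a
# type defined in another package (like `OMEinsum.EinCode`).
module Caller
import ..MacroHome

module Inner
struct Thing end
end

# Works: `Inner` is looked up in Caller, where it exists.
const Good = MacroHome.@escaped_bound(Inner.Thing)

# Hygiene rewrites `Inner` to `MacroHome.Inner`. Defining the function is
# fine, but calling it throws UndefVarError because MacroHome has no `Inner`.
bad() = MacroHome.@unescaped_bound(Inner.Thing)
end # module Caller
```

Wrapping user-supplied expressions in `esc` is the usual way to opt them out of hygiene; presumably the PR applies the same idea to the type arguments of `@import_frule` and `@import_rrule`, though its exact diff is not reproduced here.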

If possible, can you add a test?

wsmoses, May 14 '24 16:05

Added! It tests against a MockType in a MockModule.

mofeing, May 14 '24 16:05

@mofeing this fails CI

wsmoses, May 14 '24 19:05

Forgot to add methods for fdiff and rdiff. Should be fixed now.

mofeing, May 14 '24 21:05

I fixed all errors in the tests except one: the rrule test for MockType.

The pullback of the rrule defined for mock_function should return a number:

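function ChainRulesCore.rrule(::typeof(MockModule.mock_function), x)
    y = MockModule.mock_function(x)
    # A ChainRulesCore pullback returns one tangent per argument, starting
    # with the function itself (which has no tangent here).
    return y, ȳ -> (ChainRulesCore.NoTangent(), 2 * ȳ)
end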

But it seems like Enzyme is returning a MockType. Is this because we are annotating the return activity to be Active?

rdiff(f, x::MockModule.MockType) = autodiff(Reverse, f, Active, Active(x))[1][1]

Enzyme.@import_rrule typeof(MockModule.mock_function) MockModule.MockType
@test rdiff(MockModule.mock_function, MockModule.MockType(1f0)) === 2f0

...

import_rrule: Test Failed at /Users/mofeing/Developer/Enzyme.jl/test/ext/chainrulescore.jl:117
  Expression: rdiff(MockModule.mock_function, MockModule.MockType(1.0f0)) === 2.0f0
   Evaluated: Main.MockModule.MockType(2.0f0) === 2.0f0

mofeing, May 15 '24 20:05

@mofeing CI still fails:

import_rrule: Test Failed at /home/runner/work/Enzyme.jl/Enzyme.jl/test/ext/chainrulescore.jl:117
  Expression: rdiff(MockModule.mock_function, MockModule.MockType(1.0f0)) === 2.0f0
   Evaluated: Main.MockModule.MockType(2.0f0) === 2.0f0

wsmoses, May 16 '24 05:05

And no: for Active values, Enzyme returns a value of the same type as the argument (whatever that type is).

wsmoses, May 16 '24 05:05
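To illustrate the point about Active values, here is a hedged sketch with a hypothetical `MockType` (a single `Float32` field) and a hypothetical `mock_function`; these are guesses at the shape of the PR's test fixtures, not copies of them. In reverse mode, the derivative returned for an `Active` struct argument is a struct of the same type, which is consistent with the `MockType(2.0f0)` seen in the failing test.

```julia
using Enzyme

# Hypothetical stand-in for the PR's MockType: an immutable struct wrapping
# a single Float32, so it can be passed as an Active argument.
struct MockType
    x::Float32
end

# Hypothetical scalar-valued function of a MockType.
mock_function(m::MockType) = 2 * m.x

# Same helper as in the thread: reverse mode, Active return, Active argument.
# autodiff returns a tuple whose first element holds the argument derivatives.
rdiff(f, x) = autodiff(Reverse, f, Active, Active(x))[1][1]

# The derivative with respect to the Active MockType is itself a MockType
# whose field holds d(mock_function)/d(m.x).
g = rdiff(mock_function, MockType(1f0))
println(g)
```

A test for such an argument therefore has to compare against a `MockType` (or unwrap its field) rather than against a bare `Float32`.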

Okay, fixed now.

mofeing, May 16 '24 08:05

Codecov Report

All modified and coverable lines are covered by tests.

Project coverage is 72.43%. Comparing base (cc8ceb6) to head (8130415). Report is 2 commits behind head on main.


Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1446      +/-   ##
==========================================
+ Coverage   68.04%   72.43%   +4.39%     
==========================================
  Files          30       30              
  Lines       10772    10838      +66     
==========================================
+ Hits         7330     7851     +521     
+ Misses       3442     2987     -455     


codecov-commenter, May 16 '24 16:05