Enzyme.jl

Fix Hessian sret type

Status: Open · wsmoses opened this pull request 3 months ago · 5 comments

Requires https://github.com/EnzymeAD/Enzyme/pull/2592.

— wsmoses, Dec 01 '25 01:12

Your PR requires formatting changes to meet the project's style guidelines. Please consider running Runic (`git runic main`) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/compiler.jl b/src/compiler.jl
index bec4673b..6fad4576 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -1059,15 +1059,15 @@ end
                     EnumAttribute("willreturn"),
                     EnumAttribute("nosync"),
                     EnumAttribute("nofree"),
-           	    StringAttribute("enzyme_preserve_primal", "*"),
+            StringAttribute("enzyme_preserve_primal", "*"),
 		      ]
     else
         LLVM.Attribute[EnumAttribute("memory", NoEffects.data), StringAttribute("enzyme_shouldrecompute"),
                     EnumAttribute("willreturn"),
                     EnumAttribute("nosync"),
-		    EnumAttribute("nofree"),
-           	    StringAttribute("enzyme_preserve_primal", "*"),
-		    ]
+            EnumAttribute("nofree"),
+            StringAttribute("enzyme_preserve_primal", "*"),
+        ]
     end
     handleCustom(state, custom, k_name, llvmfn, name, attrs)
     return
@@ -6751,11 +6751,11 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri
             for f in functions(mod)
                 for i in 1:length(parameters(f))
                     for a in collect(parameter_attributes(f, i))
-                       if kind(a) == "enzyme_sret"
-                           API.EnzymeDumpValueRef(f)
-                       end
-                       @assert kind(a) != "enzyme_sret"
-                       @assert kind(a) != "enzyme_sret_v"
+                        if kind(a) == "enzyme_sret"
+                            API.EnzymeDumpValueRef(f)
+                        end
+                        @assert kind(a) != "enzyme_sret"
+                        @assert kind(a) != "enzyme_sret_v"
                     end
                 end
             end
@@ -6765,7 +6765,7 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri
             if DumpPrePostOpt[]
                 API.EnzymeDumpModuleRef(mod.ref)
             end
-            post_optimize!(mod, JIT.get_tm(); callconv=false)
+            post_optimize!(mod, JIT.get_tm(); callconv = false)
             if DumpPostOpt[]
                 API.EnzymeDumpModuleRef(mod.ref)
             end
diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl
index 6cfb5959..da3b1895 100644
--- a/src/compiler/optimize.jl
+++ b/src/compiler/optimize.jl
@@ -380,7 +380,7 @@ const DumpPostCallConv = Ref(false)
 function fixup_callconv!(mod::LLVM.Module, tm::LLVM.TargetMachine)
     addr13NoAlias(mod)
     
-    removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)
+    removeDeadArgs!(mod, tm, #=post_gc_fixup=# false)
 
     memcpy_sret_split!(mod)
     # if we did the move_sret_tofrom_roots, we will have loaded out of the sret, then stored into the rooted.
@@ -436,17 +436,17 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool
     if callconv
         fixup_callconv!(mod, tm)
     end
-    
+
     for f in functions(mod)
-	if isempty(blocks(f))
-		continue
-	end
-	if has_fn_attr(f, StringAttribute("enzyme_preserve_primal"))
-	     delete!(LLVM.function_attributes(f), StringAttribute("enzyme_preserve_primal"))
-	end
+        if isempty(blocks(f))
+            continue
+        end
+        if has_fn_attr(f, StringAttribute("enzyme_preserve_primal"))
+            delete!(LLVM.function_attributes(f), StringAttribute("enzyme_preserve_primal"))
+        end
     end
 
-    removeDeadArgs!(mod, tm, #=post_gc_fixup=#true)
+    removeDeadArgs!(mod, tm, #=post_gc_fixup=# true)
 
     @dispose pb = NewPMPassBuilder() begin
         registerEnzymeAndPassPipeline!(pb)
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index 00381caf..d557abf9 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -2527,14 +2527,14 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup
             )
                 for u in LLVM.uses(fn)
                     u = LLVM.user(u)
-		    if !isa(u, LLVM.CallInst)
-                    	msg = sprint() do io
-			   println(io, "Unknown user of fn: ", string(u))
-			   println(io, "fn: ", string(fn))
-			   println(io, "mod: ", string(parent(fn)))
-			end
-			throw(AssertionError(msg))
-		    end
+                        if !isa(u, LLVM.CallInst)
+                            msg = sprint() do io
+                                println(io, "Unknown user of fn: ", string(u))
+                                println(io, "fn: ", string(fn))
+                                println(io, "mod: ", string(parent(fn)))
+                            end
+                            throw(AssertionError(msg))
+                        end
                     B = IRBuilder()
                     nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u))
                     position!(B, nextInst)

— github-actions[bot], Dec 01 '25 01:12

Codecov Report

✅ All modified and coverable lines are covered by tests. ✅ Project coverage is 48.00%. Comparing base (d64497a) to head (dd3f4eb). ⚠️ Report is 5 commits behind head on main.

❗ There is a different number of reports uploaded between BASE (d64497a) and HEAD (dd3f4eb). Click for more details.

HEAD has 34 fewer uploads than BASE:

| Flag | BASE (d64497a) | HEAD (dd3f4eb) |
|------|----------------|----------------|
|      | 38             | 4              |
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2824       +/-   ##
===========================================
- Coverage   67.53%   48.00%   -19.53%     
===========================================
  Files          58       13       -45     
  Lines       21051     1256    -19795     
===========================================
- Hits        14217      603    -13614     
+ Misses       6834      653     -6181     

☂️ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

— codecov[bot], Dec 01 '25 01:12

Benchmark Results

| Benchmark | main | dd3f4eb2d82acd... | main / dd3f4eb2d82acd... |
|-----------|------|-------------------|--------------------------|
| basics/make_zero/namedtuple | 0.049 ± 0.0096 μs | 0.0492 ± 0.011 μs | 0.996 ± 0.29 |
| basics/make_zero/struct | 0.252 ± 0.0078 μs | 0.251 ± 0.0067 μs | 1.01 ± 0.041 |
| basics/overhead | 3.18 ± 0.003 ns | 3.46 ± 0.003 ns | 0.917 ± 0.0012 |
| basics/remake_zero!/namedtuple | 0.222 ± 0.007 μs | 0.219 ± 0.011 μs | 1.01 ± 0.059 |
| basics/remake_zero!/struct | 0.223 ± 0.011 μs | 0.225 ± 0.0091 μs | 0.992 ± 0.063 |
| fold_broadcast/multidim_sum_bcast/1D | 10.9 ± 0.24 μs | 10.9 ± 0.24 μs | 1 ± 0.031 |
| fold_broadcast/multidim_sum_bcast/2D | 12.2 ± 0.32 μs | 12.2 ± 0.35 μs | 1 ± 0.039 |
| time_to_load | 0.983 ± 0.0081 s | 0.978 ± 0.0058 s | 1 ± 0.01 |

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/20237959906/artifacts/4873816709.

— github-actions[bot], Dec 01 '25 01:12

@copilot Isolate an MWE (minimal working example) of whichever test is timing out.

— wsmoses, Dec 13 '25 22:12

@wsmoses I've opened a new pull request, #2848, to work on those changes. Once the pull request is ready, I'll request review from you.

— Copilot, Dec 13 '25 22:12