JUDI.jl
Enable scalar/broadcast operation for LazyPropagation
- Enable scalar/broadcast operations for `LazyPropagation`; add an associated test (which won't pass with the current master).
- `LazyPropagation` now has an attribute `val`, which stores `F * q` if it was previously computed.
- Fix the `reshape` issue for multi-source vectors, which can be of size `nsrc` and also of size `nsrc * nt * nrec`.
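The reshape fix in the last bullet amounts to a size check that accepts both lengths. Below is a minimal sketch of that check. `reshape_multisource` is a hypothetical helper, not JUDI's API. The sketch assumes one value per source when the length is `nsrc`, and a column-major `(nt, nrec, nsrc)` layout when the length is `nsrc * nt * nrec`.

```julia
# Hypothetical helper (not part of JUDI): accept a flat vector holding either
# one value per source (length nsrc) or full traces for every source
# (length nsrc * nt * nrec), and reshape it accordingly.
function reshape_multisource(v::AbstractVector, nsrc::Int, nt::Int, nrec::Int)
    n = length(v)
    if n == nsrc
        # One entry per source: keep it as a length-nsrc vector.
        return reshape(v, nsrc)
    elseif n == nsrc * nt * nrec
        # Assumed layout: time fastest, then receiver, then source.
        return reshape(v, nt, nrec, nsrc)
    else
        throw(DimensionMismatch("length $n is neither nsrc=$nsrc nor nsrc*nt*nrec=$(nsrc * nt * nrec)"))
    end
end

# Usage: 2 sources, 3 time samples, 4 receivers.
a = reshape_multisource(collect(1f0:2f0), 2, 3, 4)    # size (2,)
b = reshape_multisource(collect(1f0:24f0), 2, 3, 4)   # size (3, 4, 2)
```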
Codecov Report
Base: 81.88% // Head: 81.59% // Decreases project coverage by 0.29%.
Coverage data is based on head (171170f) compared to base (8f65ed4). Patch coverage: 41.17% of modified lines in the pull request are covered.
Additional details and impacted files:
```diff
@@            Coverage Diff             @@
##           master     #167      +/-   ##
==========================================
- Coverage   81.88%   81.59%   -0.30%
==========================================
  Files          28       28
  Lines        2186     2200      +14
==========================================
+ Hits         1790     1795       +5
- Misses        396      405       +9
```
| Impacted Files | Coverage Δ |
|---|---|
| src/TimeModeling/LinearOperators/lazy.jl | 83.72% <0.00%> (-0.99%) |
| src/rrules.jl | 62.02% <40.00%> (-5.14%) |
| src/TimeModeling/Types/abstract.jl | 77.41% <100.00%> (+0.18%) |
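As a quick sanity check on the report, project coverage is hits divided by lines. Recomputing it from the counts in the coverage diff reproduces the reported base and head figures and a drop of roughly 0.29 percentage points.

```julia
# Project coverage in percent: covered (hit) lines over all tracked lines.
coverage(hits, lines) = 100 * hits / lines

base = coverage(1790, 2186)   # master
head = coverage(1795, 2200)   # PR #167
println(round(base; digits=2), " ", round(head; digits=2), " ", round(head - base; digits=2))
# prints: 81.88 81.59 -0.29
```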
Your change led to method ambiguities... please run these basic tests locally.
On a side note: does it make sense to move the scalar operations (all of `+`, `-`, `*`, `/`) into `LazyPropagation.post`?
No, because then it's not a linear operation anymore.
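One way to read that reply: scaling commutes with a linear map, but adding a scalar does not. Folding `+` or `-` into `post` would therefore make the wrapped map affine rather than linear. The toy check below uses plain matrices, with no JUDI involved. `lin` and `aff` are illustrative names only.

```julia
using LinearAlgebra   # provides ≈ (isapprox) for arrays

A = [1.0 2.0; 3.0 4.0]
lin(x) = A * x            # linear in x
aff(x) = A * x .+ 1.0     # same map followed by a scalar `+ 1`: only affine

# Superposition check: f(α*x + y) == α*f(x) + f(y) holds only for linear f.
x, y, α = [1.0, -1.0], [0.5, 2.0], 3.0
is_linear(f) = f(α * x + y) ≈ α * f(x) + f(y)

println(is_linear(lin))   # true
println(is_linear(aff))   # false
```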
Hmm, I'd appreciate your comment on this one, @mloubout: I am now on JUDI master and
julia> gs_inv = gradient(x -> norm(F(x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.50 s
Operator `gradient` ran in 0.34 s
(Float32[-0.081900775 0.07301128 … 6.170804f-6 7.20752f-6; 0.0637427 0.027981473 … 9.756089f-7 5.4272978f-6; … ; 0.06374304 0.027981216 … 9.755976f-7 5.4272914f-6; -0.08189945 0.07301152 … 6.170794f-6 7.2075245f-6],)
julia> gs_inv1 = gradient(x -> norm(F(1f0*x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.55 s
Operator `forward` ran in 0.49 s
Operator `gradient` ran in 0.34 s
ERROR: MethodError: no method matching *(::Float32, ::JUDI.LazyPropagation)
Closest candidates are:
*(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
*(::T, ::T) where T<:Union{Float16, Float32, Float64} at float.jl:385
*(::Union{Float16, Float32, Float64}, ::BigFloat) at mpfr.jl:414
...
Stacktrace:
[1] (::ChainRules.var"#1490#1494"{JUDI.LazyPropagation, Float32, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}})()
@ ChainRules ~/.julia/packages/ChainRules/ajkp7/src/rulesets/Base/arraymath.jl:111
[2] unthunk
@ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:204 [inlined]
[3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1490#1494"{JUDI.LazyPropagation, Float32, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, ChainRules.var"#1489#1493"{JUDI.LazyPropagation, Float32}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:237
[4] wrap_chainrules_output
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:105 [inlined]
[5] map
@ ./tuple.jl:223 [inlined]
[6] wrap_chainrules_output
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:106 [inlined]
[7] ZBack
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:206 [inlined]
[8] Pullback
@ ./REPL[26]:1 [inlined]
[9] (::typeof(∂(#10)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[10] (::Zygote.var"#60#61"{typeof(∂(#10))})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
[11] gradient(f::Function, args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
[12] top-level scope
@ REPL[26]:1
[13] top-level scope
@ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
julia> import Base.*;
julia> *(y::Float32, F::JUDI.LazyPropagation) = JUDI.LazyPropagation(F.post, F.F, *(y, F.q));
julia> gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.56 s
Operator `forward` ran in 0.53 s
Operator `gradient` ran in 0.34 s
Operator `forward` ran in 0.43 s
Operator `gradient` ran in 0.35 s
(Float32[-0.081900775 0.07301128 … 6.170804f-6 7.20752f-6; 0.0637427 0.027981473 … 9.756089f-7 5.4272978f-6; … ; 0.06374304 0.027981216 … 9.755976f-7 5.4272914f-6; -0.08189945 0.07301152 … 6.170794f-6 7.2075245f-6],)
`gs_inv` performs a nonlinear forward modeling and an RTM. `gs_inv1` fails because scalar multiplication is not defined yet. After defining the multiplication, `gs_inv2` evaluates the `LazyPropagation` twice (note the extra `forward` and `gradient` runs in the log), which confuses me... any idea why? Thanks
Full script below:
```julia
using JUDI
using Flux
using ArgParse, Test, Printf, Aqua
using SegyIO, LinearAlgebra, Distributed, JOLI
using TimerOutputs: TimerOutputs, @timeit
Flux.Random.seed!(2022)

### Model
tti = false
viscoacoustic = false
nsrc = 1
dt = 1f0
# Test helpers (setup_model, setup_geom) from JUDI's test suite
include(joinpath(JUDIPATH, "../test/seismic_utils.jl"))
model, model0, dm = setup_model(tti, viscoacoustic, 4)
m, m0 = model.m.data, model0.m.data
q, srcGeometry, recGeometry, f0 = setup_geom(model; nsrc=nsrc, dt=dt)

# Common op
Pr = judiProjection(recGeometry)
Ps = judiProjection(srcGeometry)
ra = false
stype = "Point"
Pq = Ps
opt = Options(return_array=ra, sum_padding=true, f0=f0)
A_inv = judiModeling(model; options=opt)
A_inv0 = judiModeling(model0; options=opt)

# Operators
F = Pr*A_inv*adjoint(Pq)
F0 = Pr*A_inv0*adjoint(Pq)

# Works: one nonlinear forward modeling plus one RTM
gs_inv = gradient(x -> norm(F(x)*q), m0)
# Throws on master: MethodError: no method matching *(::Float32, ::JUDI.LazyPropagation)
# (comment this line out to reach gs_inv2)
gs_inv1 = gradient(x -> norm(F(1f0*x)*q), m0)
# Workaround: fold the scalar into the source of the lazy propagation
import Base.*;
*(y::Float32, L::JUDI.LazyPropagation) = JUDI.LazyPropagation(L.post, L.F, y * L.q);
gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0)
```
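The workaround relies on `F(m)*q` being linear in the source `q`: a scalar in front of the lazy result can be folded into `q` without running any propagation. Below is a toy sketch of that idea. It uses a hypothetical `ToyLazy` type that mimics the `post`, `F`, `q` fields used in the workaround above, with a plain matrix standing in for the modeling operator.

```julia
# Hypothetical stand-in for JUDI.LazyPropagation: represents post(F * q)
# without computing it until `materialize` is called.
struct ToyLazy{P,O,V}
    post::P   # post-processing applied to F*q (assumed linear)
    F::O      # operator, linear in q
    q::V      # source
end

materialize(L::ToyLazy) = L.post(L.F * L.q)

# Scalar multiplication stays lazy: y * post(F*q) == post(F*(y*q))
# as long as both post and F are linear.
Base.:*(y::Number, L::ToyLazy) = ToyLazy(L.post, L.F, y * L.q)
Base.:*(L::ToyLazy, y::Number) = y * L

F = [1f0 2f0; 0f0 1f0]
L = ToyLazy(identity, F, [1f0, 3f0])
L2 = 2f0 * L                    # still a ToyLazy; nothing evaluated yet
println(materialize(L2))        # Float32[14.0, 6.0]
println(2f0 * materialize(L))   # Float32[14.0, 6.0]
```

Defining the method as `Base.:*` extends Base's `*` without needing `import Base.*`. The rewrite is only valid if `post` is linear too.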
That's quite curious indeed, I'll see if I can figure out what's going on.
Well, that is baaaaaaaad; this is why people don't want to use Julia for serious stuff.
When you do `gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0)`, Zygote doesn't correctly understand that you want "only" the derivative w.r.t. `m0`, in part because it doesn't understand thunks. So it ends up computing what you want, i.e. d(F(1*m0)*q)/dm0. However, diff rules for `mul` are defined for both the left and the right input, and Zygote always computes and evaluates everything. So it also computes d(F(1*m0)*q)/d1, which calls `dot`, which in turn calls `eval_prop`.
So there is no trivial way out of it, except maybe having `LazyPropagation` store the result at its first evaluation so it's only computed once (the above computes the same gradient twice).
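The mechanism and the caching idea can be reproduced in a toy model without JUDI or Zygote. In the sketch, the pullback of `y * x` gets a lazy cotangent `Δ` and returns cotangents for both inputs, as the generic product rule does. The one for `x` can stay lazy. The one for the scalar, `dot(Δ, x)`, forces `Δ` to be materialized even though nobody asked for it. `CachedLazy`, `materialize!`, `mul_pullback`, `gradient_wrt_x`, and the `nevals` counter are hypothetical names; the `val` field mirrors the cache attribute from the PR description.

```julia
using LinearAlgebra

# Hypothetical stand-in for a lazy gradient such as JUDI.LazyPropagation:
# it represents F*q, computes it on demand, and can cache the result in `val`.
mutable struct CachedLazy
    F::Matrix{Float64}
    q::Vector{Float64}
    val::Union{Nothing,Vector{Float64}}
    nevals::Int            # how many times the expensive product ran
end
CachedLazy(F, q) = CachedLazy(F, q, nothing, 0)

function materialize!(L::CachedLazy; use_cache::Bool=true)
    use_cache && L.val !== nothing && return L.val
    L.nevals += 1          # stands in for one eval_prop call (e.g. one RTM)
    v = L.F * L.q
    use_cache && (L.val = v)
    return v
end

# Pullback of z = y * x (scalar y, array x) for a lazy cotangent Δ:
#   ∂y = dot(Δ, x) -> materializes Δ right away
#   ∂x = y * Δ     -> kept lazy as a (scale, parent) pair sharing Δ's cache
function mul_pullback(y, x, Δ::CachedLazy; use_cache::Bool=true)
    ∂y = dot(materialize!(Δ; use_cache=use_cache), x)
    ∂x = (y, Δ)
    return ∂y, ∂x
end

# The caller finally materializes the gradient it actually asked for.
function gradient_wrt_x(∂x::Tuple; use_cache::Bool=true)
    y, Δ = ∂x
    return y .* materialize!(Δ; use_cache=use_cache)
end

F, q, x = [1.0 2.0; 3.0 4.0], [1.0, 1.0], [0.5, 0.5]

# No cache: the same lazy gradient is propagated twice.
Δ = CachedLazy(F, q)
_, ∂x = mul_pullback(1.0, x, Δ; use_cache=false)
g = gradient_wrt_x(∂x; use_cache=false)
println(Δ.nevals)          # 2

# With a cache: the second request reuses `val`.
Δ_cached = CachedLazy(F, q)
_, ∂x_cached = mul_pullback(1.0, x, Δ_cached)
g_cached = gradient_wrt_x(∂x_cached)
println(Δ_cached.nevals)   # 1
```

The cache only helps if every derived object (here the `(y, Δ)` pair) refers back to the same parent. A scaled copy with its own `q` would start with an empty cache. This hints at why a clean fix needs some design work.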
Could you explain how (with code or otherwise) you reached the conclusion in https://github.com/slimgroup/JUDI.jl/pull/167#issuecomment-1373145567? I am experiencing the issue below and would like to check what went wrong...
julia> gs_inv = gradient(() -> norm(F(1f0*m)*q), Flux.params(m))
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.48 s
Operator `gradient` ran in 0.34 s
Grads(...)
julia> gs_inv = gradient(() -> norm(F(m*1f0)*q), Flux.params(m))
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.53 s
Operator `gradient` ran in 0.34 s
Operator `forward` ran in 0.49 s
Operator `gradient` ran in 0.34 s
Grads(...)
Debug every `eval_prop` call to see where it's called from and what the inputs are. In that other case it was evaluated in `dot`; from there you can infer why, and check that it is indeed that gradient being computed by requesting it as a param.
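A generic way to follow that advice is to record the caller each time the expensive function runs. The sketch below does this for a hypothetical `expensive_eval`, which stands in for `eval_prop` (its real signature is not shown in this thread). It uses a counter and Base's `stacktrace()`. While debugging, the same few lines could be added temporarily to the real method.

```julia
using LinearAlgebra

const CALLERS = String[]   # one entry per evaluation: who triggered it

# Hypothetical stand-in for an expensive lazy evaluation such as eval_prop.
@noinline function expensive_eval(x)
    frame = stacktrace()[2]   # [1] is expensive_eval itself, [2] is its caller
    push!(CALLERS, string(frame.func, " at ", frame.file, ":", frame.line))
    @info "expensive_eval called" caller = CALLERS[end] input_size = size(x)
    return 2 .* x
end

# Two call sites that both force an evaluation.
@noinline via_dot(x) = dot(expensive_eval(x), x)
@noinline via_norm(x) = norm(expensive_eval(x))

x = [1.0, 2.0]
via_dot(x)
via_norm(x)
foreach(println, CALLERS)   # one line per evaluation, naming the calling function
```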
Not sure where you are in the debugging, but I can tell you that it's not super trivial, and the fix will require some proper design to extend it cleanly to this type of case. But I'll leave it to you to at least find what the issue is, as an exercise.
Thanks! Yes I agree this is not simple. I will pick it up some time later this week