
Enable scalar/broadcast operation for LazyPropagation

Open · ziyiyin97 opened this issue 1 year ago · 11 comments

  1. Enable scalar/broadcast operations for LazyPropagation and add an associated test (which won't pass with the current master); a rough sketch follows this list
  2. LazyPropagation now has an attribute val, which stores F * q once it has been computed
  3. Fix the reshape issue for multi-source vectors, which can have size nsrc or size nsrc * nt * nrec
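
A minimal sketch of item 1, in the spirit of the workaround shown later in this thread; it assumes the LazyPropagation fields post, F and q used there, and is not necessarily how the PR implements it:

import Base: *
import Base.Broadcast: broadcasted

# Scale the (not yet propagated) source instead of the propagation result;
# this stays lazy because F * (y * q) == y * (F * q) for a linear F.
*(y::Number, P::JUDI.LazyPropagation) = JUDI.LazyPropagation(P.post, P.F, y * P.q)
*(P::JUDI.LazyPropagation, y::Number) = y * P
# Broadcasting a scalar over a LazyPropagation can reuse the same rule.
broadcasted(::typeof(*), y::Number, P::JUDI.LazyPropagation) = y * P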

ziyiyin97 · Jan 04 '23 23:01

Codecov Report

Base: 81.88% // Head: 81.59% // Decreases project coverage by -0.29% :warning:

Coverage data is based on head (171170f) compared to base (8f65ed4). Patch coverage: 41.17% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #167      +/-   ##
==========================================
- Coverage   81.88%   81.59%   -0.30%     
==========================================
  Files          28       28              
  Lines        2186     2200      +14     
==========================================
+ Hits         1790     1795       +5     
- Misses        396      405       +9     
Impacted Files                               Coverage Δ
src/TimeModeling/LinearOperators/lazy.jl     83.72% <0.00%> (-0.99%) :arrow_down:
src/rrules.jl                                62.02% <40.00%> (-5.14%) :arrow_down:
src/TimeModeling/Types/abstract.jl           77.41% <100.00%> (+0.18%) :arrow_up:



codecov[bot] · Jan 04 '23 23:01

Your change led to ambiguities... please run these basic tests locally

mloubout · Jan 05 '23 00:01

On a side note: does it make sense to move the scalar operations (all of +-*/) into LazyPropagation.post?

ziyiyin97 · Jan 05 '23 15:01

> does it make sense to move the scalar operations (all of +-*/) into LazyPropagation.post?

No because then it's not a linear operation anymore

mloubout · Jan 05 '23 15:01

Hmm, I'd appreciate your comment on this one @mloubout: I am now on JUDI master and

julia> gs_inv = gradient(x -> norm(F(x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.50 s
Operator `gradient` ran in 0.34 s
(Float32[-0.081900775 0.07301128 … 6.170804f-6 7.20752f-6; 0.0637427 0.027981473 … 9.756089f-7 5.4272978f-6; … ; 0.06374304 0.027981216 … 9.755976f-7 5.4272914f-6; -0.08189945 0.07301152 … 6.170794f-6 7.2075245f-6],)

julia> gs_inv1 = gradient(x -> norm(F(1f0*x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.55 s
Operator `forward` ran in 0.49 s
Operator `gradient` ran in 0.34 s
ERROR: MethodError: no method matching *(::Float32, ::JUDI.LazyPropagation)
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  *(::T, ::T) where T<:Union{Float16, Float32, Float64} at float.jl:385
  *(::Union{Float16, Float32, Float64}, ::BigFloat) at mpfr.jl:414
  ...
Stacktrace:
  [1] (::ChainRules.var"#1490#1494"{JUDI.LazyPropagation, Float32, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}})()
    @ ChainRules ~/.julia/packages/ChainRules/ajkp7/src/rulesets/Base/arraymath.jl:111
  [2] unthunk
    @ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:204 [inlined]
  [3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1490#1494"{JUDI.LazyPropagation, Float32, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, ChainRules.var"#1489#1493"{JUDI.LazyPropagation, Float32}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:237
  [4] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:105 [inlined]
  [5] map
    @ ./tuple.jl:223 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:106 [inlined]
  [7] ZBack
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:206 [inlined]
  [8] Pullback
    @ ./REPL[26]:1 [inlined]
  [9] (::typeof(∂(#10)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#60#61"{typeof(∂(#10))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
 [11] gradient(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
 [12] top-level scope
    @ REPL[26]:1
 [13] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52

julia> import Base.*;

julia> *(y::Float32, F::JUDI.LazyPropagation) = JUDI.LazyPropagation(F.post, F.F, *(y, F.q));

julia> gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.56 s
Operator `forward` ran in 0.53 s
Operator `gradient` ran in 0.34 s
Operator `forward` ran in 0.43 s
Operator `gradient` ran in 0.35 s
(Float32[-0.081900775 0.07301128 … 6.170804f-6 7.20752f-6; 0.0637427 0.027981473 … 9.756089f-7 5.4272978f-6; … ; 0.06374304 0.027981216 … 9.755976f-7 5.4272914f-6; -0.08189945 0.07301152 … 6.170794f-6 7.2075245f-6],)

gs_inv performs a nonlinear forward modeling and an RTM. gs_inv1 fails because scalar multiplication is not defined for LazyPropagation yet. After defining the multiplication, gs_inv2 does 2 evaluations of the LazyPropagation, which confuses me ... any idea why? Thanks

Full script below

using JUDI
using Flux
using ArgParse, Test, Printf, Aqua
using SegyIO, LinearAlgebra, Distributed, JOLI
using TimerOutputs: TimerOutputs, @timeit

Flux.Random.seed!(2022)

### Model
tti = false
viscoacoustic = false

nsrc = 1
dt = 1f0
include(joinpath(JUDIPATH, "../test/seismic_utils.jl"))
model, model0, dm = setup_model(tti, viscoacoustic, 4)
m, m0 = model.m.data, model0.m.data
q, srcGeometry, recGeometry, f0 = setup_geom(model; nsrc=nsrc, dt=dt)

# Common op
Pr = judiProjection(recGeometry)
Ps = judiProjection(srcGeometry)

ra = false
stype = "Point"
Pq = Ps

opt = Options(return_array=ra, sum_padding=true, f0=f0)
A_inv = judiModeling(model; options=opt)
A_inv0 = judiModeling(model0; options=opt)

# Operators
F = Pr*A_inv*adjoint(Pq)
F0 = Pr*A_inv0*adjoint(Pq)

gs_inv = gradient(x -> norm(F(x)*q), m0)

gs_inv1 = gradient(x -> norm(F(1f0*x)*q), m0)

import Base.*;
*(y::Float32, F::JUDI.LazyPropagation) = JUDI.LazyPropagation(F.post, F.F, *(y, F.q));
gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0)

ziyiyin97 · Jan 06 '23 00:01

That's quite curious indeed, I'll see if I can figure out what's going on

mloubout · Jan 06 '23 04:01

Well that is baaaaaaaad, this is why people don't wanna use Julia for serious stuff.

When you do gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0), Zygote doesn't correctly understand that you want "only" the derivative w.r.t. m0, in part because it doesn't understand thunks. So it ends up computing what you want, i.e. d F(1*m0)*q / d m0, but because diff rules are defined for both the left and right input of mul (and again, since Zygote always computes and evaluates everything), it also computes d F(1*m0)*q / d 1, which calls dot, which calls eval_prop.
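
For illustration, here is a hypothetical, simplified version of the differentiation rule for y = a * X (scalar times array); the real rule in ChainRules.jl wraps both cotangents in thunks, but they end up evaluated anyway, which is what forces the extra propagation:

using LinearAlgebra

# Hypothetical simplified pullback for y = a * X; not the actual ChainRules code.
function scalar_times_array_pullback(a::Real, X::AbstractArray)
    y = a * X
    function pullback(Ȳ)
        ∂a = dot(X, Ȳ)   # needs the values of X: a LazyPropagation here has to be evaluated
        ∂X = a * Ȳ       # the cotangent we actually care about for d/dm0
        return ∂a, ∂X
    end
    return y, pullback
end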

So there is no trivial way out of it, except maybe having LazyPropagation store the result at its first evaluation so it's only computed once (the above computes the same gradient twice)
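
A rough sketch of that caching idea, using a simplified stand-in type (the field names post, F, q, the val attribute and the name eval_prop mirror the ones discussed in this thread, but the real JUDI type differs):

mutable struct CachedPropagation{TF, Tq}
    post::Function
    F::TF
    q::Tq
    val::Any           # `nothing` until the first evaluation
end

CachedPropagation(post, F, q) = CachedPropagation(post, F, q, nothing)

function eval_prop(P::CachedPropagation)
    # Propagate only once; later calls (e.g. from the extra dot above) reuse the stored result.
    isnothing(P.val) && (P.val = P.post(P.F * P.q))
    return P.val
end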

mloubout · Jan 06 '23 04:01

Could you enlighten me on how (by code or something) you reached the conclusion here https://github.com/slimgroup/JUDI.jl/pull/167#issuecomment-1373145567 ? I am experiencing the issue below and would like to check what went wrong ...

julia> gs_inv = gradient(() -> norm(F(1f0*m)*q), Flux.params(m))
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.48 s
Operator `gradient` ran in 0.34 s
Grads(...)

julia> gs_inv = gradient(() -> norm(F(m*1f0)*q), Flux.params(m))
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.53 s
Operator `gradient` ran in 0.34 s
Operator `forward` ran in 0.49 s
Operator `gradient` ran in 0.34 s
Grads(...)

ziyiyin97 · Jan 08 '23 21:01

Debug every eval_prop to see where it's called and what the inputs are. In that other case it was evaluated in dot; then you can infer why, and check that it's indeed the gradient it computes by requesting it as a param.
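
One hypothetical way to do that, by temporarily adding a trace at the top of eval_prop in a local JUDI checkout (the function name and frame count here are only for illustration):

function traced_eval_prop(args...)
    @info "eval_prop called"
    for frame in first(stacktrace(), 8)   # print the first few calling frames
        println("  ", frame)
    end
    # ... the original eval_prop body would follow here
end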

mloubout · Jan 08 '23 23:01

Not sure where you are in the debug, but I can tell you that it's not super trivial and the fix will require some proper design to extend it cleanly to this type of case. But I'll leave it to you to at least find what the issue is, as an exercise.

mloubout · Jan 09 '23 20:01

Thanks! Yes, I agree this is not simple. I will pick it up some time later this week.

ziyiyin97 · Jan 09 '23 21:01