Zygote.jl
Incremental accumulation of gradients?
From discourse https://discourse.julialang.org/t/zygote-gradient-accumulation/55654
I have a densenet-inspired architecture implemented in PyTorch and ported it to Julia. Sadly I get out-of-memory errors now. Here is a MWE, where Julia's memory consumption is more than ~5x~ 2x that of PyTorch: on my laptop (GeForce RTX 2060, 5.9 GB) Julia throws an out-of-memory error, while PyTorch does not. Observe that the PyTorch tensor is ~5x~ 2x bigger! Also, when using the same tensor size in both Zygote and PyTorch, PyTorch is 2.5x faster.
Zygote
using Zygote
using CUDA
function net(x1)
    x2 = x1
    x3 = x1 + x2
    x4 = x1 + x2 + x3
    x5 = x1 + x2 + x3 + x4
    x6 = x1 + x2 + x3 + x4 + x5
    x7 = x1 + x2 + x3 + x4 + x5 + x6
    x8 = x1 + x2 + x3 + x4 + x5 + x6 + x7
    x9 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8
    x10 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9
    x11 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10
    x12 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11
    #x13 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11 + x12
    return x12
end
function loss(x)
    sum(abs2, net(x))
end
x = CUDA.randn(128,128,128,20)
Zygote.gradient(loss, x) # OOM error
x = CUDA.randn(128,128,128,10)
Zygote.gradient(loss, x) # warmup
CUDA.@time for _ in 1:100
    Zygote.gradient(loss, x)
end
# 26.188100 seconds (449.09 k CPU allocations: 15.913 MiB, 46.80% gc time) (11.30 k GPU allocations: 867.188 GiB, 59.73% gc time of which 18.51% spent allocating)
PyTorch
import torch
def net(x1):
    x2 = x1
    x3 = x1 + x2
    x4 = x1 + x2 + x3
    x5 = x1 + x2 + x3 + x4
    x6 = x1 + x2 + x3 + x4 + x5
    x7 = x1 + x2 + x3 + x4 + x5 + x6
    x8 = x1 + x2 + x3 + x4 + x5 + x6 + x7
    x9 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8
    x10 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9
    x11 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10
    x12 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11
    #x13 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11 + x12
    return x12
def loss(x):
    return torch.sum(x**2)
x = torch.randn(128,128,128,40).to("cuda")
x.requires_grad = True
y = loss(net(x))
y.backward()
x.grad
x = torch.randn(128,128,128,10).to("cuda")
x.requires_grad = True
y = loss(net(x))
y.backward()
import time
start = time.time()
#with torch.autograd.profiler.profile(use_cuda=True) as prof:
for _ in range(100):
    x.grad.zero_()
    y = loss(net(x))
    y.backward()
stop = time.time()
stop - start
# 10.764797449111938
Did you mean for your PyTorch code to be y = loss(net(x))? Regardless, I can see why Zygote might OOM here and PyTorch would not.
Did you mean for your PyTorch code to be y = loss(net(x))?
:smile: you are right!
Any updates on the memory differences from then?
Any updates on the memory differences from then?
Yes, I already updated the post.
Thanks! 2x is still a lot. What is the runtime difference between the two? I guess we could make ours more efficient. Still good to have the numbers handy.
For completeness, here's the output of @code_adjoint and @code_typed for loss when I inline net into it:
@code_adjoint loss(x)
Zygote.Adjoint(1: (%3, %4 :: Zygote.Context, %1, %2)
%5 = Zygote._pullback(%4, Main.:+, %2, %2)
%6 = Base.getindex(%5, 1)
%7 = Base.getindex(%5, 2)
%8 = Zygote._pullback(%4, Main.:+, %2, %2, %6)
%9 = Base.getindex(%8, 1)
%10 = Base.getindex(%8, 2)
%11 = Zygote._pullback(%4, Main.:+, %2, %2, %6, %9)
%12 = Base.getindex(%11, 1)
%13 = Base.getindex(%11, 2)
%14 = Zygote._pullback(%4, Main.:+, %2, %2, %6, %9, %12)
%15 = Base.getindex(%14, 1)
%16 = Base.getindex(%14, 2)
%17 = Zygote._pullback(%4, Main.:+, %2, %2, %6, %9, %12, %15)
%18 = Base.getindex(%17, 1)
%19 = Base.getindex(%17, 2)
%20 = Zygote._pullback(%4, Main.:+, %2, %2, %6, %9, %12, %15, %18)
%21 = Base.getindex(%20, 1)
%22 = Base.getindex(%20, 2)
%23 = Zygote._pullback(%4, Main.:+, %2, %2, %6, %9, %12, %15, %18, %21)
%24 = Base.getindex(%23, 1)
%25 = Base.getindex(%23, 2)
%26 = Zygote._pullback(%4, Main.:+, %2, %2, %6, %9, %12, %15, %18, %21, %24)
%27 = Base.getindex(%26, 1)
%28 = Base.getindex(%26, 2)
%29 = Zygote._pullback(%4, Main.:+, %2, %2, %6, %9, %12, %15, %18, %21, %24, %27)
%30 = Base.getindex(%29, 1)
%31 = Base.getindex(%29, 2)
%32 = Zygote._pullback(%4, Main.:+, %2, %2, %6, %9, %12, %15, %18, %21, %24, %27, %30)
%33 = Base.getindex(%32, 1)
%34 = Base.getindex(%32, 2)
%35 = Zygote._pullback(%4, Main.sum, Main.abs2, %33)
%36 = Base.getindex(%35, 1)
%37 = Base.getindex(%35, 2)
return %36, 1: (%1)
%2 = (@37)(%1)
%3 = Zygote.gradindex(%2, 3)
%4 = (@34)(%3)
%5 = Zygote.gradindex(%4, 2)
%6 = Zygote.gradindex(%4, 3)
%7 = Zygote.gradindex(%4, 4)
%8 = Zygote.gradindex(%4, 5)
%9 = Zygote.gradindex(%4, 6)
%10 = Zygote.gradindex(%4, 7)
%11 = Zygote.gradindex(%4, 8)
%12 = Zygote.gradindex(%4, 9)
%13 = Zygote.gradindex(%4, 10)
%14 = Zygote.gradindex(%4, 11)
%15 = Zygote.gradindex(%4, 12)
%16 = (@31)(%15)
%17 = Zygote.gradindex(%16, 2)
%18 = Zygote.gradindex(%16, 3)
%19 = Zygote.gradindex(%16, 4)
%20 = Zygote.gradindex(%16, 5)
%21 = Zygote.gradindex(%16, 6)
%22 = Zygote.gradindex(%16, 7)
%23 = Zygote.gradindex(%16, 8)
%24 = Zygote.gradindex(%16, 9)
%25 = Zygote.gradindex(%16, 10)
%26 = Zygote.gradindex(%16, 11)
%27 = Zygote.accum(%14, %26)
%28 = (@28)(%27)
%29 = Zygote.gradindex(%28, 2)
%30 = Zygote.gradindex(%28, 3)
%31 = Zygote.gradindex(%28, 4)
%32 = Zygote.gradindex(%28, 5)
%33 = Zygote.gradindex(%28, 6)
%34 = Zygote.gradindex(%28, 7)
%35 = Zygote.gradindex(%28, 8)
%36 = Zygote.gradindex(%28, 9)
%37 = Zygote.gradindex(%28, 10)
%38 = Zygote.accum(%13, %25, %37)
%39 = (@25)(%38)
%40 = Zygote.gradindex(%39, 2)
%41 = Zygote.gradindex(%39, 3)
%42 = Zygote.gradindex(%39, 4)
%43 = Zygote.gradindex(%39, 5)
%44 = Zygote.gradindex(%39, 6)
%45 = Zygote.gradindex(%39, 7)
%46 = Zygote.gradindex(%39, 8)
%47 = Zygote.gradindex(%39, 9)
%48 = Zygote.accum(%12, %24, %36, %47)
%49 = (@22)(%48)
%50 = Zygote.gradindex(%49, 2)
%51 = Zygote.gradindex(%49, 3)
%52 = Zygote.gradindex(%49, 4)
%53 = Zygote.gradindex(%49, 5)
%54 = Zygote.gradindex(%49, 6)
%55 = Zygote.gradindex(%49, 7)
%56 = Zygote.gradindex(%49, 8)
%57 = Zygote.accum(%11, %23, %35, %46, %56)
%58 = (@19)(%57)
%59 = Zygote.gradindex(%58, 2)
%60 = Zygote.gradindex(%58, 3)
%61 = Zygote.gradindex(%58, 4)
%62 = Zygote.gradindex(%58, 5)
%63 = Zygote.gradindex(%58, 6)
%64 = Zygote.gradindex(%58, 7)
%65 = Zygote.accum(%10, %22, %34, %45, %55, %64)
%66 = (@16)(%65)
%67 = Zygote.gradindex(%66, 2)
%68 = Zygote.gradindex(%66, 3)
%69 = Zygote.gradindex(%66, 4)
%70 = Zygote.gradindex(%66, 5)
%71 = Zygote.gradindex(%66, 6)
%72 = Zygote.accum(%9, %21, %33, %44, %54, %63, %71)
%73 = (@13)(%72)
%74 = Zygote.gradindex(%73, 2)
%75 = Zygote.gradindex(%73, 3)
%76 = Zygote.gradindex(%73, 4)
%77 = Zygote.gradindex(%73, 5)
%78 = Zygote.accum(%8, %20, %32, %43, %53, %62, %70, %77)
%79 = (@10)(%78)
%80 = Zygote.gradindex(%79, 2)
%81 = Zygote.gradindex(%79, 3)
%82 = Zygote.gradindex(%79, 4)
%83 = Zygote.accum(%7, %19, %31, %42, %52, %61, %69, %76, %82)
%84 = (@7)(%83)
%85 = Zygote.gradindex(%84, 2)
%86 = Zygote.gradindex(%84, 3)
%87 = Zygote.accum(%5, %6, %17, %18, %29, %30, %40, %41, %50, %51, %59, %60, %67, %68, %74, %75, %80, %81, %85, %86)
%88 = Zygote.tuple(nothing, %87)
return %88)
back = pullback(loss, x)[2]; @code_typed back.back(1)
CodeInfo(
1 ─── %1 = Base.getfield(#self#, :t)::Tuple{Zygote.var"#2643#back#574"{Zygote.var"#572#573"{Array{Float64, 4}}}, typeof(∂(+)), typeof(∂(+)), typeof(∂(+)), typeof(∂(+)), typeof(∂(+)), typeof(∂(+)), typeof(∂(+)), typeof(∂(+)), Zygote.var"#3295#back#826"{Zygote.var"#824#825"}, typeof(∂(+))}
│ %2 = Base.getfield(%1, 11, true)::typeof(∂(+))
│ %3 = Base.getfield(%1, 9, true)::typeof(∂(+))
│ %4 = Base.getfield(%1, 8, true)::typeof(∂(+))
│ %5 = Base.getfield(%1, 7, true)::typeof(∂(+))
│ %6 = Base.getfield(%1, 6, true)::typeof(∂(+))
│ %7 = Base.getfield(%1, 5, true)::typeof(∂(+))
│ %8 = Base.getfield(%1, 4, true)::typeof(∂(+))
│ %9 = Base.getfield(%1, 3, true)::typeof(∂(+))
│ %10 = Base.getfield(%1, 2, true)::typeof(∂(+))
│ %11 = Base.getfield(%1, 1, true)::Zygote.var"#2643#back#574"{Zygote.var"#572#573"{Array{Float64, 4}}}
│ %12 = Core.getfield(%11, Symbol("#2642#_back"))::Zygote.var"#572#573"{Array{Float64, 4}}
│ %13 = invoke %12(_2::Int64)::Tuple{Nothing, Array{Float64, 4}}
│ %14 = Core.getfield(%13, 2)::Array{Float64, 4}
│ %15 = invoke %2(%14::Array{Float64, 4})::Core.PartialStruct(Tuple{Nothing, Any, Vararg{Any, N} where N}, Any[Core.Const(nothing), Any, Vararg{Any, N} where N])
│ %16 = Base.getfield(%15, 2, true)::Any
│ %17 = Base.getfield(%15, 3, true)::Any
│ %18 = Base.getfield(%15, 4, true)::Any
│ %19 = Base.getfield(%15, 5, true)::Any
│ %20 = Base.getfield(%15, 6, true)::Any
│ %21 = Base.getfield(%15, 7, true)::Any
│ %22 = Base.getfield(%15, 8, true)::Any
│ %23 = Base.getfield(%15, 9, true)::Any
│ %24 = Base.getfield(%15, 10, true)::Any
│ %25 = Base.getfield(%15, 11, true)::Any
│ %26 = Base.getfield(%15, 12, true)::Any
│ %27 = (%7)(%26)::Any
│ %28 = (isa)(%27, Nothing)::Bool
└──── goto #3 if not %28
2 ─── goto #4
3 ─── %31 = Zygote.gradindex(%27, 2)::Any
└──── goto #4
4 ┄── %33 = φ (#2 => nothing, #3 => %31)::Any
│ %34 = (isa)(%27, Nothing)::Bool
└──── goto #6 if not %34
5 ─── goto #7
6 ─── %37 = Zygote.gradindex(%27, 3)::Any
└──── goto #7
7 ┄── %39 = φ (#5 => nothing, #6 => %37)::Any
│ %40 = (isa)(%27, Nothing)::Bool
└──── goto #9 if not %40
8 ─── goto #10
9 ─── %43 = Zygote.gradindex(%27, 4)::Any
└──── goto #10
10 ┄─ %45 = φ (#8 => nothing, #9 => %43)::Any
│ %46 = (isa)(%27, Nothing)::Bool
└──── goto #12 if not %46
11 ── goto #13
12 ── %49 = Zygote.gradindex(%27, 5)::Any
└──── goto #13
13 ┄─ %51 = φ (#11 => nothing, #12 => %49)::Any
│ %52 = (isa)(%27, Nothing)::Bool
└──── goto #15 if not %52
14 ── goto #16
15 ── %55 = Zygote.gradindex(%27, 6)::Any
└──── goto #16
16 ┄─ %57 = φ (#14 => nothing, #15 => %55)::Any
│ %58 = (isa)(%27, Nothing)::Bool
└──── goto #18 if not %58
17 ── goto #19
18 ── %61 = Zygote.gradindex(%27, 7)::Any
└──── goto #19
19 ┄─ %63 = φ (#17 => nothing, #18 => %61)::Any
│ %64 = (isa)(%27, Nothing)::Bool
└──── goto #21 if not %64
20 ── goto #22
21 ── %67 = Zygote.gradindex(%27, 8)::Any
└──── goto #22
22 ┄─ %69 = φ (#20 => nothing, #21 => %67)::Any
│ %70 = (isa)(%27, Nothing)::Bool
└──── goto #24 if not %70
23 ── goto #25
24 ── %73 = Zygote.gradindex(%27, 9)::Any
└──── goto #25
25 ┄─ %75 = φ (#23 => nothing, #24 => %73)::Any
│ %76 = (isa)(%27, Nothing)::Bool
└──── goto #27 if not %76
26 ── goto #28
27 ── %79 = Zygote.gradindex(%27, 10)::Any
└──── goto #28
28 ┄─ %81 = φ (#26 => nothing, #27 => %79)::Any
│ %82 = (isa)(%27, Nothing)::Bool
└──── goto #30 if not %82
29 ── goto #31
30 ── %85 = Zygote.gradindex(%27, 11)::Any
└──── goto #31
31 ┄─ %87 = φ (#29 => nothing, #30 => %85)::Any
│ %88 = Zygote.accum(%25, %87)::Any
│ %89 = (%5)(%88)::Any
│ %90 = (isa)(%89, Nothing)::Bool
└──── goto #33 if not %90
32 ── goto #34
33 ── %93 = Zygote.gradindex(%89, 2)::Any
└──── goto #34
34 ┄─ %95 = φ (#32 => nothing, #33 => %93)::Any
│ %96 = (isa)(%89, Nothing)::Bool
└──── goto #36 if not %96
35 ── goto #37
36 ── %99 = Zygote.gradindex(%89, 3)::Any
└──── goto #37
37 ┄─ %101 = φ (#35 => nothing, #36 => %99)::Any
│ %102 = (isa)(%89, Nothing)::Bool
└──── goto #39 if not %102
38 ── goto #40
39 ── %105 = Zygote.gradindex(%89, 4)::Any
└──── goto #40
40 ┄─ %107 = φ (#38 => nothing, #39 => %105)::Any
│ %108 = (isa)(%89, Nothing)::Bool
└──── goto #42 if not %108
41 ── goto #43
42 ── %111 = Zygote.gradindex(%89, 5)::Any
└──── goto #43
43 ┄─ %113 = φ (#41 => nothing, #42 => %111)::Any
│ %114 = (isa)(%89, Nothing)::Bool
└──── goto #45 if not %114
44 ── goto #46
45 ── %117 = Zygote.gradindex(%89, 6)::Any
└──── goto #46
46 ┄─ %119 = φ (#44 => nothing, #45 => %117)::Any
│ %120 = (isa)(%89, Nothing)::Bool
└──── goto #48 if not %120
47 ── goto #49
48 ── %123 = Zygote.gradindex(%89, 7)::Any
└──── goto #49
49 ┄─ %125 = φ (#47 => nothing, #48 => %123)::Any
│ %126 = (isa)(%89, Nothing)::Bool
└──── goto #51 if not %126
50 ── goto #52
51 ── %129 = Zygote.gradindex(%89, 8)::Any
└──── goto #52
52 ┄─ %131 = φ (#50 => nothing, #51 => %129)::Any
│ %132 = (isa)(%89, Nothing)::Bool
└──── goto #54 if not %132
53 ── goto #55
54 ── %135 = Zygote.gradindex(%89, 9)::Any
└──── goto #55
55 ┄─ %137 = φ (#53 => nothing, #54 => %135)::Any
│ %138 = (isa)(%89, Nothing)::Bool
└──── goto #57 if not %138
56 ── goto #58
57 ── %141 = Zygote.gradindex(%89, 10)::Any
└──── goto #58
58 ┄─ %143 = φ (#56 => nothing, #57 => %141)::Any
│ %144 = Zygote.accum(%24, %81, %143)::Any
│ %145 = (%9)(%144)::Any
│ %146 = (isa)(%145, Nothing)::Bool
└──── goto #60 if not %146
59 ── goto #61
60 ── %149 = Zygote.gradindex(%145, 2)::Any
└──── goto #61
61 ┄─ %151 = φ (#59 => nothing, #60 => %149)::Any
│ %152 = (isa)(%145, Nothing)::Bool
└──── goto #63 if not %152
62 ── goto #64
63 ── %155 = Zygote.gradindex(%145, 3)::Any
└──── goto #64
64 ┄─ %157 = φ (#62 => nothing, #63 => %155)::Any
│ %158 = (isa)(%145, Nothing)::Bool
└──── goto #66 if not %158
65 ── goto #67
66 ── %161 = Zygote.gradindex(%145, 4)::Any
└──── goto #67
67 ┄─ %163 = φ (#65 => nothing, #66 => %161)::Any
│ %164 = (isa)(%145, Nothing)::Bool
└──── goto #69 if not %164
68 ── goto #70
69 ── %167 = Zygote.gradindex(%145, 5)::Any
└──── goto #70
70 ┄─ %169 = φ (#68 => nothing, #69 => %167)::Any
│ %170 = (isa)(%145, Nothing)::Bool
└──── goto #72 if not %170
71 ── goto #73
72 ── %173 = Zygote.gradindex(%145, 6)::Any
└──── goto #73
73 ┄─ %175 = φ (#71 => nothing, #72 => %173)::Any
│ %176 = (isa)(%145, Nothing)::Bool
└──── goto #75 if not %176
74 ── goto #76
75 ── %179 = Zygote.gradindex(%145, 7)::Any
└──── goto #76
76 ┄─ %181 = φ (#74 => nothing, #75 => %179)::Any
│ %182 = (isa)(%145, Nothing)::Bool
└──── goto #78 if not %182
77 ── goto #79
78 ── %185 = Zygote.gradindex(%145, 8)::Any
└──── goto #79
79 ┄─ %187 = φ (#77 => nothing, #78 => %185)::Any
│ %188 = (isa)(%145, Nothing)::Bool
└──── goto #81 if not %188
80 ── goto #82
81 ── %191 = Zygote.gradindex(%145, 9)::Any
└──── goto #82
82 ┄─ %193 = φ (#80 => nothing, #81 => %191)::Any
│ %194 = Zygote.accum(%23, %75, %137, %193)::Any
│ %195 = (%6)(%194)::Any
│ %196 = (isa)(%195, Nothing)::Bool
└──── goto #84 if not %196
83 ── goto #85
84 ── %199 = Zygote.gradindex(%195, 2)::Any
└──── goto #85
85 ┄─ %201 = φ (#83 => nothing, #84 => %199)::Any
│ %202 = (isa)(%195, Nothing)::Bool
└──── goto #87 if not %202
86 ── goto #88
87 ── %205 = Zygote.gradindex(%195, 3)::Any
└──── goto #88
88 ┄─ %207 = φ (#86 => nothing, #87 => %205)::Any
│ %208 = (isa)(%195, Nothing)::Bool
└──── goto #90 if not %208
89 ── goto #91
90 ── %211 = Zygote.gradindex(%195, 4)::Any
└──── goto #91
91 ┄─ %213 = φ (#89 => nothing, #90 => %211)::Any
│ %214 = (isa)(%195, Nothing)::Bool
└──── goto #93 if not %214
92 ── goto #94
93 ── %217 = Zygote.gradindex(%195, 5)::Any
└──── goto #94
94 ┄─ %219 = φ (#92 => nothing, #93 => %217)::Any
│ %220 = (isa)(%195, Nothing)::Bool
└──── goto #96 if not %220
95 ── goto #97
96 ── %223 = Zygote.gradindex(%195, 6)::Any
└──── goto #97
97 ┄─ %225 = φ (#95 => nothing, #96 => %223)::Any
│ %226 = (isa)(%195, Nothing)::Bool
└──── goto #99 if not %226
98 ── goto #100
99 ── %229 = Zygote.gradindex(%195, 7)::Any
└──── goto #100
100 ┄ %231 = φ (#98 => nothing, #99 => %229)::Any
│ %232 = (isa)(%195, Nothing)::Bool
└──── goto #102 if not %232
101 ─ goto #103
102 ─ %235 = Zygote.gradindex(%195, 8)::Any
└──── goto #103
103 ┄ %237 = φ (#101 => nothing, #102 => %235)::Any
│ %238 = Zygote.accum(%22, %69, %131, %187, %237)::Any
│ %239 = (%10)(%238)::Any
│ %240 = (isa)(%239, Nothing)::Bool
└──── goto #105 if not %240
104 ─ goto #106
105 ─ %243 = Zygote.gradindex(%239, 2)::Any
└──── goto #106
106 ┄ %245 = φ (#104 => nothing, #105 => %243)::Any
│ %246 = (isa)(%239, Nothing)::Bool
└──── goto #108 if not %246
107 ─ goto #109
108 ─ %249 = Zygote.gradindex(%239, 3)::Any
└──── goto #109
109 ┄ %251 = φ (#107 => nothing, #108 => %249)::Any
│ %252 = (isa)(%239, Nothing)::Bool
└──── goto #111 if not %252
110 ─ goto #112
111 ─ %255 = Zygote.gradindex(%239, 4)::Any
└──── goto #112
112 ┄ %257 = φ (#110 => nothing, #111 => %255)::Any
│ %258 = (isa)(%239, Nothing)::Bool
└──── goto #114 if not %258
113 ─ goto #115
114 ─ %261 = Zygote.gradindex(%239, 5)::Any
└──── goto #115
115 ┄ %263 = φ (#113 => nothing, #114 => %261)::Any
│ %264 = (isa)(%239, Nothing)::Bool
└──── goto #117 if not %264
116 ─ goto #118
117 ─ %267 = Zygote.gradindex(%239, 6)::Any
└──── goto #118
118 ┄ %269 = φ (#116 => nothing, #117 => %267)::Any
│ %270 = (isa)(%239, Nothing)::Bool
└──── goto #120 if not %270
119 ─ goto #121
120 ─ %273 = Zygote.gradindex(%239, 7)::Any
└──── goto #121
121 ┄ %275 = φ (#119 => nothing, #120 => %273)::Any
│ %276 = Zygote.accum(%21, %63, %125, %181, %231, %275)::Any
│ %277 = (%3)(%276)::Any
│ %278 = (isa)(%277, Nothing)::Bool
└──── goto #123 if not %278
122 ─ goto #124
123 ─ %281 = Zygote.gradindex(%277, 2)::Any
└──── goto #124
124 ┄ %283 = φ (#122 => nothing, #123 => %281)::Any
│ %284 = (isa)(%277, Nothing)::Bool
└──── goto #126 if not %284
125 ─ goto #127
126 ─ %287 = Zygote.gradindex(%277, 3)::Any
└──── goto #127
127 ┄ %289 = φ (#125 => nothing, #126 => %287)::Any
│ %290 = (isa)(%277, Nothing)::Bool
└──── goto #129 if not %290
128 ─ goto #130
129 ─ %293 = Zygote.gradindex(%277, 4)::Any
└──── goto #130
130 ┄ %295 = φ (#128 => nothing, #129 => %293)::Any
│ %296 = (isa)(%277, Nothing)::Bool
└──── goto #132 if not %296
131 ─ goto #133
132 ─ %299 = Zygote.gradindex(%277, 5)::Any
└──── goto #133
133 ┄ %301 = φ (#131 => nothing, #132 => %299)::Any
│ %302 = (isa)(%277, Nothing)::Bool
└──── goto #135 if not %302
134 ─ goto #136
135 ─ %305 = Zygote.gradindex(%277, 6)::Any
└──── goto #136
136 ┄ %307 = φ (#134 => nothing, #135 => %305)::Any
│ %308 = Zygote.accum(%20, %57, %119, %175, %225, %269, %307)::Any
│ %309 = (%4)(%308)::Any
│ %310 = (isa)(%309, Nothing)::Bool
└──── goto #138 if not %310
137 ─ goto #139
138 ─ %313 = Zygote.gradindex(%309, 2)::Any
└──── goto #139
139 ┄ %315 = φ (#137 => nothing, #138 => %313)::Any
│ %316 = (isa)(%309, Nothing)::Bool
└──── goto #141 if not %316
140 ─ goto #142
141 ─ %319 = Zygote.gradindex(%309, 3)::Any
└──── goto #142
142 ┄ %321 = φ (#140 => nothing, #141 => %319)::Any
│ %322 = (isa)(%309, Nothing)::Bool
└──── goto #144 if not %322
143 ─ goto #145
144 ─ %325 = Zygote.gradindex(%309, 4)::Any
└──── goto #145
145 ┄ %327 = φ (#143 => nothing, #144 => %325)::Any
│ %328 = (isa)(%309, Nothing)::Bool
└──── goto #147 if not %328
146 ─ goto #148
147 ─ %331 = Zygote.gradindex(%309, 5)::Any
└──── goto #148
148 ┄ %333 = φ (#146 => nothing, #147 => %331)::Any
│ %334 = Zygote.accum(%19, %51, %113, %169, %219, %263, %301, %333)::Any
│ %335 = (%8)(%334)::Any
│ %336 = (isa)(%335, Nothing)::Bool
└──── goto #150 if not %336
149 ─ goto #151
150 ─ %339 = Zygote.gradindex(%335, 2)::Any
└──── goto #151
151 ┄ %341 = φ (#149 => nothing, #150 => %339)::Any
│ %342 = (isa)(%335, Nothing)::Bool
└──── goto #153 if not %342
152 ─ goto #154
153 ─ %345 = Zygote.gradindex(%335, 3)::Any
└──── goto #154
154 ┄ %347 = φ (#152 => nothing, #153 => %345)::Any
│ %348 = (isa)(%335, Nothing)::Bool
└──── goto #156 if not %348
155 ─ goto #157
156 ─ %351 = Zygote.gradindex(%335, 4)::Any
└──── goto #157
157 ┄ %353 = φ (#155 => nothing, #156 => %351)::Any
│ %354 = Zygote.accum(%18, %45, %107, %163, %213, %257, %295, %327, %353)::Any
│ %355 = (isa)(%354, Nothing)::Bool
└──── goto #159 if not %355
158 ─ goto #160
159 ─ %358 = (Zygote.var"#3295#back#826"{Zygote.var"#824#825"}(Zygote.var"#824#825"()))(%354)::Union{Nothing, Tuple{Nothing, Any, Any}}
└──── goto #160
160 ┄ %360 = φ (#158 => nothing, #159 => %358)::Union{Nothing, Tuple{Nothing, Any, Any}}
│ %361 = (isa)(%360, Nothing)::Bool
└──── goto #162 if not %361
161 ─ goto #163
162 ─ %364 = Zygote.gradindex(%360, 2)::Any
└──── goto #163
163 ┄ %366 = φ (#161 => nothing, #162 => %364)::Any
│ %367 = (isa)(%360, Nothing)::Bool
└──── goto #165 if not %367
164 ─ goto #166
165 ─ %370 = Zygote.gradindex(%360, 3)::Any
└──── goto #166
166 ┄ %372 = φ (#164 => nothing, #165 => %370)::Any
│ %373 = Zygote.accum(%16, %17, %33, %39, %95, %101, %151, %157, %201, %207, %245, %251, %283, %289, %315, %321, %341, %347, %366, %372)::Any
│ %374 = Zygote.tuple(nothing, %373)::Core.PartialStruct(Tuple{Nothing, Any}, Any[Core.Const(nothing), Any])
└──── return %374
167 ─ $(Expr(:meta, :inline))
) => Tuple{Nothing, Any}
I am far from the best one to interpret these, but it seems that the pullback doesn't have much of a chance to "clean up after itself" as it calculates the final gradient. The accum looks especially suspicious, given that back (which normally calls back.back) simply pulls out one element and discards the rest.
More interesting would be seeing how PyTorch handles the equivalent autograd graph. Does it free as it goes, and if so can Zygote be taught to do the same?
@DhairyaLGandhi I added timings, pytorch is 2.5x faster.
Trying #962 on this (CPU):
julia> @btime gradient(x -> sum(abs2, net(x)), $(rand(50,50,50,50)));
1.133 s (8266 allocations: 3.07 GiB) # master
618.508 ms (8120 allocations: 524.94 MiB) # first version of PR, 8fafdf0, wrong answer :(
951.498 ms (8206 allocations: 954.10 MiB) # without mutation, e95ba74
--
1.889 s (8289 allocations: 3.07 GiB) # with mutation and preventative copies
995.328 ms (8384 allocations: 954.08 MiB) # with simple copy-on-write wrapper
@mcabbott very cool that you are working on this. With your branch I get:
x = CUDA.randn(128,128,128,10)
Zygote.gradient(loss, x) #warmup
CUDA.@time for _ in 1:100
Zygote.gradient(loss, x)
end
# original
# 26.188100 seconds (449.09 k CPU allocations: 15.913 MiB, 46.80% gc time) (11.30 k GPU allocations: 867.188 GiB, 59.73% gc time of which 18.51% spent allocating)
# mcabbott:getindex2 e95ba74c34f1aaedc43522f13ce0fc3f8efd98bf (latest commit when writing this post)
# 11.991899 seconds (29.67 M CPU allocations: 462.594 MiB, 19.44% gc time) (6.70 k GPU allocations:
# 507.813 GiB, 19.71% gc time of which 0.58% spent allocating)
Which looks too good to be true. ~Also I don't get the out of memory error anymore.~ Also the out-of-memory situation improved: before, PyTorch could handle 4x bigger arrays; now it can handle only 2x bigger arrays. Is this real, or am I using an incorrect commit e95ba74c34f1aaedc43522f13ce0fc3f8efd98bf?
I believe that commit should be safe and give correct answers. Great that it helps much more in your timing than in mine, labelled "without mutation" above.
I think that all the benefit is from this line: accum(x::AbstractArray, ys::AbstractArray...) = accum.(x, ys...). Interesting that it cut allocations to 1/3 for me, but for you only from 867 GiB to 507 GiB.
Accumulating all at once like this (not pairwise as before) seems perfectly aligned with this test. Not sure how much it will help on more realistic problems, though. Do you have any examples handy of simple but not quite so simple functions, which might be worth timing?
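To make the difference concrete, here is a tiny standalone sketch (the names pairwise_accum and nary_accum are made up for illustration; plain Julia, not Zygote's actual code). Pairwise accumulation materialises an intermediate array for every addition, while a single n-ary broadcast allocates only the result:
pairwise_accum(gs...) = reduce((a, b) -> a .+ b, gs)   # allocates length(gs)-1 temporary arrays
nary_accum(gs...) = (+).(gs...)                        # one fused broadcast, one allocation
Both return the same array; for N gradient contributions the pairwise form makes roughly N-1 intermediate copies instead of one result, which is what the accum method above avoids.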
Not pushed to #962, but I think it ought to be possible to safely accumulate in-place if you make some preventative copies of the gradient. This works wonders for the examples there, but is a disaster here (my last timing above). Edit -- here's the branch: https://github.com/FluxML/Zygote.jl/compare/master...mcabbott:inplace .
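For reference, the in-place idea could look roughly like the sketch below (accum_inplace! and add_gradients are hypothetical names, not the code in that branch): copy the first contribution once, then fuse the remaining additions into that buffer, so no array a caller still holds gets mutated.
accum_inplace!(x::AbstractArray, ys::AbstractArray...) = (x .= (+).(x, ys...); x)
function add_gradients(g1::AbstractArray, gs::AbstractArray...)
    buf = copy(g1)                # the "preventative copy"
    return accum_inplace!(buf, gs...)
end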
@mcabbott here is an example closer to my actual use case. Here PyTorch's runtime is ~1.6x faster and its memory usage is ~4x lower.
Zygote implementation
using Flux
import CUDA
struct DenseLayer
    channelwise
    spatial
end
Flux.@functor DenseLayer
function (o::DenseLayer)(xs)
    if length(xs) != length(o.channelwise)
        @show length(xs)
        @show length(o.channelwise)
        @show map(size, xs)
        error()
    end
    x = mapreduce(+, o.channelwise, xs) do c, xi
        c(xi)
    end
    return o.spatial(leakyrelu.(x))
end
function DenseLayer(;spacedim, bn_channels, in_channels, out_channels)
    channelwise = map(in_channels) do ncin
        Flux.Conv(ntuple(_->1, spacedim), ncin => bn_channels)
    end
    spatial = Flux.Conv(ntuple(_->3, spacedim), bn_channels=>out_channels, leakyrelu, pad=SamePad())
    DenseLayer(channelwise, spatial)
end
struct DenseNet
    layers::Vector{DenseLayer}
    final::DenseLayer
end
Flux.@functor DenseNet
function (o::DenseNet)(x)
    xs = (x,)
    for l in o.layers
        x_new = l(xs)
        xs = (xs..., x_new)
    end
    o.final(xs)
end
function DenseNet(;
    spacedim,
    in_channels,
    out_channels,
    nlayers,
    growth,
)
    layers = DenseLayer[]
    cin = (in_channels,)
    for i in 1:nlayers
        l = DenseLayer(;spacedim, bn_channels=4*growth, in_channels=cin, out_channels=growth)
        push!(layers, l)
        cin = (cin..., growth,)
    end
    final = DenseLayer(;spacedim, bn_channels=4*out_channels, in_channels=cin, out_channels=out_channels)
    DenseNet(layers, final)
end
using Zygote
nb = 1 # nb = 2 gives out of memory
nx = 64
ny = 128
nz = 128
nc = 3
net = DenseNet(spacedim=3, growth=4, nlayers=4, out_channels=1, in_channels=nc) |> gpu
x = CUDA.randn(Float32, nx, ny, nz, nc, nb)
loss = () -> sum(abs2, net(x))
Zygote.gradient(loss, params(net)) # warmup
for _ in 1:5
    CUDA.@time for _ in 1:20
        Zygote.gradient(loss, params(net))
    end
end
# [email protected]
# 7.015238 seconds (8.62 M CPU allocations: 171.998 MiB, 52.21% gc time) (4.06 k GPU allocations: 89.101 GiB, 54.87% gc time of which 1.31% spent allocating)
# 7.308631 seconds (9.32 M CPU allocations: 182.692 MiB, 50.38% gc time) (4.06 k GPU allocations: 89.101 GiB, 52.91% gc time of which 1.75% spent allocating)
# 7.435183 seconds (9.70 M CPU allocations: 188.559 MiB, 49.42% gc time) (4.06 k GPU allocations: 89.101 GiB, 51.31% gc time of which 1.71% spent allocating)
# 7.419201 seconds (9.71 M CPU allocations: 188.697 MiB, 49.69% gc time) (4.06 k GPU allocations: 89.101 GiB, 51.60% gc time of which 1.76% spent allocating)
# 7.437993 seconds (10.05 M CPU allocations: 193.934 MiB, 49.81% gc time) (4.06 k GPU allocations: 89.101 GiB, 51.28% gc time of which 1.02% spent allocating)
# master 2021-05-27 Zygote#7c66eff0a9200b1095f1c94ac6faec9f16b740fe
# 6.947657 seconds (8.75 M CPU allocations: 178.133 MiB, 52.08% gc time) (4.06 k GPU allocations: 89.101 GiB, 53.18% gc time of which 0.61% spent allocating)
# 7.086846 seconds (8.47 M CPU allocations: 170.352 MiB, 52.77% gc time) (4.06 k GPU allocations: 89.101 GiB, 54.91% gc time of which 1.31% spent allocating)
# 7.072559 seconds (8.53 M CPU allocations: 170.607 MiB, 52.17% gc time) (4.06 k GPU allocations: 89.101 GiB, 54.68% gc time of which 1.91% spent allocating)
# 7.046045 seconds (8.52 M CPU allocations: 170.505 MiB, 52.73% gc time) (4.06 k GPU allocations: 89.101 GiB, 54.55% gc time of which 1.82% spent allocating)
# 7.162578 seconds (9.06 M CPU allocations: 178.848 MiB, 51.45% gc time) (4.06 k GPU allocations: 89.101 GiB, 53.06% gc time of which 1.23% spent allocating)
# inplace accum https://github.com/mcabbott/Zygote.jl.git#5315d12b2080c76d21889d2266ca11a9160e25fe:
# 7.458384 seconds (9.75 M CPU allocations: 193.955 MiB, 50.40% gc time) (4.34 k GPU allocations: 87.461 GiB, 51.51% gc time of which 0.60% spent allocating)
# 7.352417 seconds (9.76 M CPU allocations: 190.535 MiB, 49.46% gc time) (4.34 k GPU allocations: 87.461 GiB, 50.97% gc time of which 0.53% spent allocating)
# 7.742949 seconds (10.17 M CPU allocations: 196.290 MiB, 49.13% gc time) (4.34 k GPU allocations: 87.461 GiB, 51.26% gc time of which 1.56% spent allocating)
# 7.046813 seconds (8.96 M CPU allocations: 177.890 MiB, 51.17% gc time) (4.34 k GPU allocations: 87.461 GiB, 52.89% gc time of which 1.34% spent allocating)
# 7.180376 seconds (8.77 M CPU allocations: 175.013 MiB, 52.58% gc time) (4.34 k GPU allocations: 87.461 GiB, 54.41% gc time of which 1.83% spent allocating)
PyTorch implementation
import torch
import torch.nn as nn
from torch.nn.functional import leaky_relu
def Conv(in_channels, out_channels, kernel_size, **kwargs):
    spacedim = len(kernel_size)
    Ctor = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d][spacedim]
    return Ctor(in_channels, out_channels, kernel_size, **kwargs)
class DenseLayer(nn.Module):
    def __init__(self, spacedim, bn_channels, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.bn_channels = bn_channels
        self.out_channels = out_channels
        self.channelwise = nn.ModuleList()
        for ncin in self.in_channels:
            l = Conv(in_channels=ncin,
                     out_channels=self.bn_channels,
                     kernel_size=tuple(1 for _ in range(spacedim)))
            self.channelwise.append(l)
        self.spatial = Conv(in_channels=bn_channels,
                            out_channels=out_channels,
                            kernel_size=tuple(3 for _ in range(spacedim)),
                            padding=1)
    def forward(self, xs):
        assert len(xs) == len(self.channelwise)
        x = self.channelwise[0](xs[0])
        for i in range(1, len(xs)):
            x += self.channelwise[i](xs[i])
        x = leaky_relu(x)
        x = self.spatial(x)
        x = leaky_relu(x)
        return x
class DenseNet(nn.Module):
    def __init__(self, spacedim, in_channels, out_channels, nlayers, growth):
        super().__init__()
        self.layers = nn.ModuleList()
        cin = [in_channels]
        for i in range(nlayers):
            l = DenseLayer(spacedim=spacedim, bn_channels=4*growth, in_channels=cin, out_channels=growth)
            self.layers.append(l)
            cin.append(growth)
        self.final = DenseLayer(spacedim=spacedim, bn_channels=4*out_channels, in_channels=cin,
                                out_channels=out_channels)
    def forward(self, x):
        xs = [x]
        for l in self.layers:
            x_new = l(xs)
            xs.append(x_new)
        return self.final(xs)
nb = 5 # nb = 6 gives out of memory
nx = 64
ny = 128
nz = 128
nc = 3
x = torch.randn(nb, nc, nx, ny, nz).to("cuda")
net = DenseNet(spacedim=3, growth=4, nlayers=4, out_channels=1, in_channels=nc).to("cuda")
def loss(x):
    return torch.sum(x**2)
y = loss(net(x))
y.backward()
import time
for _ in range(5):
    start = time.time()
    for _ in range(20):
        net.zero_grad()
        y = loss(net(x))
        y.backward()
    stop = time.time()
    print((stop - start) / nb)
# 3.6041640758514406
# 4.214616394042968
# 4.326757097244263
# 4.446330928802491
# 4.760152387619018
OK, this is a real model! I'm curious whether #962 (i.e. Zygote#master, right now) and #981 make any difference? My guess is that they won't, since every weight is used once; maybe a tiny bit from that mapreduce(+, ...)?
Am I reading correctly that there are no nonlinearities here, you never broadcast relu etc.? That step was something which looked like there was room to fuse things & save.
Am I reading correctly that there are no nonlinearities here, you never broadcast relu etc? That step was something which looked like there was room to fuse things & save.
You are right, it is a good idea to add activations to the example. I will update and then rerun with #981 and #962. With the current variant I already tested that there is no significant speedup from #981.
I updated my post to include nonlinearities and multiple Zygote commits.
This is a useful Zygote issue, but just to address the larger issue: could you try replacing your DenseNet implementation with SkipConnections? This would avoid the xs = (xs..., x_new) entirely. I'm just curious to see if there is a difference.
Thanks for the suggestion @darsnack. I don't see a good way to implement this using SkipConnection. I see complicated ways, or ways that lead to the quadratic memory problem of naive densenet implementations. Can you sketch how you would do this? Also, is there a reason why xs = (xs..., x_new) should be avoided, or why SkipConnection could improve performance?
You could use this as a reference. It does get a bit more complicated, though it looks like you won't need a bunch of the extra code in there like pooling or normalization. Basically, right now your DenseLayer (which I think is referred to as a "bottleneck" in the DenseNet paper) expects a tuple of the outputs from previous bottlenecks. It applies a separate convolution to each element in the tuple, then reduces with +. This is the same as if the elements of the tuple were concatenated and you applied a single conv. So, to use SkipConnection, you would follow the recipe in dense_bottleneck in the code I linked. dense_block in the linked code will give you the equivalent of DenseNet in your code.
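For concreteness, a minimal sketch of that recipe (assumes using Flux; sizes and names are invented, and the real Metalhead code also handles normalization etc.): each SkipConnection concatenates its block's output onto the running feature map along the channel dimension.
dense_bottleneck(cin, growth) = SkipConnection(
    Chain(Conv((1,1,1), cin => 4growth, leakyrelu),
          Conv((3,3,3), 4growth => growth, leakyrelu, pad=SamePad())),
    (y, x) -> cat(x, y; dims=4))                       # dim 4 is the channel dim for 3D convs
dense_block(cin, growth, nlayers) =
    Chain([dense_bottleneck(cin + (i-1)*growth, growth) for i in 1:nlayers]...)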
Also is there a reason why xs = (xs..., x_new) should be avoided or why SkipConnection could improve performance?
I'm not positive if it will improve performance or memory allocations. But SkipConnection is widely used for constructing models like a dense block, so I would hope it does not suffer from the same problem. I think conceptually the current implementation looks like:
x2 = f(x1)
x3 = f(x1, x2)
x4 = f(x1, x2, x3)
...
But SkipConnection avoids the building of xs to make it look like:
x2 = h(x1, g1(x1))
x3 = h(x2, g2(x2))
x4 = h(x3, g3(x3))
...
When you unroll the full computation, they are equivalent. But the latter doesn't keep increasing the number of input arguments. Perhaps that is irrelevant for Zygote, but it would be good to know if so.
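In Flux terms that second shape is just a chain of SkipConnections; a runnable toy version (assumes using Flux; Dense layers stand in for the real bottlenecks, h is the combination operator):
h(y, x) = x .+ y                               # note Flux passes the block output first
model = Chain(
    SkipConnection(Dense(8, 8, relu), h),      # x2 = h(g1(x1), x1)
    SkipConnection(Dense(8, 8, relu), h),      # x3 = h(g2(x2), x2)
    SkipConnection(Dense(8, 8, relu), h))      # x4 = h(g3(x3), x3)
model(randn(Float32, 8, 4))                    # same unrolled computation as the pseudo-code above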
You could use this as a reference. It does get a bit more complicated, though it looks like you won't need a bunch of the extra code in there like pooling or normalization. Basically, right now your DenseLayer (which I think is referred to as a "bottleneck" in the DenseNet paper)
Ah yes, you are right, "bottleneck" may be a better name.
It applies a separate convolution to each element in the tuple, then reduces with +. This is the same as if the elements of the tuple were concatenated and you applied a single conv.
I agree both are mathematically equivalent. The trouble with the concatenation formulation is that it produces a big tensor at each bottleneck, so the memory needed to record the forward pass is quadratic in the number of layers. In fact, I specifically tried to avoid the quadratic memory problem by using the + formulation. An alternative way to deal with the memory hunger is gradient checkpointing.
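If memory rather than time is the binding constraint, one hedged sketch of the checkpointing route, assuming a Zygote version that provides Zygote.checkpointed (which re-runs the wrapped call during the backward pass instead of storing its intermediates), would be to wrap each bottleneck call in the DenseNet forward above:
function (o::DenseNet)(x)
    xs = (x,)
    for l in o.layers
        x_new = Zygote.checkpointed(l, xs)   # trade recomputation for memory
        xs = (xs..., x_new)
    end
    o.final(xs)
end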
But SkipConnection is widely used for constructing models like a dense block, so I would hope it does not suffer from the same problem.
Can you give examples where this is used for densenet-like stuff? For sure it is useful for ResNet-like things. Is the Metalhead densenet you linked meant for use in memory-intensive situations?
I think conceptually the current implementation looks like:
x2 = f(x1) x3 = f(x1, x2) x4 = f(x1, x2, x3) ...
But SkipConnection avoids the building of xs to make it look like:
x2 = h(x1, g1(x1)) x3 = h(x2, g2(x2)) x4 = h(x3, g3(x3)) ...
There is a subtle point here. The xi in the first case are much smaller than the xi in the second approach. Let's make this difference more visually explicit:
x12 = h(x1, g1(x1))
x123 = h(x12, g2(x12))
x1234 = h(x123, g3(x123))
...
The relationship with my approach is x123... = cat(x1, x2, x3, ...). At the end of the day, you have to pass the same number of values to the next bottleneck, whether they are packaged in one big tensor or many small ones. But doing all these cats naively will lead to quadratic memory usage. One way to think about my approach is that it does the cat implicitly, so the concatenated tensor is never kept alive for the backward pass.
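To put rough numbers on the quadratic-vs-linear point (channel bookkeeping only, ignoring the bottleneck intermediates; the example values match the model above with in_channels=3, growth=4, nlayers=4):
stored_cat(c0, g, L)      = sum(c0 + i*g for i in 1:L)   # each bottleneck keeps its full concatenation
stored_separate(c0, g, L) = c0 + L*g                     # only the new g-channel outputs are kept
stored_cat(3, 4, 4), stored_separate(3, 4, 4)            # (52, 19) channels' worth of activations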
I was going to suggest replacing the mapreduce (and perhaps the spatial, if you can give up the intermediate leakyrelu) with a depthwise conv layer, but unfortunately we don't support that layer on GPU yet.
The trouble with the concatenation formulation is that it produces a big tensor at each bottleneck. This means the memory needed to record a forward pass is quadratic in the number of layers.
Ah, I get what you mean now. I think your implementation is the simplest way to get a model without this problem. So solving the Zygote issue here should be the primary goal.
You could write your approach using SkipConnection with (x, y) -> (x..., y) as the combination operator. But I don't know if it would make a difference. The only difference would be the xs = (xs..., x_new) (reassigning to the same tuple vs creating a new tuple without allocating new tensors). You could even use Parallel, which is a more generic version of SkipConnection that understands a Vararg/Tuple input as being multiple separate inputs. It would make mapping multiple Convs across each tuple element easier.
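A hedged sketch of what that might look like (assumes using Flux; sizes invented, and the very first layer would need to be fed a one-element tuple): append_tuple is the SkipConnection combinator, and Parallel maps one 1×1×1 conv over each element of the incoming tuple before summing.
append_tuple(y, xs) = (xs..., y)                     # block output is appended to the input tuple
bottleneck(cins, growth) = Chain(
    Parallel(+, [Conv((1,1,1), c => 4growth) for c in cins]...),
    x -> leakyrelu.(x),
    Conv((3,3,3), 4growth => growth, leakyrelu, pad=SamePad()))
dense_layer(cins, growth) = SkipConnection(bottleneck(cins, growth), append_tuple)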
Another way to look at it is that the problem is that cat(x, y; dims=3) produces a new array instead of an "anti-view" of x and y. So, an alternative way to keep the cat-style SkipConnections and not have a memory explosion is something like CatViews.jl, but for N-D arrays instead of just vectors.
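Just to illustrate the "anti-view" idea (a toy for two 3-D parents concatenated along dim 3, with no AD rules or GPU support, so not a drop-in solution):
struct LazyCat3{T,A<:AbstractArray{T,3},B<:AbstractArray{T,3}} <: AbstractArray{T,3}
    a::A
    b::B
end
Base.size(c::LazyCat3) = (size(c.a,1), size(c.a,2), size(c.a,3) + size(c.b,3))
Base.getindex(c::LazyCat3, i::Int, j::Int, k::Int) =
    k <= size(c.a,3) ? c.a[i,j,k] : c.b[i,j,k-size(c.a,3)]
Indexing reads straight from the parents, so no concatenated copy ever exists; the hard part is making convolutions and AD treat such a wrapper efficiently.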
Can you give examples where this is used for densenet-like stuff? For sure it is useful for ResNet-like things. Is the Metalhead densenet you linked meant for use in memory-intensive situations?
The repo I linked probably has all the examples you would want. There's also some Unet stuff floating around that uses Parallel. But with cat, all these approaches have the problem you highlighted. So not really a solution.