Correct `AveragedTimeInterval` to use actuations
This is a cleaned-up version of the closed PR https://github.com/CliMA/Oceananigans.jl/pull/3717.
@liuchihl, thanks for cleaning up these changes by separating them from the background flux PR—it's much clearer now.
Consolidating @glwagner and @navidcy's earlier comments, it seems there are three things that need to be done before this can be merged:
- Review the existing OutputWriter tests and verify that they still pass with the new implementation in this PR.
- Create a new, more rigorous test that is capable of flagging the bizarre behavior you found in your issue but that (hopefully) now passes thanks to the changes in this branch.
- Add some warnings to let users know that TimeInterval and AveragedTimeInterval (and probably other diagnostic schedules) are currently broken and give incorrect results after picking up from a checkpoint whenever the checkpoint interval is not an integer multiple of the scheduled time interval.
Add some warnings to let users know that TimeInterval and AveragedTimeInterval (and probably other diagnostic schedules) are currently broken and give incorrect results after picking up from a checkpoint whenever the checkpoint interval is not an integer multiple of the scheduled time interval.
Can we make it so this warning gets thrown only if one is using a Checkpointer? I think the majority of simulations do not use a Checkpointer so the warning would be irrelevant in most cases. Maybe we should put the warning within the checkpointer constructor.
Checkpointing certainly needs love. I think it's only used for barebones stuff right now, not complicated simulations. To be fully featured we have to somehow have a system for checkpointing all callbacks. It's not just AveragedTimeInterval that would have a problem.
I think the majority of simulations do not use a Checkpointer so the warning would be irrelevant in most cases.
I don't get this. How are people not using a Checkpointer? Is no one else limited by HPC wall times or running long simulations? It seems like one of the most fundamental capabilities of any time-stepped numerical model.
But yes, this warning should only be issued if both OutputWriter and Checkpointers are being used and if the checkpointer interval is not an integer multiple of the OutputWriter interval.
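One possible shape for that check, sketched here as a hypothetical helper (not an existing Oceananigans function), which could live wherever both the Checkpointer and the OutputWriters are visible:
# Hypothetical helper (not part of Oceananigans): warn only when a Checkpointer is
# present and its interval is not an integer multiple of an output schedule's interval.
function warn_if_incompatible_intervals(checkpoint_interval, output_interval; atol = 1e-10)
    ratio = checkpoint_interval / output_interval
    if !isapprox(ratio, round(ratio); atol = atol)
        @warn "Checkpoint interval ($checkpoint_interval) is not an integer multiple of the " *
              "output interval ($output_interval); TimeInterval and AveragedTimeInterval output " *
              "may be incorrect after picking up from this checkpoint."
    end
end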
I think the majority of simulations do not use a Checkpointer so the warning would be irrelevant in most cases.
I don't get this. How are people not using a Checkpointer? Is no one else limited by HPC wall times or running long simulations? It seems like one of the most fundamental capabilities of any time-stepped numerical model.
But yes, this warning should only be issued if a Checkpointer is used (and maybe also when a simulation picks up from an existing Checkpoint).
I think for sophisticated research Checkpointing is common, but for simpler classroom and LES applications the checkpointer is used less. After all, probably most simulations are actually run as our examples on CI -- and there are no examples with a checkpointer! (It would be nice to change that)
I can't speak for others, but for boundary layer parameterization work the LES typically run in less than 24 hours of wall time. We also only utilize very simple diagnostics, like the horizontally-averaged solution at the final time step. So in those rare cases where we need a checkpointer (I have used one a handful of times), barebones checkpointing is sufficient.
Of course we are currently working on building an OMIP simulation and that will require much longer runs, so we will definitely need more sophisticated checkpointing very soon.
@simone-silvestri and @tomchor might have more to add. Or @sandreza, what do you use for the neverworld work?
I'm not saying we don't want to develop this, I'm just providing some context about why this hasn't been resolved / developed yet.
In an ideal world the simulations would run fast enough that we wouldn't need checkpointing, after all 😄
I think for sophisticated research Checkpointing is common, but for simpler classroom and LES applications the checkpointer is used less. After all, probably most simulations are actually run as our examples on CI -- and there are no examples with a checkpointer! (It would be nice to change that)
Agreed!
I can't speak for others, but for boundary layer parameterization work the LES typically run in less than 24 hours of wall time. We also only utilize very simple diagnostics, like the horizontally-averaged solution at the final time step. So in those rare cases where we need a checkpointer (I have used one a handful of times), barebones checkpointing is sufficient.
Of course we are currently working on building an OMIP simulation and that will require much longer runs, so we will definitely need more sophisticated checkpointing very soon.
@simone-silvestri and @tomchor might have more to add. Or @sandreza, what do you use for the neverworld work?
For context, 100% of my simulations have used checkpoints. As far as I know, 100% of the simulations from others in my group also use checkpoints. The only exceptions for my case are very early scripts still in the development phase, and still with very coarse grids. As soon as I try to get more serious with it, I need checkpoints. So this is a PR I'm very much looking forward to seeing merged ;)
@hdrake @glwagner
The original windowed_time_average successfully passes the NetCDF OutputWriter test. However, when running the existing MWE, test_netcdf_time_averaging, setting Δt to 0.01 with an averaging window of 3Δt produces a discontinuity similar to the one observed in https://github.com/CliMA/Oceananigans.jl/issues/3670#issuecomment-2259395775, as shown in the figure below. The vertical lines indicate the start and end of the windows.
With this new PR, the same test now yields a smooth solution:
While the discontinuity caused by rounding errors has been resolved, not all cases with different Δt pass the test, e.g.,
for (n, t) in enumerate(single_ds["time"][2:end])
    averaging_times = [t - n*Δt for n in 0:stride:window_size-1 if t - n*Δt >= 0]
    @test all(isapprox.(single_ds["c1"][:, n+1], c̄1(averaging_times), rtol=1e-5))
end
Here is an example of a case that does not pass the test:
using Oceananigans
using Plots
using NCDatasets
using Test
arch = CPU()
topo = (Periodic, Periodic, Periodic)
domain = (x=(0, 1), y=(0, 1), z=(0, 1))
grid = RectilinearGrid(arch, topology=topo, size=(4, 4, 4); domain...)
λ1(x, y, z) = x + (1 - y)^2 + tanh(z)
λ2(x, y, z) = x + (1 - y)^2 + tanh(4z)
Fc1(x, y, z, t, c1) = - λ1(x, y, z) * c1
Fc2(x, y, z, t, c2) = - λ2(x, y, z) * c2
c1_forcing = Forcing(Fc1, field_dependencies=:c1)
c2_forcing = Forcing(Fc2, field_dependencies=:c2)
model = NonhydrostaticModel(; grid,
                            timestepper = :RungeKutta3,
                            tracers = (:c1, :c2),
                            forcing = (c1=c1_forcing, c2=c2_forcing))
set!(model, c1=1, c2=1)
Δt = 0.01 # changed from 1/64 (a "nice" floating-point number) to expose rounding errors
simulation = Simulation(model, Δt=Δt, stop_time=150Δt)
∫c1_dxdy = Field(Average(model.tracers.c1, dims=(1, 2)))
∫c2_dxdy = Field(Average(model.tracers.c2, dims=(1, 2)))
nc_outputs = Dict("c1" => ∫c1_dxdy, "c2" => ∫c2_dxdy)
nc_dimensions = Dict("c1" => ("zC",), "c2" => ("zC",))
single_time_average_nc_filepath = "single_decay_windowed_time_average_test.nc"
window_nΔt = 3
window = window_nΔt*Δt
interval_nΔt = 5
interval = interval_nΔt*Δt
stride = 1
single_nc_output = Dict("c1" => ∫c1_dxdy)
single_nc_dimension = Dict("c1" => ("zC",))
simulation.output_writers[:single_output_time_average] =
    NetCDFOutputWriter(model, single_nc_output,
                       array_type = Array{Float64},
                       verbose = true,
                       filename = single_time_average_nc_filepath,
                       schedule = AveragedTimeInterval(interval, window = window, stride = stride),
                       dimensions = single_nc_dimension,
                       overwrite_existing = true)
run!(simulation)
##### For each λ, horizontal average should evaluate to
#####
##### c̄(z, t) = ∫₀¹ ∫₀¹ exp{- λ(x, y, z) * t} dx dy
##### = 1 / (Nx*Ny) * Σᵢ₌₁ᴺˣ Σⱼ₌₁ᴺʸ exp{- λ(i, j, k) * t}
#####
##### which we can compute analytically.
# ds = NCDataset(horizontal_average_nc_filepath)
Nx, Ny, Nz = size(grid)
xs, ys, zs = nodes(model.tracers.c1)
c̄1(z, t) = 1 / (Nx * Ny) * sum(exp(-λ1(x, y, z) * t) for x in xs for y in ys)
c̄2(z, t) = 1 / (Nx * Ny) * sum(exp(-λ2(x, y, z) * t) for x in xs for y in ys)
rtol = 1e-5 # need custom rtol for isapprox because roundoff errors accumulate (?)
# Compute time averages...
c̄1(ts) = 1/length(ts) * sum(c̄1.(zs, t) for t in ts)
c̄2(ts) = 1/length(ts) * sum(c̄2.(zs, t) for t in ts)
#####
##### Test strided windowed time average against analytic solution
##### for *single* NetCDF output
#####
single_ds = NCDataset(single_time_average_nc_filepath)
attribute_names = ("schedule", "interval", "output time interval",
"time_averaging_window", "time averaging window",
"time_averaging_stride", "time averaging stride")
for name in attribute_names
    @test haskey(single_ds.attrib, name) && !isnothing(single_ds.attrib[name])
end
window_size = Int(window/Δt)
@info " Testing time-averaging of a single NetCDF output [$(typeof(arch))]..."
for (n, t) in enumerate(single_ds["time"][2:end])
    averaging_times = [t - n*Δt for n in 0:stride:window_size-1 if t - n*Δt >= 0]
    # @info n, t, averaging_times, c̄1(averaging_times) .- single_ds["c1"][:, n+1]
    @test all(isapprox.(single_ds["c1"][:, n+1], c̄1(averaging_times), rtol=rtol))
end
I believe there might be some minor issues in our new PR that still need to be addressed.
@liuchihl, did you find any cases where the windowed_time_average in this PR fails the test with window_nΔt == interval_nΔt? Maybe we have an indexing error when we are waiting for the window to start and we actually average over a slightly wrong period? It would also be good to confirm that the problem is with windowed_time_average, and not with the analytical solution that we're comparing to!
did you find any cases where the windowed_time_average in this PR fails the test with window_nΔt == interval_nΔt?
No, the test passes if window_nΔt equals interval_nΔt. However, if they differ, even when one is an integer multiple of the other, the test fails.
Well that's promising at least!
Ah, one caveat: this only holds when the timestep is sufficiently small.
That's great work! Can you summarize what you did? I'm wondering if it makes sense that this is hard or if we should actually consider a more fundamental redesign to make it more robust...
Sure, the PR resolves the rounding issue caused by previous_interval_stop_time through the use of actuations (as inspired by its application here).
Here is the part of the code showing these changes.
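A minimal sketch of the idea (hypothetical struct and function names, not the literal diff): rather than accumulating previous_interval_stop_time += interval, which picks up floating-point drift over many outputs, the schedule counts how many times it has fired and reconstructs the window boundaries by multiplication.
# Sketch only: a stand-in for the real AveragedTimeInterval
mutable struct SketchAveragedTimeInterval
    interval   :: Float64  # how often output is written
    window     :: Float64  # how long to average before each output
    actuations :: Int      # how many outputs have been written so far
end

# The next averaging window ends at (actuations + 1) * interval, computed by
# multiplication so roundoff does not accumulate from one output to the next ...
end_of_current_window(s::SketchAveragedTimeInterval) = (s.actuations + 1) * s.interval

# ... and data collection for that window starts `window` before it ends.
start_of_current_window(s::SketchAveragedTimeInterval) = end_of_current_window(s) - s.window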
Another important change is that
# Save averaging start time and the initial data collection time
wta.window_start_time = model.clock.time
wta.window_start_iteration = model.clock.iteration
wta.previous_collection_time = model.clock.time
wta.schedule.collecting = false
wta.schedule.actuations += 1
occurs only when the window ends, i.e., when end_of_window(wta.schedule, model.clock) == true. In contrast, the previous version triggered this only when the model wasn't collecting.
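In code terms, a sketch of the new guard (using the names from the snippet above; not the literal diff):
# Reset the accumulator state and advance the schedule only once the averaging
# window has actually ended, rather than whenever the schedule is not collecting.
if end_of_window(wta.schedule, model.clock)
    wta.window_start_time = model.clock.time
    wta.window_start_iteration = model.clock.iteration
    wta.previous_collection_time = model.clock.time
    wta.schedule.collecting = false
    wta.schedule.actuations += 1
end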
I'm wondering if it makes sense that this is hard or if we should actually consider a more fundamental redesign to make it more robust...
I agree that a more fundamental redesign could improve robustness in the long term. That said, the current adjustments seem to resolve the issue for now (I'll look into why certain cases aren't passing the test). We can continue to monitor its performance and consider a redesign if further issues arise.
@liuchihl, can you explain a bit more your caveat about the new method only passing the test if the timestep is small enough? With the new method, the windowed time averages still look good by eye but quantitatively the errors are larger than the default relative tolerance of 1e-5, right? That still seems like an improvement over the spurious zeros that sometimes show up with the previous method?
Yes, I realize that increasing Δt leads to larger errors (especially for the first two windows, though the numerical and analytical solutions match in subsequent windows), which exceed the default relative tolerance of 1e-5.
The spurious zeros are definitely incorrect, so solving that is an improvement.
For sure, we are happy to merge any improvement no matter how small. PRs only have to push the project forward, they don't have to be "big" or "important".
Independent of that, it'd be good to have a clear conclusion about whether a redesign really is needed too because we are in a good position to make that judgment now.
Should we unmark this as a draft and try to get it merged? If there is still a problem and you know how to test it, you can add a @test_broken to the test suite. Then after we merge this PR, you can open another PR which marks that as a @test and takes up where this one leaves off.
@glwagner, give us a couple more days to sort this out. I'm not so sure that we need a fundamental redesign, but I think that should be revisited when we decide to take on the more challenging issue of checkpointing partially-accumulated time-averages so that windowed_time_average always works as intended when picking up from checkpoints.
I think that should be revisited when we decide to take on the more challenging issue of checkpointing partially-accumulated time-averages so that windowed_time_average always works as intended when picking up from checkpoints.
good point!
There is a possibility that it is not very hard. It will require reshuffling code (which I can do), but with Julia we can serialize objects to disk and then load them back seamlessly in a single line... which might be all we need. The only limitation of serialization is that we haven't yet figured out how to serialize methods (eg functions) which prevents us from serializing entire models. When functions are not involved things can be pretty simple.
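For instance, a minimal sketch of that workflow with Julia's Serialization standard library; the schedule_state object here is a placeholder for whatever function-free schedule or accumulator state would be checkpointed:
using Serialization

# Placeholder standing in for e.g. an AveragedTimeInterval's state; any object
# that does not contain functions can be round-tripped this way.
schedule_state = (actuations = 12, window_start_time = 3.0, collecting = false)

serialize("schedule_state.jls", schedule_state)   # write the object to disk
restored = deserialize("schedule_state.jls")      # load it back in a single line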
Our earlier tests with a simple sine function indicate that when the checkpoint interval is an integer multiple of the AveragedTimeInterval, the results after the checkpoint seem reasonable. However, I’ve noticed this isn't the case with the following parameter settings, for instance:
Δt = .01 # timestep
T1 = 6Δt # first simulation stop time
T2 = 2T1 # second simulation stop time
window_nΔt = 2 # window interval: 2Δt
interval_nΔt = 2 # time average saving interval: 2Δt
stride = 1
The averaged values are clearly off after the checkpoint (t>6Δt):
This issue does not only occur in the existing MWE (decaying function); it also occurs in our MWE using the exact same parameters mentioned above. The dashed curve and steps indicate TimeInterval and AveragedTimeInterval outputs, respectively, shown in the figure below. It is unclear to me why the spurious zero appears in this case.
The point of these tests is to show that even when the checkpoint interval is an integer multiple of the AveragedTimeInterval, issues can still arise.
Here is the MWE with the decaying function for reference:
using Oceananigans
using Plots
using NCDatasets
using Test
if isfile("single_decay_windowed_time_average_test.nc")
    rm("single_decay_windowed_time_average_test.nc")
end
run(`sh -c "rm test_iteration*.jld2"`)
function test_simulation(stop_time, Δt, window_nΔt, interval_nΔt, stride, overwrite)
    arch = CPU()
    topo = (Periodic, Periodic, Periodic)
    domain = (x=(0, 1), y=(0, 1), z=(0, 1))
    grid = RectilinearGrid(arch, topology=topo, size=(4, 4, 4); domain...)
    λ1(x, y, z) = x + (1 - y)^2 + tanh(z)
    λ2(x, y, z) = x + (1 - y)^2 + tanh(4z)
    Fc1(x, y, z, t, c1) = - λ1(x, y, z) * c1
    Fc2(x, y, z, t, c2) = - λ2(x, y, z) * c2
    c1_forcing = Forcing(Fc1, field_dependencies=:c1)
    c2_forcing = Forcing(Fc2, field_dependencies=:c2)
    model = NonhydrostaticModel(; grid,
                                timestepper = :RungeKutta3,
                                tracers = (:c1, :c2),
                                forcing = (c1=c1_forcing, c2=c2_forcing))
    set!(model, c1=1, c2=1)
    simulation = Simulation(model, Δt=Δt, stop_time=stop_time)
    ∫c1_dxdy = Field(Average(model.tracers.c1, dims=(1, 2)))
    ∫c2_dxdy = Field(Average(model.tracers.c2, dims=(1, 2)))
    nc_outputs = Dict("c1" => ∫c1_dxdy, "c2" => ∫c2_dxdy)
    nc_dimensions = Dict("c1" => ("zC",), "c2" => ("zC",))
    single_time_average_nc_filepath = "single_decay_windowed_time_average_test.nc"
    window = window_nΔt*Δt
    interval = interval_nΔt*Δt
    single_nc_output = Dict("c1" => ∫c1_dxdy)
    single_nc_dimension = Dict("c1" => ("zC",))
    simulation.output_writers[:single_output_time_average] =
        NetCDFOutputWriter(model, single_nc_output,
                           array_type = Array{Float64},
                           verbose = true,
                           filename = single_time_average_nc_filepath,
                           schedule = AveragedTimeInterval(interval, window = window, stride = stride),
                           dimensions = single_nc_dimension,
                           overwrite_existing = overwrite)
    checkpointer = Checkpointer(model,
                                schedule = TimeInterval(stop_time),
                                prefix = "test",
                                cleanup = true)
    simulation.output_writers[:checkpointer] = checkpointer
    return simulation
end
Δt = 0.01 # changed from 1/64 (a "nice" floating-point number) to expose rounding errors
T1 = 6Δt # first simulation stop time (s)
T2 = 2T1 # second simulation stop time (s)
window_nΔt = 2
interval_nΔt = 2
stride = 1
# Run a simulation that saves data to a checkpoint
simulation = test_simulation(T1, Δt, window_nΔt, interval_nΔt, stride, true)
run!(simulation)
# Now try again, but picking up from the previous checkpoint
N = iteration(simulation)
checkpoint = "test_iteration$N.jld2"
simulation = test_simulation(T2, Δt, window_nΔt, interval_nΔt, stride, false)
run!(simulation, pickup=checkpoint)
##### For each λ, horizontal average should evaluate to
#####
##### c̄(z, t) = ∫₀¹ ∫₀¹ exp{- λ(x, y, z) * t} dx dy
##### = 1 / (Nx*Ny) * Σᵢ₌₁ᴺˣ Σⱼ₌₁ᴺʸ exp{- λ(i, j, k) * t}
#####
##### which we can compute analytically.
arch = CPU()
topo = (Periodic, Periodic, Periodic)
domain = (x=(0, 1), y=(0, 1), z=(0, 1))
grid = RectilinearGrid(arch, topology=topo, size=(4, 4, 4); domain...)
λ1(x, y, z) = x + (1 - y)^2 + tanh(z)
λ2(x, y, z) = x + (1 - y)^2 + tanh(4z)
Fc1(x, y, z, t, c1) = - λ1(x, y, z) * c1
Fc2(x, y, z, t, c2) = - λ2(x, y, z) * c2
c1_forcing = Forcing(Fc1, field_dependencies=:c1)
c2_forcing = Forcing(Fc2, field_dependencies=:c2)
model = NonhydrostaticModel(; grid,
                            timestepper = :RungeKutta3,
                            tracers = (:c1, :c2),
                            forcing = (c1=c1_forcing, c2=c2_forcing))
Nx, Ny, Nz = size(grid)
xs, ys, zs = nodes(model.tracers.c1)
c̄1(z, t) = 1 / (Nx * Ny) * sum(exp(-λ1(x, y, z) * t) for x in xs for y in ys)
c̄2(z, t) = 1 / (Nx * Ny) * sum(exp(-λ2(x, y, z) * t) for x in xs for y in ys)
rtol = 1e-5 # need custom rtol for isapprox because roundoff errors accumulate (?)
# Compute time averages...
c̄1(ts) = 1/length(ts) * sum(c̄1.(zs, t) for t in ts)
c̄2(ts) = 1/length(ts) * sum(c̄2.(zs, t) for t in ts)
#####
##### Test strided windowed time average against analytic solution
##### for *single* NetCDF output
#####
single_time_average_nc_filepath = "single_decay_windowed_time_average_test.nc"
single_ds = NCDataset(single_time_average_nc_filepath)
attribute_names = ("schedule", "interval", "output time interval",
"time_averaging_window", "time averaging window",
"time_averaging_stride", "time averaging stride")
for name in attribute_names
    @test haskey(single_ds.attrib, name) && !isnothing(single_ds.attrib[name])
end
window_size = window_nΔt
window = window_size*Δt
time = single_ds["time"][:]
data_plot = single_ds["c1"][1:4, :]
c̄1_timeaverage = zeros(4,length(time[1:end]))
for (n, t) in enumerate(time[1:end])
    averaging_times = [t - n*Δt for n in 0:stride:window_size-1 if t - n*Δt >= 0]
    # @info n, t, averaging_times, c̄1(averaging_times)
    c̄1_timeaverage[:, n] = c̄1(averaging_times)
    # @test all(isapprox.(single_ds["c1"][:, n+1], c̄1(averaging_times), rtol=rtol))
end
# Plot each of the four lines
pl = plot()
plot!(time, data_plot[1, :], label="1", color=:blue, legend=:topright)
plot!(time, data_plot[2, :], label="2", color=:red)
plot!(time, data_plot[3, :], label="3", color=:orange)
plot!(time, data_plot[4, :], label="4", color=:green)
plot!(time[1:end], c̄1_timeaverage[1, :], color=:black, linestyle=:dash, label="1-analytic")
plot!(time[1:end], c̄1_timeaverage[2, :], color=:black, linestyle=:dash, label="2-analytic")
plot!(time[1:end], c̄1_timeaverage[3, :], color=:black, linestyle=:dash, label="3-analytic")
plot!(time[1:end], c̄1_timeaverage[4, :], color=:black, linestyle=:dash, label="4-analytic")
tt = 0:window:T2
for i in 1:length(tt)
    plot!([tt[i], tt[i]], [0, 1], color=:grey, label="")
end
title!(pl, string("Δt=",Δt,", average window=",window_nΔt,"Δt")) # Add the title to the plot
ylims!(pl,(minimum(c̄1_timeaverage[4,:]),maximum(c̄1_timeaverage[4,:])))
xlims!(pl,(0,T2))
close(single_ds)
display(pl)
@glwagner @hdrake With the latest commit, I believe we've addressed the remaining issues in this PR, including those that occur 1) when the saving interval differs from the time-averaging window, and 2) after resuming from a checkpoint.
A quick result showing that the numerical solution matches the analytical solution well, even when window ≠ interval and when picking up from a checkpoint at t = 1:
To properly handle the checkpoint pickup, we manually adjusted the actuations to match the correct value based on the pre-pickup simulation. Here's an example to illustrate what I mean. While this resolves the checkpointing issue for now, it's more of a workaround, and a more robust high-level design is still needed. I believe some help is required with the checkpoint pickup design, but aside from that, the windowed time averaging appears to be functioning correctly now!
@liuchihl that's awesome. Is there anything relating to AveragedTimeInterval that doesn't work after this PR?
Also, based on what you said I'm marking this PR as ready for review. (i.e. no longer a draft)
@tomchor After this PR, I believe there are no remaining issues, except that we currently have to manually adjust the actuations during setup to match the value from the pre-pickup simulation. For example, we have to do something like this:
# Run a simulation that saves data to a checkpoint
simulation = test_simulation(T1, Δt, window_nΔt, interval_nΔt, stride, true)
run!(simulation)
checkpointed_wta = simulation.output_writers[:single_output_time_average].outputs["c1"]
checkpointed_actuations = checkpointed_wta.schedule.actuations
# Now try again, but picking up from the previous checkpoint
N = iteration(simulation)
checkpoint = "test_iteration$N.jld2"
simulation = test_simulation(T2, Δt, window_nΔt, interval_nΔt, stride, false)
simulation.output_writers[:single_output_time_average].outputs["c1"].schedule.actuations = checkpointed_actuations
run!(simulation, pickup=checkpoint)
The point is to ensure that the actuation after the pickup matches the actuation value from before the checkpoint.
simulation.output_writers[:single_output_time_average].outputs["c1"].schedule.actuations = checkpointed_actuations
Sorry for the confusion, test_netcdf_timeaverage.jl is just my minimum working example, which is modified from test_netcdf_output_writer.jl. So I believe that MWE I created is not really needed (it's not the original test).
To properly handle the checkpoint pickup, we manually adjusted the actuation to match the correct value based on the pre-pickup simulation. Here's an example to illustrate what I mean. While this resolves the issue of checkpoint for now, it's more of a workaround
Agree we need to refactor the Checkpointer. Thank you for documenting a way to proceed for the time being.
Except we currently have to manually adjust the actuation to match the correct value based on the pre-pickup simulation during setup.
@liuchihl this is outside the scope of this PR, isn't it? This PR says nothing about checkpointing, the issue is AveragedTimeInterval. The checkpointing stuff is separate and orthogonal to this. It's best to keep PRs limited in scope as much as possible.
So I believe that MWE I created is not really needed (it's not the original test).
For the future, note that an efficient workflow is to use an MWE as a test. They are closely related.
@glwagner yeah I agree, I think you're right!
@liuchihl I see there are some unaddressed comments, but this PR otherwise looks to be ready, no? Do you need help? Let's not let this nice PR go stale!
Thanks for coming back to this @tomchor, I've gotten rid of test_netcdf_timeaverage.jl because it is similar to the original test. About the checkpointing, I agree that it is a separate issue (though we have a solution here). So I think this PR should be ready.
To summarize the discussion above and in the related https://github.com/CliMA/Oceananigans.jl/issues/3670: the MWE that we used to flag and debug these rounding errors in the WindowedTimeAverage scheme was actually exactly the same as the test_netcdf_spatial_average test function in test/test_netcdf_output_writer.jl, just with the value of the timestep changed from
Δt = 1/64 # Nice floating-point number
to
Δt = 0.01 # Floating point number chosen conservatively to flag rounding errors
The original choice of a "nice" floating-point number is convenient because it is not as susceptible to rounding errors, but it is less useful for a test because it is overly optimistic. Instead, the test should use a pessimistic value that is more prone to rounding errors, as this protects against the edge cases that users might run into (as we did in our production runs!).
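As a small illustration (separate from the test itself) of why the choice of Δt matters so much here: 1/64 is a power of two, so repeatedly accumulated times land exactly on n*Δt, whereas a decimal like 0.1 or 0.01 has no exact binary representation, so previous_interval_stop_time-style accumulation can drift away from n*Δt and break exact time comparisons in a schedule.
exact = let t = 0.0
    for _ in 1:64
        t += 1/64   # every partial sum k/64 is exactly representable
    end
    t
end

drifted = let t = 0.0
    for _ in 1:10
        t += 0.1    # 0.1 has no exact binary representation; roundoff accumulates
    end
    t
end

exact == 1.0     # true
drifted == 1.0   # false: drifted == 0.9999999999999999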
Here is the MWE that @liuchihl has modified from the existing one, but which I have removed in favor of a very small modification to the existing MWE test. https://github.com/CliMA/Oceananigans.jl/pull/3721/commits/a0c725c01952a84b4dcffc142a38ccde2dc6890b#diff-60f3eb03acc634682fc82a022d0a0b32382d1e890be3057750545a3d0ada1c09L99
@tomchor and @glwagner, I think this is ready to be merged!
We've tested this quite a bit in various MWEs and also in our complicated production runs with tidally-averaged outputs, and find the results to be much improved over the previous method.
Feedback and testing from others welcome!