ClimaOcean.jl icon indicating copy to clipboard operation
ClimaOcean.jl copied to clipboard

Distributed MPI runs unable to save variables using `jldsave`

Open taimoorsohail opened this issue 5 months ago • 17 comments

Hi,

I am trying to save some variables of interest in my simulation using jldsave. Note that I can't use JLD2Writer or outputwriter because I need to save variables from the ocean, atmosphere and sea ice in the same file, for checkpointing, and JLD2Writer is not supported for all components of a coupled OceanSeaIceModel yet (?).

I have the below MWE:

using MPI
using CUDA

MPI.Init()
atexit(MPI.Finalize)  

using Oceananigans
using Oceananigans.Units
using Oceananigans.DistributedComputations
using Printf
using Dates
using Oceananigans.Architectures: on_architecture

# Directory for the checkpoint files; `save_restart` builds its output filename from it.
output_path = expanduser("/g/data/v46/txs156/ocean-ensembles/outputs/")

# Distributed GPU architecture with an `Equal()` partition specified in y and
# `synchronized_communication` switched off.
arch = Distributed(GPU(); partition = Partition(y = DistributedComputations.Equal()), synchronized_communication=false)

"""
    immersed_latlon_grid(underlying_grid::LatitudeLongitudeGrid;
                         radius = 5, height = 2000, λc = 0, φc = 45,
                         active_cells_map = false)

Return an `ImmersedBoundaryGrid` built on `underlying_grid` with a `GridFittedBottom`
shaped like a Gaussian hill.

The bottom height at `(λ, φ)` is
`-Lz + height * exp(-((λ - λc)^2 + (φ - φc)^2) / (2 * radius^2))`,
where `Lz` is read from `underlying_grid`.

# Keywords
- `radius`: controls the width of the Gaussian (in degrees).
- `height`: hill height in meters.
- `λc`, `φc`: hill center (lon, lat).
- `active_cells_map`: forwarded unchanged to `ImmersedBoundaryGrid`.
"""
function immersed_latlon_grid(underlying_grid::LatitudeLongitudeGrid;
                              radius = 5,
                              height = 2000,
                              λc = 0, φc = 45,
                              active_cells_map = false)

    domain_depth = underlying_grid.Lz

    # Coordinates are used in degrees as given; no conversion to radians is applied.
    function bottom_height(λ, φ)
        squared_distance = (λ - λc)^2 + (φ - φc)^2
        return -domain_depth + height * exp(-squared_distance / (2 * radius^2))
    end

    return ImmersedBoundaryGrid(underlying_grid,
                                GridFittedBottom(bottom_height);
                                active_cells_map)
end

# Grid size passed to `LatitudeLongitudeGrid` below.
Nx, Ny, Nz = 100, 100, 50
# NOTE(review): Lx and Ly are not used anywhere in the script shown here.
Lx, Ly = 100, 100

@info "Defining vertical z faces"

depth = -6000.0 # Depth of the ocean in meters
# Vertical faces built from Nz, `depth` and 0 — presumably Nz cells spanning depth..0; confirm against the API.
z_faces = ExponentialDiscretization(Nz, depth, 0) 

@info "Creating grid"

# Latitude-longitude grid on the distributed architecture `arch`, with longitude
# (0, 360), latitude (-70, 70) and halo (6, 6, 3).
underlying_grid = LatitudeLongitudeGrid(arch,
                                        size = (Nx, Ny, Nz),
                                        z = z_faces,
                                        halo = (6, 6, 3),
                                        longitude = (0, 360),
                                        latitude = (-70, 70))




@info "Defining grid"

# Wrap the grid with the Gaussian-hill bottom from `immersed_latlon_grid`,
# requesting an active cells map.
grid = immersed_latlon_grid(underlying_grid; active_cells_map=true)

@info "Creating free surface"
# Split-explicit free surface configured with 70 substeps.
free_surface = SplitExplicitFreeSurface(grid; substeps = 70)

@info "Creating model"

# Hydrostatic model on the immersed grid; all other model options are left at their defaults.
ocean_model = HydrostaticFreeSurfaceModel(; grid, free_surface, timestepper = :SplitRungeKutta3)

@info "Creating simulation"

# Time step of 10 (seconds, assuming Oceananigans' usual units — confirm) and a stop time of 2 hours.
simulation = Simulation(ocean_model; Δt=10, verbose=false, stop_time=2hours)

"""
    save_restart(sim)

Write the `u` velocity of `sim.model`, moved to the CPU with `on_architecture`,
into one JLD2 file per rank.

The file is created inside `output_path` and is named
`mwe_jldsave_distributed<iteration>_rank<local_rank>.jld2`, where the iteration
comes from `sim.model.clock` and the rank from `sim.model.architecture.local_rank`.

`jldsave` is provided by JLD2, so `using JLD2` must be in effect wherever this
callback is defined; without it the call fails with
`UndefVarError: jldsave not defined`.
"""
function save_restart(sim)
    @info "Saving checkpoint file"

    localrank = Integer(sim.model.architecture.local_rank)
    @info "Local rank: " * string(localrank)

    # Build the path once so the logged name and the written file cannot drift apart.
    filename = "mwe_jldsave_distributed" * string(sim.model.clock.iteration) * "_rank$(localrank).jld2"
    filepath = joinpath(output_path, filename)
    @info "Saving filename: " * filepath

    jldsave(filepath; u = on_architecture(CPU(), sim.model.velocities.u))
end

# Nice progress messaging is helpful:

## Print a progress message
"""
    progress_message(sim)

Print a one-line progress report for `sim`: iteration, model time, time step,
the maximum of `|w|` over `sim.model.velocities.w`, and the run wall time.
"""
function progress_message(sim)
    max_abs_w = maximum(abs, sim.model.velocities.w)

    @printf("Iteration: %04d, time: %s, Δt: %s, max(|w|) = %.1e ms⁻¹, wall time: %s\n",
            iteration(sim), prettytime(sim), prettytime(sim.Δt),
            max_abs_w, prettytime(sim.run_wall_time))
end

# Register both callbacks on the same schedule, `IterationInterval(40)`.
add_callback!(simulation, progress_message, IterationInterval(40))
add_callback!(simulation, save_restart, IterationInterval(40))

@info "Running simulation"
# The stack traces reported for this script end in this call (top-level scope -> `run!`).
run!(simulation)

But running the simulation leads to the following error at the call `jldsave(output_path * "mwe_jldsave_distributed" * string(sim.model.clock.iteration) * "_rank$(localrank).jld2"; ...)`:

ERROR: LoadError: type Nothing has no field rank
Stacktrace:
  [1] getproperty(x::Nothing, f::Symbol)
    @ Base ./Base.jl:37
  [2] (::Oceananigans.BoundaryConditions.MultiRegionFillHalo{Oceananigans.BoundaryConditions.SouthAndNorth})(c::OffsetArrays.OffsetArray{Float64, 3, Array{Float64, 3}}, southbc::BoundaryCondition{Oceananigans.BoundaryConditions.MultiRegionCommunication, Nothing}, northbc::BoundaryCondition{Oceananigans.BoundaryConditions.MultiRegionCommunication, Nothing}, loc::Tuple{Center, Center, Nothing}, grid::LatitudeLongitudeGrid{Float64, Periodic, Oceananigans.Grids.FullyConnected, Bounded, Oceananigans.Grids.StaticVerticalDiscretization{OffsetArrays.OffsetVector{Float64, Vector{Float64}}, OffsetArrays.OffsetVector{Float64, Vector{Float64}}, OffsetArrays.OffsetVector{Float64, Vector{Float64}}, OffsetArrays.OffsetVector{Float64, Vector{Float64}}}, Float64, Float64, OffsetArrays.OffsetVector{Float64, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}, OffsetArrays.OffsetVector{Float64, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}, Float64, Float64, OffsetArrays.OffsetVector{Float64, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}, OffsetArrays.OffsetVector{Float64, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}, OffsetArrays.OffsetVector{Float64, Vector{Float64}}, OffsetArrays.OffsetVector{Float64, Vector{Float64}}, OffsetArrays.OffsetVector{Float64, Vector{Float64}}, OffsetArrays.OffsetVector{Float64, Vector{Float64}}, Float64, Float64, CPU, Int64}, buffers::Tuple{})
    @ Oceananigans.MultiRegion /g/data/v46/txs156/Oceananigans.jl/src/MultiRegion/multi_region_boundary_conditions.jl:113
  [3] fill_halo_event!
    @ /g/data/v46/txs156/Oceananigans.jl/src/BoundaryConditions/fill_halo_regions.jl:40 [inlined]
  [4] #fill_halo_regions!#26
    @ /g/data/v46/txs156/Oceananigans.jl/src/BoundaryConditions/fill_halo_regions.jl:32 [inlined]
  [5] fill_halo_regions!
    @ /g/data/v46/txs156/Oceananigans.jl/src/BoundaryConditions/fill_halo_regions.jl:25 [inlined]
  [6] #fill_halo_regions!#63
    @ /g/data/v46/txs156/Oceananigans.jl/src/Fields/field.jl:857 [inlined]
    @ Oceananigans.Fields /g/data/v46/txs156/Oceananigans.jl/src/Fields/field.jl:843
    @ Oceananigans.ImmersedBoundaries /g/data/v46/txs156/Oceananigans.jl/src/ImmersedBoundaries/grid_fitted_bottom.jl:85
    @ Oceananigans.ImmersedBoundaries /g/data/v46/txs156/Oceananigans.jl/src/ImmersedBoundaries/immersed_boundary_grid.jl:129
    @ Oceananigans.Fields /g/data/v46/txs156/Oceananigans.jl/src/Fields/field.jl:468
    @ Main /g/data/v46/txs156/ocean-ensembles/mwes/JLDSave_distributed.jl:82
 [12] UndefVarError: `jldsave` not definedUndefVarError: `jldsave` not defined
Stacktrace:
 [1] 
Stacktrace:
    @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/callback.jl:15
   @ Main /g/data/v46/txs156/ocean-ensembles/mwes/JLDSave_distributed.jl:82
   @ Main /g/data/v46/txs156/ocean-ensembles/mwes/JLDSave_distributed.jl:82
    @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:238
   @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/callback.jl:15
   @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/callback.jl:15
    @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:136
   @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:238
   @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:238
    @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:105
   @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:136
   @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:136
    @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:92
 [17] top-level scope
    @ /g/data/v46/txs156/ocean-ensembles/mwes/JLDSave_distributed.jl:97
in expression starting at /g/data/v46/txs156/ocean-ensembles/mwes/JLDSave_distributed.jl:97
   @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:105
   @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:105
   @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:92
 [7] top-level scope
   @ /g/data/v46/txs156/ocean-ensembles/mwes/JLDSave_distributed.jl:97
in expression starting at /g/data/v46/txs156/ocean-ensembles/mwes/JLDSave_distributed.jl:97
   @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:92
 [7] top-level scope
   @ /g/data/v46/txs156/ocean-ensembles/mwes/JLDSave_distributed.jl:97
in expression starting at /g/data/v46/txs156/ocean-ensembles/mwes/JLDSave_distributed.jl:97
--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:

  Process name: [[17346,1],0]
  Exit code:    1
--------------------------------------------------------------------------

Note that I had to delete most of the stack trace because of character limits on GitHub.

But if you run the MWE with MPI (more than 1 GPU) you will see the issue I have!

cc @navidcy @simone-silvestri @xkykai

taimoorsohail avatar Oct 06 '25 06:10 taimoorsohail

hm... this

 [12] UndefVarError: `jldsave` not definedUndefVarError: `jldsave` not defined

suggests that simply the using JLD2 is missing?

navidcy avatar Oct 06 '25 07:10 navidcy

@simone-silvestri I'm a bit puzzled, how come the stacktrace includes methods from MultiRegion?

navidcy avatar Oct 06 '25 07:10 navidcy

@taimoorsohail tangential to this PR, but can you write up a "feature request" issue with the thing you would like to be supported by JLD2Writer? (If that issue doesn't exist yet.)

glwagner avatar Oct 06 '25 16:10 glwagner

> @taimoorsohail tangential to this PR, but can you write up a "feature request" issue with the thing you would like to be supported by JLD2Writer? (If that issue doesn't exist yet.)

In brief: it's checkpointing for coupled ocean-sea ice models + ability to checkpoint for Distributed grids.

navidcy avatar Oct 06 '25 18:10 navidcy

I think JLD2Writer should work though and can be used for manual checkpointing without having to call jldsave? I'm wondering what the specific problem is (maybe the problem is an assumption that all fields in JLD2Writer have the same grid? I think that is easy to fix)

glwagner avatar Oct 06 '25 18:10 glwagner

Hmm, the same error here is being hit by the cubed sphere GPU tests on this PR: https://github.com/CliMA/Oceananigans.jl/pull/4785

Something is borked because here it makes no sense to hit multi region halo filling, right?

glwagner avatar Oct 06 '25 18:10 glwagner

> Something is borked because here it makes no sense to hit multi region halo filling, right?

I feel similarly puzzled w MultiRegion!

navidcy avatar Oct 06 '25 18:10 navidcy

MultiRegion is the default boundary condition for a FullyConnected topology. If it appears here it means something is wrong.

simone-silvestri avatar Oct 06 '25 20:10 simone-silvestri

> @taimoorsohail tangential to this PR, but can you write up a "feature request" issue with the thing you would like to be supported by JLD2Writer? (If that issue doesn't exist yet.)

What Navid said - there is an active PR on checkpointing coupled models here: https://github.com/CliMA/ClimaOcean.jl/pull/401

But in the meantime, I am saving JLD2 files and reading them as a workaround. I can try using JLD2Writer for the ocean and sea ice to see if that works :)

taimoorsohail avatar Oct 07 '25 01:10 taimoorsohail

Just an update - JLD2Writer only works for Fields, but for the checkpointer to work, I need to save the ocean model clock as well:

# Outputs for the ocean checkpoint: velocities, tracers, barotropic velocities,
# the free-surface field η, and the model clock, merged into one collection.
ocean_checkpointer_tracers = merge(
    ocean.model.velocities,
    ocean.model.tracers,
    ocean.model.free_surface.barotropic_velocities,
    # NOTE(review): `clock` is the model clock, unlike the other entries; the
    # `MethodError` reported for this snippet is raised in `fetch_output` for a `Clock`.
    (η = simulation.model.ocean.model.free_surface.η,
     clock = ocean.model.clock)
)
# Outputs for the sea-ice checkpoint: thickness, concentration, top surface
# temperature, the dynamics auxiliary fields, and the velocities.
sea_ice_checkpointer_tracers = merge(  
                                (ice_thickness = sea_ice.model.ice_thickness,
                                ice_concentration = sea_ice.model.ice_concentration,
                                top_surface_temperature = sea_ice.model.ice_thermodynamics.top_surface_temperature),
                                sea_ice.model.dynamics.auxiliaries.fields, 
                                sea_ice.model.velocities)

# The iteration is read once, here, and baked into both filenames below.
iteration_number = string(simulation.model.clock.iteration)
# One JLD2Writer per component, both scheduled with `TimeInterval(5days)` and
# writing into `output_path`.
@time ocean.output_writers[:checkpointer] = JLD2Writer(ocean.model, ocean_checkpointer_tracers;
                                            dir = output_path,
                                            schedule = TimeInterval(5days),
                                            filename = "ocean_checkpointer_vars_iteration" * iteration_number,
                                            overwrite_existing = true)

@time sea_ice.output_writers[:checkpointer] = JLD2Writer(sea_ice.model, sea_ice_checkpointer_tracers;
                                            dir = output_path,
                                            schedule = TimeInterval(5days),
                                            filename = "sea_ice_checkpointer_vars_iteration" * iteration_number,
                                            overwrite_existing = true)

This leads to an error when running the simulation:

ERROR: MethodError: objects of type Clock{Float64, Float64, Int64, Int64} are not callable
Stacktrace:
  [1] fetch_output(output::Clock{…}, model::HydrostaticFreeSurfaceModel{…})
    @ Oceananigans.OutputWriters /g/data/v46/txs156/Oceananigans.jl/src/OutputWriters/fetch_output.jl:12
  [2] fetch_and_convert_output(output::Clock{…}, model::HydrostaticFreeSurfaceModel{…}, writer::JLD2Writer{…})
    @ Oceananigans.OutputWriters /g/data/v46/txs156/Oceananigans.jl/src/OutputWriters/fetch_output.jl:39
  [3] (::Oceananigans.OutputWriters.var"#36#37"{JLD2Writer{…}, HydrostaticFreeSurfaceModel{…}})(::Tuple{Symbol, Clock{…}})
    @ Oceananigans.OutputWriters ./none:0
  [4] iterate
    @ ./generator.jl:47 [inlined]
  [5] merge(a::@NamedTuple{}, itr::Base.Generator{Base.Iterators.Zip{Tuple{…}}, Oceananigans.OutputWriters.var"#36#37"{JLD2Writer{…}, HydrostaticFreeSurfaceModel{…}}})
    @ Base ./namedtuple.jl:370
  [6] NamedTuple
    @ ./namedtuple.jl:151 [inlined]
  [7] macro expansion
    @ ./timing.jl:395 [inlined]
  [8] write_output!(writer::JLD2Writer{…}, model::HydrostaticFreeSurfaceModel{…})
    @ Oceananigans.OutputWriters /g/data/v46/txs156/Oceananigans.jl/src/OutputWriters/jld2_writer.jl:255
  [9] write_output!(writer::JLD2Writer{…}, sim::Simulation{…})
    @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/simulation.jl:257
 [10] initialize!(sim::Simulation{…})
    @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:243
 [11] initialize!
    @ /g/data/v46/txs156/ClimaOcean.jl/src/OceanSeaIceModels/ocean_sea_ice_model.jl:79 [inlined]
 [12] initialize!(sim::Simulation{…})
    @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:209
 [13] time_step!(sim::Simulation{…})
    @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:136
 [14] run!(sim::Simulation{…}; pickup::Bool)
    @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:105
 [15] run!(sim::Simulation{…})
    @ Oceananigans.Simulations /g/data/v46/txs156/Oceananigans.jl/src/Simulations/run.jl:92
 [16] top-level scope
    @ REPL[63]:1
Some type information was truncated. Use `show(err)` to see complete types.

taimoorsohail avatar Oct 07 '25 01:10 taimoorsohail

So I think the error is that a relevant method for jldsave doesn't exist, so the code goes to the fallback which may be in MultiRegion.

taimoorsohail avatar Oct 08 '25 04:10 taimoorsohail

we need to merge https://github.com/CliMA/Oceananigans.jl/pull/4749 and https://github.com/CliMA/ClimaSeaIce.jl/pull/90 before we can use a checkpointer for the sea ice model, though

just noticed this issue involves OutputWriter, not Checkpointer

simone-silvestri avatar Oct 08 '25 11:10 simone-silvestri

JLD2Writer indexes fields by their time and iteration. Is there additional information needed for clock besides these?

glwagner avatar Oct 08 '25 19:10 glwagner

That's a good point @glwagner - I need time, iteration and last_Δt - the latter of which is in clock.

taimoorsohail avatar Oct 09 '25 04:10 taimoorsohail

Update: I got the manual checkpointing to work by saving three files:

################################### START CHECKPOINTING ######################################

@info "Saving restart"

"""
    save_restart(sim)

Save the ocean model clock, `sim.model.ocean.model.clock`, to a per-rank JLD2
file in `output_path`.

The filename carries the iteration of `sim.model.clock` and the rank of this
process in `MPI.COMM_WORLD`.
"""
function save_restart(sim)
    rank = MPI.Comm_rank(MPI.COMM_WORLD)
    iteration_label = string(sim.model.clock.iteration)
    filepath = output_path * "ocean_checkpointer_clock_iteration" * iteration_label * "_rank$(rank).jld2"

    jldsave(filepath; clock = sim.model.ocean.model.clock)
end

# Outputs for the ocean checkpoint: velocities, tracers, barotropic velocities
# and the free-surface field η. The clock is not included here.
ocean_checkpointer_tracers = merge(
    ocean.model.velocities,
    ocean.model.tracers,
    ocean.model.free_surface.barotropic_velocities,
    (; η = simulation.model.ocean.model.free_surface.η)
)
# Outputs for the sea-ice checkpoint: thickness, concentration, top surface
# temperature, the dynamics auxiliary fields, and the velocities.
sea_ice_checkpointer_tracers = merge(  
                                (ice_thickness = sea_ice.model.ice_thickness,
                                ice_concentration = sea_ice.model.ice_concentration,
                                top_surface_temperature = sea_ice.model.ice_thermodynamics.top_surface_temperature),
                                sea_ice.model.dynamics.auxiliaries.fields, 
                                sea_ice.model.velocities)

# One JLD2Writer per component, both scheduled with `IterationInterval(40)`.
# NOTE(review): `iteration_number` is not defined in this snippet; it must already exist in scope.
@time ocean.output_writers[:checkpointer] = JLD2Writer(ocean.model, ocean_checkpointer_tracers;
                                            dir = output_path,
                                            schedule = IterationInterval(40),
                                            filename = "ocean_checkpointer_vars_iteration" * iteration_number,
                                            overwrite_existing = true)

@time sea_ice.output_writers[:checkpointer] = JLD2Writer(sea_ice.model, sea_ice_checkpointer_tracers;
                                            dir = output_path,
                                            schedule = IterationInterval(40),
                                            filename = "sea_ice_checkpointer_vars_iteration" * iteration_number,
                                            overwrite_existing = true)

# The clock is saved separately through the `save_restart` callback.
# NOTE(review): `checkpoint_intervals` is not defined in this snippet; confirm it matches the writers' schedule.
add_callback!(simulation, save_restart, checkpoint_intervals)

################################### END CHECKPOINTING ######################################

Closing this issue now as the files are saving with distributed GPUs.

jldsave still isn't supported but a combination of jldsave for the clock and JLD2Writer for the Fields seems to work.

Happy to reopen if we think adding jldsave support for distributed GPUs is worth doing.

taimoorsohail avatar Oct 09 '25 04:10 taimoorsohail

@taimoorsohail, the MultiRegion error was never resolved, right? Potentially https://github.com/CliMA/Oceananigans.jl/pull/4860 deals with it.

navidcy avatar Oct 15 '25 21:10 navidcy

Good catch @navidcy - I can test whether it fixes the issue of jldsave-ing fields...

taimoorsohail avatar Oct 15 '25 23:10 taimoorsohail