earth2studio
feat: Integrate CorrDiffSolar model with MultiDiffusion support
Earth2Studio Pull Request
Hi @Charlelie,
As discussed over email, this PR integrates the retrained CorrDiffSolar model into Earth2Studio.
Key Changes
- Added `CorrDiffSolarMDModel` (`earth2studio/models/dx/corrdiffMD.py`).
- Added `MultiDiffusionSampler` (`earth2studio/utils/multidiffusion.py`).
- Merged the `diagnostic_solar()` logic into `earth2studio/run.py`.
Usage Example
- An example inference script has been added at `examples/19_multidiff_solar.py`.
- Note: we need your help uploading the model package to NGC. After that, users can call `load_default_package()` to download it. (The package is included in the Docker image, as I described in the email.)
Greptile Overview
Greptile Summary
This PR integrates the CorrDiffSolar model with MultiDiffusion support for high-resolution solar radiation downscaling.
Key additions:
- New `CorrDiffMD` base class and `CorrDiffSolarMD` subclass in `corrdiffMD.py`
- New `MultiDiffusion` sampler for windowed diffusion processing
- New `diagnostic_solar` workflow function in `run.py`
- Example script demonstrating usage
Critical issues found:
- Multiple import errors and syntax bugs that will cause runtime failures
- Access to a missing `self.device` attribute will crash during inference
- Undefined `solarcorrdiffic.srx` attribute access in `run.py`
- Incorrect class name import (`corrdiffMD` vs `CorrDiffMD`) in `__init__.py`
- Deprecated `datetime.utcfromtimestamp` usage
- Hardcoded `img_shape` override defeats the dynamic shape calculation
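The deprecated `datetime.utcfromtimestamp` call can be replaced with a timezone-aware conversion. The sketch below is minimal and assumes the model only needs a UTC `datetime` built from a POSIX timestamp. The helper name `to_utc_datetime` is hypothetical and does not come from the PR.

```python
from datetime import datetime, timezone


def to_utc_datetime(timestamp: float) -> datetime:
    """Convert a POSIX timestamp to a timezone-aware UTC datetime.

    Hypothetical replacement for the deprecated ``datetime.utcfromtimestamp``
    (deprecated since Python 3.12). Call ``.replace(tzinfo=None)`` on the
    result if downstream code expects a naive UTC datetime.
    """
    return datetime.fromtimestamp(timestamp, tz=timezone.utc)


if __name__ == "__main__":
    print(to_utc_datetime(0).isoformat())  # 1970-01-01T00:00:00+00:00
```

The aware result compares and formats unambiguously. Stripping `tzinfo` afterwards reproduces the old naive behaviour when that is what the zenith-angle code expects.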
Architecture concerns:
- The hardcoded 320x320 window size may limit flexibility for different resolutions
- MultiDiffusion implementation appears sound for handling large images via overlapping windows
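The overlapping-window scheme mentioned above can be sketched in NumPy. This is an illustration under stated assumptions, not the PR's `get_windows`/`MultiDiffusion` code. The names `window_starts`, `tile_windows` and `fuse_windows` are hypothetical, as are the `(top, left)` window representation and the edge-snapping rule.

In this sketch, windows are placed every `stride` pixels. If the stride leaves a strip uncovered, one extra window is snapped flush to the border. Per-window predictions are then fused by dividing the accumulated sum by a per-pixel coverage count.

```python
from __future__ import annotations

import numpy as np


def window_starts(size: int, window: int, stride: int) -> list[int]:
    """Start offsets along one axis so windows of length `window` cover [0, size)."""
    if stride <= 0 or window > size:
        raise ValueError("need stride > 0 and window <= size")
    starts = list(range(0, size - window + 1, stride))
    # If the stride leaves an uncovered strip at the end, add one window
    # snapped flush to the border (it simply overlaps its neighbour more).
    if starts[-1] + window < size:
        starts.append(size - window)
    return starts


def tile_windows(shape: tuple[int, int], window: int, stride: int) -> list[tuple[int, int]]:
    """(top, left) corners of square windows tiling an H x W image."""
    h, w = shape
    rows = window_starts(h, window, stride)
    cols = window_starts(w, window, stride)
    return [(top, left) for top in rows for left in cols]


def fuse_windows(
    patches: list[np.ndarray],
    windows: list[tuple[int, int]],
    shape: tuple[int, int, int],
    window: int,
) -> np.ndarray:
    """Average per-window predictions of shape (C, window, window) into (C, H, W)."""
    total = np.zeros(shape, dtype=np.float64)
    count = np.zeros(shape[-2:], dtype=np.float64)
    for patch, (top, left) in zip(patches, windows):
        total[:, top:top + window, left:left + window] += patch
        count[top:top + window, left:left + window] += 1.0
    # Every pixel is covered at least once, so the division is safe.
    return total / count
```

For example, with a 320-pixel window and a 240-pixel stride, a hypothetical 1000-pixel axis gives `window_starts(1000, 320, 240) == [0, 240, 480, 680]`. The image does not need to be padded to a multiple of the window size.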
Confidence Score: 1/5
- This PR has critical runtime errors that will cause immediate failures.
- Multiple critical bugs are present, including missing attributes (`self.device`, `solarcorrdiffic.srx`), incorrect imports, syntax errors in type annotations, and deprecated API usage. These issues will raise `AttributeError` and `ImportError` exceptions at runtime.
- Primary attention is needed on `earth2studio/models/dx/corrdiffMD.py` (missing `self.device`), `earth2studio/run.py` (undefined `srx` attribute), and `earth2studio/models/dx/__init__.py` (incorrect import name).
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| earth2studio/models/dx/corrdiffMD.py | 2/5 | New CorrDiffMD and CorrDiffSolarMD model classes with MultiDiffusion support. Critical bugs: duplicate imports, hardcoded img_shape, missing self.device attribute, deprecated datetime function, incorrect type hints |
| earth2studio/utils/multidiffusion.py | 4/5 | New MultiDiffusion sampler implementing windowed diffusion with overlapping regions. Minor style issues with Optional type hints, but logic appears sound |
| earth2studio/run.py | 2/5 | New diagnostic_solar workflow function. Critical bugs: incorrect type annotation for plt_lr parameter, undefined attribute access (solarcorrdiffic.srx) |
| `earth2studio/models/dx/__init__.py` | 2/5 | Added imports for corrdiffMD and CorrDiffSolarMD. Issue: importing the non-existent corrdiffMD class (lowercase) will cause an ImportError |
| examples/19_multidiff_solar.py | 4/5 | Example script for CorrDiffSolarMD inference. Minor style issue with commented environment variable, but overall implementation is clean |
Sequence Diagram
```mermaid
sequenceDiagram
    participant User
    participant diagnostic_solar
    participant PrognosticModel
    participant CorrDiffSolarMD
    participant MultiDiffusion
    participant RegressionModel
    participant ResidualModel
    User->>diagnostic_solar: Run inference with time, nsteps
    diagnostic_solar->>DataSource: Fetch initial conditions
    DataSource-->>diagnostic_solar: Return x, coords
    loop For each timestep (nsteps)
        diagnostic_solar->>PrognosticModel: Forward pass
        PrognosticModel-->>diagnostic_solar: pro_out, pro_out_coord
        diagnostic_solar->>CorrDiffSolarMD: __call__(pro_out, pro_out_coord)
        CorrDiffSolarMD->>CorrDiffSolarMD: Interpolate to output grid
        CorrDiffSolarMD->>CorrDiffSolarMD: Compute solar zenith angle
        CorrDiffSolarMD->>CorrDiffSolarMD: Preprocess input (normalize)
        CorrDiffSolarMD->>CorrDiffSolarMD: get_windows(stride)
        loop For each window
            CorrDiffSolarMD->>RegressionModel: regression_step(window)
            RegressionModel-->>CorrDiffSolarMD: regression output
        end
        CorrDiffSolarMD->>CorrDiffSolarMD: Average overlapping regions
        alt inference_mode == "both"
            CorrDiffSolarMD->>MultiDiffusion: __call__(net, img_lr, regression_output, windows)
            loop For each diffusion step
                loop For each window
                    MultiDiffusion->>ResidualModel: Forward pass (denoising)
                    ResidualModel-->>MultiDiffusion: denoised output
                    MultiDiffusion->>ResidualModel: Forward pass (second order)
                    ResidualModel-->>MultiDiffusion: refined output
                end
                MultiDiffusion->>MultiDiffusion: Average overlapping regions
            end
            MultiDiffusion-->>CorrDiffSolarMD: residual output
            CorrDiffSolarMD->>CorrDiffSolarMD: Add regression + residual
        end
        CorrDiffSolarMD->>CorrDiffSolarMD: Postprocess output (denormalize)
        CorrDiffSolarMD-->>diagnostic_solar: Solar radiation output
        diagnostic_solar->>IOBackend: Write output
    end
    diagnostic_solar-->>User: Return IOBackend
```
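The last model-side steps in the diagram are "Add regression + residual" and "Postprocess output (denormalize)". They follow the CorrDiff decomposition: a regression network predicts the conditional mean, and the diffusion model predicts a residual correction on top of it.

Below is a minimal NumPy sketch of the normalize/combine/denormalize round trip. It assumes per-channel mean/std standardization with statistics of shape `(C,)`. The names `normalize` and `postprocess` are hypothetical, not the PR's API.

```python
from __future__ import annotations

import numpy as np


def normalize(x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Standardize a (C, H, W) field with per-channel statistics of shape (C,)."""
    return (x - mean[:, None, None]) / std[:, None, None]


def postprocess(
    regression: np.ndarray,
    residual: np.ndarray | None,
    mean: np.ndarray,
    std: np.ndarray,
) -> np.ndarray:
    """Add the diffusion residual to the regression mean, then denormalize.

    Both inputs are in normalized units. `residual=None` stands in for a
    regression-only run (assumed to be any inference mode other than "both").
    """
    out = regression if residual is None else regression + residual
    # Broadcast the (C,) statistics over the spatial dimensions.
    return out * std[:, None, None] + mean[:, None, None]
```

The residual is added before denormalizing, matching the order shown in the diagram. A unit residual in normalized space therefore shifts the output by one standard deviation per channel.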
Thanks for opening this, @sunjingan!