
feat: Integrate CorrDiffSolar model with MultiDiffusion support

Open • sunjingan opened this pull request 1 month ago • 2 comments

Earth2Studio Pull Request

Hi @Charlelie,

As discussed over email, this PR integrates the retrained CorrDiffSolar model into Earth2Studio.

Key Changes

  • Added the CorrDiffSolarMD model (earth2studio/models/dx/corrdiffMD.py).
  • Added MultiDiffusion Sampler (earth2studio/utils/multidiffusion.py).
  • Merged diagnostic_solar() logic into earth2studio/run.py.

Usage Example

  • An example inference script has been added at examples/19_multidiff_solar.py.
  • Note: we need your help uploading the model package to NGC. Once it is there, users can download it with load_default_package(). (The package is included in the Docker image described in the email.)

sunjingan • Nov 11 '25 07:11

Greptile Overview

Greptile Summary

This PR integrates the CorrDiffSolar model with MultiDiffusion support for high-resolution solar radiation downscaling.

Key additions:

  • New CorrDiffMD base class and CorrDiffSolarMD subclass in corrdiffMD.py
  • New MultiDiffusion sampler for windowed diffusion processing
  • New diagnostic_solar workflow function in run.py
  • Example script demonstrating usage

Critical issues found:

  • Multiple import errors and syntax bugs that will cause runtime failures
  • Accessing the undefined self.device attribute will crash during inference
  • Access to the undefined attribute solarcorrdiffic.srx in run.py
  • Incorrect class name import (corrdiffMD vs CorrDiffMD) in __init__.py
  • Use of datetime.utcfromtimestamp, which is deprecated since Python 3.12
  • A hardcoded img_shape override defeats the dynamically computed shape
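For the deprecated datetime.utcfromtimestamp call flagged above, a common replacement is sketched below. This is a minimal sketch under assumptions: the helper names are hypothetical, and it assumes the call site in corrdiffMD.py converts a POSIX timestamp to a UTC datetime (the original call is not shown in this thread). `datetime.fromtimestamp(ts, tz=timezone.utc)` returns a timezone-aware datetime. The naive variant reproduces the old behaviour and should be used only if downstream code expects naive UTC values.

```python
from datetime import datetime, timezone


def utc_from_timestamp(ts: float) -> datetime:
    """Timezone-aware UTC datetime (hypothetical helper)."""
    # Replacement for the deprecated datetime.utcfromtimestamp(ts).
    return datetime.fromtimestamp(ts, tz=timezone.utc)


def utc_from_timestamp_naive(ts: float) -> datetime:
    """Naive UTC datetime, matching what utcfromtimestamp used to return."""
    # Assumption: only needed if callers compare against naive datetimes.
    return utc_from_timestamp(ts).replace(tzinfo=None)
```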

Architecture concerns:

  • The hardcoded 320x320 window size may limit flexibility for different resolutions
  • MultiDiffusion implementation appears sound for handling large images via overlapping windows
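To make the window-size concern concrete, here is one way overlapping windows can be laid out over an image of arbitrary size. This is a hypothetical sketch, not the PR's get_windows(stride) implementation, whose internals are not shown here. The names window_starts and tile_windows and the default stride of 160 are assumptions. The window is clamped to the image size, and an extra edge-aligned window is appended whenever the stride would otherwise leave the border uncovered, so resolutions that are not a multiple of the stride still get full coverage.

```python
def window_starts(length: int, window: int, stride: int) -> list[int]:
    """Start offsets of size-`window` windows that together cover [0, length)."""
    if window <= 0 or stride <= 0:
        raise ValueError("window and stride must be positive")
    if window >= length:
        return [0]
    starts = list(range(0, length - window + 1, stride))
    # Append a final, edge-aligned window if the stride skipped the border.
    if starts[-1] + window < length:
        starts.append(length - window)
    return starts


def tile_windows(height: int, width: int, window: int = 320, stride: int = 160):
    """Return (row_slice, col_slice) pairs covering an H x W image."""
    # Clamp the window so images smaller than 320 px still get one window.
    win_h, win_w = min(window, height), min(window, width)
    return [
        (slice(top, top + win_h), slice(left, left + win_w))
        for top in window_starts(height, win_h, stride)
        for left in window_starts(width, win_w, stride)
    ]


# Example: a 700 x 700 grid gives row/column starts [0, 160, 320, 380],
# i.e. 16 windows of 320 x 320.
windows = tile_windows(700, 700)
```

Making window and stride parameters, rather than constants, is what would let the same sampler serve other output resolutions.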

Confidence Score: 1/5

  • This PR has critical runtime errors that will cause immediate failures
  • Multiple critical bugs present including missing attributes (self.device, solarcorrdiffic.srx), incorrect imports, syntax errors in type annotations, and deprecated API usage. These issues will cause AttributeError and ImportError exceptions at runtime
  • Primary attention needed on earth2studio/models/dx/corrdiffMD.py (missing self.device), earth2studio/run.py (undefined srx attribute), and earth2studio/models/dx/__init__.py (incorrect import name)

Important Files Changed

File Analysis

  • earth2studio/models/dx/corrdiffMD.py (score 2/5): New CorrDiffMD and CorrDiffSolarMD model classes with MultiDiffusion support. Critical bugs: duplicate imports, hardcoded img_shape, missing self.device attribute, deprecated datetime function, incorrect type hints.
  • earth2studio/utils/multidiffusion.py (score 4/5): New MultiDiffusion sampler implementing windowed diffusion with overlapping regions. Minor style issues with Optional type hints, but the logic appears sound.
  • earth2studio/run.py (score 2/5): New diagnostic_solar workflow function. Critical bugs: incorrect type annotation for the plt_lr parameter, undefined attribute access (solarcorrdiffic.srx).
  • earth2studio/models/dx/__init__.py (score 2/5): Added imports for corrdiffMD and CorrDiffSolarMD. Issue: importing the non-existent (lowercase) corrdiffMD class will raise ImportError.
  • examples/19_multidiff_solar.py (score 4/5): Example script for CorrDiffSolarMD inference. Minor style issue with a commented-out environment variable, but the implementation is otherwise clean.

Sequence Diagram

sequenceDiagram
    participant User
    participant diagnostic_solar
    participant PrognosticModel
    participant CorrDiffSolarMD
    participant MultiDiffusion
    participant RegressionModel
    participant ResidualModel

    User->>diagnostic_solar: Run inference with time, nsteps
    diagnostic_solar->>DataSource: Fetch initial conditions
    DataSource-->>diagnostic_solar: Return x, coords
    
    loop For each timestep (nsteps)
        diagnostic_solar->>PrognosticModel: Forward pass
        PrognosticModel-->>diagnostic_solar: pro_out, pro_out_coord
        
        diagnostic_solar->>CorrDiffSolarMD: __call__(pro_out, pro_out_coord)
        
        CorrDiffSolarMD->>CorrDiffSolarMD: Interpolate to output grid
        CorrDiffSolarMD->>CorrDiffSolarMD: Compute solar zenith angle
        CorrDiffSolarMD->>CorrDiffSolarMD: Preprocess input (normalize)
        
        CorrDiffSolarMD->>CorrDiffSolarMD: get_windows(stride)
        
        loop For each window
            CorrDiffSolarMD->>RegressionModel: regression_step(window)
            RegressionModel-->>CorrDiffSolarMD: regression output
        end
        
        CorrDiffSolarMD->>CorrDiffSolarMD: Average overlapping regions
        
        alt inference_mode == "both"
            CorrDiffSolarMD->>MultiDiffusion: __call__(net, img_lr, regression_output, windows)
            
            loop For each diffusion step
                loop For each window
                    MultiDiffusion->>ResidualModel: Forward pass (denoising)
                    ResidualModel-->>MultiDiffusion: denoised output
                    MultiDiffusion->>ResidualModel: Forward pass (second order)
                    ResidualModel-->>MultiDiffusion: refined output
                end
                MultiDiffusion->>MultiDiffusion: Average overlapping regions
            end
            
            MultiDiffusion-->>CorrDiffSolarMD: residual output
            CorrDiffSolarMD->>CorrDiffSolarMD: Add regression + residual
        end
        
        CorrDiffSolarMD->>CorrDiffSolarMD: Postprocess output (denormalize)
        CorrDiffSolarMD-->>diagnostic_solar: Solar radiation output
        
        diagnostic_solar->>IOBackend: Write output
    end
    
    diagnostic_solar-->>User: Return IOBackend
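The "Average overlapping regions" steps in the diagram above (after the windowed regression pass and after each diffusion step) amount to summing per-window predictions and dividing by how many windows cover each pixel. The following is a minimal NumPy sketch under assumptions: average_over_windows is a hypothetical name, fn stands in for the regression or residual network applied to one window, and a uniform per-pixel mean is assumed. The PR may weight overlaps differently.

```python
import numpy as np


def average_over_windows(image, windows, fn):
    """Apply `fn` to every window and average the results where windows overlap."""
    out = np.zeros(image.shape, dtype=np.float64)
    count = np.zeros(image.shape[-2:], dtype=np.float64)
    for rows, cols in windows:
        # fn must return an array with the same shape as its input window.
        out[..., rows, cols] += fn(image[..., rows, cols])
        count[rows, cols] += 1.0
    if np.any(count == 0):
        raise ValueError("windows do not cover every pixel")
    # Broadcasting divides every channel by the per-pixel overlap count.
    return out / count


# Usage: four overlapping 4 x 4 windows on a 2-channel 6 x 6 field.
img = np.arange(2 * 6 * 6, dtype=float).reshape(2, 6, 6)
wins = [(slice(r, r + 4), slice(c, c + 4)) for r in (0, 2) for c in (0, 2)]
smoothed = average_over_windows(img, wins, lambda x: x)  # identity fn returns img
```

Because each pixel is divided by its own coverage count, window seams blend smoothly instead of showing the last-written window.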

greptile-apps[bot] • Nov 11 '25 08:11

Thanks for opening this, @sunjingan!

NickGeneva • Nov 12 '25 18:11