Added Pangu, FengWu, SwinRNN models
Modulus Pull Request
Description
Checklist
- [x] I am familiar with the Contributing Guidelines.
- [x] New or existing tests cover these changes.
- [ ] The documentation is up to date with these changes.
- [x] The CHANGELOG.md is up to date with these changes.
- [ ] An issue is linked to this pull request.
Dependencies
timm>=0.9.12
Dear Dallas,
Thank you very much for the review. I will make the following changes and open two new pull requests for your review. I aim to finish them this week.
- Tidy up the layers and open a first pull request for the models only.
- Open a second pull request for the WeatherBench datapipe.
Best regards, Ivan
From: Dallas Foster @.>
Date: Saturday, 18 May 2024 at 2:57 AM
To: NVIDIA/modulus @.>
Cc: Ivan Au Yeung @.>, Author @.>
Subject: Re: [NVIDIA/modulus] Added Pangu, Fengu, SwinRNN models (PR #341)
@dallasfoster requested changes on this pull request.
The layers need tidying, and then we can re-review. Also consider merging the WeatherBench datapipe separately, which will smooth the process.
In examples/weather/pangu_weather/train_pangu_era5.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605417203):
+from torch.nn.parallel import DistributedDataParallel
+from omegaconf import DictConfig
+from modulus.models.pangu import Pangu
+from modulus.datapipes.climate import ERA5HDF5Datapipe
+from modulus.distributed import DistributedManager
+from modulus.utils import StaticCaptureTraining, StaticCaptureEvaluateNoGrad
+from modulus.launch.logging import LaunchLogger, PythonLogger, initialize_mlflow
+from modulus.launch.utils import load_checkpoint, save_checkpoint
+try:
+    from apex import optimizers
+except:
+    raise ImportError(
+        "FCN training requires apex package for optimizer."
Change the import error here; its message still refers to FCN training.
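A minimal sketch of a clearer import guard, assuming the script only needs apex for its optimizer. The helper name require_optional and its message wording are hypothetical, not part of Modulus. It catches ImportError specifically instead of using a bare except, and chains the original error so the root cause stays visible.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, purpose: str) -> ModuleType:
    """Import an optional dependency or fail with an actionable message.

    Hypothetical helper: catches ImportError only (not a bare except)
    and chains the original error with "from err".
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        raise ImportError(
            f"{purpose} requires the '{module_name}' package; please install it."
        ) from err


# In the Pangu script this might read (assumption, not the PR's code):
# optimizers = require_optional("apex.optimizers", "Pangu training")
```

importlib.import_module("apex.optimizers") returns the same submodule that "from apex import optimizers" binds, so the call site stays a one-liner.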
In examples/weather/pangu_weather/train_pangu_era5.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605418283):
    # Initialize loggers
    # initialize_wandb(
    #     project="Modulus-Launch-Dev",
    #     entity="Modulus",
    #     name="FourCastNet-Training",
    #     group="FCN-DDP-Group",
    # )
Let's remove this commented code.
In examples/weather/pangu_weather/train_pangu_era5.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605418602):
        batch_size=1,
        patch_size=(1, 1),
        num_workers=8,
Can any of these be added to the config?
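One way to act on this, sketched under the assumption that the script keeps using Hydra/OmegaConf: group the datapipe knobs into a typed config whose defaults are the values currently hard-coded. The class and function names below are hypothetical; with Hydra the same fields would live in the YAML config and be read from cfg.

```python
from dataclasses import asdict, dataclass
from typing import Tuple


@dataclass
class DatapipeConfig:
    """Hypothetical config group for the ERA5 datapipe settings.

    Defaults mirror the values currently hard-coded in the script.
    """

    batch_size: int = 1
    patch_size: Tuple[int, int] = (1, 1)
    num_workers: int = 8

    def __post_init__(self) -> None:
        # Fail early on values that would break the datapipe.
        if self.batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        if self.num_workers < 0:
            raise ValueError("num_workers must be >= 0")
        if len(self.patch_size) != 2 or min(self.patch_size) < 1:
            raise ValueError("patch_size must be two positive ints")


def datapipe_kwargs(cfg: DatapipeConfig) -> dict:
    """Turn the config into keyword arguments for the datapipe constructor."""
    kwargs = asdict(cfg)
    # YAML yields lists; normalize to the tuple the script currently passes.
    kwargs["patch_size"] = tuple(kwargs["patch_size"])
    return kwargs
```

At the call site the hard-coded arguments would then be replaced by unpacking, e.g. ERA5HDF5Datapipe(..., **datapipe_kwargs(DatapipeConfig(num_workers=4)), ...).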
In examples/weather/pangu_weather/train_pangu_era5.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605419059):
    no_channals_pangu = 4 + 5 * 13
    datapipe = ERA5HDF5Datapipe(
        data_dir="/data/train/",
        stats_dir="/data/stats/",
        channels=[i for i in range(no_channals_pangu)],
        num_samples_per_year=cfg.num_samples_per_year_train,
        batch_size=1,
        patch_size=(1, 1),
        num_workers=8,
        device=dist.device,
        process_rank=dist.rank,
        world_size=dist.world_size,
    )
    logger.success(f"Loaded datapipe of size {len(datapipe)}")
    mask_dir = "/data/constant_mask"
Can this be configurable?
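A sketch of moving the paths and the channel count into config, assuming the 69 channels come from 4 surface variables plus 5 upper-air variables on 13 pressure levels, which is what 4 + 5 * 13 suggests. The class and field names are hypothetical; the defaults are the paths currently hard-coded in the script.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class DataPathsConfig:
    """Hypothetical config for data locations (defaults = current hard-coded paths)."""

    data_dir: str = "/data/train/"
    stats_dir: str = "/data/stats/"
    mask_dir: str = "/data/constant_mask"


@dataclass
class ChannelConfig:
    """Hypothetical channel layout: surface variables + upper-air variables x levels."""

    num_surface_vars: int = 4
    num_upper_air_vars: int = 5
    num_pressure_levels: int = 13

    @property
    def num_channels(self) -> int:
        return self.num_surface_vars + self.num_upper_air_vars * self.num_pressure_levels

    def channel_indices(self) -> List[int]:
        # Same list as [i for i in range(n)] in the script, but derived from config.
        return list(range(self.num_channels))
```

Deriving the count from its parts keeps the magic number out of the script and makes a change in the number of pressure levels a one-field edit.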
In examples/weather/pangu_weather/train_pangu_era5.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605419536):
        img_size=(721, 1440),
        patch_size=(2, 4, 4),
        embed_dim=192,
        num_heads=(6, 12, 12, 6),
        window_size=(2, 6, 12),
Can any of these be configurable?
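A sketch of a model config whose defaults match the hard-coded values. It also derives the token grid from img_size and patch_size, assuming ceiling division (padding up to a multiple of the patch size): with these defaults that is ceil(13 / 2) = 7 pressure-level patches, ceil(721 / 4) = 181 latitude patches and 1440 / 4 = 360 longitude patches. The class and method names are hypothetical, and whether the PR pads exactly this way is an assumption.

```python
import math
from dataclasses import dataclass
from typing import Tuple


@dataclass
class PanguModelConfig:
    """Hypothetical config group; defaults copy the values in the training script."""

    img_size: Tuple[int, int] = (721, 1440)
    num_pressure_levels: int = 13
    patch_size: Tuple[int, int, int] = (2, 4, 4)
    embed_dim: int = 192
    num_heads: Tuple[int, ...] = (6, 12, 12, 6)
    window_size: Tuple[int, int, int] = (2, 6, 12)

    def token_grid(self) -> Tuple[int, int, int]:
        """Patches along (pressure level, latitude, longitude), padding up as needed."""
        pl, lat, lon = self.patch_size
        h, w = self.img_size
        return (
            math.ceil(self.num_pressure_levels / pl),
            math.ceil(h / lat),
            math.ceil(w / lon),
        )
```

Exposing the derived grid makes it easy to check that a new img_size or patch_size still produces a shape the window attention can handle.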
In examples/weather/pangu_weather/train_pangu_lite_era5.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605421654):
    # Initialize loggers
    # initialize_wandb(
    #     project="Modulus-Launch-Dev",
    #     entity="Modulus",
    #     name="FourCastNet-Training",
    #     group="FCN-DDP-Group",
    # )
Remove commented code here.
In examples/weather/pangu_weather/train_pangu_lite_era5.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605421836):
        batch_size=1,
        patch_size=(1, 1),
        num_workers=8,
Same comment: can these be configurable?
In examples/weather/pangu_weather/train_pangu_lite_era5.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605422169):
    datapipe = ERA5HDF5Datapipe(
        data_dir="/data/train/",
        stats_dir="/data/stats/",
        channels=[i for i in range(no_channals_pangu)],
        num_samples_per_year=cfg.num_samples_per_year_train,
        batch_size=1,
        patch_size=(1, 1),
        num_workers=8,
        device=dist.device,
        process_rank=dist.rank,
        world_size=dist.world_size,
    )
    logger.success(f"Loaded datapipe of size {len(datapipe)}")
    mask_dir = "/data/constant_mask"
    land_mask = torch.from_numpy(
Can the data type be configurable here? Either float32 or float16?
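A sketch of making the mask dtype configurable, assuming the mask is loaded as a NumPy array before torch.from_numpy is called. torch.from_numpy keeps the array's dtype, so casting the array first is enough to get a float16 or float32 tensor. The function name and the allowed-dtype table are hypothetical.

```python
import numpy as np

# Hypothetical whitelist: only the two precisions discussed in review.
_ALLOWED_DTYPES = {"float32": np.float32, "float16": np.float16}


def cast_mask(mask: np.ndarray, dtype_name: str = "float32") -> np.ndarray:
    """Cast a constant mask to the configured floating-point dtype.

    The result can be passed to torch.from_numpy, which preserves the dtype.
    """
    try:
        dtype = _ALLOWED_DTYPES[dtype_name]
    except KeyError:
        raise ValueError(
            f"Unsupported mask dtype {dtype_name!r}; "
            f"expected one of {sorted(_ALLOWED_DTYPES)}"
        ) from None
    # copy=False skips the copy when the array already has the requested dtype.
    return mask.astype(dtype, copy=False)
```

Validating against a small table turns a typo in the config into a clear error at startup instead of a silent precision change.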
On modulus/datapipes/climate/weatherbench.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605423398):
Can we break this off as a separate MR?
On modulus/models/fengwu/fengwu.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605425877):
There is a lot of code in this file. Can we group the layers and subcomponents into separate files? For example, let's put the MLP, transformer blocks, attention blocks, and upsampling into models/layers or something similar.
In modulus/models/fengwu/fengwu.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605426791):
        )
        self.decoder_t = DecoderLayer(
            img_size=img_size,
            patch_size=patch_size,
            out_chans=pressure_level,
            dim=embed_dim,
            output_resolution=resolution[0],
            middle_resolution=resolution[1],
            depth=2,
            depth_middle=6,
            num_heads=num_heads[:2],
            window_size=window_size[1:],
            drop_path=drop_path,
        )

    def forward(self, surface, z, r, u, v, t):
I think we should make it a principle that the forward call is strictly def forward(self, x):, and that the internals of the model separate the components as necessary.
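A sketch of the split that a single-input forward(self, x) could perform internally, assuming the six inputs are stacked along the channel axis in the order surface, z, r, u, v, t, with each upper-air variable contributing pressure_level channels. NumPy keeps the sketch dependency-light; inside an nn.Module the same section sizes can be passed to torch.split. The function name and the channel order are hypothetical, and the sizes in the usage check are illustrative only.

```python
from typing import Dict

import numpy as np

UPPER_AIR_VARS = ("z", "r", "u", "v", "t")  # assumed stacking order


def split_inputs(x: np.ndarray, surface_chans: int, pressure_level: int) -> Dict[str, np.ndarray]:
    """Split a (batch, channels, lat, lon) array into named components."""
    expected = surface_chans + len(UPPER_AIR_VARS) * pressure_level
    if x.shape[1] != expected:
        raise ValueError(f"expected {expected} channels, got {x.shape[1]}")
    # Cumulative boundaries: surface block first, then one block per upper-air variable.
    bounds = np.cumsum([surface_chans] + [pressure_level] * (len(UPPER_AIR_VARS) - 1))
    parts = np.split(x, bounds, axis=1)
    return dict(zip(("surface",) + UPPER_AIR_VARS, parts))
```

With this shape of API, forward(self, x) would call the split once and hand each piece to its encoder, so callers only ever pass one tensor.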
On modulus/models/fengwu/shift_window_mask.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605427430):
Consider putting these into the modulus/models/utils folder.
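For reference, a minimal 2D NumPy sketch of the window partition and reverse operations that helpers like these implement, in the style of Swin Transformer. The PR's versions operate on 3D (pressure level, latitude, longitude) windows; this sketch assumes a (batch, height, width, channels) layout with height and width divisible by the window size, and the _2d names are hypothetical.

```python
from typing import Tuple

import numpy as np


def window_partition_2d(x: np.ndarray, window: Tuple[int, int]) -> np.ndarray:
    """(B, H, W, C) -> (B * nH * nW, wh, ww, C) non-overlapping windows."""
    b, h, w, c = x.shape
    wh, ww = window
    x = x.reshape(b, h // wh, wh, w // ww, ww, c)
    # Bring the two window-grid axes next to each other, then flatten them into the batch.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, wh, ww, c)


def window_reverse_2d(windows: np.ndarray, window: Tuple[int, int], h: int, w: int) -> np.ndarray:
    """Inverse of window_partition_2d: (B * nH * nW, wh, ww, C) -> (B, H, W, C)."""
    wh, ww = window
    b = windows.shape[0] // ((h // wh) * (w // ww))
    x = windows.reshape(b, h // wh, w // ww, wh, ww, -1)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, h, w, -1)
```

Because these are pure reshapes, they are shared naturally between models, which fits the suggestion to keep them in a common utils module.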
On modulus/models/fengwu/utils.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605428272):
Again, consider putting these into modulus/models/utils.
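Padding helpers such as get_pad3d typically compute how much padding makes each dimension divisible by the window size. A sketch of that calculation, assuming padding is split as evenly as possible with any odd remainder on the trailing side; the function name and the (before, after) return layout are hypothetical and may differ from the PR's get_pad3d. For example, with the window_size=(2, 6, 12) from the training script, a latitude extent of 181 needs 5 extra rows.

```python
from typing import List, Sequence, Tuple


def window_padding(shape: Sequence[int], window: Sequence[int]) -> List[Tuple[int, int]]:
    """Return (before, after) padding per dimension so each size divides its window."""
    if len(shape) != len(window):
        raise ValueError("shape and window must have the same length")
    pads = []
    for size, win in zip(shape, window):
        total = (-size) % win  # smallest amount that rounds size up to a multiple of win
        before = total // 2
        pads.append((before, total - before))
    return pads
```

A matching crop helper would slice the same amounts back off after attention, which is why padding and cropping utilities tend to live together.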
In modulus/models/pangu/pangu.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605429204):
+import math
+from dataclasses import dataclass
+import numpy as np
+import torch
+from timm.layers import DropPath, trunc_normal_
+from torch import nn
+from ..meta import ModelMetaData
+from ..module import Module
+from .patch_embed import PatchEmbed2D, PatchEmbed3D, PatchRecovery2D, PatchRecovery3D
+from .shift_window_mask import get_shift_window_mask, window_partition, window_reverse
+from .utils import crop3d, get_earth_position_index, get_pad3d
+class UpSample(nn.Module):
Is this class duplicated from the FengWu code? Let's consolidate all of the layers as much as possible.
On modulus/models/pangu/pangu.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605430074):
Similar comment to FengWu: let's separate all of the layers into a separate folder and abstract them so that they can be shared by Pangu, FengWu, etc.
On modulus/models/swinvrnn/swinvrnn.py (https://github.com/NVIDIA/modulus/pull/341#discussion_r1605430474):
Similar comment about consolidation applies here.
/blossom-ci