
Add support for training and evaluation on WeatherBench 2

jacobbieker opened this issue on Mar 5, 2024 · 13 comments

We want to be able to train and evaluate on WeatherBench2 to enable easy comparisons of the data.

Detailed Description

https://weatherbench2.readthedocs.io/en/latest/data-guide.html

Context

The data is already in Zarr, easy to pull and use, and freely available, which is ideal. It does tend to be 6-hourly, so it is not as relevant for OCF, but it is a good place to start and a good benchmark to check against.

Possible Implementation

Adding a data module that is extensible

jacobbieker avatar Mar 05 '24 15:03 jacobbieker

Hi @jacobbieker, is it possible to work on this one? I would like to try

0xFrama avatar Mar 06 '24 20:03 0xFrama

Hi, I was browsing through the repository and WeatherBench2 as well. As far as I can tell, WeatherBench2 benchmarks different datasets/evaluation tasks based on your data/output, and I can also see some models/training pipelines. My question is: what specifically would you like implemented? What kind of functionality? Is it a function you call to load and use a specific dataset from WeatherBench, or just something to process/create the WeatherBench config dictionaries?

It would be amazing if you could share how you wanted to use these functions and what you wanted included!

aavashsubedi avatar Mar 09 '24 11:03 aavashsubedi

@0xFrama Yes it is! This one could also be split into adding support for the different datasets in WeatherBench2; it might be good to break it into smaller issues, such as supporting the ERA5 dataset for training and HRES for testing, or the like.

@aavashsubedi I'd like the ability to load the WeatherBench 2 datasets in a training/eval script for the models in this repo. The functionality would be to load and use a specific dataset from WeatherBench. As they are stored as Zarrs in Google Cloud, the implementation should be fairly small: xarray can open the Zarrs directly, and they can then be streamed in for evaluation. Including the mean/stddev, or min/max, for normalization would also be ideal.
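For example, something as small as this should already stream a timestep out of one of the stores (the variable picks here are just illustrative):

import xarray as xr

# Lazily open the ERA5 store straight from Google Cloud Storage; depending on the
# environment you may need to authenticate to GCS first (e.g. google.colab.auth on Colab)
era5 = xr.open_zarr(
    "gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr"
)

# Nothing is downloaded until .load() is called on a selection
sample = era5[["temperature", "geopotential"]].isel(time=0).load()
print(sample)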

If you both want to work on this, then it might work best to split which datasets you work on. I would propose 2 different PRs, each adding support for one of the following combinations:

  1. ERA5 Ground Truth one for training, and ERA5 Forecast for testing
  2. IFS Analysis for Training, and IFS HRES/ENS Mean for testing

That should allow us to use graph_weather and compare against other models in WeatherBench 2 more easily.

jacobbieker avatar Mar 09 '24 12:03 jacobbieker

While working on PR #87 I did some tests with WeatherBench2 and the ERA5 dataset, editing the run_fulll.py code. I don't know if this is what @jacobbieker had in mind, but maybe this snippet can be helpful:

import apache_beam   # Needs to be imported separately to avoid TypingError
import numpy as np
import weatherbench2
import xarray as xr
from torch.utils.data import Dataset
from torchvision import transforms

import const  # expected to provide per-variable ERA5_MEANS / ERA5_STD dicts (still to be estimated)

# Run the code below to access cloud data on Colab!
from google.colab import auth
auth.authenticate_user()

class XrDataset(Dataset):
    def __init__(self):
        super().__init__()

        obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr'
        self.data = xr.open_zarr(obs_path) 
        
        # TODO: Check training period
        self.data = self.data.where(self.data['time.year'] < 2018, drop=True)
        self.num_levels = len(self.data["level"])
        self.num_vars = len(self.data.keys())
        self.NWP_features = ["geopotential", 
                         "specific_humidity",
                         "temperature",
                         "u_component_of_wind",
                         "v_component_of_wind",
                         "vertical_velocity"]

        self.aux_features = ["geopotential_at_surface", 
                            "land_sea_mask",
                            "toa_incident_solar_radiation",
                            "toa_incident_solar_radiation_12hr"]
    def __len__(self):
        # The last timestep has no successor to use as a target, so stop one step early
        return len(self.data["time"]) - 1

    def __getitem__(self, item):
        start = self.data.isel(time = item)
        end = self.data.isel(time = item+1)

        # Stack the NWP features for input
        input_data = np.stack(
                  [
                      (start[f"{var}"].values - np.array(const.ERA5_MEANS[f"{var}"])[:, None, None])
                      / (np.array(const.ERA5_STD[f"{var}"])[:, None, None] + 0.0001)
                      for var in self.NWP_features
                  ],
                  axis=-1,
              ).astype(np.float32)
        num_layers, num_lat, num_lon, num_vars = input_data.shape 
        input_data = input_data.reshape(num_lat, num_lon, num_vars*num_layers)
        input_data = np.nan_to_num(input_data)
        assert not np.isnan(input_data).any()

        # Stack the non-NWP features for input
        aux_data = np.stack(
            [
                (start[f"{var}"].values - const.ERA5_MEANS[f"{var}"])
                / (const.ERA5_STD[f"{var}"]+ 0.0001)
                for var in self.aux_features
            ],
            axis=-1,
        ).astype(np.float32)
        aux_data = np.nan_to_num(aux_data)
        assert not np.isnan(aux_data).any()
        assert input_data.shape[:2] == aux_data.shape[:2]

        # Stack space-time features for input
        lat_lons = np.array(np.meshgrid(start.latitude.values, start.longitude.values))
        # Coordinates are in degrees; convert to radians so the sin/cos encodings behave
        sin_lats = np.sin(np.deg2rad(lat_lons[0, :, :]))
        cos_lats = np.cos(np.deg2rad(lat_lons[0, :, :]))
        sin_lons = np.sin(np.deg2rad(lat_lons[1, :, :]))
        cos_lons = np.cos(np.deg2rad(lat_lons[1, :, :]))

        # Cyclical encoding of the day of year, broadcast to the grid
        day_of_year = start.time.dt.dayofyear.values / 365.0
        sin_day_of_year = np.sin(2 * np.pi * day_of_year) * np.ones_like(lat_lons[0, :, :])
        cos_day_of_year = np.cos(2 * np.pi * day_of_year) * np.ones_like(lat_lons[0, :, :])

        space_time_data = np.stack([sin_lats, 
                                    cos_lats, 
                                    sin_lons, 
                                    cos_lons, 
                                    sin_day_of_year, 
                                    cos_day_of_year], axis = -1)
        space_time_data = np.nan_to_num(space_time_data).astype(np.float32)
        assert not np.isnan(space_time_data).any()
        assert input_data.shape[:2] == space_time_data.shape[:2]

        # Stack NWP features for output
        output_data = np.stack(
            [
                (end[f"{var}"].values - np.array(const.ERA5_MEANS[f"{var}"])[:, None, None])
                / (np.array(const.ERA5_STD[f"{var}"])[:, None, None] + 0.0001)
                for var in self.NWP_features
            ],
            axis=-1,
        ).astype(np.float32)
        num_layers, num_lat, num_lon, num_vars = output_data.shape 
        output_data = output_data.reshape(num_lat, num_lon, num_vars*num_layers)
        output_data = np.nan_to_num(output_data)
        assert not np.isnan(output_data).any()
        assert input_data.shape == output_data.shape

        input_data = np.concatenate( [input_data, aux_data, space_time_data], axis = -1)
        
        # ToTensor gives (C, H, W); permute back to (H, W, C) before flattening so each
        # row of the result is the full feature vector of one grid point
        transform = transforms.Compose([transforms.ToTensor()])
        return (
            transform(input_data).permute(1, 2, 0).reshape(-1, input_data.shape[-1]),
            transform(output_data).permute(1, 2, 0).reshape(-1, output_data.shape[-1]),
        )

I still need to compute good estimates of the mean and std of the dataset. Moreover, the features and the training period are arbitrarily selected. In the coming days I will try to compute the variance of the 3/6-hour changes as in #87.
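A rough way to get those change statistics with xarray could look like this (one year of data and two variables just to keep it cheap; not a final implementation):

import xarray as xr

obs_path = "gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr"
data = xr.open_zarr(obs_path)

# One year as a cheap stand-in for the full record
subset = data[["temperature", "geopotential"]].sel(time=slice("2016-01-01", "2016-12-31"))

# Variance of the 6-hourly change, per variable and pressure level
delta = subset.diff("time")
change_var = delta.var(dim=["time", "latitude", "longitude"]).compute()
print(change_var)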

gbruno16 avatar Mar 09 '24 15:03 gbruno16

@jacobbieker Thanks, that helps!

aavashsubedi avatar Mar 10 '24 09:03 aavashsubedi

@gbruno16 It looks like you have already taken care of most of the ERA5 stuff, nice! Perhaps add an input variable that reads the data types from a config file/list (and then sorts them into aux/NWP features within the dataset)?
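Something along those lines could be as simple as this (the YAML layout and the extra XrDataset arguments are only a suggestion):

import yaml

# Hypothetical features.yaml:
#   nwp_features: [geopotential, temperature, u_component_of_wind]
#   aux_features: [land_sea_mask, toa_incident_solar_radiation]
with open("features.yaml") as f:
    feature_config = yaml.safe_load(f)

# The XrDataset above would then take the lists as arguments instead of hard-coding them:
# dataset = XrDataset(
#     NWP_features=feature_config["nwp_features"],
#     aux_features=feature_config["aux_features"],
# )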

aavashsubedi avatar Mar 10 '24 09:03 aavashsubedi

Looks like there are already plenty of people working on this one! Happy to either take the IFS dataset or leave this to @0xFrama and work on something else : )

aavashsubedi avatar Mar 10 '24 09:03 aavashsubedi

@aavashsubedi up to you! There is also #90 that I just added that might interest you? Probably a bit more involved than this one, but in a similar vein.

jacobbieker avatar Mar 11 '24 13:03 jacobbieker

Hi @aavashsubedi, I've already started working on this issue, but let me know if you want to work on it as well. I can look at something else, no worries :)

0xFrama avatar Mar 11 '24 15:03 0xFrama

Hi @aavashsubedi, I've already started working on this issue, but let me know if you want to work on it as well. I can look at something else, no worries :)

Heyy no worries! I haven't started working on this so all yours ; )

aavashsubedi avatar Mar 11 '24 18:03 aavashsubedi

Hi @jacobbieker, I'm having some trouble in calculating the mean and std for each of the variables in the dataset. My computer can't process that much data in a reasonable amount of time. Do you have any suggestion on how to proceed?

0xFrama avatar Mar 20 '24 15:03 0xFrama

Hi @jacobbieker, I'm having some trouble in calculating the mean and std for each of the variables in the dataset. My computer can't process that much data in a reasonable amount of time. Do you have any suggestion on how to proceed?

Not Jacob, but I seem to have notifications enabled. I have been playing around with graph_weather on Kaggle, so maybe you can try that: plenty of RAM and CPU cores on there. This may not be the most surefire solution, but for the mean/std you can try a batched Monte Carlo estimate: randomly load batches/subsets of the entire dataset, compute the mean/std of each batch, and take the mean of those. This should give you a good estimate of the true mean/std. It should be good enough (Jacob can chime in) until someone computes the population values, if computing them is a massive headache.

aavashsubedi avatar Mar 20 '24 15:03 aavashsubedi

Hi,

Yeah, the Monte Carlo estimate seems like a good way to do that. Otherwise, just taking a few hundred random timesteps can still give a good approximation of the values. Besides Kaggle, another option could be Planetary Computer, as you can access a Dask cluster with 2TB of RAM for free there.
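Something like this is probably enough for a first pass (the sample size and variable list are arbitrary here):

import numpy as np
import xarray as xr

obs_path = "gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr"
data = xr.open_zarr(obs_path)

# A few hundred random timesteps, then per-variable / per-level statistics
rng = np.random.default_rng(0)
idx = np.sort(rng.choice(len(data.time), size=300, replace=False))
sample = data[["temperature", "geopotential", "specific_humidity"]].isel(time=idx)

means = sample.mean(dim=["time", "latitude", "longitude"]).compute()
stds = sample.std(dim=["time", "latitude", "longitude"]).compute()
print(means, stds)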

jacobbieker avatar Mar 20 '24 17:03 jacobbieker