State of the package and roadmap for development
Using this issue to document the current state of the package, remaining loose ends, and possible paths for enhancements.
1. Current state of the package
The package has all the functionality for running data assimilation experiments, split into four main sub-modules:
- dabench.data: Generate or download data for experiments. Includes a range of dynamical systems and an interface to ERA5 data on Google Cloud Platform. Output is an xarray DataArray.
- dabench.observer: Gather synthetic observations from data produced by dabench.data. Observation locations and times can be sampled randomly or specified explicitly, and observers can be stationary or move between timesteps. Gaussian error can be added automatically, or the error can be specified manually. Output is an xarray DataArray with extra variables and attributes that record where, when, and how observations were sampled (used by the DA cycler).
- dabench.model: Parent class that users subclass to define their own forecast model, which must have a forecast() method. The forecast() method can be implemented however the user likes, but its input must always be an xarray DataArray, and its output must be a tuple of two xarray DataArrays: the final timestep of the forecast and the full forecast. This is likely the trickiest part of the process, but if you’re using one of the data generators as your forecast model for the experiment, it’s straightforward.
- dabench.dacycler: Run data assimilation experiments using 3DVar, ETKF, Hybrid-Gain (3DVar combined with ETKF), 4DVar, and 4DVar-Backprop. Requires an input state vector and an observation vector (both xarray DataArrays), plus other parameters including the analysis window, the estimated observation error (used to build the observation error covariance matrix), and the number of cycles to run. Output is also an xarray DataArray.
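To make the forecast() contract concrete, here is a minimal sketch of a user-defined model. The class name DampedModel, its forecast(state, n_steps) signature, and the toy decay dynamics are hypothetical and do not come from dabench.model (the sketch does not subclass it). The point is only the shape of the output: a (carry, full forecast) tuple of DataArrays in which the carry keeps the input’s dimensions, including a length-1 time dimension.

```python
import numpy as np
import xarray as xr


class DampedModel:
    """Hypothetical toy model (x_{t+1} = decay * x_t) showing the forecast() contract.

    It does not subclass the real dabench.model class.
    """

    def __init__(self, decay=0.9):
        self.decay = decay

    def forecast(self, state, n_steps=3):
        # `state` is a DataArray with dims ("time", "i") and a single timestep.
        x = state.isel(time=-1).values
        t0 = float(state["time"].values[-1])
        steps = []
        for _ in range(n_steps):
            x = self.decay * x
            steps.append(x)
        full = xr.DataArray(
            np.stack(steps),
            dims=("time", "i"),
            coords={"time": t0 + np.arange(1, n_steps + 1), "i": state["i"].values},
            attrs=state.attrs,
        )
        # Index with a list ([-1]) rather than an integer (-1) so the carry keeps
        # its length-1 time dimension and matches the input's dims and shape.
        carry = full.isel(time=[-1])
        return carry, full


state0 = xr.DataArray(
    np.ones((1, 4)), dims=("time", "i"), coords={"time": [0.0], "i": np.arange(4)}
)
carry, full = DampedModel().forecast(state0)
print(carry.dims, carry.shape)  # ('time', 'i') (1, 4)
```

Writing `full.isel(time=-1)` instead would silently drop the time dimension, which is exactly the mismatch described under the loose ends below.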
The examples repo has straightforward examples of running basic data assimilation experiments. The great thing is that the process is always the same:
- Generate “nature run” data using dabench.data.
- Gather observations using dabench.observer.
- Define forecast model using dabench.model.
- Run the cycle using dabench.dacycler.
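The four steps can be mimicked end to end with plain NumPy stand-ins, which may help when reading the examples repo. This is a hypothetical, self-contained illustration, not dabench code: a Lorenz-96 integrator plays both the nature run and the forecast model, observations are randomly chosen grid points with Gaussian error, and the analysis is a closed-form 3DVar-style update with a static diagonal background covariance B (an assumption chosen for brevity).

```python
import numpy as np

rng = np.random.default_rng(0)
n, dt, F = 40, 0.01, 8.0  # Lorenz-96 size, RK4 time step, forcing


def l96_tendency(x):
    # dx_i/dt = (x_{i+1} - x_{i-2}) * x_{i-1} - x_i + F, with periodic indices
    return (np.roll(x, -1) - np.roll(x, 2)) * np.roll(x, 1) - x + F


def forecast(x, n_steps):
    # Step 3 stand-in: the forecast model is the same RK4 integrator (perfect model).
    for _ in range(n_steps):
        k1 = l96_tendency(x)
        k2 = l96_tendency(x + 0.5 * dt * k1)
        k3 = l96_tendency(x + 0.5 * dt * k2)
        k4 = l96_tendency(x + dt * k3)
        x = x + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
    return x


# Step 1: "nature run", spun up and then saved once per analysis window.
n_cycles, window = 50, 5
x0 = forecast(F + 0.01 * rng.standard_normal(n), 1000)
truth, x = [], x0
for _ in range(n_cycles):
    x = forecast(x, window)
    truth.append(x)
truth = np.array(truth)  # shape (n_cycles, n)

# Step 2: each cycle, observe m randomly chosen grid points with Gaussian error.
m, sigma_obs = 20, 0.5
obs_idx = np.array([rng.choice(n, size=m, replace=False) for _ in range(n_cycles)])
obs = np.take_along_axis(truth, obs_idx, axis=1) + sigma_obs * rng.standard_normal((n_cycles, m))

# Step 4: cycle forecast + 3DVar-style analysis with static B and R.
B = 0.5 * np.eye(n)  # assumed diagonal background error covariance
R = sigma_obs**2 * np.eye(m)
xa = x0 + rng.standard_normal(n)  # perturbed initial analysis
x_free = xa.copy()  # run without DA, for comparison
err_da, err_free = [], []
for k in range(n_cycles):
    xf = forecast(xa, window)
    x_free = forecast(x_free, window)
    H = np.eye(n)[obs_idx[k]]  # selects the observed components
    # Solve (H B H^T + R) z = y - H xf rather than inverting the matrix.
    z = np.linalg.solve(H @ B @ H.T + R, obs[k] - H @ xf)
    xa = xf + B @ H.T @ z
    err_da.append(np.sqrt(np.mean((xa - truth[k]) ** 2)))
    err_free.append(np.sqrt(np.mean((x_free - truth[k]) ** 2)))

print(f"analysis RMSE {np.mean(err_da[25:]):.2f}, free-run RMSE {np.mean(err_free[25:]):.2f}")
```

Because B is diagonal here, each update only corrects the observed components; the unobserved ones are adjusted indirectly through the model dynamics. A B with off-diagonal correlations would spread information to nearby unobserved points.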
2. Loose ends (also see other GitHub issues)
- Larger systems: Currently, the DA methods fail on large systems because of memory constraints and large matrix computations. For example, running 3DVar on a ~100,000-dimensional system (like a simple dinosaur run) means creating and inverting a 100,000 x 100,000 background covariance matrix, which alone takes about 80 GB in float64. There are various ways around this: a control variable transform, sparse matrices, localized versions of the DA methods (e.g. the local ensemble transform Kalman filter), and/or rearranging the problem to use a linear solver instead of inverting matrices directly. We discuss some of these in equations 26 and 27 of the 4DVar-Backprop paper.
- JAX errors in DA cycle runs: All of the DA methods are executed with jax.lax.scan(), which imposes some restrictions and sometimes produces unforeseen errors. All arrays passed into and out of each step of the scan must have a fixed shape. For example, to handle DA experiments where the number of available observations varies from cycle to cycle, we had to implement observation time and location masks. Problems can also be hard to debug because you usually can’t print the values of variables inside jit-compiled functions (jax.debug.print() helps), and JAX error messages can be somewhat opaque. More checks on user inputs that raise clear warnings and errors would make the package easier to use. See the JAX documentation for more information: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html.
- Errors with user-defined forecast functions: Related to the JAX challenges above, if users provide their own model and forecast() method, the DA cycler may break if (1) the method converts a traced JAX array to a NumPy array, or (2) the method does not return, as the “carry” output (the last timestep of the forecast, returned as the first element of the forecast() output tuple), an xarray DataArray that exactly matches the shape and metadata of the input DataArray. For example, if the input DataArray has a time dimension (even a single timestep) and the carry output does not, jax.lax.scan() will throw an error. We could provide more input checks and guidance for working around this issue. Another option is a second, non-JAX pathway for running DA cycles to fall back on when scan() doesn’t work (e.g. a simple “for” loop or apply).
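For the larger-systems item above, here is a minimal sketch of the “linear solver instead of inverting matrices” route for a 3DVar-style analysis, x_a = x_b + B H^T (H B H^T + R)^-1 (y - H x_b), with the inner system solved by conjugate gradients. Assumptions, none of which come from dabench or the 4DVar-Backprop paper: a periodic 1-D grid, B applied matrix-free as a circulant Gaussian correlation via FFTs, H a point-selection operator, and a diagonal R. No n x n matrix is ever formed; only n- and m-length vectors are stored.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)

n, m = 256, 64  # state size, number of observations
sigma_b, sigma_o, length = 1.0, 0.5, 5.0

# B = sigma_b^2 * C, where C is a circulant Gaussian correlation on a periodic
# grid. Circulant matrices are diagonalized by the FFT, so B v costs O(n log n)
# and B itself is never stored.
dist = jnp.minimum(jnp.arange(n), n - jnp.arange(n))
corr = jnp.exp(-0.5 * (dist / length) ** 2)
spectrum = jnp.maximum(jnp.real(jnp.fft.fft(corr)), 0.0)  # clip round-off negatives


def apply_B(v):
    return sigma_b**2 * jnp.real(jnp.fft.ifft(spectrum * jnp.fft.fft(v)))


k_idx, k_obs = jax.random.split(jax.random.PRNGKey(0))
obs_idx = jnp.sort(jax.random.choice(k_idx, n, shape=(m,), replace=False))
x_b = jnp.zeros(n)  # background state
y = jax.random.normal(k_obs, (m,))  # synthetic observations


def H(x):  # observation operator: pick the observed grid points
    return x[obs_idx]


def H_T(v):  # its adjoint: scatter values back onto the grid
    return jnp.zeros(n).at[obs_idx].set(v)


def S(w):  # (H B H^T + R) w, applied without forming any matrix
    return H(apply_B(H_T(w))) + sigma_o**2 * w


d = y - H(x_b)  # innovation
w, _ = cg(S, d, tol=1e-10)  # solve (H B H^T + R) w = d iteratively
x_a = x_b + apply_B(H_T(w))  # analysis
```

CG only needs S to be symmetric positive definite, which holds here because R is. A control variable transform would instead change variables so that B no longer has to be inverted in the cost function; that route is not shown.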
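For the JAX-errors item, here is a minimal illustration of the masking idea: cycles with 3, 1, and 2 observations are padded to a fixed length of 3 so that jax.lax.scan() sees the same shapes at every step, and padded entries get zero weight. The per-point gain from a diagonal B and R, and the omission of a forecast step (the model is the identity), are simplifications for this sketch and do not reflect how dabench implements its masks.

```python
import jax
import jax.numpy as jnp

n, max_obs = 8, 3
b_var, r_var = 1.0, 0.25
gain = b_var / (b_var + r_var)  # diagonal B and R: scalar gain of 0.8 per observed point

# Cycles have 3, 1 and 2 observations; every cycle is padded to max_obs entries.
# Real observations within a cycle are assumed to be at distinct grid points.
obs_vals = jnp.array([[1.0, 2.0, 3.0], [4.0, 0.0, 0.0], [5.0, 6.0, 0.0]])
obs_locs = jnp.array([[0, 3, 5], [2, 0, 0], [1, 7, 0]])
obs_mask = jnp.array([[True, True, True], [True, False, False], [True, True, False]])


def cycle(x, inputs):
    # No forecast step here: the model is the identity.
    vals, locs, mask = inputs
    innovation = vals - x[locs]
    # Padded slots get zero weight, so they leave the state unchanged.
    increment = jnp.where(mask, gain * innovation, 0.0)
    x = x.at[locs].add(increment)
    # Both the carry and the per-cycle output keep the same shape every cycle.
    return x, x


x_final, history = jax.lax.scan(cycle, jnp.zeros(n), (obs_vals, obs_locs, obs_mask))
print(history.shape)  # (3, 8)
```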
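For the user-defined forecast item, two hypothetical helpers sketch the suggestions above: check_carry() fails early with a readable message when the carry does not match the input state’s dimensions and shape, and run_cycles() falls back to a plain Python loop when use_scan=False. Neither function exists in dabench; the names and signatures are illustrative, and the check compares only dims and shape (a fuller version might also compare coordinates and attributes).

```python
import jax
import jax.numpy as jnp
import numpy as np
import xarray as xr


def check_carry(state_in, carry_out):
    """Hypothetical input check: fail early, with a readable message."""
    if not isinstance(carry_out, xr.DataArray):
        raise TypeError(
            f"forecast() carry must be an xarray DataArray, got {type(carry_out).__name__}"
        )
    if carry_out.dims != state_in.dims or carry_out.shape != state_in.shape:
        raise ValueError(
            f"forecast() carry has dims {carry_out.dims} and shape {carry_out.shape}, "
            f"but the input state has dims {state_in.dims} and shape {state_in.shape}"
        )


def run_cycles(step, carry, xs, use_scan=True):
    """Hypothetical runner: apply step(carry, x) -> (carry, y) along xs's first axis.

    use_scan=True needs traceable, fixed-shape JAX code; use_scan=False is a plain
    Python loop that also accepts NumPy/xarray code (xs must be a single array here).
    """
    if use_scan:
        return jax.lax.scan(step, carry, xs)
    ys = []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def numpy_step(carry, x):
    # np.asarray() is fine in the Python loop but not on a traced array.
    c = np.asarray(carry) + x
    return c, c.sum()


carry, ys = run_cycles(numpy_step, np.zeros(3), np.ones((4, 3)), use_scan=False)
print(carry, ys)  # carry: [4, 4, 4]; ys: [3, 6, 9, 12]

try:
    run_cycles(numpy_step, jnp.zeros(3), jnp.ones((4, 3)), use_scan=True)
except jax.errors.TracerArrayConversionError:
    print("scan path failed; use_scan=False is the fallback")

state = xr.DataArray(np.zeros((1, 3)), dims=("time", "i"))
try:
    check_carry(state, state.isel(time=-1))  # carry lost its time dimension
except ValueError as err:
    print(err)
```

The loop fallback gives up jit compilation and is slower, but it accepts arbitrary NumPy or xarray code and makes intermediate values easy to inspect.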
3. Enhancements
- Metrics: It would be great to have a sub-module that calculates basic metrics. I started working on this in dabench/metrics, although that code uses plain JAX/NumPy arrays rather than xarray objects. For xarray, the xskillscore package could be used as inspiration or to compute some metrics directly: https://github.com/xarray-contrib/xskillscore. Note: if the metrics are all computed with JAX, they could also be used as loss functions for optimization.
- New observer types: The observer class is highly flexible and, given times and locations, can mimic satellites, weather balloons, and more. One possible enhancement is to interface with pyorbital and/or other packages to generate locations and times automatically instead of requiring users to provide them manually. There could be specialized observer types, e.g. SatelliteObserver(), that inherit from the main parent class. Adding basic functionality wouldn’t be a huge lift and could be a fun task. See the examples repo for some code snippets.
- Improved documentation: We have an API guide built with Sphinx and hosted on ReadTheDocs, but it would be great to add more basic instructions and examples for installing and using the package.
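For the metrics item, a minimal sketch of why JAX-native metrics are attractive: the same function can report skill and also act as a differentiable loss. The rmse() and bias() helpers are hypothetical and are not the current dabench/metrics code; the optimization (fitting a single multiplicative correction a to a forecast) is only a toy.

```python
import jax
import jax.numpy as jnp


def rmse(forecast, truth):
    """Root-mean-square error over all elements (hypothetical helper)."""
    return jnp.sqrt(jnp.mean((forecast - truth) ** 2))


def bias(forecast, truth):
    """Mean error, forecast minus truth (hypothetical helper)."""
    return jnp.mean(forecast - truth)


truth = jnp.array([1.0, 2.0, 3.0, 4.0])
forecast = jnp.array([1.5, 2.0, 2.5, 4.0])
print(rmse(forecast, truth), bias(forecast, truth))  # sqrt(0.125) ~ 0.354, and 0.0

# Because rmse() is pure JAX, jax.grad can differentiate it, so it can serve
# as a loss. Toy example: fit a multiplicative correction `a` by gradient descent.
grad_loss = jax.grad(lambda a: rmse(a * forecast, truth))
a = 1.0
for _ in range(200):
    a = a - 0.02 * grad_loss(a)
print(a)  # approaches the least-squares value sum(f*t) / sum(f*f) = 29 / 28.5
```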
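For the new-observer item, a sketch of what a SatelliteObserver might compute: sample times and the corresponding ground-track locations, snapped to a regular lat/lon grid. Everything here is an assumption for illustration: the class does not inherit from the real dabench observer, the method names are hypothetical, and the geometry is simplified (spherical Earth, circular orbit, no nodal precession, ascending node over longitude lon0_deg at t = 0). A production version would more likely rely on pyorbital and real orbital elements.

```python
import numpy as np

EARTH_ROTATION_RAD_S = 7.2921159e-5  # Earth's sidereal rotation rate


class SatelliteObserver:
    """Hypothetical sketch: (lat, lon) samples along a circular-orbit ground track.

    Ignores nodal precession, Earth's oblateness and instrument swath width.
    """

    def __init__(self, inclination_deg=98.0, period_s=6000.0, lon0_deg=0.0):
        self.inc = np.deg2rad(inclination_deg)
        self.mean_motion = 2 * np.pi / period_s
        self.lon0 = np.deg2rad(lon0_deg)

    def ground_track(self, times_s):
        t = np.asarray(times_s, dtype=float)
        u = self.mean_motion * t  # argument of latitude (angle from ascending node)
        lat = np.arcsin(np.sin(self.inc) * np.sin(u))
        # Longitude along the orbit, minus the Earth's rotation since t = 0.
        lon = self.lon0 + np.arctan2(np.cos(self.inc) * np.sin(u), np.cos(u))
        lon = lon - EARTH_ROTATION_RAD_S * t
        lon = (lon + np.pi) % (2 * np.pi) - np.pi  # wrap to [-180, 180) degrees
        return np.rad2deg(lat), np.rad2deg(lon)

    def nearest_grid_points(self, times_s, grid_lats, grid_lons):
        """Snap track points to (lat, lon) indices of a regular grid."""
        lat, lon = self.ground_track(times_s)
        i = np.abs(lat[:, None] - grid_lats[None, :]).argmin(axis=1)
        dlon = (lon[:, None] - grid_lons[None, :] + 180.0) % 360.0 - 180.0
        j = np.abs(dlon).argmin(axis=1)
        return i, j


sat = SatelliteObserver()
times = np.arange(0, 6000, 600.0)  # one sample every 10 minutes for one orbit
lat, lon = sat.ground_track(times)
i, j = sat.nearest_grid_points(times, np.linspace(-90, 90, 73), np.arange(0, 360, 2.5))
```

The resulting times and grid indices are the kind of inputs a user currently has to supply to the observer by hand.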