earthkit-data

standardize requests to optimize caching

Open · malmans2 opened this pull request 2 years ago · 10 comments

Closes #208

malmans2 · Oct 13 '23

CLA assistant check
All committers have signed the CLA.

FussyDuck · Oct 13 '23

Conflicts with develop are fixed; this is ready for review.

The idea is to normalize the requests to optimize caching:

  1. Length-1 iterables are squeezed to scalars. (I could probably have done the opposite as well, but I'm not sure whether all request parameters accept lists; also, the request shown by the portal uses strings rather than lists when a single element is selected.)
  2. Iterables with more than one element, and the request dict itself, are sorted, so the generated cache key does not depend on the order of the elements. A few parameters must not be sorted; area and grid are the only ones I am aware of. These unsortable parameters are currently hardcoded; I guess they could become a global variable.
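A minimal sketch of the two normalization rules, assuming requests are flat dicts whose list values hold mutually comparable items. `normalize_request` and `UNSORTABLE_KEYS` are hypothetical names for illustration, not the code in this PR:

```python
from typing import Any

# Hypothetical: parameters whose element order is meaningful ("area" and "grid", per the description above).
UNSORTABLE_KEYS = frozenset({"area", "grid"})


def normalize_request(request: dict[str, Any]) -> dict[str, Any]:
    """Return a normalized copy of ``request`` so that equivalent requests compare equal."""
    normalized: dict[str, Any] = {}
    for key in sorted(request):  # rule 2: sort the request dict by key
        value = request[key]
        # Lists, tuples and ranges count as iterables; strings are left untouched.
        if isinstance(value, (list, tuple, range)):
            items = list(value)
            if len(items) == 1:
                value = items[0]  # rule 1: squeeze length-1 iterables
            elif key in UNSORTABLE_KEYS:
                value = items  # keep order-sensitive parameters as given
            else:
                value = sorted(items)  # rule 2: sort the elements
        normalized[key] = value
    return normalized


a = normalize_request({"variable": ["b", "a"], "year": ["2000"], "area": [90, 0, -90, 360]})
b = normalize_request({"area": [90, 0, -90, 360], "year": "2000", "variable": ["a", "b"]})
print(a == b)  # True
print(a)  # {'area': [90, 0, -90, 360], 'variable': ['a', 'b'], 'year': '2000'}
```

Squeezing (rather than wrapping scalars into lists) follows the first rule's reasoning that the portal shows single selections as strings.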

malmans2 · Oct 17 '23

Thank you for this PR. If it is only needed for caching, one option is to leave the request unaltered and change only the way the cache key is generated, e.g. in cds.py:

def _retrieve(self, dataset, request):
    def retrieve(target, args):
        self.client().retrieve(args[0], args[1], target)

    # normalized copy of the request, used only for cache key generation
    request_cache = self._sort_request_for_caching(request)

    return self.cache_file(
        retrieve,
        (dataset, request),
        cache_args=(dataset, request_cache),
        extension=EXTENSIONS.get(request.get("format"), ".cache"),
    )

In my opinion this would be a cleaner solution. Note: it would require modifying cache_file() in caching.py so that it accepts the separate cache_args.
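The idea of leaving the submitted request untouched and normalizing only the cache key might look like the sketch below. `cache_key` and `ORDER_SENSITIVE` are hypothetical helpers, not earthkit-data's actual cache-key code, and hashing a JSON dump with `sort_keys=True` is just one assumed way to make the key independent of ordering:

```python
import hashlib
import json
from typing import Any

ORDER_SENSITIVE = ("area", "grid")  # assumption, mirroring the PR description


def cache_key(dataset: str, request: dict[str, Any]) -> str:
    """Hash a normalized copy of the request; the request actually submitted is not modified."""
    normalized = {
        key: sorted(value) if isinstance(value, list) and key not in ORDER_SENSITIVE else value
        for key, value in request.items()
    }
    # sort_keys=True makes the key independent of dict insertion order
    payload = json.dumps([dataset, normalized], sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


request = {"variable": ["b", "a"], "year": "2000"}
key = cache_key("some-dataset", request)
print(key == cache_key("some-dataset", {"year": "2000", "variable": ["a", "b"]}))  # True
print(request["variable"])  # ['b', 'a'] (the original request is unchanged)
```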

sandorkertesz · Oct 20 '23

I thought about it, but I have 2 concerns:

  1. This would trigger 6 requests rather than just 2, because each ordering splits the variables into different chunks:
     from_source("cds", ..., variable=["1", "2", "3", "4"], split_on={"variable": 2})
     from_source("cds", ..., variable=["1", "3", "2", "4"], split_on={"variable": 2})
     from_source("cds", ..., variable=["1", "4", "2", "3"], split_on={"variable": 2})
  2. The public attribute requests of CDSRetriever does not show the request that has actually been submitted. I find that confusing from a user's perspective.
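The arithmetic behind the first concern can be checked with a small sketch. It assumes, hypothetically, that `split_on={"variable": 2}` cuts the list into consecutive chunks of two in the order given; `split` is an illustrative helper, not earthkit-data's implementation:

```python
def split(values: list[str], size: int) -> list[tuple[str, ...]]:
    """Cut ``values`` into consecutive chunks of ``size`` elements (assumed split_on behaviour)."""
    return [tuple(values[i:i + size]) for i in range(0, len(values), size)]


orderings = [["1", "2", "3", "4"], ["1", "3", "2", "4"], ["1", "4", "2", "3"]]

# Without normalization, each ordering yields different chunks, i.e. different requests.
unsorted_chunks = {chunk for order in orderings for chunk in split(order, 2)}
# Sorting before splitting makes all three calls produce the same two chunks.
sorted_chunks = {chunk for order in orderings for chunk in split(sorted(order), 2)}

print(len(unsorted_chunks), len(sorted_chunks))  # 6 2
```

Normalizing only the cache key cannot help here, because the chunks already differ before any key is computed.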

I'll implement your suggestion if you still think that's better/cleaner.

malmans2 · Oct 20 '23

You are right. To resolve point 2, I modified the sample code to separate the actual request from the one used for the cache key generation. Of course it still would not solve point 1 and I agree that you need to modify the actual request to cover that case.

However, please note that CDSRetriever is not part of the public API and, as such, should not be used to retrieve data from the CDS. The public API is from_source("cds", ...).

sandorkertesz · Oct 23 '23

> Of course it still would not solve point 1 and I agree that you need to modify the actual request to cover that case.

How do you want to proceed? Should earthkit-data cover that case (keep this PR as it is) or not (implement your cache-key suggestion instead)? If the latter, in EQC we will standardize the requests before feeding them to earthkit-data (i.e., we don't need this PR).

> However, please note that CDSRetriever is not part of the public API and, as such, should not be used to retrieve data from the CDS. The public API is from_source("cds", ...).

If we don't use CDSRetriever directly, we can't associate the downloaded data with their requests (we need this to cache processed data, as explained in point 3 here). If EQC should not use CDSRetriever.requests, then we need to chunk the requests ourselves before feeding them to earthkit-data (i.e., we can't use the split_on functionality implemented in https://github.com/ecmwf/earthkit-data/pull/227).

malmans2 · Oct 23 '23

I still think CDSRetriever should not be used like that. This type of usage is not tested, and the class can be refactored, potentially breaking your code.

My suggestion is to pass an object to from_source that performs the request preprocessing and can return the actual requests sent to the CDS, e.g.:

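from earthkit.data import RequestPreprocessor, from_source  # RequestPreprocessor and pre_processor= are the proposal

rp = RequestPreprocessor()
ds = from_source("cds", ..., pre_processor=rp)
requests = rp.request  # the requests actually sent to the CDS


class RequestPreprocessor:
    ...

    @property
    def request(self):
        ...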

This way the preprocessing could be reused for other retrieval types such as "mars", and you would be able to subclass it and change the behaviour without having to modify earthkit-data.
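A minimal, self-contained sketch of this design follows. `RequestPreprocessor` and the `pre_processor=` keyword are only a proposal at this point; `fake_from_source` is a hypothetical stand-in for `from_source` that merely shows where the hook would be called, and `SqueezingPreprocessor` illustrates customizing the behaviour by subclassing:

```python
from typing import Any


class RequestPreprocessor:
    """Hypothetical hook: normalizes each request and records what was actually sent."""

    def __init__(self) -> None:
        self._sent: list[dict[str, Any]] = []

    def normalize(self, request: dict[str, Any]) -> dict[str, Any]:
        # Default: sort keys and list values, leaving order-sensitive keys alone.
        return {
            k: sorted(v) if isinstance(v, list) and k not in ("area", "grid") else v
            for k, v in sorted(request.items())
        }

    def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
        actual = self.normalize(request)
        self._sent.append(actual)
        return actual

    @property
    def request(self) -> list[dict[str, Any]]:
        """The requests actually submitted, in submission order."""
        return list(self._sent)


class SqueezingPreprocessor(RequestPreprocessor):
    """User subclass: additionally squeeze length-1 lists, without touching the library."""

    def normalize(self, request: dict[str, Any]) -> dict[str, Any]:
        request = super().normalize(request)
        return {k: v[0] if isinstance(v, list) and len(v) == 1 else v for k, v in request.items()}


def fake_from_source(name: str, request: dict[str, Any], pre_processor: RequestPreprocessor) -> dict[str, Any]:
    # Stand-in for from_source(name, ..., pre_processor=rp); no data is downloaded here.
    return pre_processor(request)


rp = SqueezingPreprocessor()
fake_from_source("cds", {"year": ["2000"], "variable": ["b", "a"]}, pre_processor=rp)
print(rp.request)  # [{'variable': ['a', 'b'], 'year': '2000'}]
```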

An alternative solution is to somehow query the request from each retrieved file source object.

sandorkertesz · Oct 24 '23

Being able to pass some sort of setup/teardown class that allows custom pre/post-processing would work for us.

We can use it to pre-process the requests (e.g., the normalize decorators already implemented or the normalization in this PR) and to post-process the data (e.g., global mean or time mean).

If you go that way and the post-processed results are cached, we wouldn't even need our own caching. One important feature we need is the ability to choose between post-processing and caching each chunk separately (e.g., global mean with split_on="date") and post-processing the results after the chunks are concatenated (e.g., time mean with split_on="date").
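Why this choice matters can be shown with plain numbers. In the sketch below (hypothetical data, unweighted means, no xarray), a spatial mean gives the same result whether it is applied per chunk or after concatenation, while averaging per-chunk time means does not reproduce the overall time mean when chunks have different lengths, as calendar months do:

```python
from statistics import fmean

# Hypothetical data: each chunk is a list of time steps; each time step is a list of grid-point values.
chunk_jan = [[1.0, 3.0], [2.0, 4.0], [3.0, 5.0]]  # 3 time steps
chunk_feb = [[10.0, 12.0]]  # 1 time step


def spatial_mean(chunk: list[list[float]]) -> list[float]:
    """Unweighted "global mean": one value per time step."""
    return [fmean(step) for step in chunk]


def time_mean(chunk: list[list[float]]) -> list[float]:
    """Mean over time: one value per grid point."""
    return [fmean(values) for values in zip(*chunk)]


# Spatial mean: transforming each chunk and concatenating equals transforming the concatenation.
per_chunk = spatial_mean(chunk_jan) + spatial_mean(chunk_feb)
whole = spatial_mean(chunk_jan + chunk_feb)
print(per_chunk == whole)  # True

# Time mean: averaging the per-chunk means over-weights the short chunk.
mean_of_chunk_means = [fmean(pair) for pair in zip(time_mean(chunk_jan), time_mean(chunk_feb))]
overall = time_mean(chunk_jan + chunk_feb)
print(mean_of_chunk_means, overall)  # [6.0, 8.0] [4.0, 6.0]
```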

For reference, this is the sort of use EQC evaluators are currently doing:

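import xarray as xr

# download_and_transform, collection_id and requests are defined elsewhere in the EQC code


def global_mean(ds: xr.Dataset, weighted: bool) -> xr.Dataset:
    if weighted:
        # area-weighted mean, using the cell_area variable as weights
        ds = ds.weighted(ds["cell_area"])
    return ds.mean(("latitude", "longitude"))


def time_mean(ds: xr.Dataset) -> xr.Dataset:
    return ds.mean("time")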


ds_global = download_and_transform(
    collection_id,
    requests,
    chunks={"year": 1, "month": 1},
    transform_each_chunk=True,
    transform_func=global_mean,
    transform_func_kwargs={"weighted": False},
)  # First, cache raw data. Then, cache the global mean of each chunk. Finally, concatenate cached results.

ds_global_weighted = download_and_transform(
    collection_id,
    requests,
    chunks={"year": 1, "month": 1},
    transform_each_chunk=True,
    transform_func=global_mean,
    transform_func_kwargs={"weighted": True},
)  # Cache the global weighted mean of each chunk re-using cached raw data

ds_time = download_and_transform(
    collection_id,
    requests,
    chunks={"year": 1, "month": 1},
    transform_each_chunk=False,
    transform_func=time_mean,
)  # Cache the time mean of all chunks concatenated re-using cached raw data

If you decide to implement https://github.com/ecmwf/earthkit-data/pull/228#issuecomment-1776781130, feel free to ping me if you'd like me to test it out.

malmans2 · Oct 24 '23

Talking with @EddyCMWF and @JamesVarndell, I realised that it was not clear why we need to associate each file with its corresponding request. Here is a simpler example:

from typing import Any

from earthkit.data import from_source
from earthkit.data.sources.cds import CdsRetriever

import cacholote
import xarray as xr


@cacholote.cacheable
def _global_mean(collection_id: str, request: dict[str, Any]) -> xr.Dataset:
    ds = from_source("cds", collection_id, request).to_xarray()
    return ds.mean(("x", "y"))


def global_mean(collection_id: str, *args: dict[str, Any], **kwargs: Any) -> xr.Dataset:
    datasets = []
    for request in CdsRetriever(collection_id, *args, **kwargs).requests:
        datasets.append(_global_mean(collection_id, request))
    return xr.merge(datasets)


ds_100yrs = global_mean(..., year=range(1900, 2000), split_on="year")
print(ds_100yrs.chunksizes)  # {"year": (1, ) * 100}

# Retrieve cached data
ds_10yrs = global_mean(..., year=range(1990, 2000), split_on="year")
print(ds_10yrs.chunksizes)  # {"year": (1, ) * 10}
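The benefit of keeping each file associated with its request can be sketched without earthkit-data or cacholote. Below, `functools.lru_cache` stands in for `cacholote.cacheable`, `split_by_year` is a hypothetical stand-in for `split_on="year"`, the returned value is a placeholder for the real global mean, and `download_calls` counts simulated retrievals. Because each per-year request is cached on its own, the 10-year call reuses results from the 100-year call:

```python
import functools
import json
from typing import Any

download_calls = 0


@functools.lru_cache(maxsize=None)
def _global_mean(collection_id: str, request_json: str) -> float:
    # lru_cache needs hashable arguments, so the request is passed as a canonical JSON string.
    global download_calls
    download_calls += 1  # simulated retrieval + processing of one chunk
    request = json.loads(request_json)
    return float(request["year"])  # placeholder for the real global mean


def split_by_year(request: dict[str, Any]) -> list[dict[str, Any]]:
    # Hypothetical stand-in for split_on="year": one request per year.
    return [{**request, "year": year} for year in request["year"]]


def global_mean(collection_id: str, **request: Any) -> list[float]:
    return [
        _global_mean(collection_id, json.dumps(chunk, sort_keys=True))
        for chunk in split_by_year(request)
    ]


global_mean("some-collection", year=list(range(1900, 2000)))
print(download_calls)  # 100
global_mean("some-collection", year=list(range(1990, 2000)))
print(download_calls)  # 100 (all 10 years were already cached)
```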

malmans2 · Oct 26 '23

Hi @malmans2, I am sorry I could not work on this PR for a while. I will resume reviewing next week.

sandorkertesz · Nov 24 '23