standardize requests to optimize caching
Closes #208
Conflicts with develop are fixed; this is ready for review.
The idea is to normalize the requests to optimize caching:
- length-1 iterables are squeezed to scalars (I could have done the opposite as well, but I'm not sure all request parameters accept lists, and the request shown by the portal also uses strings rather than lists when a single element is selected)
- iterables of length > 1 and the request dict itself are sorted, so the generated cache key is not affected by the order of the elements. There are a few parameters that must not be sorted; area and grid are the only ones I am aware of. The parameters that cannot be sorted are currently hardcoded, but I guess they could become a global variable (see the sketch after this list).
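A minimal sketch of the normalization described above (the function name and the hardcoded set of order-sensitive keys are illustrative, not the exact code in this PR):

```python
# Illustrative sketch of the request normalization; names are not the exact ones in this PR.
UNSORTABLE_KEYS = {"area", "grid"}  # order is meaningful for these parameters

def normalize_request(request: dict) -> dict:
    normalized = {}
    for key, value in request.items():
        if isinstance(value, (list, tuple)):
            if len(value) == 1:
                value = value[0]  # squeeze length-1 iterables to a scalar
            elif key not in UNSORTABLE_KEYS:
                value = sorted(value, key=str)  # make the cache key order-independent
        normalized[key] = value
    # sort the keys themselves so dict ordering does not affect the cache key
    return dict(sorted(normalized.items()))
```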
Thank you for this PR. If it is only needed for caching, one option is to not alter the request and just change the way the cache key is generated, e.g. in cds.py:
def _retrieve(self, dataset, request):
    def retrieve(target, args):
        self.client().retrieve(args[0], args[1], target)

    # modify the request used for the cache key generation
    request_cache = self._sort_request_for_caching(request)
    return self.cache_file(
        retrieve,
        (dataset, request),
        cache_args=(dataset, request_cache),
        extension=EXTENSIONS.get(request.get("format"), ".cache"),
    )
In my opinion it would be a cleaner solution. Note that it would also require a modification to cache_file() in caching.py.
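For reference, one possible shape for the _sort_request_for_caching helper used in the sample above (no such method exists in earthkit-data; this is only an illustrative sketch of the intent):

```python
def _sort_request_for_caching(self, request):
    """Return a copy of the request with order-independent values.

    Only used to build the cache key; the request actually submitted
    to the CDS is left untouched. Illustrative sketch, not existing code.
    """
    cached = {}
    for key, value in request.items():
        # area and grid are order-sensitive and must not be sorted
        if isinstance(value, (list, tuple)) and key not in ("area", "grid"):
            value = sorted(value, key=str)
        cached[key] = value
    return dict(sorted(cached.items()))
```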
I thought about it, but I have two concerns:
- This would trigger 6 requests rather than just 2 (each call below splits the four variables into chunks of two; without normalizing the request itself, the three orderings produce six distinct chunked requests instead of the same two):
  from_source("cds", ..., variable=["1", "2", "3", "4"], split_on={"variable": 2})
  from_source("cds", ..., variable=["1", "3", "2", "4"], split_on={"variable": 2})
  from_source("cds", ..., variable=["1", "4", "2", "3"], split_on={"variable": 2})
- The public attribute requests of CDSRetriever does not show the request that was actually submitted. I find it confusing from a user perspective.
I'll implement your suggestion if you still think that's better/cleaner.
You are right. To resolve point 2, I modified the sample code to separate the actual request from the one used for the cache key generation. Of course it still would not solve point 1 and I agree that you need to modify the actual request to cover that case.
However, please note that CDSRetriever is not part of the public API and as such should not be used to retrieve data from the CDS. The public API is from_source("cds", ...).
> Of course it still would not solve point 1 and I agree that you need to modify the actual request to cover that case.
How do you want to proceed? Should earthkit-data cover that case (leave this PR as it is) or not (implement this)? If the latter, in EQC we will standardize the requests before feeding them to earthkit-data (i.e., we don't need this PR).
> However, please note that CDSRetriever is not part of the public API and as such should not be used to retrieve data from the CDS. The public API is from_source("cds", ...).
If we don't use CDSRetriever directly, we can't associate the downloaded data with their requests (we need this to cache processed data, as explained in point 3 here). If EQC should not use CDSRetriever.requests, then we need to chunk the requests ourselves before feeding them to earthkit-data (i.e., we can't use the split_on functionality implemented in https://github.com/ecmwf/earthkit-data/pull/227).
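(For illustration, chunking the requests ourselves, i.e. without split_on, would look roughly like the sketch below; chunk_request is a hypothetical helper, not part of earthkit-data or EQC.)

```python
from itertools import islice

def chunk_request(request: dict, key: str, size: int):
    """Hypothetical helper: yield copies of `request` with `request[key]` split into chunks of `size`."""
    values = iter(request[key])
    while chunk := list(islice(values, size)):
        yield {**request, key: chunk}

# e.g. two requests with two variables each, instead of one request with four
chunked = list(chunk_request({"variable": ["1", "2", "3", "4"]}, "variable", 2))
```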
I still think CDSRetriever should not be used like that. This type of usage is not tested and may be refactored at any time, potentially breaking your code.
My suggestion is to pass an object to from_source that can perform the request pre-processing and is able to return the actual requests sent to the CDS. E.g.
from earthkit.data import RequestPreprocessor

rp = RequestPreprocessor()
ds = from_source("cds", ..., pre_processor=rp)
requests = rp.request
class RequestPreprocessor:
    ....

    @property
    def request(self):
        ...
In this way the pre-processing could be reused for other retrieval types like "mars". And you would be able to subclass it and change the behaviour without having to modify earthkit-data.
An alternative solution is to somehow query the request from each retrieved file source object.
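To make the idea above concrete, a sketch of how such an object could be subclassed to apply the normalization from this PR and record the submitted requests (the pre_processor keyword, the base class, and the __call__ hook are all assumptions about the proposed interface, not existing earthkit-data API; normalize_request refers to the illustrative sketch earlier in this thread):

```python
class NormalizingPreprocessor(RequestPreprocessor):
    """Hypothetical subclass: normalizes each request and records what was submitted."""

    def __init__(self):
        self._requests = []

    def __call__(self, request):
        # apply the squeeze/sort normalization before the request is submitted
        request = normalize_request(request)
        self._requests.append(request)
        return request

    @property
    def request(self):
        # the actual requests sent to the CDS, as in the sketch above
        return list(self._requests)
```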
Being able to pass some sort of setup/teardown class that allows custom pre/post-processing would work for us.
We can use it to pre-process the requests (e.g., the normalize decorators already implemented, or the normalization in this PR) and to post-process the results (e.g., global mean or time mean).
If you go that way and post-processed results are cached, then we don't even need to use our own caching. One important feature that we need is to be able to decide whether to post-process and cache each chunk separately (e.g., global mean & split_on="date") or post-process requests after concatenating them (e.g., time mean & split_on="date").
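A rough sketch of what such a hook could look like from our side, with a flag to choose between per-chunk and concatenated post-processing (the class, its method names, and the per_chunk flag are all hypothetical, not an existing or agreed interface):

```python
class PostProcessor:
    """Hypothetical post-processing hook.

    per_chunk=True : apply `func` to each retrieved chunk and cache the result.
    per_chunk=False: concatenate the chunks first, then apply `func` once.
    """

    def __init__(self, func, per_chunk=True, **kwargs):
        self.func = func
        self.per_chunk = per_chunk
        self.kwargs = kwargs

    def process_chunk(self, ds):
        return self.func(ds, **self.kwargs) if self.per_chunk else ds

    def process_result(self, ds):
        return ds if self.per_chunk else self.func(ds, **self.kwargs)
```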
For reference, this is the sort of use EQC evaluators are currently doing:
def global_mean(ds: xr.Dataset, weighted: bool) -> xr.Dataset:
    if weighted:
        ds = ds.weighted(ds["cell_area"])
    return ds.mean(("latitude", "longitude"))

def time_mean(ds: xr.Dataset) -> xr.Dataset:
    return ds.mean("time")

ds_global = download_and_transform(
    collection_id,
    requests,
    chunks={"year": 1, "month": 1},
    transform_each_chunk=True,
    transform_func=global_mean,
    transform_func_kwargs={"weighted": False},
)  # First, cache raw data. Then, cache the global mean of each chunk. Finally, concatenate cached results.

ds_global_weighted = download_and_transform(
    collection_id,
    requests,
    chunks={"year": 1, "month": 1},
    transform_each_chunk=True,
    transform_func=global_mean,
    transform_func_kwargs={"weighted": True},
)  # Cache the global weighted mean of each chunk, re-using cached raw data

ds_time = download_and_transform(
    collection_id,
    requests,
    chunks={"year": 1, "month": 1},
    transform_each_chunk=False,
    transform_func=time_mean,
)  # Cache the time mean of all chunks concatenated, re-using cached raw data
If you decide to implement https://github.com/ecmwf/earthkit-data/pull/228#issuecomment-1776781130, feel free to ping me if you'd like me to test it out.
Talking with @EddyCMWF and @JamesVarndell, I realised that the reason why we need to be able to associate each file with the corresponding request was not clear. Here is a simpler example:
from typing import Any

import cacholote
import xarray as xr

from earthkit.data import from_source
from earthkit.data.sources.cds import CdsRetriever


@cacholote.cacheable
def _global_mean(collection_id: str, request: dict[str, Any]) -> xr.Dataset:
    ds = from_source("cds", collection_id, request).to_xarray()
    return ds.mean(("x", "y"))


def global_mean(collection_id: str, *args: dict[str, Any], **kwargs: Any) -> xr.Dataset:
    datasets = []
    for request in CdsRetriever(collection_id, *args, **kwargs).requests:
        datasets.append(_global_mean(collection_id, request))
    return xr.merge(datasets)


ds_100yrs = global_mean(..., year=range(1900, 2000), split_on="year")
print(ds_100yrs.chunksizes)  # {"year": (1,) * 100}

# Retrieve cached data
ds_10yrs = global_mean(..., year=range(1990, 2000), split_on="year")
print(ds_10yrs.chunksizes)  # {"year": (1,) * 10}
Hi @malmans2, I am sorry I could not work on this PR for a while. I will resume reviewing next week.