[RFC][Train] Allow for reporting results from multiple workers
Description
Current state
Currently, Ray Train only reports metrics from the first worker. This is fine in most cases, but for some applications, it may be desirable to report metrics from all workers and/or report aggregations, such as mean and std. We also require that functionality for some tests.
Note: Saving checkpoints from multiple workers is beyond the scope of this proposal.
Before Ray AIR, Ray Train supported reporting result aggregation through result preprocessors (https://github.com/ray-project/ray/pull/22099).
With the current structure of the DataParallelTrainer, the reporting code is fully contained within the _report method:
def _report(self, training_iterator: TrainingIterator) -> None:
    for results in training_iterator:
        first_worker_results = results[0]
        tune.report(**first_worker_results)
As can be seen, it would be trivial to extend this functionality to an arbitrary number of workers or to aggregation logic. Below are a few proposals on how to allow users to do that in a lightweight manner.
Proposal 1: Promote _report to DeveloperAPI and encourage users to subclass
In this proposal, we encourage users to simply subclass DataParallelTrainer/TorchTrainer (and so on) and override the report method with their own custom logic, e.g.:
class TorchTrainerMean(TorchTrainer):
    def report(self, training_iterator: TrainingIterator) -> None:
        for results in training_iterator:
            mean_results = {
                f"mean_{k}": np.mean([result[k] for result in results])
                for k in results[0]
                if not k.startswith("_")
            }
            tune.report(**mean_results)
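For illustration, using such a subclass would look the same as using the stock trainer. The sketch below assumes the Ray AIR constructor arguments; my_train_loop is a placeholder for a user-defined training function.

from ray.air.config import ScalingConfig

# Hypothetical usage of the TorchTrainerMean subclass from Proposal 1.
trainer = TorchTrainerMean(
    train_loop_per_worker=my_train_loop,  # placeholder user training function
    scaling_config=ScalingConfig(num_workers=4),
)
result = trainer.fit()
# Tune would now receive e.g. "mean_loss" aggregated over all workers
# instead of the rank-0 "loss".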
Proposal 2: Add results_processing_fn argument to DataParallelTrainer
The class would be modified to include:
def __init__(
    self,
    ...,
    *,
    results_processing_fn: Callable[
        [List[Dict[str, Any]]], Dict[str, Any]
    ] = lambda results: results[0],
):
    ...

def _report(self, training_iterator: TrainingIterator) -> None:
    for results in training_iterator:
        processed_results = self.results_processing_fn(results)
        tune.report(**processed_results)
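For example, a user who wants mean and std aggregation would only need to pass a function. This is a sketch under the proposed API: results_processing_fn does not exist today, and my_train_loop is a placeholder.

import numpy as np

from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

def mean_and_std(results):
    # `results` is a list of per-worker result dicts, as passed to _report.
    keys = [k for k in results[0] if not k.startswith("_")]
    aggregated = {}
    for k in keys:
        values = [result[k] for result in results]
        aggregated[f"mean_{k}"] = np.mean(values)
        aggregated[f"std_{k}"] = np.std(values)
    return aggregated

# Hypothetical usage, assuming TorchTrainer forwards the proposed argument:
trainer = TorchTrainer(
    train_loop_per_worker=my_train_loop,  # placeholder user training function
    results_processing_fn=mean_and_std,
    scaling_config=ScalingConfig(num_workers=4),
)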
Proposal 3: Direct users to use third party libraries like torchmetrics
For Torch, users can use torchmetrics, which has built-in support for DDP. Similar solutions may exist for TensorFlow. It's unclear how well this covers non-metric use cases, such as time measurement or profiling info (e.g. memory usage). On the other hand, this approach would only require us to update the documentation.
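As a rough sketch of what the documented best practice could look like (assuming the Ray AIR session API and torchmetrics >= 0.7; the metric and key names are made up for illustration):

import torch
import torchmetrics
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    # MeanMetric aggregates arbitrary scalar values; since TorchTrainer sets up
    # a torch.distributed process group, compute() synchronizes metric state
    # across all workers via allreduce/allgather.
    mean_loss = torchmetrics.MeanMetric()
    for epoch in range(config["epochs"]):
        loss = torch.rand(1).item()  # placeholder for a real training step
        mean_loss.update(loss)
        # Every worker reports the same, already-aggregated value.
        session.report({"mean_loss_across_workers": mean_loss.compute().item()})
        mean_loss.reset()

trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={"epochs": 2},
    scaling_config=ScalingConfig(num_workers=2),
)
trainer.fit()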
Conclusion
Either of the first two proposals would be a lightweight way to allow users to modify the data reported to Tune. I do not have a strong personal preference, though I feel that Proposal 2 fits better with the rest of the API.
Proposal 3 requires only documentation changes and can be implemented independently (tracked in https://github.com/ray-project/ray/issues/31434).
Thanks for putting this together!
A few questions:
- Is this only a problem with DLTrainer and its subclasses? Would it make sense to have a results_processing_fn for other trainers like XGBoostTrainer?
- Would results_processing_fn be exposed in TorchTrainer and TensorflowTrainer?
- Would it make sense to place results_processing_fn in a config like RunConfig instead of exposing it as a top-level parameter?
- Other trainers have different internals - e.g. for XGBoost, the reporting is done by a callback (which can also be overridden to report metrics from multiple workers). I'd like to focus on DL for now, as this is where we have had the most requests and where the change is simplest to make.
- Yes!
- I think it shouldn't, unless it's supported by all Trainers. We could make it an argument in TorchConfig (and so on), but I am not sure whether it makes sense to put it there, as those configs deal with setting up the workers and not with what happens on the Tune side.
Currently, Ray Train only reports metrics from the first worker. This is fine in most cases, but for some applications, it may be desirable to report metrics from all workers and/or report aggregations, such as mean and std. We also require that functionality for some tests.
For aggregations, can they use torchmetrics instead? That's becoming the standard in the pytorch ecosystem AFAICT
Yeah, it's possible to use that right now. That being said, torchmetrics doesn't cover TensorFlow or anything else that you may want to log from multiple workers aside from actual metrics.
@Yard1 how would I use torchmetrics for aggregation? Wouldn't you run into the same problem of not having access to all of the results?
@Yard1 Also, what are the metrics that you want to aggregate from all workers individually?
@bveeramani torchmetrics is distributed training compatible... it will automatically aggregate across workers using allreduce.
@richardliaw I was thinking profiling information could be useful? I don't have a special need myself - this is something we have been talking about on and off for a while. Some users were also interested in this feature, eg. https://discuss.ray.io/t/how-can-i-synchronization-metrics-in-ray-train-valid-loop/8500 https://discuss.ray.io/t/pytorch-distributedtrainable-tune-report-on-rank-0-only/5127/1
For both cases, it seems like we just need to provide best practices - telling users to do a sum/average/median across all workers with torchmetrics, and also to report the same things on all workers if necessary?
I'll add that as a proposal!
Sorry if I wasn't clear before. I don't think we need to discuss multiple options here because I don't see a very concrete use case yet for any of the other alternatives.
Let me know if that makes sense.
That's fair. In any case, if we do not want to provide an API for this and instead rely on third party tools like torchmetrics, we should update documentation & provide an example, so that's still an action item.
Yep exactly. Can we perhaps update this issue to track the action item?
I'll make a separate issue for that, and we can defer this one until we have a concrete use case.
https://github.com/ray-project/ray/issues/31434
Closing this one since we have a separate issue for now. When we have a concrete use case we can bring it up again!