Distributed XGBoostTrainer Improvement
Description
xgboost_ray.train can accept a list of Parquet or CSV files in a RayDMatrix object as input (ref). This does not appear to be possible with XGBoostTrainer, so users end up loading all of the data in the driver rather than in the workers, which limits scaling. One solution is to let XGBoostTrainer handle lists of Parquet or CSV paths, as xgboost_ray.train does.
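A minimal sketch of why passing paths helps: if each worker receives only a subset of the file paths, it can read those files itself, so no single process has to hold the full dataset. The helper below, shard_paths, is hypothetical and uses simple round-robin (interleaved) assignment, loosely analogous to the interleaved sharding mode xgboost_ray offers; it is not xgboost_ray's internal code.

```python
from typing import Dict, List


def shard_paths(paths: List[str], num_workers: int) -> Dict[int, List[str]]:
    """Hypothetical helper: assign file paths to workers round-robin.

    Worker `rank` gets paths[rank], paths[rank + num_workers], ...
    so each worker only ever opens its own files.
    """
    if num_workers < 1:
        raise ValueError("num_workers must be >= 1")
    return {rank: paths[rank::num_workers] for rank in range(num_workers)}


if __name__ == "__main__":
    files = [f"part-{i}.parquet" for i in range(5)]
    for rank, shard in shard_paths(files, 2).items():
        print(rank, shard)
```

Only the short list of path strings travels from the driver to the workers; the actual row data is read where it is used.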
Use case
No response
I'd argue that this is somewhat of a bug, since it is a regression from our existing xgboost behavior...
@richardliaw, lol, I also wondered whether there was a good reason this feature wasn't included. Locally, I was able to add support for Parquet paths by changing _get_dmatrices to this:
```python
from typing import Any, Dict
from xgboost_ray.matrix import RayFileType

def _get_dmatrices(
    self, dmatrix_params: Dict[str, Any]
) -> Dict[str, "xgboost_ray.RayDMatrix"]:
    result = {}
    for k, v in self.datasets.items():
        # Default to {} so datasets without an entry in dmatrix_params don't raise KeyError
        params = dmatrix_params.get(k, {})
        if params.get("filetype") in (RayFileType.PARQUET, RayFileType.CSV):
            # The Dataset should contain only file paths; materialize that small list
            # so the RayDMatrix reads the files itself instead of the driver loading the data
            result[k] = self._dmatrix_cls(v.take_all(), label=self.label_column, **params)
        else:
            # Otherwise pass the Ray Dataset through as before
            result[k] = self._dmatrix_cls(v, label=self.label_column, **params)
    return result
```
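The patch passes v.take_all() straight to the RayDMatrix, which assumes every row is a bare path string. Depending on the Ray version, a Dataset built from a list of paths (for example with ray.data.from_items) may instead yield single-column rows such as {"item": path}. A minimal sketch of a hypothetical normalization helper, rows_to_paths, under that assumption (the "item" key is the assumed column name):

```python
from typing import Any, Iterable, List


def rows_to_paths(rows: Iterable[Any], key: str = "item") -> List[str]:
    """Hypothetical helper: flatten Dataset rows into a list of path strings.

    Accepts bare strings or single-column dict rows like {"item": path};
    anything else raises so a misconfigured Dataset fails loudly.
    """
    paths: List[str] = []
    for row in rows:
        if isinstance(row, str):
            paths.append(row)
        elif isinstance(row, dict) and isinstance(row.get(key), str):
            paths.append(row[key])
        else:
            raise TypeError(f"Expected a path string or {{{key!r}: path}} row, got {row!r}")
    return paths


if __name__ == "__main__":
    print(rows_to_paths(["a.parquet", {"item": "b.parquet"}]))
```

Inside the patch this would be called as rows_to_paths(v.take_all()), so the RayDMatrix always receives a plain list of paths regardless of how the rows are represented.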
Would you like me to submit a PR for this change?
cc @Yard1
cc @krfricke as well
This will be fixed as part of the XGB refactor which is in progress right now.