DeepCTR
DeepCTR copied to clipboard
加载大数据的训练流程
我最近在使用deepCtr框架的时候,遇到问题:如果加载数据量过大的情况下,目前的数据加载、训练方式都会报内存错误。我也调研了相关解决方案,但是限于自身代码能力,总是改动不成功。希望deepctr(包括deepctr-torch)能够给出example,在数据量过大(内存加载不下)的情况下进行训练的方式,这有助于deepctr的实际工业化。
这种情况下,应该使用 TensorFlow 的 dataset pipeline,常用的接口就是 tf.data.Dataset.from_generator 来构造dataset。PyTroch 也有相应的 dataloader 接口。TensorFlow 的例子可以参考我们针对 k8s 分布式训练的例子的 input_fn https://github.com/sql-machine-learning/elasticdl/blob/develop/model_zoo/iris/dnn_estimator.py#L101-L110
Hi, deepctr_torch可以利用pytorch库里的IterableDataset来构建流式的数据输入,然后构造DataLoader。 目前deepctr_torch.models.basemodel不支持DataLoader的输入,但是可以看一下deepctr_torch.models.basemodel.fit的代码,模型训练时候依然是将输入转为DataLoader再训练的,所以可以修改一下源码,直接将DataLoader作为fit的输入,就可以避免内存错误的问题了。 Hi, you can use torch.utils.data.IterableDataset to construct flow dataset and use torch.utils.data.DataLoader to construct the model input. DataLoader is not supported in the newest version of deepctr_torch.models.basemodel. However, you can check the codes in deepctr_torch.models.basemodel.fit, then you will find the fit method transform the (x, y) to DataLoader before training. So, you can revise the codes and use DataLoader as one of the fit method's parameters directly. I think this may solve your problem.