[Question] Is there FP8 embeddings support for training?
Hello, it looks like EmbeddingBagCollection forces the data type to be float32 or float16 during initialization: https://github.com/pytorch/torchrec/blob/main/torchrec/modules/embedding_modules.py#L179
Is there any support for making the embeddings float8? Note that this is for training.
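For context, here is a quick sketch of the restriction I mean (table name and sizes are arbitrary): with the default data_type, the weights the collection allocates come out as torch.float32, and as far as I can tell only FP32/FP16 configs are accepted at initialization.

import torch
import torchrec

# Sketch: build a tiny EmbeddingBagCollection with the default data_type (FP32)
# and inspect the dtype of the allocated weights. Table name/sizes are made up.
ebc = torchrec.EmbeddingBagCollection(
    device="cpu",
    tables=[
        torchrec.EmbeddingBagConfig(
            name="t1",
            embedding_dim=8,
            num_embeddings=32,
            feature_names=["f1"],
            pooling=torchrec.PoolingType.SUM,
        ),
    ],
)
for name, param in ebc.named_parameters():
    print(name, param.dtype)  # expected: torch.float32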
Doesn't look like it, no. What is your use case? Feel free to put up a pull request.
Hello @PaulZhang12, thanks for your reply. The use case is ordinary deep learning recommendation model training with all of the embeddings in FP8 format. The reason I do not use FP32 or FP16 embeddings is that I want to save memory. A simple example is below:
import torch
import torchrec
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class myModel(torch.nn.Module):
    def __init__(self, input_size: int, output_size: int):
        super(myModel, self).__init__()
        self.L = torch.nn.Linear(input_size, output_size)
        self.ebc = torchrec.EmbeddingBagCollection(
            device="cpu",
            tables=[
                torchrec.EmbeddingBagConfig(
                    name="t1",
                    embedding_dim=8,
                    num_embeddings=32,
                    feature_names=["f1"],
                    pooling=torchrec.PoolingType.SUM,
                    data_type=torchrec.modules.embedding_configs.DataType.FP8,
                ),
                torchrec.EmbeddingBagConfig(
                    name="t2",
                    embedding_dim=8,
                    num_embeddings=32,
                    feature_names=["f2"],
                    pooling=torchrec.PoolingType.SUM,
                    data_type=torchrec.modules.embedding_configs.DataType.FP8,
                ),
            ],
        )

    def forward(self, kjt: KeyedJaggedTensor):
        # Pooled embeddings for both features, concatenated and fed to a dense layer.
        embeddings = self.ebc(kjt)
        dense_input = torch.cat([embeddings["f1"], embeddings["f2"]], dim=1)
        return self.L(dense_input)


# Training
model = myModel(input_size=16, output_size=1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(1000):
    optimizer.zero_grad()
    # Batch of 2 samples; lengths [2, 2] belong to "f1" and [1, 3] to "f2" (8 ids total).
    kjt = KeyedJaggedTensor(
        keys=["f1", "f2"],
        values=torch.randint(0, 31, (8,)),
        lengths=torch.tensor([2, 2, 1, 3]),
    )
    prediction = model(kjt)
    target = torch.randint(0, 2, (2, 1))  # random 0/1 targets for the toy loss
    loss = loss_fn(prediction, target.float())
    loss.backward()
    optimizer.step()
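To give a sense of why the precision matters for us, here is a rough back-of-the-envelope memory comparison (the table size below is hypothetical, purely for illustration): FP32 stores 4 bytes per element, FP16 stores 2, and FP8 would store 1.

# Rough memory footprint of one embedding table at different precisions.
# num_embeddings / embedding_dim are made-up illustrative values.
num_embeddings = 100_000_000
embedding_dim = 128
for name, bytes_per_elem in [("FP32", 4), ("FP16", 2), ("FP8", 1)]:
    gib = num_embeddings * embedding_dim * bytes_per_elem / 2**30
    print(f"{name}: {gib:.1f} GiB")
# FP32: 47.7 GiB, FP16: 23.8 GiB, FP8: 11.9 GiB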