jetseg
A couple of efficiency questions
Hi, thanks for this repo. Here are a couple of efficiency questions: a) In the REU function, why is there a cast to double, and why only in the backward pass? I think this is inefficient, maybe even useless, and the computation could stay in FP32. What do you think?
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
    """Performs a backpropagation."""
    (data,) = ctx.saved_tensors
    data = data.double()
    grad = torch.where(data >= 0.0, 1.0, torch.exp(data) * (data + 1))
    return grad_output * grad
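To make the question concrete, here is a small sketch (the `reu_grad` helper is mine, not the repo's) comparing the backward gradient computed with and without the `.double()` cast:

```python
import torch

# Illustrative helper mirroring the repo's backward formula, with the
# float64 cast made optional so the two variants can be compared.
def reu_grad(data: torch.Tensor, cast_to_double: bool) -> torch.Tensor:
    if cast_to_double:
        data = data.double()
    return torch.where(data >= 0.0, 1.0, torch.exp(data) * (data + 1))

x = torch.randn(10_000)
g64 = reu_grad(x, cast_to_double=True)   # float64, as in the repo
g32 = reu_grad(x, cast_to_double=False)  # stays in float32

# The two agree to within float32 precision, which suggests the cast
# mostly adds an extra copy and a slower float64 kernel.
max_diff = torch.max(torch.abs(g64.float() - g32)).item()
```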
b) Why is FP16 mixed precision not used? It could speed up both training and inference.
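For reference, here is a minimal sketch of what mixed precision could look like with `torch.autocast`; the tiny model and random data are placeholders, not from the repo:

```python
import torch

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
# float16 autocast needs CUDA; bfloat16 autocast also works on CPU.
amp_dtype = torch.float16 if use_cuda else torch.bfloat16

model = torch.nn.Linear(64, 8).to(device)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
x = torch.randn(32, 64, device=device)
y = torch.randn(32, 8, device=device)

with torch.autocast(device_type=device, dtype=amp_dtype):
    out = model(x)  # matmul runs in reduced precision
loss = (out.float() - y).pow(2).mean()  # loss kept in FP32

if use_cuda:
    # float16 gradients need loss scaling to avoid underflow
    scaler = torch.cuda.amp.GradScaler()
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
else:
    loss.backward()
    opt.step()
```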
c) Also in the REU function:
@staticmethod
def forward(ctx, data: torch.Tensor) -> torch.Tensor:
    """Performs a forward pass."""
    ctx.save_for_backward(data)
    return torch.where(data <= 0.0, data * torch.exp(data), data)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
    """Performs a backpropagation."""
    (data,) = ctx.saved_tensors
    data = data.double()
    grad = torch.where(data >= 0.0, 1.0, torch.exp(data) * (data + 1))
    return grad_output * grad
The exponential is computed twice (once in the forward pass and once in the backward pass), each time over the whole range (positive and negative values). I think it could be computed only once, and only for the negative values, for efficiency.
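A possible shape for that optimization (names are illustrative, not the repo's): evaluate the exponential once in forward, clamped so it is the identity `exp(0) = 1` on the positive side and cheap to compute, save it, and reuse it in backward instead of calling `exp` again:

```python
import torch

class REUSingleExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, data: torch.Tensor) -> torch.Tensor:
        # exp is only needed where data <= 0; clamping keeps it a single
        # elementwise pass and avoids overflow for large positive inputs.
        exp_neg = torch.exp(data.clamp(max=0.0))
        ctx.save_for_backward(data, exp_neg)
        return torch.where(data <= 0.0, data * exp_neg, data)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        data, exp_neg = ctx.saved_tensors
        # Reuse the exponential saved in forward instead of recomputing it.
        grad = torch.where(data >= 0.0,
                           torch.ones_like(data),
                           exp_neg * (data + 1))
        return grad_output * grad

# Quick check against the original formulation
x = torch.randn(1000, requires_grad=True)
y = REUSingleExp.apply(x)
y.sum().backward()
ref = torch.where(x <= 0.0, x * torch.exp(x), x)
```

This trades one extra saved tensor (`exp_neg`) for one fewer `exp` kernel per backward call, which may or may not pay off depending on activation size and memory pressure.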