jetseg

A couple of efficiency questions

Open • bobarbobo opened this issue 10 months ago • 0 comments

Hi, thanks for this repo. Here are a couple of efficiency questions:

a) In the REU function, why is there a cast to double, and why only in the backward pass? It also promotes `grad_output * grad` to float64. I think this is inefficient, maybe useless, and the computation could stay in FP32. What do you think?

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """Performs a backpropagation."""

        (data, ) = ctx.saved_tensors
        data = data.double()
        grad = torch.where(data >= 0.0, 1.0, torch.exp(data) * (data + 1))
        return grad_output * grad
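On a), one way to check whether the float64 cast buys any accuracy is to evaluate the same gradient formula in FP32 and in FP64 and compare. A minimal sketch, assuming REU is x·exp(x) for x ≤ 0 and x otherwise (as in the `forward` quoted in c)); `reu_grad` and `max_fp32_error` are hypothetical helpers, not jetseg code:

```python
import torch


def reu_grad(data: torch.Tensor) -> torch.Tensor:
    """Derivative of REU (x * exp(x) for x <= 0, x otherwise), in data's own dtype."""
    return torch.where(data >= 0.0, torch.ones_like(data), torch.exp(data) * (data + 1.0))


def max_fp32_error(n: int = 1_000_000, seed: int = 0) -> float:
    """Largest absolute gap between the FP32 gradient and an FP64 reference on the same inputs."""
    gen = torch.Generator().manual_seed(seed)
    # Scale standard-normal samples so both branches see a wide range of values.
    x32 = torch.randn(n, generator=gen) * 10.0
    g32 = reu_grad(x32)
    g64 = reu_grad(x32.double())  # identical inputs, evaluated in float64
    return (g32.double() - g64).abs().max().item()


print(f"max |grad_fp32 - grad_fp64| = {max_fp32_error():.2e}")
```

The REU gradient is bounded (from about -0.135 at x = -2 up to 1), so the gap one would expect is at the level of FP32 rounding; if that holds on real activations, dropping `.double()` loses nothing meaningful.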

b) Why not use FP16 mixed precision, which could speed up training and inference?
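For b), a minimal sketch of what FP16 mixed precision usually looks like in PyTorch: `torch.autocast` around the forward pass plus a `GradScaler` for loss scaling. The model, optimizer, and the helpers `amp_train_step` and `demo_step` are hypothetical stand-ins, not jetseg's training loop; without CUDA, the sketch falls back to bfloat16 autocast on CPU with the scaler disabled:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def amp_train_step(model, batch, target, optimizer, scaler, device_type, amp_dtype) -> float:
    """One optimizer step with autocast; the scaler is a pass-through when disabled."""
    optimizer.zero_grad(set_to_none=True)
    # Inside autocast, eligible ops (e.g. the nn.Linear matmuls) run in amp_dtype.
    with torch.autocast(device_type=device_type, dtype=amp_dtype):
        output = model(batch)
    # Compute the loss in FP32, outside the autocast region.
    loss = F.mse_loss(output.float(), target)
    # With FP16, scale the loss so small gradients do not underflow; step() unscales them.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


def demo_step() -> float:
    """One step on random data: CUDA + FP16 if available, otherwise CPU + bfloat16."""
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    amp_dtype = torch.float16 if use_cuda else torch.bfloat16
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1)).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    # Loss scaling is only needed for FP16; the scaler is disabled on CPU.
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
    batch = torch.randn(64, 16, device=device)
    target = torch.randn(64, 1, device=device)
    return amp_train_step(model, batch, target, optimizer, scaler, device, amp_dtype)


print(f"loss of one mixed-precision step: {demo_step():.4f}")
```

Whether this actually speeds up jetseg depends on the hardware: FP16 autocast mainly pays off on GPUs with tensor cores. A custom `autograd.Function` such as REU would also receive reduced-precision inputs inside the autocast region, where the `.double()` cast in its backward would likely work against the savings.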

c) In the REU function:

    @staticmethod
    def forward(ctx, data: torch.Tensor) -> torch.Tensor:
        """Performs a forward pass."""

        ctx.save_for_backward(data)
        return torch.where(data <= 0.0, data * torch.exp(data), data)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """Performs a backpropagation."""

        (data, ) = ctx.saved_tensors
        data = data.double()
        grad = torch.where(data >= 0.0, 1.0, torch.exp(data) * (data + 1))
        return grad_output * grad

The exponential is computed twice (once in the forward pass, once in the backward pass) over the whole input range, positive and negative. I think it could be computed only once, and only for the negative values, for efficiency.
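On c), a minimal sketch of the suggestion: evaluate the exponential once in `forward`, save it, and reuse it in `backward`, with no cast to double. `REUOnce` is a hypothetical name, not jetseg's class. Clamping the input to at most 0 before `torch.exp` means positive entries get exp(0) = 1 rather than a large exp(x) (which in FP16 overflows to inf beyond x ≈ 11), although `exp` is still evaluated elementwise over the whole tensor:

```python
import torch


class REUOnce(torch.autograd.Function):
    """Hypothetical REU variant: f(x) = x * exp(x) for x <= 0, f(x) = x otherwise.

    The exponential is computed once in forward and reused in backward.
    """

    @staticmethod
    def forward(ctx, data: torch.Tensor) -> torch.Tensor:
        # exp(min(x, 0)) equals exp(x) on the non-positive entries and 1 elsewhere.
        exp_neg = torch.exp(torch.clamp(data, max=0.0))
        ctx.save_for_backward(data, exp_neg)
        return torch.where(data <= 0.0, data * exp_neg, data)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        data, exp_neg = ctx.saved_tensors
        # d/dx [x * exp(x)] = exp(x) * (x + 1); both branches give 1 at x = 0.
        # Everything stays in the input's dtype.
        grad = torch.where(data >= 0.0, torch.ones_like(data), exp_neg * (data + 1.0))
        return grad_output * grad


x = torch.randn(8, requires_grad=True)
REUOnce.apply(x).sum().backward()
print(x.grad)
```

The trade-off is memory: `exp_neg` is one extra saved tensor per call, so whether caching beats recomputing depends on whether the layer is memory-bound. Restricting the exponential strictly to the negative entries (e.g. with boolean-mask indexing) would need a gather/scatter, which on a GPU may well cost more than the elementwise `exp` it saves.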

bobarbobo • Aug 14 '23 11:08