flash-linear-attention

[Feature Request]

Open conceptofmind opened this issue 8 months ago • 7 comments

Feature Request

Hello,

Thank you for all of your great work.

I was wondering whether it would be reasonable to add more fused activation + linear functions to the repository, such as GELU, squared ReLU (sqrelu), and GEGLU.

Motivation

I would think having these additional fused functions would give more flexibility in the activations available for model pretraining.

Your Contribution

I attempted to implement a squared-ReLU linear class here, following what you did for swiglu, using a basic MLP as an example. I would still need to add the jiterator kernels as well (a rough sketch is included after the MLP below):

import functools

import torch
import torch.distributed.tensor  # may be needed for the DTensor isinstance checks below
import torch.nn as nn
import torch.nn.functional as F


def sqrelu_fwd_torch(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)

def sqrelu_bwd_torch(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)

def sqrelu_fwdbwd_torch(x, g):
    dtype = x.dtype
    x, g = x.float(), g.float()
    r = F.relu(x)
    dx = (2.0 * g * r)
    y = (r * r)
    return dx.to(dtype), y.to(dtype)

@torch.compile
def sqrelu_fwd_compiled(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)

@torch.compile
def sqrelu_bwd_compiled(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)

@torch.compile
def sqrelu_fwdbwd_compiled(x, g):
    dtype = x.dtype
    x, g = x.float(), g.float()
    r = F.relu(x)
    dx = (2.0 * g * r)
    y = (r * r)
    return dx.to(dtype), y.to(dtype)

autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type="cuda")
autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type="cuda")

class SQReLULinearFunction(torch.autograd.Function):

    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, x, weight, bias):
        with torch.no_grad():
            if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
                y = sqrelu_fwd_compiled(x)
            else:
                y = sqrelu_fwd_torch(x)

        out = F.linear(y, weight, bias)
        ctx.save_for_backward(x, weight)
        ctx.linear_bias_is_none = bias is None
        return out
    
    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, dout):
        x, weight = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        # dz = dout @ weight: gradient w.r.t. the activation output y
        dz = F.linear(dout, weight.t()).view_as(x)
        # Recompute y = relu(x)^2 (needed for the weight grad) and dx in one pass
        with torch.no_grad():
            if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
                dx, y = sqrelu_fwdbwd_compiled(x, dz)
            else:
                dx, y = sqrelu_fwdbwd_torch(x, dz)
        
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y.reshape(-1, y.shape[-1]))
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        return dx, dlinear_weight, dlinear_bias
    
sqrelu_linear = SQReLULinearFunction.apply

class SQReLULinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None
        self.reset_parameters()
        
    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
        if self.bias is not None:
            nn.init.zeros_(self.bias)
    
    def forward(self, x):
        return sqrelu_linear(x, self.weight, self.bias)

# MLP
class MLP(nn.Module):
    def __init__(
        self, 
        dim: int, 
        hidden_dim: int, 
        dropout: float,
        use_bias: bool,
    ):
        super().__init__()

        self.linear_in = nn.Linear(dim, hidden_dim, bias=use_bias)
        self.d1 = nn.Dropout(dropout)
        self.sqrelu_linear = SQReLULinear(hidden_dim, dim, bias=use_bias)
        self.d2 = nn.Dropout(dropout)

    def forward(self, x):
        x = self.linear_in(x)
        x = self.d1(x)
        x = self.sqrelu_linear(x)
        x = self.d2(x)
        return x
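
For reference, here is roughly what I imagine the jiterator kernels would look like. This is only an untested sketch of the element-wise torch.cuda.jiterator pattern; the codestrings and the sqrelu_fwd_jit / sqrelu_bwd_jit names are placeholders of my own:

# Untested sketch: element-wise jiterator kernels for squared ReLU (CUDA-only, beta API).
import torch

sqrelu_fwd_codestring = """
template <typename T> T sqrelu_fwd(T x) {
    return float(x) > 0.f ? T(float(x) * float(x)) : T(0.f);
}
"""

sqrelu_bwd_codestring = """
template <typename T> T sqrelu_bwd(T x, T g) {
    return float(x) > 0.f ? T(2.f * float(g) * float(x)) : T(0.f);
}
"""

# _create_jit_fn compiles the codestring into an element-wise CUDA kernel on first use.
sqrelu_fwd_jit = torch.cuda.jiterator._create_jit_fn(sqrelu_fwd_codestring)
sqrelu_bwd_jit = torch.cuda.jiterator._create_jit_fn(sqrelu_bwd_codestring)

A fused fwdbwd variant returning both dx and y would presumably need torch.cuda.jiterator._create_multi_output_jit_fn instead.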

If this is verified for correctness, I could attempt to implement the others as well.
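
For verification, the kind of sanity check I have in mind is just a comparison against an unfused PyTorch baseline (relu(x)**2 followed by F.linear), along these lines; the shapes and tolerances below are arbitrary:

# Minimal sanity check of SQReLULinear against an unfused baseline.
import torch
import torch.nn.functional as F

def reference_sqrelu_linear(x, weight, bias):
    # Unfused reference: relu(x)^2 followed by a plain linear layer.
    return F.linear(F.relu(x).square(), weight, bias)

torch.manual_seed(0)
layer = SQReLULinear(512, 256, bias=True)

x = torch.randn(4, 64, 512, requires_grad=True)
x_ref = x.detach().clone().requires_grad_()
w_ref = layer.weight.detach().clone().requires_grad_()
b_ref = layer.bias.detach().clone().requires_grad_()

out = layer(x)
out_ref = reference_sqrelu_linear(x_ref, w_ref, b_ref)
torch.testing.assert_close(out, out_ref, atol=1e-4, rtol=1e-4)

grad = torch.randn_like(out)
out.backward(grad)
out_ref.backward(grad)
torch.testing.assert_close(x.grad, x_ref.grad, atol=1e-4, rtol=1e-4)
torch.testing.assert_close(layer.weight.grad, w_ref.grad, atol=1e-4, rtol=1e-4)
torch.testing.assert_close(layer.bias.grad, b_ref.grad, atol=1e-4, rtol=1e-4)
print("forward/backward match the unfused reference")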

I am not sure whether this is better suited as a discussion or a feature request.

I appreciate your time and consideration.

Thank you,

Enrico

conceptofmind · Feb 26 '25 16:02