flash-linear-attention

[Feature Request]

Open conceptofmind opened this issue 9 months ago • 7 comments

Feature Request

Hello,

Thank you for all of your great work.

I was wondering whether it would be reasonable to add more fused activation + linear functions to the repository, such as gelu, sqrelu, geglu, etc.

Motivation

I think having these additional functions would give more flexibility in choosing the activation for model pretraining.

Your Contribution

I attempted to implement a sqrelu + linear class below, following what you had done for SwiGLU. I am just using a basic MLP as an example. I would still need to add the jiterator kernels:

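import functools

import torch
import torch.distributed.tensor  # makes torch.distributed.tensor.DTensor available (PyTorch >= 2.5)
import torch.nn as nn
import torch.nn.functional as F

# Squared ReLU: y = relu(x)^2, so dy/dx = 2 * relu(x).
def sqrelu_fwd_torch(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)

def sqrelu_bwd_torch(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)

# Backward helper: returns the input gradient dx and recomputes the activation y
# (in fp32) so that y does not have to be saved during the forward pass.
def sqrelu_fwdbwd_torch(x, g):
    dtype = x.dtype
    x, g = x.float(), g.float()
    r = F.relu(x)
    dx = 2.0 * g * r
    y = r * r
    return dx.to(dtype), y.to(dtype)

# The same helpers, compiled with torch.compile.
@torch.compile
def sqrelu_fwd_compiled(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)

@torch.compile
def sqrelu_bwd_compiled(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)

@torch.compile
def sqrelu_fwdbwd_compiled(x, g):
    dtype = x.dtype
    x, g = x.float(), g.float()
    r = F.relu(x)
    dx = 2.0 * g * r
    y = r * r
    return dx.to(dtype), y.to(dtype)

# Make the autograd function below respect CUDA autocast in forward and backward.
autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type="cuda")
autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type="cuda")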

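class SQReLULinearFunction(torch.autograd.Function):
    """Computes F.linear(sqrelu(x), weight, bias) without saving sqrelu(x).

    Only x and weight are saved; the activation is recomputed in backward,
    trading a cheap elementwise op for one activation-sized tensor of memory.
    """

    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, x, weight, bias):
        with torch.no_grad():
            if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
                # Under torch.compile or with DTensor inputs: use the compiled helper.
                y = sqrelu_fwd_compiled(x)
            else:
                # Eager path: plain PyTorch for now, meant to be replaced by a jiterator kernel.
                y = sqrelu_fwd_torch(x)

        out = F.linear(y, weight, bias)
        # Save x instead of y; y is recomputed in backward.
        ctx.save_for_backward(x, weight)
        ctx.linear_bias_is_none = bias is None
        return out

    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, dout):
        x, weight = ctx.saved_tensors
        # Flatten leading dims: (..., out_features) -> (N, out_features).
        dout = dout.reshape(-1, dout.shape[-1])
        # Gradient w.r.t. the activation output: dout @ weight, reshaped like x.
        dz = F.linear(dout, weight.t()).view_as(x)
        with torch.no_grad():
            if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
                dx, y = sqrelu_fwdbwd_compiled(x, dz)
            else:
                dx, y = sqrelu_fwdbwd_torch(x, dz)

        # dW = dout^T @ y (y flattened to (N, in_features)); db = column sums of dout.
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y.reshape(-1, y.shape[-1]))
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        return dx, dlinear_weight, dlinear_bias

sqrelu_linear = SQReLULinearFunction.apply

class SQReLULinear(nn.Module):
    """nn.Linear-like layer that squares the ReLU of its input first: linear(relu(x)^2)."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        return sqrelu_linear(x, self.weight, self.bias)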

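# MLP: linear_in -> dropout -> fused (sqrelu -> linear_out) -> dropout
class MLP(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float,
        use_bias: bool,
    ):
        super().__init__()

        self.linear_in = nn.Linear(dim, hidden_dim, bias=use_bias)
        # Because sqrelu is fused into the output projection, this dropout runs
        # before the activation instead of after it (the usual
        # linear -> act -> dropout -> linear order). Dropout scales kept values by
        # 1 / (1 - p) and squaring turns that into 1 / (1 - p)^2, so the expected
        # training-time activation is relu(x)^2 / (1 - p) rather than relu(x)^2.
        self.d1 = nn.Dropout(dropout)
        self.sqrelu_linear = SQReLULinear(hidden_dim, dim, bias=use_bias)
        self.d2 = nn.Dropout(dropout)

    def forward(self, x):
        x = self.linear_in(x)
        x = self.d1(x)
        x = self.sqrelu_linear(x)
        x = self.d2(x)
        return x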

If this is verified to be correct, I could attempt to implement the others as well.
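One way to verify the hand-written backward is `torch.autograd.gradcheck` in float64, plus a comparison against plain autograd on `F.linear(F.relu(x) ** 2, weight, bias)`. The sketch below assumes a CPU-only copy of the same math: the autocast decorators and the compiled/jiterator dispatch are stripped out, and the names `SQReLULinearRef` and `check_sqrelu_linear` are hypothetical.

```python
import torch
import torch.nn.functional as F


class SQReLULinearRef(torch.autograd.Function):
    """Hypothetical CPU-only copy of SQReLULinearFunction for testing.

    Same math as the fused version (save x and weight, recompute relu(x)^2
    in backward), without autocast decorators or kernel dispatch.
    """

    @staticmethod
    def forward(ctx, x, weight, bias):
        r = F.relu(x)
        ctx.save_for_backward(x, weight)
        ctx.bias_is_none = bias is None
        return F.linear(r * r, weight, bias)

    @staticmethod
    def backward(ctx, dout):
        x, weight = ctx.saved_tensors
        dout2d = dout.reshape(-1, dout.shape[-1])
        r = F.relu(x)
        dy = (dout2d @ weight).view_as(x)  # grad w.r.t. y = relu(x)^2
        dx = 2.0 * dy * r                  # d(relu(x)^2)/dx = 2 * relu(x)
        y2d = (r * r).reshape(-1, x.shape[-1])
        dweight = dout2d.t() @ y2d
        dbias = None if ctx.bias_is_none else dout2d.sum(0)
        return dx, dweight, dbias


def check_sqrelu_linear(shape=(2, 3, 5), d_out=4, seed=0):
    """Return True if the custom backward passes gradcheck and matches autograd."""
    torch.manual_seed(seed)
    opts = dict(dtype=torch.float64, requires_grad=True)
    x = torch.randn(*shape, **opts)
    w = torch.randn(d_out, shape[-1], **opts)
    b = torch.randn(d_out, **opts)

    # Finite-difference check (raises on mismatch, returns True on success).
    ok = torch.autograd.gradcheck(SQReLULinearRef.apply, (x, w, b))

    # Compare outputs and gradients with autograd on the unfused expression.
    ref = F.linear(F.relu(x) ** 2, w, b)
    out = SQReLULinearRef.apply(x, w, b)
    g = torch.randn_like(ref)
    ref_grads = torch.autograd.grad(ref, (x, w, b), g)
    out_grads = torch.autograd.grad(out, (x, w, b), g)
    same = torch.allclose(out, ref) and all(
        torch.allclose(a, c) for a, c in zip(out_grads, ref_grads)
    )
    return bool(ok and same)
```

The same harness can be reused for other activations by swapping the activation and its derivative in `forward`/`backward`.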

I'm not sure whether this is better suited as a discussion or a feature request.

I appreciate your time and consideration.

Thank you,

Enrico

conceptofmind, Feb 26 '25 16:02

@conceptofmind Hi, why not use torch.compile directly? I think it would also give a reasonable speedup.

yzhangcs, Feb 26 '25 16:02

I am currently using torch.compile, but I did not know whether it made sense to go further and explicitly fuse linear + activation (or linear + activation + linear + dropout) in this case, rather than relying on PyTorch's compile codegen alone.

Interesting to hear that it causes issues with tensor parallelism!

Previously, adding the jiterator on top of torch.compile gave decent speedups, but I was looking to improve on that further.

conceptofmind, Feb 26 '25 16:02

@conceptofmind In your case, this is necessary: torch.compile is still not smart enough to avoid saving one additional activation (which can be cheaply recomputed in the backward pass). However, making these fused functions work in setups like 4D parallelism takes a lot of effort, so we have only implemented SwiGLU in fla.

I just walked through your code snippets and they make sense to me. PRs are welcome if you need them to work with fla.
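A rough way to see the extra activation being described is to total the storages autograd saves for backward, using `torch.autograd.graph.saved_tensors_hooks`. This is a sketch under stated assumptions: 2D inputs, no bias, and hypothetical names (`RecomputeSQReLULinear`, `saved_bytes`, `naive`). The unfused expression keeps both `relu(x)` and `relu(x)^2` alive, while the recompute-style function keeps only `x` (plus the weight in both cases).

```python
import torch
import torch.nn.functional as F
from torch.autograd.graph import saved_tensors_hooks


class RecomputeSQReLULinear(torch.autograd.Function):
    """Hypothetical minimal fused op (2D x, no bias): saves x and weight only."""

    @staticmethod
    def forward(ctx, x, weight):
        r = F.relu(x)
        ctx.save_for_backward(x, weight)
        return F.linear(r * r, weight)

    @staticmethod
    def backward(ctx, dout):
        x, weight = ctx.saved_tensors
        r = F.relu(x)                       # recompute the activation
        dx = 2.0 * (dout @ weight) * r
        dweight = dout.t() @ (r * r)
        return dx, dweight


def saved_bytes(fn, *args):
    """Total bytes of the distinct storages autograd saves while running fn(*args)."""
    storages = {}

    def pack(t):
        s = t.untyped_storage()
        storages[s.data_ptr()] = s.nbytes()
        return t

    with saved_tensors_hooks(pack, lambda t: t):
        fn(*args)
    return sum(storages.values())


def naive(x, w):
    r = F.relu(x)
    # autograd keeps r (relu output, also needed by mul), r * r (for linear) and w
    return F.linear(r * r, w)


x = torch.randn(64, 128, requires_grad=True)
w = torch.randn(32, 128, requires_grad=True)
naive_bytes = saved_bytes(naive, x, w)
fused_bytes = saved_bytes(RecomputeSQReLULinear.apply, x, w)
print(f"naive: {naive_bytes} bytes saved, recompute: {fused_bytes} bytes saved")
```

The trade-off is one extra elementwise `relu(x)^2` in backward, which is cheap compared with the matmuls around it.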

yzhangcs, Feb 26 '25 16:02

Thank you for the insight!

I will look into creating a PR, then, to add more activation + linear functions following the same approach as above.

I will adjust this code to use the jiterator instead, to be consistent with the existing SwiGLU implementation:

(The sqrelu_*_torch and sqrelu_*_compiled helper functions from the original post, reposted here unchanged.)
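For the jiterator version mentioned in this comment, a sketch of the sqrelu kernels might look like the following. Assumptions, stated loudly: `torch.cuda.jiterator._create_jit_fn` and `_create_multi_output_jit_fn` are private, CUDA-only PyTorch APIs that may change between releases; the wrapper names `sqrelu_fwd` and `sqrelu_fwdbwd` are hypothetical; and on CPU (or without CUDA) the wrappers fall back to plain PyTorch ops.

```python
import torch
import torch.nn.functional as F

# C++ templates compiled at runtime by the jiterator (CUDA only).
# Math is done in float32 and converted back to the element type T.
sqrelu_fwd_codestring = """
template <typename T> T sqrelu_fwd(T x) {
    float r = float(x) > 0.0f ? float(x) : 0.0f;
    return r * r;
}
"""

# Multi-output kernel: input grad dx = 2 * g * relu(x), recomputed y = relu(x)^2.
sqrelu_fwdbwd_codestring = """
template <typename T> void sqrelu_fwdbwd(T x, T g, T& dx, T& y) {
    float r = float(x) > 0.0f ? float(x) : 0.0f;
    dx = 2.0f * float(g) * r;
    y = r * r;
}
"""

if torch.cuda.is_available():
    # Private API (assumption: still present in your PyTorch version).
    _sqrelu_fwd_jit = torch.cuda.jiterator._create_jit_fn(sqrelu_fwd_codestring)
    _sqrelu_fwdbwd_jit = torch.cuda.jiterator._create_multi_output_jit_fn(
        sqrelu_fwdbwd_codestring, num_outputs=2
    )
else:
    _sqrelu_fwd_jit = _sqrelu_fwdbwd_jit = None


def sqrelu_fwd(x):
    """relu(x)^2: jiterator kernel for CUDA tensors, plain PyTorch otherwise."""
    if x.is_cuda and _sqrelu_fwd_jit is not None:
        return _sqrelu_fwd_jit(x)
    r = F.relu(x.float())
    return (r * r).to(x.dtype)


def sqrelu_fwdbwd(x, g):
    """Return (dx, y) with dx = 2 * relu(x) * g and y = relu(x)^2 recomputed."""
    if x.is_cuda and _sqrelu_fwdbwd_jit is not None:
        dx, y = _sqrelu_fwdbwd_jit(x, g)
        return dx, y
    dtype = x.dtype
    r = F.relu(x.float())
    return (2.0 * g.float() * r).to(dtype), (r * r).to(dtype)
```

In SQReLULinearFunction these would take the place of the eager `sqrelu_fwd_torch` / `sqrelu_fwdbwd_torch` calls, with the compiled helpers kept for the torch.compile and DTensor branch.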

Would it also make sense to include something that fuses the full MLP, i.e. linear_in + activation + linear_out?

conceptofmind, Feb 26 '25 16:02

@conceptofmind I think not lol; that would make designing TP plans much harder.

With the native torch DTensor APIs, one needs to design module pre/post hooks to handle inputs and outputs, so we must pass the linear weight as an argument for it to fit properly with TP. Passing several linear weights is not so neat to me.

yzhangcs, Feb 26 '25 16:02

Ok haha. I will leave that out, then, and just focus on gelu + linear, sqrelu + linear, etc. lol.
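For gelu + linear, only the activation helpers change; the linear part can stay as in SQReLULinearFunction. A sketch in the same style as the sqrelu helpers, assuming the exact (erf-based) GELU rather than the tanh approximation, with hypothetical function names. The derivative used is d/dx [x * Phi(x)] = Phi(x) + x * phi(x), where Phi and phi are the standard normal CDF and PDF.

```python
import math

import torch
import torch.nn.functional as F

_INV_SQRT2 = 1.0 / math.sqrt(2.0)
_INV_SQRT_2PI = 1.0 / math.sqrt(2.0 * math.pi)


def gelu_fwd_torch(x):
    """Exact GELU, computed in fp32 and cast back to the input dtype."""
    return F.gelu(x.float()).to(x.dtype)


def gelu_fwdbwd_torch(x, g):
    """Return (dx, y) for y = gelu(x), given the upstream grad g w.r.t. y.

    y is recomputed here so it does not need to be saved in forward.
    """
    dtype = x.dtype
    x, g = x.float(), g.float()
    cdf = 0.5 * (1.0 + torch.erf(x * _INV_SQRT2))   # Phi(x)
    pdf = torch.exp(-0.5 * x * x) * _INV_SQRT_2PI    # phi(x)
    dx = g * (cdf + x * pdf)
    y = x * cdf
    return dx.to(dtype), y.to(dtype)
```

If the tanh approximation (`F.gelu(x, approximate="tanh")`) is used instead, the derivative in `gelu_fwdbwd_torch` has to change to match. A GeGLU variant would follow the SwiGLU pattern with gelu in place of silu on the gate.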

conceptofmind, Feb 26 '25 16:02

@conceptofmind Thank you in advance!

yzhangcs, Feb 26 '25 16:02

@conceptofmind Hi, I will close this issue since we have implemented most of your request. Feel free to open a new PR if you run into any issues.

https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/activations.py

zhiyuan1i, Aug 06 '25 06:08

I will check it out.

Thank you,

Enrico

conceptofmind, Aug 06 '25 23:08