flash-linear-attention
[Feature Request]
Feature Request
Hello,
Thank you for all of your great work.
I was wondering whether it would be reasonable to add more fused activation + linear functions to the repository, such as gelu, sqrelu, geglu, etc.
Motivation
I think having these additional functions would allow for greater variety in activation choices during model pretraining.
Your Contribution
I attempted to implement a sqrelu + linear class below, following what you did for swiglu, using a basic MLP as an example. I would still need to add the jiterator as well:
import functools

import torch
import torch.distributed.tensor  # so the DTensor checks below resolve
import torch.nn as nn
import torch.nn.functional as F


# Eager reference implementations of squared ReLU: y = relu(x)^2.
def sqrelu_fwd_torch(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


def sqrelu_bwd_torch(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)


def sqrelu_fwdbwd_torch(x, g):
    dtype = x.dtype
    x, g = x.float(), g.float()
    r = F.relu(x)
    dx = 2.0 * g * r
    y = r * r
    return dx.to(dtype), y.to(dtype)


# torch.compile variants, used when compiling or when inputs are DTensors.
@torch.compile
def sqrelu_fwd_compiled(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


@torch.compile
def sqrelu_bwd_compiled(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)


@torch.compile
def sqrelu_fwdbwd_compiled(x, g):
    dtype = x.dtype
    x, g = x.float(), g.float()
    r = F.relu(x)
    dx = 2.0 * g * r
    y = r * r
    return dx.to(dtype), y.to(dtype)


autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type="cuda")
autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type="cuda")


class SQReLULinearFunction(torch.autograd.Function):

    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, x, weight, bias):
        with torch.no_grad():
            if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
                y = sqrelu_fwd_compiled(x)
            else:
                y = sqrelu_fwd_torch(x)
        out = F.linear(y, weight, bias)
        # Save only the pre-activation input; the activation output is
        # recomputed in backward instead of being stored.
        ctx.save_for_backward(x, weight)
        ctx.linear_bias_is_none = bias is None
        return out

    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, dout):
        x, weight = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        dz = F.linear(dout, weight.t()).view_as(x)
        with torch.no_grad():
            if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
                dx, y = sqrelu_fwdbwd_compiled(x, dz)
            else:
                dx, y = sqrelu_fwdbwd_torch(x, dz)
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y.reshape(-1, y.shape[-1]))
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        return dx, dlinear_weight, dlinear_bias


sqrelu_linear = SQReLULinearFunction.apply


class SQReLULinear(nn.Module):

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        return sqrelu_linear(x, self.weight, self.bias)


# MLP
class MLP(nn.Module):

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float,
        use_bias: bool,
    ):
        super().__init__()
        self.linear_in = nn.Linear(dim, hidden_dim, bias=use_bias)
        self.d1 = nn.Dropout(dropout)
        self.sqrelu_linear = SQReLULinear(hidden_dim, dim, bias=use_bias)
        self.d2 = nn.Dropout(dropout)

    def forward(self, x):
        x = self.linear_in(x)
        x = self.d1(x)
        x = self.sqrelu_linear(x)
        x = self.d2(x)
        return x
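For reference, a quick correctness check against a plain eager implementation could look like the following (a minimal sketch using the SQReLULinear class above; device handling and tolerances are illustrative):

import torch
import torch.nn.functional as F

# Hypothetical sanity check: compare SQReLULinear (defined above) against a
# plain eager reference, relu(x)**2 followed by F.linear with the same weights.
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(2, 16, 64, device=device, requires_grad=True)
layer = SQReLULinear(64, 32).to(device)

# Fused custom-autograd path.
out = layer(x)
out.sum().backward()
dx_fused = x.grad.clone()
dw_fused = layer.weight.grad.clone()

# Plain eager reference path with the same parameters.
x.grad = None
layer.zero_grad()
ref = F.linear(F.relu(x) ** 2, layer.weight, layer.bias)
ref.sum().backward()

assert torch.allclose(out, ref, atol=1e-4, rtol=1e-4)
assert torch.allclose(dx_fused, x.grad, atol=1e-4, rtol=1e-4)
assert torch.allclose(dw_fused, layer.weight.grad, atol=1e-4, rtol=1e-4)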
If this is verified for correctness, I could attempt to implement the others as well.
I am not sure whether this is better suited as a discussion or a feature request.
I appreciate your time and consideration.
Thank you,
Enrico
@conceptofmind Hi, why not directly use torch.compile? I think it would also lead to a reasonable speedup.
I am currently using torch.compile, but I did not know whether it made sense to fully fuse linear + activation (or linear + activation + linear + dropout) in this case, as opposed to relying on PyTorch's compile codegen alone (roughly the baseline sketched below).
Interesting to hear that it causes issues with tensor parallel!
Previously, adding the jiterator plus compile had given decent speedups, but I was looking to improve on that further.
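For context, the compile-only baseline referred to above might look roughly like this (a sketch; NaiveMLP and the shapes are illustrative and not part of fla):

import torch
import torch.nn as nn
import torch.nn.functional as F

# A "naive" MLP that leaves all fusion decisions to torch.compile.
# This is the baseline being compared against, not fla code.
class NaiveMLP(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.0, use_bias=True):
        super().__init__()
        self.linear_in = nn.Linear(dim, hidden_dim, bias=use_bias)
        self.linear_out = nn.Linear(hidden_dim, dim, bias=use_bias)
        self.d1 = nn.Dropout(dropout)
        self.d2 = nn.Dropout(dropout)

    def forward(self, x):
        x = self.d1(self.linear_in(x))
        x = F.relu(x) ** 2            # squared ReLU in eager form
        x = self.d2(self.linear_out(x))
        return x

# torch.compile fuses the elementwise ops, but what gets saved for backward
# versus recomputed is left to the compiler.
naive = torch.compile(NaiveMLP(1024, 4096))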
@conceptofmind In your case, this is necessary. torch.compile is still not smart enough to avoid storing one additional activation (which can be recomputed cheaply in the backward pass).
But it would require a lot of effort to make this fit scenarios like 4D parallelism, so we only implement swiglu in fla.
I just walked through your code snippets and they make sense to me. PRs are welcome if you do need them to work with fla.
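One rough way to see the extra-activation point above is to compare peak memory over a forward + backward pass; a sketch, assuming a CUDA device plus the MLP class from the first snippet and the NaiveMLP baseline sketched earlier (numbers will depend on compiler decisions, so this is illustrative only):

import torch

def peak_mem_mib(model, x):
    # Peak CUDA memory across one forward + backward pass, in MiB.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    model(x).sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

x = torch.randn(8, 2048, 1024, device="cuda", requires_grad=True)
fused = MLP(dim=1024, hidden_dim=4096, dropout=0.0, use_bias=True).cuda()
naive = torch.compile(NaiveMLP(1024, 4096).cuda())

print(f"fused MLP : {peak_mem_mib(fused, x):.1f} MiB")
print(f"naive MLP : {peak_mem_mib(naive, x):.1f} MiB")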
Thank you for the insight!
I will look into creating a PR to include additional activation + linear functions, following a similar pattern to the above.
I will make adjustments to this code to use the jiterator instead, to keep it consistent with the existing implementations:
def sqrelu_fwd_torch(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


def sqrelu_bwd_torch(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)


def sqrelu_fwdbwd_torch(x, g):
    dtype = x.dtype
    x, g = x.float(), g.float()
    r = F.relu(x)
    dx = 2.0 * g * r
    y = r * r
    return dx.to(dtype), y.to(dtype)


@torch.compile
def sqrelu_fwd_compiled(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


@torch.compile
def sqrelu_bwd_compiled(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)


@torch.compile
def sqrelu_fwdbwd_compiled(x, g):
    dtype = x.dtype
    x, g = x.float(), g.float()
    r = F.relu(x)
    dx = 2.0 * g * r
    y = r * r
    return dx.to(dtype), y.to(dtype)
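For reference, the jiterator-based elementwise pieces might end up looking roughly like this (a sketch using the private torch.cuda.jiterator._create_jit_fn / _create_multi_output_jit_fn entry points, which are CUDA-only; the code strings are illustrative and unvalidated):

import torch

# Sketch of jiterator-based elementwise kernels for squared ReLU, mirroring
# the code-string pattern commonly used for fused activations. These rely on
# private PyTorch APIs and only run on CUDA tensors.

sqrelu_fwd_codestring = """
template <typename T> T sqrelu_fwd(T x) {
    return x > 0.f ? x * x : 0.f;
}
"""

sqrelu_bwd_codestring = """
template <typename T> T sqrelu_bwd(T x, T g) {
    return x > 0.f ? 2.f * x * g : 0.f;
}
"""

sqrelu_fwdbwd_codestring = """
template <typename T> void sqrelu_fwdbwd(T x, T g, T &dx, T &y) {
    float r = x > 0.f ? float(x) : 0.f;
    dx = 2.f * float(g) * r;
    y = r * r;
}
"""

sqrelu_fwd_jit = torch.cuda.jiterator._create_jit_fn(sqrelu_fwd_codestring)
sqrelu_bwd_jit = torch.cuda.jiterator._create_jit_fn(sqrelu_bwd_codestring)
sqrelu_fwdbwd_jit = torch.cuda.jiterator._create_multi_output_jit_fn(
    sqrelu_fwdbwd_codestring, num_outputs=2
)

# Usage on CUDA tensors: y = sqrelu_fwd_jit(x); dx, y = sqrelu_fwdbwd_jit(x, g)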
Would it make sense to also include something that fuses the full MLP, i.e., linear_in + activation + linear_out?
@conceptofmind I think not, lol, since that makes designing TP plans much harder.
With the native torch DTensor APIs, one needs to design module pre/post hooks to handle inputs and outputs, so we must pass the linear weight as an argument for it to fit properly with TP. Passing several linear weights is not so neat to me.
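For illustration only, the kind of per-module TP plan being referred to looks roughly like this with the native DTensor APIs (a sketch on a plain two-linear MLP, assuming the script is launched under torchrun; PlainMLP and the shapes are hypothetical):

import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Each sub-module with a single weight maps cleanly onto one entry of the
# plan; a fully fused MLP op that consumed several weights at once would not
# slot into this per-module scheme as neatly, which is the objection above.
class PlainMLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.linear_in = nn.Linear(dim, hidden_dim, bias=False)
        self.linear_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.linear_out(torch.relu(self.linear_in(x)) ** 2)

world_size = int(os.environ.get("WORLD_SIZE", "1"))
mesh = init_device_mesh("cuda", (world_size,))
mlp = parallelize_module(
    PlainMLP(1024, 4096).cuda(),
    mesh,
    {"linear_in": ColwiseParallel(), "linear_out": RowwiseParallel()},
)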
Ok haha. I will exclude that then and just focus on gelu + linear, sqrelu + linear, etc.
@conceptofmind Thank you in advance!
@conceptofmind Hi, I will close this issue since we have implemented most of your request. Feel free to open a new PR if you run into any issues.
https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/activations.py
I will check it out.
Thank you,
Enrico