[Feature Request]
Feature Request
Hello,
Thank you for all of your great work.
I was wondering whether it would be reasonable to add more fused linear activation functions to the repository, such as gelu, sqrelu, geglu, etc.?
Motivation
I would think that having these additional fused activations available would give users more architectural options when pretraining models with the library.
Your Contribution
I attempted to implement a sqrelu linear class below, following what you had done for swiglu, using a basic MLP as an example. I would still need to add the jiterator kernels as well (a rough sketch of what I mean follows the code):
import functools

import torch
import torch.distributed.tensor  # so the DTensor isinstance check below resolves
import torch.nn as nn
import torch.nn.functional as F


def sqrelu_fwd_torch(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


def sqrelu_bwd_torch(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)


def sqrelu_fwdbwd_torch(x, g):
    dtype = x.dtype
    x, g = x.float(), g.float()
    r = F.relu(x)
    dx = 2.0 * g * r
    y = r * r
    return dx.to(dtype), y.to(dtype)


@torch.compile
def sqrelu_fwd_compiled(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


@torch.compile
def sqrelu_bwd_compiled(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)


@torch.compile
def sqrelu_fwdbwd_compiled(x, g):
    dtype = x.dtype
    x, g = x.float(), g.float()
    r = F.relu(x)
    dx = 2.0 * g * r
    y = r * r
    return dx.to(dtype), y.to(dtype)


autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type="cuda")
autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type="cuda")


class SQReLULinearFunction(torch.autograd.Function):

    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, x, weight, bias):
        with torch.no_grad():
            if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
                y = sqrelu_fwd_compiled(x)
            else:
                y = sqrelu_fwd_torch(x)
            out = F.linear(y, weight, bias)
        ctx.save_for_backward(x, weight)
        ctx.linear_bias_is_none = bias is None
        return out

    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, dout):
        x, weight = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        # gradient w.r.t. the activation output y = relu(x)**2
        dz = F.linear(dout, weight.t()).view_as(x)
        with torch.no_grad():
            # recompute y and get dx = 2 * dz * relu(x) in one pass
            if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor):
                dx, y = sqrelu_fwdbwd_compiled(x, dz)
            else:
                dx, y = sqrelu_fwdbwd_torch(x, dz)
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y.reshape(-1, y.shape[-1]))
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        return dx, dlinear_weight, dlinear_bias


sqrelu_linear = SQReLULinearFunction.apply


class SQReLULinear(nn.Module):

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        return sqrelu_linear(x, self.weight, self.bias)


# MLP
class MLP(nn.Module):

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float,
        use_bias: bool,
    ):
        super().__init__()
        self.linear_in = nn.Linear(dim, hidden_dim, bias=use_bias)
        self.d1 = nn.Dropout(dropout)
        self.sqrelu_linear = SQReLULinear(hidden_dim, dim, bias=use_bias)
        self.d2 = nn.Dropout(dropout)

    def forward(self, x):
        x = self.linear_in(x)
        x = self.d1(x)
        x = self.sqrelu_linear(x)
        x = self.d2(x)
        return x
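For the jiterator part, something along these lines is what I had in mind. This is only a rough sketch modeled on how flash-attention registers its elementwise kernels with torch.cuda.jiterator; the code strings and the names sqrelu_fwd_jit / sqrelu_bwd_jit are my own placeholders, not anything already in this repo:

# Rough sketch (placeholder names, not repo code): elementwise sqrelu kernels
# JIT-compiled with torch.cuda.jiterator as a CUDA-only alternative to the
# torch.compile fallbacks above. _create_jit_fn is a private PyTorch API that
# builds an elementwise op from a C++ template function string.
import torch

sqrelu_fwd_codestring = """
template <typename T> T sqrelu_fwd(T x) {
    float r = (x > T(0.f)) ? float(x) : 0.f;
    return T(r * r);
}
"""
sqrelu_bwd_codestring = """
template <typename T> T sqrelu_bwd(T x, T g) {
    float r = (x > T(0.f)) ? float(x) : 0.f;
    return T(2.f * float(g) * r);
}
"""

sqrelu_fwd_jit = torch.cuda.jiterator._create_jit_fn(sqrelu_fwd_codestring)
sqrelu_bwd_jit = torch.cuda.jiterator._create_jit_fn(sqrelu_bwd_codestring)

# Usage (CUDA tensors only), e.g. in the non-compiled branch of the
# autograd.Function above:
#   y = sqrelu_fwd_jit(x)
#   dx = sqrelu_bwd_jit(x, dz)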
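To verify correctness, a quick sanity check I would run (illustrative only, not existing repo code) is to compare the fused layer against an unfused reference in fp64:

# Sanity check (illustrative): the fused SQReLULinear should match the unfused
# reference F.linear(relu(x)**2, W, b) in both the forward output and gradients.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
layer = SQReLULinear(16, 8).double()
x = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)

out_fused = layer(x)
out_ref = F.linear(F.relu(x).pow(2), layer.weight, layer.bias)
g = torch.randn_like(out_fused)

dx_fused, dw_fused = torch.autograd.grad(out_fused, (x, layer.weight), g)
dx_ref, dw_ref = torch.autograd.grad(out_ref, (x, layer.weight), g)

print(torch.allclose(out_fused, out_ref))  # forward outputs match
print(torch.allclose(dx_fused, dx_ref))    # input gradients match
print(torch.allclose(dw_fused, dw_ref))    # weight gradients match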
If this is verified to be correct, I could attempt to implement the others as well.
I am not sure whether this is better suited as a discussion or a feature request.
I appreciate your time and consideration.
Thank you,
Enrico