How to compute MACs or FLOPs of mamba?
We calculate FLOPs based on the reference code below, though it is very different from the real speed in practice.
def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
    """
    import numpy as np

    # fvcore.nn.jit_handles
    def get_flops_einsum(input_shapes, equation):
        np_arrs = [np.zeros(s) for s in input_shapes]
        optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
        for line in optim.split("\n"):
            if "optimized flop" in line.lower():
                # divided by 2 because we count MAC (multiply-add counted as one flop)
                flop = float(np.floor(float(line.split(":")[-1]) / 2))
                return flop

    assert not with_complex

    flops = 0  # below code flops = 0
    if False:
        ...
        """
        dtype_in = u.dtype
        u = u.float()
        delta = delta.float()
        if delta_bias is not None:
            delta = delta + delta_bias[..., None].float()
        if delta_softplus:
            delta = F.softplus(delta)
        batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
        is_variable_B = B.dim() >= 3
        is_variable_C = C.dim() >= 3
        if A.is_complex():
            if is_variable_B:
                B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
            if is_variable_C:
                C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
        else:
            B = B.float()
            C = C.float()
        x = A.new_zeros((batch, dim, dstate))
        ys = []
        """

    flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
    if with_Group:
        flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
    else:
        flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
    if False:
        ...
        """
        deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
        if not is_variable_B:
            deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
        else:
            if B.dim() == 3:
                deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
            else:
                B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
                deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
        if is_variable_C and C.dim() == 4:
            C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
        last_state = None
        """

    in_for_flops = B * D * N
    if with_Group:
        in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
    else:
        in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
    flops += L * in_for_flops
    if False:
        ...
        """
        for i in range(u.shape[2]):
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
            if not is_variable_C:
                y = torch.einsum('bdn,dn->bd', x, C)
            else:
                if C.dim() == 3:
                    y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
                else:
                    y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
            if i == u.shape[2] - 1:
                last_state = x
            if y.is_complex():
                y = y.real * 2
            ys.append(y)
        y = torch.stack(ys, dim=2)  # (batch dim L)
        """

    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    if False:
        ...
        """
        out = y if D is None else y + u * rearrange(D, "d -> d 1")
        if z is not None:
            out = out * F.silu(z)
        out = out.to(dtype=dtype_in)
        """

    return flops
def selective_scan_flop_jit(inputs, outputs):
    # xs, dts, As, Bs, Cs, Ds (skip), z (skip), dt_projs_bias (skip)
    assert inputs[0].debugName().startswith("xs")  # (B, D, L)
    assert inputs[2].debugName().startswith("As")  # (D, N)
    assert inputs[3].debugName().startswith("Bs")  # (B, N, L) or (B, G, N, L)
    with_Group = len(inputs[3].type().sizes()) == 4
    with_D = inputs[5].debugName().startswith("Ds")
    if not with_D:
        with_z = inputs[5].debugName().startswith("z")
    else:
        with_z = inputs[6].debugName().startswith("z")
    B, D, L = inputs[0].type().sizes()
    N = inputs[2].type().sizes()[1]
    flops = flops_selective_scan_ref(B=B, L=L, D=D, N=N, with_D=with_D, with_Z=with_z, with_Group=with_Group)
    return flops
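For context, selective_scan_flop_jit above is meant to be registered as a custom op handler with fvcore's FlopCountAnalysis; that is also why get_flops_einsum divides by 2 (fvcore reports a multiply-accumulate as one "flop"). Below is a minimal usage sketch rather than part of the original script: the model builder is hypothetical, and the op-name string is an assumption that depends on the class name of the custom autograd Function wrapping the selective scan, so check fvcore's unsupported-operator warnings for the exact name in your setup.

import torch
from fvcore.nn import FlopCountAnalysis

model = build_mamba_model().cuda().eval()   # hypothetical builder for your Mamba-based model
x = torch.randn(1, 3, 224, 224).cuda()      # example input; use whatever your model expects

analyzer = FlopCountAnalysis(model, x)
# Assumed op name: "SelectiveScanFn" must match the autograd Function used for the scan.
analyzer.set_op_handle("prim::PythonOp.SelectiveScanFn", selective_scan_flop_jit)
print(analyzer.total())                     # standard layer counts plus the custom scan handler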
The formula we used is 9 * d_state * d_model (times batch size times sequence length). This is for a forward pass, so triple that for a forward + backward pass. Note that these are the FLOPs on top of the standard 2 * parameters * tokens FLOP count incurred by the linear layers (or 6 * D^2 * tokens total for forward + backward).

This is a brief explanation:

Note that the cost of computing the input-dependent dt/B/C is baked into the linear layer FLOP counts above. We ignore batch and d_model in the calculations below, since everything is trivially parallelized over these dimensions.

- The $Bx$ and $Ch$ calculations have $2LN$ (2 * seqlen * d_state) mults and $LN$ adds (the adds are for the $C$ part only).
- The remaining FLOPs are the associative scan over $N$ (d_state) independent recurrences:
  - $2L$ associative operations
  - the SSM op is $(a_1, b_1) \circ (a_2, b_2) = (a_1 a_2, a_2 b_1 + b_2)$
  - two multiplies and one add = 3 FLOPs per associative operation
  - $2L \times 3 = 6L$ per recurrence, i.e. $6LN$ across the $N$ recurrences

Summing these gives the $9LN$ ($3LN + 6LN$).
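As a quick sanity check of this breakdown (a minimal sketch with arbitrary example values for L and N), the pieces do add up to 9LN per channel:

L, N = 2048, 16                # example seqlen and d_state
bx_ch_mults = 2 * L * N        # Bx and Ch: 2LN multiplies
ch_adds = L * N                # LN adds (for the C reduction only)
scan_flops = N * (2 * L) * 3   # N recurrences, ~2L associative ops each, 3 FLOPs per op
assert bx_ch_mults + ch_adds + scan_flops == 9 * L * N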
Thank you for your quick reply. Can you explain why there are 2L associative operations, and not L?
If you look at the algorithm for associative scan, that's how it works. See https://en.wikipedia.org/wiki/Prefix_sum for an example.
Also note that the above does not account for the expansion factor of the Mamba block. In other words, the number of channels of the selective SSM scan is 2*d_model.
Many thanks. I think I've got the answer.
Hi @MzeroMiko, were you able to figure out how to calculate FLOPs for the selective scan? I used your script, and as you noted, the result is larger than what I expected.
@llmexperiment As addressed by @albertfgu, you can just return 9BLDN if you only use the core function of selective_scan. Full script:
def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
    """
    assert not with_complex
    # https://github.com/state-spaces/mamba/issues/110
    flops = 9 * B * L * D * N
    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    return flops


def selective_scan_flop_jit(inputs, outputs):
    print_jit_input_names(inputs)
    B, D, L = inputs[0].type().sizes()
    N = inputs[2].type().sizes()[1]
    flops = flops_selective_scan_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False, with_Group=True)
    return flops
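For the later questions about a full Mamba layer: a rough per-block forward-pass estimate can be assembled from the formulas in this thread. The sketch below is an approximation under stated assumptions, not an official counter: it assumes the standard Mamba block layout (in_proj, depthwise causal conv1d, x_proj, dt_proj, out_proj) with default hyperparameters (expand=2, d_state=16, d_conv=4, dt_rank=ceil(d_model/16)), counts a multiply-add as 2 FLOPs, runs the scan over d_inner = expand * d_model channels (per the expansion-factor comment above), and ignores norms, activations, and the small D/z elementwise terms.

import math

def mamba_block_flops(batch, seqlen, d_model, d_state=16, d_conv=4, expand=2):
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)
    tokens = batch * seqlen
    # Linear projections: 2 * in_features * out_features FLOPs per token.
    linear = 2 * tokens * (
        d_model * 2 * d_inner                 # in_proj (x and z branches)
        + d_inner * (dt_rank + 2 * d_state)   # x_proj -> dt, B, C
        + dt_rank * d_inner                   # dt_proj
        + d_inner * d_model                   # out_proj
    )
    # Depthwise causal conv1d: d_conv MACs per channel per token.
    conv = 2 * tokens * d_inner * d_conv
    # Selective scan: 9 * d_state FLOPs per channel per token (see the discussion above).
    scan = 9 * tokens * d_inner * d_state
    return linear + conv + scan

# Example: d_model=768 at seqlen 2048, batch 1 (illustrative values only).
print(f"{mamba_block_flops(1, 2048, 768) / 1e9:.2f} GFLOPs per block (forward)")

As noted earlier in the thread, multiply by roughly 3 for a forward + backward pass.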
I have a naive follow-up question:
For the associative scan algorithm, the Wikipedia page on prefix sums (https://en.wikipedia.org/wiki/Prefix_sum) shows that the work-efficient version takes O(T) work while the shorter-span version takes O(T log T).
May I ask whether the Mamba kernel is closer to the work-efficient version or the fast version? It seems to me that both versions have forward/backward latency of scale O(log T), but they require different numbers of cores and have very different asymptotic growth with respect to sequence length T.
We use the work-efficient version (Blelloch's scan).
In a world with infinite parallelism the lower-span version might be faster by a constant. But GPUs have a lot of different constraints; we already max out the GPU's parallelism and the bottleneck is compute, so the work-efficient version is much faster.
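To make the "2L associative operations" concrete, here is a toy pure-Python sketch (not the actual CUDA kernel): a recursive work-efficient inclusive scan that applies the combine operator roughly 2*(L-1) times and, with the SSM operator (a1, b1) o (a2, b2) = (a1*a2, a2*b1 + b2), reproduces the sequential recurrence h_t = a_t * h_{t-1} + b_t. It assumes L is a power of two for simplicity.

import random

def work_efficient_scan(xs, op):
    # Recursive work-efficient inclusive scan (same work/structure as Blelloch's scan):
    # roughly 2 * (len(xs) - 1) applications of `op`. Assumes len(xs) is a power of two.
    if len(xs) == 1:
        return list(xs)
    paired = [op(xs[i], xs[i + 1]) for i in range(0, len(xs), 2)]  # "upsweep": L/2 ops
    prefix = work_efficient_scan(paired, op)                       # scan the half-size problem
    out = []
    for i, x in enumerate(xs):
        if i == 0:
            out.append(x)
        elif i % 2 == 1:
            out.append(prefix[i // 2])                             # already computed in the sub-scan
        else:
            out.append(op(prefix[i // 2 - 1], x))                  # "downsweep": ~L/2 ops
    return out

def ssm_op(left, right):
    # (a1, b1) o (a2, b2) = (a1*a2, a2*b1 + b2): two multiplies and one add = 3 FLOPs
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2

L = 8
elems = [(random.random(), random.random()) for _ in range(L)]
h_scan = [b for _, b in work_efficient_scan(elems, ssm_op)]
# The sequential recurrence h_t = a_t * h_{t-1} + b_t (with h_0 = 0) gives the same states.
h, h_seq = 0.0, []
for a, b in elems:
    h = a * h + b
    h_seq.append(h)
assert all(abs(p - q) < 1e-9 for p, q in zip(h_scan, h_seq))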
Hi @albertfgu,
I find your response very informative, and I am trying to understand it more deeply. I have two quick questions:
- Why are there 2L associative operations?
- Is the 9LN calculation for the scan part only?
- If you look at the computation graph of the prefix sum, there are two passes, each of which has L-1 operations. See the "work-efficient" diagram in https://en.wikipedia.org/wiki/Prefix_sum
- Yes
How should I calculate the FLOPs for a standard Mamba layer, and what would be an approximate value? Thank you very much.
Thanks for your work. Can you explain how to use this to compute the FLOPs of Mamba with an input of shape [B, L, D]? Thank you.
@lth456321 Have you figured out this problem? I am also troubled by this issue. I tried to use it but got an error. Can anyone provide the FLOPs calculation formula for the entire Mamba block, including the three fully connected layers, the causal conv1d, and the norms?
triton.compiler.errors.CompilationError: at 31:24:
    HAS_BIAS: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
                        ^
ValueError("arange's arguments must be of type tl.constexpr")
I had the same error and solved it by changing this:
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
at "mamba/mamba_ssm/ops/triton/layer_norm.py" line 365 to this:
BLOCK_N: tl.constexpr = int(min(MAX_FUSED_SIZE, triton.next_power_of_2(N)))
N = int(N)
During training, BLOCK_N was of <class 'int'> (which worked). But when called by fvcore, it was a torch tensor (which throws the error). The same goes for N. But I haven't found the reason for this, and I have no experience with Triton at all, so this is probably not a good solution.
How about FLOPs for Mamba2? Does anyone know how to calculate it manually?