Significant differences in gradients between `_ref` and `_fn` when using the complex formulation.
Hi, I was using complex dynamics for an application and noticed large differences between the gradients computed by `mamba_inner_ref` and `mamba_inner_fn`. I am assuming this happens because torch support for complex numbers is still under development, but I would appreciate any guidance on how to resolve it. The scan functions themselves worked fine and performed much better in my test, although even there I had to lower the tolerance to 1e-6 from the 1e-8 that holds in the real case. A reproducible sample for the `mamba_inner_fn` vs `mamba_inner_ref` test is attached further below.
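For reference, the scan-level check mentioned above was roughly along these lines. This is a minimal sketch, not the exact code I ran; it assumes the `selective_scan_fn` / `selective_scan_ref` signature from `mamba_ssm.ops.selective_scan_interface` and the real `(batch, dstate, 2*seqlen)` layout for input-dependent B/C when A is complex:

```python
# Sketch of the scan-level check (complex A, input-dependent B/C).
# Assumes selective_scan_fn / selective_scan_ref from
# mamba_ssm.ops.selective_scan_interface; for complex A, input-dependent
# B and C are passed as real tensors of shape (batch, dstate, 2*seqlen).
import torch
from einops import repeat
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref

device = "cuda"  # the fused kernel needs CUDA
batch, dim, dstate, seqlen = 2, 4, 8, 16
u = torch.randn(batch, dim, seqlen, device=device, requires_grad=True)
delta = 0.1 * torch.rand(batch, dim, seqlen, device=device)
delta.requires_grad_()
# S4D-style complex A with negative real part, as in the sample below
A = repeat(-0.5 + 1j * torch.arange(dstate, device=device, dtype=torch.float32),
           "n -> d n", d=dim).contiguous()
A.requires_grad_()
B = torch.randn(batch, dstate, 2 * seqlen, device=device, requires_grad=True)
C = torch.randn(batch, dstate, 2 * seqlen, device=device, requires_grad=True)
D = torch.randn(dim, device=device, requires_grad=True)

out_fn = selective_scan_fn(u, delta, A, B, C, D, delta_softplus=True)
out_ref = selective_scan_ref(u, delta, A, B, C, D, delta_softplus=True)
# 1e-6 is the relaxed tolerance for the complex case; 1e-8 held for real A
print("forward close:", torch.allclose(out_fn, out_ref, atol=1e-6))

g = torch.randn_like(out_fn)
(out_fn * g).sum().backward()
grad_u_fn, u.grad = u.grad.clone(), None
(out_ref * g).sum().backward()
print("grad(u) close:", torch.allclose(grad_u_fn, u.grad, atol=1e-6))
```

The reproducible sample for the `mamba_inner` comparison follows.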
```python
'''
Check mamba_inner_fn against mamba_inner_ref (complex formulation)
'''
import math
import torch
from torch import nn
from tqdm import tqdm
from einops import repeat
import torch.nn.functional as F
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref
# Define a random seed for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Set device to CUDA if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def test_gradient_implementation(device=device):
    # Create random input tensors and parameters
    batch_size = 4
    dstate = 10
    dim = 3
    seqlen = 7
    xz = torch.randn(batch_size, dstate * 2, seqlen, device=device, requires_grad=True)
    conv1d_weight = torch.randn(dstate, 1, 4, device=device, requires_grad=True)
    conv1d_bias = torch.randn(dstate, device=device, requires_grad=True)
    x_proj_weight = torch.randn(dim * 4 + 1, dstate, device=device, requires_grad=True)
    dt_proj_weight = torch.randn(dstate, 1, device=device, requires_grad=True)
    # Initialize special dt projection to preserve variance at initialization
    dt_init_std = 1 ** -0.5
    nn.init.uniform_(dt_proj_weight, -dt_init_std, dt_init_std)
    dt_bias = torch.randn(dstate, device=device, requires_grad=True)
    # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
    dt = torch.exp(
        torch.rand(dstate) * (math.log(0.1) - math.log(0.001))
        + math.log(0.001)
    ).clamp(min=1e-4)
    # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    with torch.no_grad():
        dt_bias.copy_(inv_dt)
    out_proj_weight = torch.randn(dstate // 2, dstate, device=device, requires_grad=True)
    out_proj_bias = None
    # S4D-style complex A with negative real part
    A_log = torch.log(repeat(
        0.5 - 1j * torch.arange(0, dim, dtype=torch.float32, device=device),
        "n -> d n",
        d=dstate,
    ).contiguous())
    A_log.requires_grad = True
    A = -torch.exp(A_log).to(torch.cfloat)
    D = torch.randn(dstate, device=device, requires_grad=True)
    A.retain_grad()  # A is not a leaf, so retain its gradient for comparison
    D.retain_grad()
    # Forward pass through mamba_inner_fn
    output_fn = mamba_inner_fn(
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        dt_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        None,  # input-dependent B (computed from x_proj)
        None,  # input-dependent C (computed from x_proj)
        D,
        delta_bias=dt_bias,
        delta_softplus=True,
    )
    # Forward pass through mamba_inner_ref
    output_ref = mamba_inner_ref(
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        dt_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        None,  # input-dependent B (computed from x_proj)
        None,  # input-dependent C (computed from x_proj)
        D,
        delta_bias=dt_bias,
        delta_softplus=True,
    )
    # Check if outputs are the same
    out_mismatch = False
    if not torch.allclose(output_fn, output_ref, atol=1e-6):
        print("Outputs do not match! Diff: ", torch.norm(output_fn - output_ref))
        out_mismatch = True
    # Create dummy targets
    target = torch.randn_like(output_fn)

    # Zero gradients
    def zero_gradients(*tensors):
        for tensor in tensors:
            if tensor is not None and tensor.grad is not None:
                tensor.grad.zero_()

    # Compute loss for mamba_inner_fn
    loss_fn = F.mse_loss(output_fn, target)
    # Backward pass through mamba_inner_fn
    zero_gradients(
        xz, conv1d_weight, conv1d_bias, x_proj_weight,
        dt_proj_weight, out_proj_weight, out_proj_bias, A, D
    )
    loss_fn.backward(retain_graph=True)
    grad_xz_fn = xz.grad.clone() if xz.grad is not None else None
    grad_conv1d_weight_fn = conv1d_weight.grad.clone() if conv1d_weight.grad is not None else None
    grad_conv1d_bias_fn = conv1d_bias.grad.clone() if conv1d_bias.grad is not None else None
    grad_x_proj_weight_fn = x_proj_weight.grad.clone() if x_proj_weight.grad is not None else None
    grad_dt_proj_weight_fn = dt_proj_weight.grad.clone() if dt_proj_weight.grad is not None else None
    grad_out_proj_weight_fn = out_proj_weight.grad.clone() if out_proj_weight.grad is not None else None
    grad_A_fn = A.grad.clone() if A.grad is not None else None
    grad_D_fn = D.grad.clone() if D.grad is not None else None
    # Compute loss for mamba_inner_ref
    loss_ref = F.mse_loss(output_ref, target)
    # Backward pass through mamba_inner_ref
    zero_gradients(
        xz, conv1d_weight, conv1d_bias, x_proj_weight,
        dt_proj_weight, out_proj_weight, out_proj_bias, A, D
    )
    loss_ref.backward(retain_graph=True)
    grad_xz_ref = xz.grad.clone() if xz.grad is not None else None
    grad_conv1d_weight_ref = conv1d_weight.grad.clone() if conv1d_weight.grad is not None else None
    grad_conv1d_bias_ref = conv1d_bias.grad.clone() if conv1d_bias.grad is not None else None
    grad_x_proj_weight_ref = x_proj_weight.grad.clone() if x_proj_weight.grad is not None else None
    grad_dt_proj_weight_ref = dt_proj_weight.grad.clone() if dt_proj_weight.grad is not None else None
    grad_out_proj_weight_ref = out_proj_weight.grad.clone() if out_proj_weight.grad is not None else None
    grad_A_ref = A.grad.clone() if A.grad is not None else None
    grad_D_ref = D.grad.clone() if D.grad is not None else None
    mismatch = False
    # Check if gradients are the same
    for grad_fn, grad_ref, name in zip(
        [grad_xz_fn, grad_conv1d_weight_fn, grad_conv1d_bias_fn, grad_x_proj_weight_fn,
         grad_dt_proj_weight_fn, grad_out_proj_weight_fn,
         grad_A_fn, grad_D_fn],
        [grad_xz_ref, grad_conv1d_weight_ref, grad_conv1d_bias_ref, grad_x_proj_weight_ref,
         grad_dt_proj_weight_ref, grad_out_proj_weight_ref,
         grad_A_ref, grad_D_ref],
        ["xz", "conv1d_weight", "conv1d_bias", "x_proj_weight", "dt_proj_weight",
         "out_proj_weight", "A", "D"]
    ):
        if grad_fn is not None and grad_ref is not None:
            if not torch.allclose(grad_fn, grad_ref, atol=1e-5):
                mismatch = True
                print(f"Gradient mismatch for {name}! Diff: {torch.norm(grad_fn - grad_ref)}")
        elif grad_fn is None or grad_ref is None:
            print(f"Gradient does not exist for {name} at least in one of the functions.")
    return out_mismatch, mismatch
# Call the test function
out_correct = 0
grad_correct = 0
for _ in tqdm(range(1000)):
    out_mismatch, grad_mismatch = test_gradient_implementation(device)
    out_correct += 1 if not out_mismatch else 0
    grad_correct += 1 if not grad_mismatch else 0
print(f"Outputs match in {out_correct} out of 1000 runs.")
print(f"Gradients match in {grad_correct} out of 1000 runs.")
```
With this setup I get about 906 out of 1000 runs where the gradients match. The outputs always match. However, unlike the real-valued case, the differences are quite large, especially for `xz` and `out_proj_weight`.
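To quantify "quite large", a per-tensor relative error is more informative than the absolute `torch.norm` printed above; something like the following helper (hypothetical, not part of the repro) can be dropped into the comparison loop:

```python
import torch

# Hypothetical helper: relative Frobenius error between the two gradients,
# which is easier to interpret than the absolute norm of the difference.
def relative_error(grad_fn, grad_ref, eps=1e-12):
    diff = torch.norm(grad_fn - grad_ref)
    scale = torch.norm(grad_ref).clamp_min(eps)
    return (diff / scale).item()

# Example usage inside the comparison loop:
#     print(f"{name}: rel err {relative_error(grad_fn, grad_ref):.3e}")
```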
I also noticed that the two functions only produce agreeing gradients when the inputs are in the range the function expects, which is why I initialize them as in the original code rather than sampling them randomly (agreement is much lower in that case). I was trying to write my own CUDA kernels for an application and wanted to measure how far off they are compared to how far off the reference implementation already is.
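To illustrate what "random samples" means here (a small sketch, assuming `dt_min`/`dt_max` of 0.001/0.1 as in the repro): skipping the inverse-softplus initialization of `dt_bias` leaves `softplus(dt_bias)` far outside the expected range, and that is where the `fn`/`ref` gradient agreement drops well below the ~906/1000 reported above.

```python
import math
import torch
import torch.nn.functional as F

dstate = 10
dt_min, dt_max = 0.001, 0.1

# In-range init (inverse softplus, as in the repro above)
dt = torch.exp(torch.rand(dstate) * (math.log(dt_max) - math.log(dt_min))
               + math.log(dt_min)).clamp(min=1e-4)
dt_bias_in_range = dt + torch.log(-torch.expm1(-dt))

# "Random samples" variant: dt_bias outside the expected range
dt_bias_random = torch.randn(dstate)

print(F.softplus(dt_bias_in_range))  # values roughly in [dt_min, dt_max]
print(F.softplus(dt_bias_random))    # frequently far outside that range
```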