Deploying in C++ using libtorch or TensorRT
I'd like to ask whether this code has been successfully deployed in C++ using libtorch or TensorRT. I've run into numerous issues when attempting to deploy VMamba, and I'm wondering if anyone has managed a successful deployment before.
Hi @cebain, thanks for raising this interesting question. While we have not tried TensorRT (or libtorch in C++) ourselves, we believe all the components should be supported. Each individual component (the Mamba SSM or self-attention) is already supported in the latest TensorRT release; for instance, the Nemotron-H model uses both of these components and is supported in TensorRT.
Please let us know if you can successfully deploy on TensorRT!
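For reference, the standard TensorRT route goes through ONNX. Below is a minimal sketch of that pipeline; `build_vmamba` and the input shape are placeholders rather than code from this repo, and the custom selective-scan CUDA op is the usual sticking point, needing either an export-friendly fallback or a TensorRT plugin:

```python
import torch

# Hypothetical model constructor and input shape -- placeholders, not repo code.
model = build_vmamba().eval().cuda()
dummy = torch.randn(1, 3, 224, 224, device="cuda")

# Standard PyTorch -> ONNX export. For VMamba this will stop at the custom
# selective_scan_cuda_core op unless an ONNX symbolic is registered for it
# or a pure-PyTorch fallback is substituted during export.
torch.onnx.export(
    model, dummy, "vmamba.onnx",
    opset_version=17,
    input_names=["input"], output_names=["logits"],
)

# The resulting graph can then be compiled into a TensorRT engine, e.g.:
#   trtexec --onnx=vmamba.onnx --saveEngine=vmamba.plan --fp16
```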
When I try to use torch.jit.script to convert the .ckpt into a .pt, the following error occurs:

```
RuntimeError: Python builtin <built-in method fwd of PyCapsule object at 0x73b2657308a0> is currently not supported in Torchscript:
  File "/home/adminroot/Documents/ccc/pycharmProjects/UNetMamba-main/unetmamba_model/classification/models/vmamba.py", line 286
    def SelectiveScanCore(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
        out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
                        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
```
which is located in:
```python
class SelectiveScanCore(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
        ctx.delta_softplus = delta_softplus
        out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
            u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
        )
        return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
```
I don't know how to solve this.
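For anyone hitting the same error: torch.jit.script cannot compile calls into a raw pybind11 extension, which is what the PyCapsule builtin in the message is. The usual fix is to register the kernel with the dispatcher as a TorchScript custom operator (a TORCH_LIBRARY block in the extension's C++ code instead of plain pybind11 bindings) and call it through torch.ops, which the scripter does understand. Below is a minimal sketch of the Python side under that assumption; the library path, the op name selective_scan::fwd, and its Tensor[] return schema are illustrative, not code from this repo:

```python
import torch
from typing import List, Optional

# Assumption: the extension was rebuilt so its kernels are registered as
# TorchScript custom ops via TORCH_LIBRARY(selective_scan, m) in C++,
# instead of (or in addition to) plain pybind11 bindings.
# The library path and the op schema below are illustrative.
torch.ops.load_library("build/libselective_scan_cuda_core.so")

def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor,
                       A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                       D: Optional[torch.Tensor] = None,
                       delta_bias: Optional[torch.Tensor] = None,
                       delta_softplus: bool = False) -> torch.Tensor:
    # Calls through torch.ops are scriptable; raw PyCapsule builtins are not.
    outs: List[torch.Tensor] = torch.ops.selective_scan.fwd(
        u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
    return outs[0]

scripted = torch.jit.script(selective_scan_fwd)  # no PyCapsule error anymore
```

Because the registration then lives in the compiled .so, the scripted .pt stays loadable from libtorch as long as the same library is linked or loaded before torch::jit::load; a Python-side torch.library registration would not survive into a C++ process. Note also that torch.jit.script does not support custom torch.autograd.Function subclasses, so for an inference-only export the op should be called directly as above rather than through SelectiveScanCore.apply. If rebuilding the extension is not an option, a pure-PyTorch sequential reference scan (along the lines of selective_scan_ref in the original Mamba codebase) is scriptable as-is, at a significant speed cost.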