
❓ [Question] Expected isITensor() to be true but got false Requested ITensor from Var, however Var type is c10::IValue

clks-wzz opened this issue · 4 comments

I am trying to use the Python package trtorch==0.4.1 to compile my own PyTorch JIT-traced model, and it fails with the following error:

```
Traceback (most recent call last):
  File "./prerecall_server.py", line 278, in <module>
    ModelServing(args),
  File "./prerecall_server.py", line 133, in __init__
    self.model = trtorch.compile(self.model, compile_settings)
  File "/usr/local/lib/python3.6/dist-packages/trtorch/_compiler.py", line 73, in compile
    compiled_cpp_mod = trtorch._C.compile_graph(module._c, _parse_compile_spec(compile_spec))
RuntimeError: [Error thrown at core/conversion/var/Var.cpp:149] Expected isITensor() to be true but got false
Requested ITensor from Var, however Var type is c10::IValue
```
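[Editor's note] For context, the compile call presumably looks something like the sketch below. This is an assumption: the actual compile_settings from prerecall_server.py are not shown in the issue, and the batch size and input shape are invented to match feat_dim=2048. trtorch 0.4.x also exposes trtorch.check_method_op_support, which can flag ops without converters before a full compile is attempted.

```python
import torch
import trtorch  # trtorch 0.4.1 (the project was later renamed Torch-TensorRT)

model = Causal_Norm_Classifier().eval().cuda()  # the module shown below
example = torch.randn(8, 2048).cuda()           # hypothetical batch of feat_dim=2048 features
traced = torch.jit.trace(model, example)

# Report whether every op in forward() has a TensorRT converter.
print(trtorch.check_method_op_support(traced, "forward"))

compile_settings = {
    "inputs": [trtorch.Input((8, 2048))],
    "enabled_precisions": {torch.float},
}
trt_mod = trtorch.compile(traced, compile_settings)
```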

I debugged this and found that the module contains an unsupported operation:

```python
import math

import torch
import torch.nn as nn


class Causal_Norm_Classifier(nn.Module):

    def __init__(self, num_classes=1000, feat_dim=2048, use_effect=False,
                 num_head=2, tau=16.0, alpha=1.0, gamma=0.03125, mu=0.9, *args):
        super(Causal_Norm_Classifier, self).__init__()
        # default alpha = 3.0
        #self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim).cuda(), requires_grad=True)
        self.scale = tau / num_head   # 16.0 / num_head
        self.norm_scale = gamma       # 1.0 / 32.0
        self.alpha = alpha            # 3.0
        self.num_head = num_head
        self.feat_dim = feat_dim
        self.head_dim = feat_dim // num_head
        self.use_effect = use_effect
        self.relu = nn.ReLU(inplace=True)
        self.mu = mu

        self.register_parameter('weight', nn.Parameter(torch.Tensor(num_classes, feat_dim), requires_grad=True))
        self.reset_parameters(self.weight)

    def reset_parameters(self, weight):
        stdv = 1. / math.sqrt(weight.size(1))
        weight.data.uniform_(-stdv, stdv)

    def forward(self, x, training=True, use_effect=True):
        # Calculate the capsule-normalized feature vector and predict.
        normed_w = self.multi_head_call(self.causal_norm, self.weight, weight=self.norm_scale)
        normed_x = self.multi_head_call(self.l2_norm, x)
        y = torch.mm(normed_x * self.scale, normed_w.t())
        return y

    def multi_head_call(self, func, x, weight=None):
        assert len(x.shape) == 2
        # Split into num_head chunks of head_dim along the feature dim.
        x_list = torch.split(x, self.head_dim, dim=1)
        if weight is not None:
            y_list = [func(item, weight) for item in x_list]
        else:
            y_list = [func(item) for item in x_list]
        assert len(x_list) == self.num_head
        assert len(y_list) == self.num_head
        return torch.cat(y_list, dim=1)

    def l2_norm(self, x):
        normed_x = x / torch.norm(x, 2, 1, keepdim=True)
        return normed_x

    def causal_norm(self, x, weight):
        norm = torch.norm(x, 2, 1, keepdim=True)
        normed_x = x / (norm + weight)
        return normed_x
```
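[Editor's note] A plausible source of the failure is the torch.split list comprehension in multi_head_call: after constant freezing, the weight parameter can reach aten::split as a constant c10::IValue rather than an ITensor, which is exactly what the error message complains about. As a workaround sketch (an assumption, not a confirmed fix), the per-head normalization can be expressed with a single reshape so that no aten::split or ListConstruct node appears in the traced graph. For the l2_norm path this is mathematically equivalent:

```python
import torch

def multi_head_l2_norm(x: torch.Tensor, num_head: int) -> torch.Tensor:
    # Equivalent to splitting x into num_head equal chunks along dim=1,
    # L2-normalizing each chunk, and concatenating the results, but
    # written as a reshape so the traced graph contains no aten::split.
    b, d = x.shape
    xh = x.view(b, num_head, d // num_head)
    xh = xh / torch.norm(xh, 2, dim=2, keepdim=True)
    return xh.reshape(b, d)
```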

Can you help me with this?

clks-wzz · Mar 15 '22

Is this still an issue with Torch-TensorRT 1.0.0?

narendasan · Apr 08 '22

Hello, I am having the same issue with v1.0.0.

Could you give us an update on this? Thanks.

mjack3 · Aug 11 '22

The original issue seems to be resolved, but this looks like a problem with the partitioner not being able to access a constant tensor that is passed as an input to an earlier subgraph.

narendasan · Aug 12 '22
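[Editor's note] For anyone hitting this on newer releases, one hedged workaround is partial compilation: the 1.x TorchScript frontend of Torch-TensorRT can leave specific ops running in PyTorch via torch_executed_ops. The op named below is a guess based on the error, and the shapes are the same invented ones as above:

```python
import torch
import torch_tensorrt  # Torch-TensorRT 1.x

model = Causal_Norm_Classifier().eval().cuda()  # the module from the issue
traced = torch.jit.trace(model, torch.randn(8, 2048).cuda())

trt_mod = torch_tensorrt.compile(
    traced,
    inputs=[torch_tensorrt.Input((8, 2048))],  # hypothetical shape, as above
    enabled_precisions={torch.float},
    # Keep the op that trips the converter in TorchScript instead of TensorRT.
    torch_executed_ops=["aten::split"],
)
```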

Thanks @narendasan

On my desktop computer with v1.0.0 the conversion works.

On a Jetson with v1.0.0 (built from source), however, the conversion fails with this error.

  • Could this be filed as a separate issue?
  • Do you know if v1.1.0 can be used on a Jetson? That may solve the problem.

mjack3 · Aug 12 '22

This issue has not seen activity for 90 days. Remove the stale label or comment, or it will be closed in 10 days.

github-actions[bot] · Nov 11 '22

> Do you know if v1.1.0 can be used on a Jetson? That may solve the problem.

Sorry for not getting back sooner. Jetson support is a bit shaky right now because we are still trying to align on a new PyTorch build process. You should be able to use the release/ngc/22.07 and release/ngc/22.10 branches with NVIDIA's build of PyTorch for Jetson.

narendasan · Dec 15 '22

Bo needs to ask the user to try with the latest codebase.

Christina-Young-NVIDIA · Dec 20 '22

This issue has not seen activity for 90 days. Remove the stale label or comment, or it will be closed in 10 days.

github-actions[bot] · Mar 21 '23