KeyError: 'aten::expand_as'
Describe the issue:
It seems that the operation aten::expand_as is still not supported. If possible, please fix this issue as soon as possible. Thanks!
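For context, expand_as is a standard PyTorch tensor op that broadcasts a tensor to the shape of another tensor; it is what appears as aten::expand_as in the traced graph. A minimal illustration in plain PyTorch (unrelated to NNI internals):

```python
import torch

# expand_as broadcasts a tensor to another tensor's shape without copying data.
# This is the op recorded as aten::expand_as when the model is traced.
b = torch.tensor([[1.0], [2.0]])   # shape (2, 1)
a = torch.zeros(2, 3)              # target shape (2, 3)
expanded = b.expand_as(a)          # now shape (2, 3)
assert tuple(expanded.shape) == (2, 3)
```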
@sverrejoh @bgianfo
@J-shang - is this new support we should add to the backlog?
@scarlett2018 yes, I will support this op in v2.7 if possible.
Have you fixed it?
@xuzhuang1996
This is a simple workaround: https://github.com/microsoft/nni/pull/4852. ~But as that PR's description said, there may be insufficient speedup near expand_as.~ (We've fixed the issue and the speedup is fine now.) Any scenario that uses expand_as will help us improve this op-related speedup. If this workaround does not meet your needs, please contact us.
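As a user-side note (an assumption on our part, not part of the PR above): when expand_as only feeds an elementwise op, PyTorch's implicit broadcasting can often replace it, so the traced graph contains no aten::expand_as node at all. A minimal sketch:

```python
import torch

# Hypothetical user-side workaround: rely on implicit broadcasting
# instead of an explicit expand_as, producing the same result while
# keeping aten::expand_as out of the traced graph.
a = torch.randn(4, 5)   # e.g. output of a wider linear layer
b = torch.randn(4, 1)   # e.g. output of a 1-unit linear layer

out_expand = a + b.expand_as(a)   # explicit expand_as
out_broadcast = a + b             # implicit broadcasting
assert torch.equal(out_expand, out_broadcast)
```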
Sadly not:
aten::expand_as is not Supported!
Could you share more details about this error? This PR has not yet been merged into the master branch. Did you install NNI from source at that commit? If so, could you provide a simple example so that we can reproduce your problem?
FYI, we have tested it on the following simple example.
import torch
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup

class TestModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 5)
        self.fc2 = torch.nn.Linear(10, 1)
        self.fc3 = torch.nn.Linear(5, 2)

    def forward(self, x):
        a = self.fc1(x)
        b = self.fc2(x).expand_as(a)
        return self.fc3(a + b)

model = TestModel()
pruner = L1NormPruner(model, [{'op_names': ['fc1'], 'sparsity': 0.5}])
_, masks = pruner.compress()
pruner._unwrap_model()
print(masks)
ModelSpeedup(model, torch.rand(10, 10), masks).speedup_model()
print(model)
hi @xuzhuang1996 Do you still have this problem?
Sadly not!