
torch.full can't use implicit scalars

Open · SolomidHero opened this issue · 1 comment

Describe the issue: Some functions, such as torch.full, need scalar inputs that are derived from other tensors. For example, I use it to create a lengths array when one is not provided.

I need this usage to be traceable:

torch.full(x.shape, x.shape[0], device=x.device, dtype=torch.int32) # desirable, not working
torch.full(x.shape, int(x.shape[0]), device=x.device, dtype=torch.int32) # not desirable, but working
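In eager mode, x.shape[0] is already a Python int, so both forms above actually run; the difference only surfaces once NNI's speedup pass traces the graph and re-materializes the size as a tensor (the aten::ScalarImplicit node in the log below). A minimal eager-mode check, assuming only a recent PyTorch:

```python
import torch

x = torch.randn(20, 5)

# x.shape[0] is a plain Python int in eager mode, so both calls succeed
# here; the failure reported in this issue only occurs inside NNI's traced
# speedup pass, where the size is fed back in as a tensor.
a = torch.full(x.shape, x.shape[0], device=x.device, dtype=torch.int32)
b = torch.full(x.shape, int(x.shape[0]), device=x.device, dtype=torch.int32)

assert torch.equal(a, b)
assert a.dtype == torch.int32 and a.shape == x.shape
```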

Environment:

  • NNI version: 2.9a2
  • Python version: 3.9
  • PyTorch version: checked 1.9, 1.12

How to reproduce it?:

import torch
import torch.nn as nn
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup


class M(nn.Module):
    def __init__(self, _in, _out):
        super().__init__()
        self.model = nn.Linear(_in, _out)

    def forward(self, x: torch.Tensor):
        add = torch.full(x.shape, x.shape[0], device=x.device, dtype=torch.int32)
        x = x + add
        x = self.model(x)
        return x

m = M(5, 7)
m_input = torch.randn(20, 5)

config_list = [{
    'sparsity_per_layer': 0.5,
    'op_types': ['Linear']
}]
pruner = L1NormPruner(m, config_list)
masked_m, masks = pruner.compress()
pruner._unwrap_model()

ModelSpeedup(m, m_input, masks, 20).speedup_model()
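Until aten::ScalarImplicit is supported, the second (working) call from the top of the issue can be applied directly to the repro module. A sketch of that workaround, torch only (MWorkaround is an illustrative name, not part of the repro):

```python
import torch
import torch.nn as nn

class MWorkaround(nn.Module):
    """Same module as M above, but with the fill value cast to a Python int,
    so the traced graph never needs aten::ScalarImplicit."""
    def __init__(self, _in, _out):
        super().__init__()
        self.model = nn.Linear(_in, _out)

    def forward(self, x: torch.Tensor):
        # int(...) keeps the fill value a Number rather than a 0-dim tensor
        add = torch.full(x.shape, int(x.shape[0]),
                         device=x.device, dtype=torch.int32)
        x = x + add
        return self.model(x)

m = MWorkaround(5, 7)
out = m(torch.randn(20, 5))
```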

Error:

[2022-08-30 09:58:02] start to speedup the model
no multi-dimension masks found.
[2022-08-30 09:58:02] infer module masks...
[2022-08-30 09:58:02] Update mask for .aten::size.1
[2022-08-30 09:58:02] Update mask for .aten::size.3
[2022-08-30 09:58:02] Update mask for .aten::size.5
[2022-08-30 09:58:02] Update mask for .aten::Int.2
[2022-08-30 09:58:02] Update mask for .aten::Int.4
[2022-08-30 09:58:02] Update mask for .aten::ScalarImplicit.6
[2022-08-30 09:58:02] ERROR: aten::ScalarImplicit is not Supported! Please report an issue at https://github.com/microsoft/nni. Thanks~
[2022-08-30 09:58:02] Update mask for .aten::full.7
[2022-08-30 09:58:02] WARNING: throw some args away when calling the function "full"
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/home/ubuntu/vol1/FastSpeech/notebooks/test_nni.ipynb Cell 2 in <cell line: 21>()
     19 print(m(m_input).sum())
     20 print(masked_m(m_input).sum())
---> 21 ModelSpeedup(m, m_input, masks, 20).speedup_model()
     23 print(m(m_input).sum())
     24 print(masked_m(m_input).sum())

File ~/vol1/miniconda3/envs/neuralmagic/lib/python3.9/site-packages/nni/compression/pytorch/speedup/compressor.py:543, in ModelSpeedup.speedup_model(self)
    540 fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
    542 _logger.info("infer module masks...")
--> 543 self.infer_modules_masks()
    544 _logger.info('resolve the mask conflict')
    546 # load the original stat dict before replace the model

File ~/vol1/miniconda3/envs/neuralmagic/lib/python3.9/site-packages/nni/compression/pytorch/speedup/compressor.py:380, in ModelSpeedup.infer_modules_masks(self)
    378 curnode = visit_queue.get()
    379 # forward mask inference for curnode
--> 380 self.update_direct_sparsity(curnode)
    381 successors = self.torch_graph.find_successors(curnode.unique_name)
    382 for successor in successors:

File ~/vol1/miniconda3/envs/neuralmagic/lib/python3.9/site-packages/nni/compression/pytorch/speedup/compressor.py:234, in ModelSpeedup.update_direct_sparsity(self, node)
    232         return
    233     # function doesn't have weights
--> 234     _auto_infer = AutoMaskInference(
    235         func, dummy_input, self, in_masks, in_constants=in_constants)
    236 else:
    237     weight_mask = None

File ~/vol1/miniconda3/envs/neuralmagic/lib/python3.9/site-packages/nni/compression/pytorch/speedup/infer_mask.py:80, in AutoMaskInference.__init__(self, module, dummy_input, speedup, in_masks, weight_mask, output_mask, name, in_constants, state_dict)
     76         self.in_masks[in_id] = torch.ones_like(self.dummy_input[in_id])
     77         # ones_like will put the created mask on the same device with the dummy_input
     78 
     79 # Initialize the mask for output tensors
---> 80 self.output = self.module(*dummy_input)
     81 # self.output.requires_grad_()
     82 if output_mask is not None:
     83     # assume the given output mask is right

File ~/vol1/miniconda3/envs/neuralmagic/lib/python3.9/site-packages/nni/compression/pytorch/speedup/jit_translate.py:244, in FuncAdapter.__call__(self, *args)
    242         for f in fs:
    243             self.keyword[p] = f(self.keyword[p])
--> 244 result = self.func(*self.positional, **self.keyword)
    245 if isinstance(result, int): # turn result of 'size' into tensor
    246     result = torch.as_tensor([result], dtype=torch.long)

TypeError: full() received an invalid combination of arguments - got (list, Tensor, dtype=torch.dtype), but expected one of:
 * (tuple of ints size, Number fill_value, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, Number fill_value, *, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

SolomidHero avatar Aug 30 '22 10:08 SolomidHero

@SolomidHero - thanks for reporting the issue and digging in. Would you like to submit a PR for it?

scarlett2018 avatar Sep 02 '22 02:09 scarlett2018
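For anyone picking this up: since FuncAdapter already turns int results of size back into tensors (see the traceback), one possible direction is to coerce tensor arguments back to Python scalars before dispatching to functions that expect a Number, roughly what aten::ScalarImplicit does at runtime. The helper below is a hypothetical sketch, not NNI's actual API:

```python
import torch

def scalar_implicit(value):
    """Hypothetical helper: mimic aten::ScalarImplicit by turning a 0-dim
    (or single-element) tensor back into the Python scalar that overloads
    like torch.full(size, Number fill_value, ...) expect."""
    if isinstance(value, torch.Tensor):
        return value.item()
    return value

# A tensor fill value trips the overload check shown in the traceback;
# after coercion it is an ordinary Python number and torch.full accepts it.
fill = torch.tensor(20)
t = torch.full((2, 3), scalar_implicit(fill), dtype=torch.int32)
```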