CLIPasso
CLIPasso copied to clipboard
代码问题
class CLIPVisualEncoder(nn.Module):
    """Wrap a CLIP model and capture the output of its visual resblocks.

    Forward hooks are registered on the first ``num_layers`` residual blocks of
    ``clip_model.visual.transformer.resblocks``. Each call to :meth:`forward`
    runs ``clip_model.encode_image`` and returns the image embedding together
    with the per-block feature maps collected by those hooks.

    NOTE(review): PyTorch forward hooks are generally not invoked inside
    TorchScript-compiled modules, so the CLIP model presumably has to be loaded
    in eager mode (e.g. ``clip.load(..., jit=False)``) -- confirm for your setup.
    """

    def __init__(self, clip_model, num_layers=12):
        """Register capture hooks on the visual transformer's residual blocks.

        Args:
            clip_model: CLIP model exposing ``visual.transformer.resblocks``
                and ``encode_image``.
            num_layers: maximum number of leading resblocks to hook. Defaults
                to 12 (the original code's assumption for the ViT visual
                transformer); a model with fewer blocks is hooked in full
                instead of raising ``IndexError``.
        """
        super().__init__()
        self.clip_model = clip_model
        # Must be a mapping, never None: the hooks assign into it by layer
        # index, including when encode_image is called directly.
        self.featuremaps = {}

        resblocks = list(self.clip_model.visual.transformer.resblocks)
        hooked = resblocks[:num_layers]  # slicing clamps to the real depth
        self.num_layers = len(hooked)
        # Keep the handles so callers can detach the hooks (handle.remove()).
        self.hook_handles = [
            block.register_forward_hook(self.make_hook(i))
            for i, block in enumerate(hooked)
        ]

    def make_hook(self, name):
        """Return a forward hook that stores a block's output under ``name``.

        3-D outputs are permuted with ``permute(1, 0, 2)``; per the original
        "LND -> NLD" note this turns (seq, batch, width) into
        (batch, seq, width). That assumes OpenAI-CLIP's sequence-first
        layout -- TODO confirm for other backbones. Any other output is
        stored unchanged.
        """

        def hook(_module, _inputs, output):
            if output.dim() == 3:
                self.featuremaps[name] = output.permute(1, 0, 2)
            else:
                self.featuremaps[name] = output

        return hook

    def forward(self, x):
        """Encode ``x`` and return ``(fc_features, featuremaps)``.

        Args:
            x: image batch accepted by ``clip_model.encode_image``.

        Returns:
            fc_features: ``encode_image(x)`` cast with ``.float()`` (CLIP may
                run in half precision; presumably downstream losses want fp32).
            featuremaps: list of the hooked block outputs, ordered by layer.

        Raises:
            KeyError: if a hooked block did not run during ``encode_image``
                (e.g. its hook was removed or the model bypasses the block).
        """
        # Reset so maps from a previous batch are never returned by mistake.
        self.featuremaps = {}
        fc_features = self.clip_model.encode_image(x).float()
        featuremaps = [self.featuremaps[k] for k in range(self.num_layers)]
        return fc_features, featuremaps
这个类（CLIPVisualEncoder）通过前向钩子无法获得特征图。