segment-anything
How to get the hidden_state from every layer of the ViT in the SAM vision encoder?
thanks a lot
The simplest way is probably to use PyTorch's forward hooks to grab the output of each Block
module (assuming that's what you mean by the hidden state of the layers in this case).
You can do something like:
import torch
from segment_anything.modeling.image_encoder import Block

# ... assuming SamPredictor & image data are already set up ...

# Use forward hooks to store 'Block' outputs when encoding image
captures = []
hook_func = lambda m, inp, out: captures.append(out)
for m in predictor.model.modules():
    if isinstance(m, Block):
        m.register_forward_hook(hook_func)

# Encoding the image runs each Block once, so one output is stored per Block
with torch.no_grad():
    predictor.set_image(image)
print("Number of blocks captured:", len(captures))
You can change the 'import ... Block' part (and the matching isinstance check) to other model components to grab the corresponding outputs (like the neck or attention outputs).
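One thing to watch with this approach is that the hooks stay registered, so every later call keeps appending to the list. Below is a minimal sketch of a variant that removes the hooks after a single forward pass. It runs on a toy model so it works without SAM weights; ToyBlock and capture_block_outputs are hypothetical names made up for this illustration, not part of segment-anything.

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a transformer block (not SAM's Block)."""

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.linear(x)


def capture_block_outputs(model, block_type, run_forward):
    """Run one forward pass and return the output of every block_type module."""
    captures = []

    def hook_func(module, inputs, output):
        # detach() so the stored tensors do not keep an autograd graph alive
        captures.append(output.detach())

    # register_forward_hook returns a handle that can remove the hook later
    handles = [
        m.register_forward_hook(hook_func)
        for m in model.modules()
        if isinstance(m, block_type)
    ]
    try:
        with torch.no_grad():
            run_forward()
    finally:
        # Remove the hooks even if the forward pass raises
        for h in handles:
            h.remove()
    return captures


# Toy usage: three blocks, so three hidden states are captured
toy_model = nn.Sequential(*[ToyBlock(8) for _ in range(3)])
x = torch.randn(2, 4, 8)
hidden_states = capture_block_outputs(toy_model, ToyBlock, lambda: toy_model(x))
print("Number of blocks captured:", len(hidden_states))
```

With SAM, the same helper would presumably be called as capture_block_outputs(predictor.model, Block, lambda: predictor.set_image(image)), assuming the predictor and image are set up as in the snippet above.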