unitable
add gptfast decoder
Add a decoder with a static kv-cache, ported from gpt-fast. The outputs were checked manually on the images in `dataset/mini_pubtabnet/val`, but the acc/TEDS metrics have not been run on the test set.
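For context, gpt-fast's static kv-cache pre-allocates fixed-size key/value buffers and writes into them by position, instead of concatenating tensors on every decode step. A minimal sketch of the idea (the class name, buffer shapes, and `update` signature here are illustrative assumptions, not the exact unitable/gpt-fast code):

```python
import torch
from torch import nn

class StaticKVCache(nn.Module):
    """Pre-allocated key/value buffers in the style of gpt-fast (sketch)."""

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim,
                 dtype=torch.float32):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        # Buffers are allocated once up front, so decoding never reallocates.
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: (seq_len,) positions being written this step.
        # k_val/v_val: (batch, n_heads, seq_len, head_dim)
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache
```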
Benchmarked the cell detection model with the following setup (a possible timing harness is sketched after this list):
- 3090
- `max_decode_len` = 512
- fp32
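A harness along these lines could reproduce the timings; `autoregressive_decode`'s exact signature is elided below, so the call here is an assumption:

```python
import time
import torch

def benchmark_decode(model, image, max_decode_len=512, warmup=3, iters=10):
    """Hypothetical harness: average wall-clock time of one full decode."""
    for _ in range(warmup):  # warm up CUDA kernels before timing
        autoregressive_decode(model=model, image=image, max_decode_len=max_decode_len)
    torch.cuda.synchronize()  # drain queued GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        autoregressive_decode(model=model, image=image, max_decode_len=max_decode_len)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
```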
The main modifications are in `full_pipeline.ipynb`.
- Specify which class to use for the decoder:

```python
backbone = ImgLinearBackbone(d_model=d_model, patch_size=patch_size)
encoder = Encoder(
    d_model=d_model,
    nhead=nhead,
    dropout=dropout,
    activation="gelu",
    norm_first=True,
    nlayer=12,
    ff_ratio=4,
)
# decoder_class = Decoder
decoder_class = GPTFastDecoder
```
- Initialize the decoder in `load_vocab_and_model`, and call `map_state_dict` on the checkpoint when using `GPTFastDecoder` (a possible shape of `map_state_dict` is sketched after the snippet below).
```python
def load_vocab_and_model(..., decoder_class: Type[nn.Module]):
    decoder = decoder_class(
        d_model=d_model,
        nhead=nhead,
        dropout=dropout,
        activation="gelu",
        norm_first=True,
        nlayer=4,
        ff_ratio=4,
    )
    model = EncoderDecoder(
        backbone=backbone,
        encoder=encoder,
        decoder=decoder,
        ...
    )
    state_dict = torch.load(model_weights, map_location="cpu")
    # GPTFastDecoder names its submodules differently, so remap the keys.
    if isinstance(model.decoder, GPTFastDecoder):
        state_dict = map_state_dict(state_dict)
    model.load_state_dict(state_dict)
    model = model.to(device)
    return vocab, model
```
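The remapping itself is just a key-rename pass over the checkpoint; a possible shape, where the prefix pairs are assumptions about how the two decoder modules name their submodules (the repo's actual `map_state_dict` may differ):

```python
def map_state_dict(state_dict):
    # Hypothetical: rename baseline Decoder keys to the GPTFastDecoder layout.
    key_map = {
        "decoder.decoder.layers.": "decoder.layers.",  # assumed prefix rename
    }
    remapped = {}
    for key, value in state_dict.items():
        for old_prefix, new_prefix in key_map.items():
            if key.startswith(old_prefix):
                key = new_prefix + key[len(old_prefix):]
                break
        remapped[key] = value
    return remapped
```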
- In `autoregressive_decode`, `setup_caches` has to be called first when `GPTFastDecoder` is used.
```python
def autoregressive_decode(...):
    model.eval()
    is_gpt_fast = isinstance(model.decoder, GPTFastDecoder)
    if is_gpt_fast:
        # Allocate the static kv-cache on the same device as the input image.
        with torch.device(image.device):
            model.decoder.setup_caches(
                max_batch_size=image.shape[0],
                max_seq_length=max_decode_len,
                dtype=image.dtype,
            )
    memory = model.encode(image)
    ...
```
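The rest of the loop (elided above) only needs to feed the newest token plus its position once the static cache is active. A hedged sketch of a greedy decode step, where `model.decode`, `vocab.bos_id`, and the `input_pos` keyword are assumptions about the surrounding code:

```python
context = torch.full(
    (image.shape[0], 1), vocab.bos_id, dtype=torch.long, device=image.device
)
for step in range(max_decode_len - 1):
    if is_gpt_fast:
        # The static kv-cache already holds the prefix, so only the newest
        # token and its position are passed in; each step is O(1) in length.
        input_pos = torch.tensor([step], device=image.device)
        logits = model.decode(context[:, -1:], memory, input_pos=input_pos)
    else:
        # The baseline decoder re-attends over the full prefix every step.
        logits = model.decode(context, memory)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    context = torch.cat([context, next_token], dim=1)
```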