parseq
Changing dec_depth results in an error during inference.
I updated parseq-patch16-224.yaml with the following changes and trained the model:
# @package _global_
defaults:
  - override /model: parseq

model:
  img_size: [ 224, 224 ]  # [ height, width ]
  patch_size: [ 16, 16 ]  # [ height, width ]
  dec_depth: 4
Training ran smoothly, and I selected a checkpoint to try inference. However, inference fails with the following error:
Traceback (most recent call last):
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/gradio/queueing.py", line 527, in process_events
    response = await route_utils.call_process_api(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/gradio/route_utils.py", line 270, in call_process_api
    output = await app.get_blocks().process_api(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/gradio/blocks.py", line 1847, in process_api
    result = await self.call_function(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/gradio/blocks.py", line 1433, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2144, in run_sync_in_worker_thread
    return await future
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 851, in run
    result = context.run(func, *args)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/gradio/utils.py", line 805, in wrapper
    response = f(*args, **kwargs)
  File "webui.py", line 45, in predict
    logits = model(img_tensor)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/shinji/projects/parseq/strhub/models/parseq/system.py", line 88, in forward
    return self.model.forward(self.tokenizer, images, max_length)
  File "/Users/shinji/projects/parseq/strhub/models/parseq/model.py", line 164, in forward
    tgt_out = self.decode(
  File "/Users/shinji/projects/parseq/strhub/models/parseq/model.py", line 103, in decode
    return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/shinji/projects/parseq/strhub/models/parseq/modules.py", line 121, in forward
    query, content = mod(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/shinji/projects/parseq/strhub/models/parseq/modules.py", line 95, in forward
    content = self.forward_stream(
  File "/Users/shinji/projects/parseq/strhub/models/parseq/modules.py", line 69, in forward_stream
    tgt2, sa_weights = self.self_attn(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/functional.py", line 5318, in multi_head_attention_forward
    raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
RuntimeError: The shape of the 2D attn_mask is torch.Size([61, 61]), but should be (2, 2).
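The shape check that fails here can be reproduced outside PARSeq. The snippet below is only a minimal sketch using plain `torch.nn.MultiheadAttention`: the sizes (`embed_dim`, `num_heads`, `num_steps`, `seq_len`) and the helper `run_self_attention` are illustrative assumptions, not values or functions from the repository. It passes a full-size causal mask together with a shorter input sequence, then retries with the mask sliced to the sequence length.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes only; not taken from the PARSeq config.
embed_dim, num_heads = 8, 2
num_steps = 61  # side of the full causal mask, as in the error message
seq_len = 2     # number of tokens actually passed to self-attention

# The module stays in its default training mode, so the call goes through
# F.multi_head_attention_forward, the function named in the traceback.
attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.randn(1, seq_len, embed_dim)

# Full-size causal mask (True = this position may not be attended to).
full_mask = torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool), 1)


def run_self_attention(mask):
    """Return the output shape, or the error message if the mask is rejected."""
    try:
        out, _ = attn(x, x, x, attn_mask=mask)
        return tuple(out.shape)
    except RuntimeError as err:
        return str(err)


mismatch = run_self_attention(full_mask)                    # mask larger than input
sliced = run_self_attention(full_mask[:seq_len, :seq_len])  # mask matches input
print(mismatch)
print(sliced)
```

Slicing the mask to the current sequence length makes the call succeed in this toy setup. Whether the same slicing is the right fix inside `model.py` is an assumption that would need to be checked against the repository code.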
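This error appears to occur when the content stream is updated during the refinement iterations: the failing self_attn call in forward_stream is reached from the content update at modules.py line 95.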
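A rough sketch of why this might only show up after raising dec_depth. It assumes, from a reading of the linked modules.py that is not confirmed in this report, that every decoder layer except the last one updates the content stream. The function `layers_updating_content` is a hypothetical helper for illustration and does not exist in the repository.

```python
def layers_updating_content(dec_depth: int) -> list[int]:
    """Indices of the decoder layers that would update the content stream.

    Assumption: each layer except the last one updates the content stream,
    so only those layers apply the content mask in self-attention.
    """
    last = dec_depth - 1
    return [i for i in range(dec_depth) if i != last]


for depth in (1, 4):
    print(depth, layers_updating_content(depth))
```

Under that assumption, a single-layer decoder never runs the content update, so a content mask of the wrong size would go unnoticed, while dec_depth: 4 runs the update in three layers and hits the shape check.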
Related code:
https://github.com/baudm/parseq/blob/main/strhub/models/parseq/model.py#L164
https://github.com/baudm/parseq/blob/main/strhub/models/parseq/modules.py#L121
https://github.com/baudm/parseq/blob/main/strhub/models/parseq/modules.py#L95
I confirmed that the error can be avoided by setting refine_iters to 0.