
Changing dec_depth results in an error during inference.

dotneet opened this issue on May 10, 2024 · 0 comments

I updated the parseq-patch16-224.yaml file with the following changes and trained the model.

# @package _global_
defaults:
  - override /model: parseq

model:
  img_size: [ 224, 224 ]  # [ height, width ]
  patch_size: [ 16, 16 ]  # [ height, width ]
  dec_depth: 4

Training went smoothly, and I selected a checkpoint to try inference. However, I encountered the following error:

Traceback (most recent call last):
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/gradio/queueing.py", line 527, in process_events
    response = await route_utils.call_process_api(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/gradio/route_utils.py", line 270, in call_process_api
    output = await app.get_blocks().process_api(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/gradio/blocks.py", line 1847, in process_api
    result = await self.call_function(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/gradio/blocks.py", line 1433, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2144, in run_sync_in_worker_thread
    return await future
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 851, in run
    result = context.run(func, *args)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/gradio/utils.py", line 805, in wrapper
    response = f(*args, **kwargs)
  File "webui.py", line 45, in predict
    logits = model(img_tensor)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/shinji/projects/parseq/strhub/models/parseq/system.py", line 88, in forward
    return self.model.forward(self.tokenizer, images, max_length)
  File "/Users/shinji/projects/parseq/strhub/models/parseq/model.py", line 164, in forward
    tgt_out = self.decode(
  File "/Users/shinji/projects/parseq/strhub/models/parseq/model.py", line 103, in decode
    return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/shinji/projects/parseq/strhub/models/parseq/modules.py", line 121, in forward
    query, content = mod(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/shinji/projects/parseq/strhub/models/parseq/modules.py", line 95, in forward
    content = self.forward_stream(
  File "/Users/shinji/projects/parseq/strhub/models/parseq/modules.py", line 69, in forward_stream
    tgt2, sa_weights = self.self_attn(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/Users/shinji/projects/parseq/.venv/lib/python3.10/site-packages/torch/nn/functional.py", line 5318, in multi_head_attention_forward
    raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
RuntimeError: The shape of the 2D attn_mask is torch.Size([61, 61]), but should be (2, 2).

This error appears to occur when the content is updated during refinement iterations.

Related code:

- https://github.com/baudm/parseq/blob/main/strhub/models/parseq/model.py#L164
- https://github.com/baudm/parseq/blob/main/strhub/models/parseq/modules.py#L121
- https://github.com/baudm/parseq/blob/main/strhub/models/parseq/modules.py#L95
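
To illustrate the failure mode in isolation, here is a minimal standalone sketch (not parseq code; the embedding size, head count, and mask construction are my assumptions): when autoregressive decoding stops early at EOS, the refinement pass can feed a short content sequence into self-attention together with a causal mask that was built for the full decoding length, and PyTorch's MultiheadAttention rejects the mismatched 2D mask. Cropping the mask to the current sequence length makes the call go through. Since only the non-final decoder layers update the content stream, this would also explain why the default dec_depth: 1 never hits the check.

```python
# Standalone sketch of the shape mismatch (not parseq code; dimensions are assumptions).
import torch
import torch.nn as nn

num_steps = 61   # size the full causal mask was built for (matches the traceback)
cur_len = 2      # actual content length after AR decoding stopped early at EOS
embed_dim, num_heads = 384, 12

attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
content = torch.randn(1, cur_len, embed_dim)
full_mask = torch.triu(torch.full((num_steps, num_steps), float('-inf')), diagonal=1)

# Passing the (61, 61) mask with a 2-token sequence reproduces the error:
# RuntimeError: The shape of the 2D attn_mask is torch.Size([61, 61]), but should be (2, 2).
try:
    attn(content, content, content, attn_mask=full_mask)
except RuntimeError as e:
    print(e)

# Cropping the mask to the current length is accepted:
out, _ = attn(content, content, content, attn_mask=full_mask[:cur_len, :cur_len])
print(out.shape)  # torch.Size([1, 2, 384])
```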

I confirmed that this error can be avoided by setting refine_iters to 0.
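
Rather than disabling refinement, a possible direction for a real fix might be to crop the full-size content mask to the current target length inside the refinement loop before it reaches the decoder. This is only a sketch based on the call sites in the traceback; the names decode, tgt_in, tgt_mask, pos_queries, and query_mask are assumptions about model.py and would need to be checked against the actual source:

```python
# Hypothetical adjustment in the refinement loop of strhub/models/parseq/model.py (forward).
# Assumption: tgt_mask was built with shape (num_steps, num_steps) while tgt_in can be
# shorter once AR decoding stops at the first EOS, so both masks are cropped to the
# current length before decode() is called.
cur_len = tgt_in.shape[1]
tgt_out = self.decode(
    tgt_in,
    memory,
    tgt_mask[:cur_len, :cur_len],        # content mask cropped to the actual length
    tgt_padding_mask,
    tgt_query=pos_queries,
    tgt_query_mask=query_mask[:, :cur_len],
)
```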
