
`unimernet` CustomMBartDecoder does not support Flash Attention 2

Open sepcnt opened this issue 3 months ago • 0 comments

Description of the bug

Due to a change in the default behavior of Torch, unimernet needs a monkey patch when it is used with torchtext==0.18.0 and torch<2.4.
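To check whether an environment matches the version combination described above, a minimal sketch along these lines may help. The helper names (`release_tuple`, `patch_condition_holds`, `installed_version`) are made up for illustration; only the version bounds come from this report.

```python
import re
from importlib import metadata


def release_tuple(version):
    """Leading numeric release of a version string, e.g. '2.3.1+cu118' -> (2, 3, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def patch_condition_holds(torch_version, torchtext_version):
    """True when torchtext is 0.18.0 and torch is older than 2.4 (bounds taken from this report)."""
    return (
        release_tuple(torchtext_version) == (0, 18, 0)
        and release_tuple(torch_version) < (2, 4)
    )


def installed_version(package):
    """Installed version of a package, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    torch_v = installed_version("torch")
    torchtext_v = installed_version("torchtext")
    if torch_v is None or torchtext_v is None:
        print("torch and/or torchtext is not installed")
    else:
        print(torch_v, torchtext_v, patch_condition_holds(torch_v, torchtext_v))
```

Run inside the venv, this prints the installed versions and whether they fall in the affected range.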

How to reproduce the bug

On Ubuntu 22.04, create a new venv with the following Makefile:

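# `source` is a bash builtin; make's default /bin/sh (dash on Ubuntu) lacks it.
SHELL := /bin/bash
# Each recipe line runs in its own shell, so `source` does not carry over.
# Put uv (~/.local/bin) and the venv's binaries on PATH for every target instead.
export PATH := $(CURDIR)/.venv/bin:$(HOME)/.local/bin:$(PATH)

.PHONY: all install setup patch download run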

all: install setup download

install:
	curl -LsSf https://astral.sh/uv/install.sh | sh
	source $(HOME)/.local/bin/env

setup:
	uv venv --python 3.10
	source .venv/bin/activate
	wget https://gitee.com/myhloli/MinerU/raw/master/requirements-docker.txt
	uv pip install -r requirements-docker.txt --extra-index-url https://wheels.myhloli.com -i https://mirrors.aliyun.com/pypi/simple
	uv pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
	uv pip install -U magic-pdf
	uv pip install modelscope
	
patch:
	git apply dep.patch 


download:
	wget https://gitee.com/myhloli/MinerU/raw/master/scripts/download_models.py
	python download_models.py

	sed -i 's|cpu|cuda|g' ~/magic-pdf.json 

run:
	magic-pdf -p ./input -o ./output -l=ch

The following monkey patch works around the issue:

diff --git a/.venv/lib/python3.10/site-packages/unimernet/models/unimernet/encoder_decoder.py b/.venv/lib/python3.10/site-packages/unimernet/models/unimernet/encoder_decoder.py
index 7cf350a..cccef39 100644
--- a/.venv/lib/python3.10/site-packages/unimernet/models/unimernet/encoder_decoder.py
+++ b/.venv/lib/python3.10/site-packages/unimernet/models/unimernet/encoder_decoder.py
@@ -432,6 +432,7 @@ class CustomMBartForCausalLM(MBartForCausalLM):
         print("CustomMBartForCausalLM init")
         super().__init__(config)
         # Modify the decoder within MBartDecoderWrapper
+        config._attn_implementation = "eager"
         self.model.decoder = CustomMBartDecoder(config)

With this patch applied, the run completes successfully; without it, the run fails with an error. The patch is offered only as a reference.
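For anyone who prefers not to edit files under site-packages, the same assignment from the diff above could in principle be applied at runtime by wrapping the decoder's constructor. The sketch below shows only the pattern: `force_eager_attention` and `FakeDecoder` are hypothetical names, and `FakeDecoder` stands in for unimernet's `CustomMBartDecoder` so that the example runs without unimernet installed.

```python
import functools
from types import SimpleNamespace


def force_eager_attention(decoder_cls):
    """Wrap decoder_cls.__init__ so the config it receives always requests eager attention."""
    original_init = decoder_cls.__init__
    if getattr(original_init, "_forces_eager", False):
        return decoder_cls  # already patched; avoid wrapping twice

    @functools.wraps(original_init)
    def patched_init(self, config, *args, **kwargs):
        # Same assignment as in the diff above, applied just before
        # the decoder is constructed.
        config._attn_implementation = "eager"
        original_init(self, config, *args, **kwargs)

    patched_init._forces_eager = True
    decoder_cls.__init__ = patched_init
    return decoder_cls


# Hypothetical stand-in for unimernet's CustomMBartDecoder (illustration only).
class FakeDecoder:
    def __init__(self, config):
        self.attn_implementation = config._attn_implementation


force_eager_attention(FakeDecoder)
config = SimpleNamespace(_attn_implementation="flash_attention_2")
decoder = FakeDecoder(config)
print(decoder.attn_implementation)  # eager
```

With the real package, the assumption is that `CustomMBartDecoder` can be imported from `unimernet.models.unimernet.encoder_decoder` (the file shown in the diff) and passed to `force_eager_attention` before the model is built. That variant has not been tested here.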

Operating system

Linux

Python version

3.10

Software version (magic-pdf --version)

0.9.x

Device mode

cuda

sepcnt, Nov 18 '24 06:11