GPU RAM Requirements for Formula Detection (do_formula_enrichment)
Hi,
I am using docling with an RTX 3090 and encountering a CUDA out-of-memory error when enabling do_formula_enrichment=True. Could you provide information on the expected GPU RAM usage for formula detection? How much memory is typically required to process documents with this setting enabled?
Thanks in advance!
Hello,
The CodeFormula model can be quite heavy, and a default batch size of 16 may be too high for some hardware setups. We’re currently working on an update that will allow each model to use its own batch size, ensuring better adaptability for models of different sizes.
In the meantime, you can manually reduce the batch size by importing and modifying the settings object. Please note that this change applies to all models, not just CodeFormula:
from docling.datamodel.settings import settings
settings.perf.elements_batch_size = 2
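For example, a minimal end-to-end sketch of where this setting would go (the input file name is a placeholder; the pipeline options are the standard ones for enabling formula enrichment):

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.datamodel.settings import settings
from docling.document_converter import DocumentConverter, PdfFormatOption

# Lower the global element batch size before converting (applies to all enrichment models)
settings.perf.elements_batch_size = 2

pipeline_options = PdfPipelineOptions()
pipeline_options.do_formula_enrichment = True

converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)
result = converter.convert("example.pdf")  # placeholder input file
print(result.document.export_to_markdown())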
I hope this helps!
I will update this issue as soon as the more fine-grained batch size selection feature is released.
Thank you !!
Same problem here. I'm using an RTX 2080 Ti with 11 GB of memory, but setting the batch size to 1 and num_threads to 1 still didn't work for me... How much memory does it take for you to run with batch size 1? @JPC612
@hjenryin around 7.5 GB GPU VRAM, but I split the PDF into single-page PDFs and merge them back together at the end.
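For anyone who wants to replicate that split-and-merge workaround, here is a rough sketch, assuming pypdf for the page splitting (file names and the output directory are placeholders):

from pathlib import Path

from pypdf import PdfReader, PdfWriter

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption

pipeline_options = PdfPipelineOptions()
pipeline_options.do_formula_enrichment = True
converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)

def convert_page_by_page(pdf_path: str, out_dir: str = "pages") -> str:
    """Split a PDF into single-page files, convert each one, and join the Markdown."""
    Path(out_dir).mkdir(exist_ok=True)
    reader = PdfReader(pdf_path)
    parts = []
    for i, page in enumerate(reader.pages):
        writer = PdfWriter()
        writer.add_page(page)
        page_file = Path(out_dir) / f"page_{i:03d}.pdf"
        with open(page_file, "wb") as fh:
            writer.write(fh)
        result = converter.convert(str(page_file))
        parts.append(result.document.export_to_markdown())
    return "\n\n".join(parts)

The trade-off is that any structure spanning page boundaries is lost, since each page is converted independently.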
Hello,
I came across this problem recently. The default setting of:
docling.models.code_formula_model.CodeFormulaModel.elements_batch_size
is 5. When I leave it at the default, it uses approximately 18-20 GB of VRAM. I tested setting it to 7 and it used a bit more (see attached image).
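For reference, this attribute can be overridden before the converter is built, the same way as in the full script further below:

import docling.models.code_formula_model

# Per-model override of the batch size used for code/formula enrichment
docling.models.code_formula_model.CodeFormulaModel.elements_batch_size = 2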
Raising the batch size, however, does not increase the speed. In my testing, I noticed a problem with the generation settings in this file: "\wsl.localhost\Ubuntu\home\wstation\miniconda3\envs\newenv\lib\python3.12\site-packages\docling_ibm_models\code_formula_model\code_formula_predictor.py"
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import logging
from typing import List, Optional, Union
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList
from docling_ibm_models.code_formula_model.models.sam_opt import SamOPTForCausalLM
from docling_ibm_models.code_formula_model.models.sam_opt_image_processor import (
SamOptImageProcessor,
)
_log = logging.getLogger(__name__)
class StopOnString(StoppingCriteria):
def __init__(self, tokenizer, stop_string):
self.stop_token_ids = tokenizer.encode(stop_string, add_special_tokens=False)
def __call__(self, input_ids, scores, **kwargs):
for sequence in input_ids:
sequence_list = sequence.tolist()
for i in range(len(sequence_list) - len(self.stop_token_ids) + 1):
if (
sequence_list[i : i + len(self.stop_token_ids)]
== self.stop_token_ids
):
return True
return False
class CodeFormulaPredictor:
"""
Code and Formula Predictor using a multi-modal vision-language model.
This class enables the prediction of code or LaTeX representations
from input images of code snippets or mathematical formulas.
Attributes
----------
_device : str
The device on which the model is loaded (e.g., 'cpu' or 'cuda').
_num_threads : int
Number of threads used for inference when running on CPU.
_tokenizer : transformers.PreTrainedTokenizer
Tokenizer for processing textual inputs to the model.
_model : transformers.PreTrainedModel
Pretrained multi-modal vision-language model.
_image_processor : transformers.ImageProcessor
Processor for normalizing and preparing input images.
_temperature : float
Sampling temperature for generation; controls randomness in predictions.
"""
def __init__(
self,
artifacts_path: str,
device: str = "cpu",
num_threads: int = 4,
):
"""
Initializes the CodeFormulaPredictor with the specified model artifacts.
Parameters
----------
artifacts_path : str
Path to the directory containing the pretrained model files.
device : str, optional
Device to run the inference on ('cpu' or 'cuda'), by default "cpu".
num_threads : int, optional
Number of threads for CPU inference, by default 4.
"""
self._device = device
self._num_threads = num_threads
if device == "cpu":
torch.set_num_threads(self._num_threads)
self._tokenizer = AutoTokenizer.from_pretrained(
artifacts_path, use_fast=True, padding_side="left"
)
self._model = SamOPTForCausalLM.from_pretrained(artifacts_path).to(self._device)
self._model.eval()
self._image_processor = SamOptImageProcessor.from_pretrained(artifacts_path)
_log.debug("CodeFormulaModel settings: {}".format(self.info()))
def info(self) -> dict:
"""
Retrieves configuration details of the CodeFormulaPredictor instance.
Returns
-------
dict
A dictionary containing configuration details such as the device and
the number of threads used.
"""
info = {
"device": self._device,
"num_threads": self._num_threads,
}
return info
def _get_prompt(self, label: str) -> str:
"""
Constructs the prompt for the model based on the input label.
Parameters
----------
label : str
The type of input, either 'code' or 'formula'.
Returns
-------
str
The constructed prompt including necessary tokens and query.
Raises
------
NotImplementedError
If the label is not 'code' or 'formula'.
"""
if label == "code":
query = "<code_image_to_text>"
elif label == "formula":
query = "<equation>"
else:
raise NotImplementedError("Label must be either code or formula")
prompt = (
"A chat between a curious user and an artificial intelligence"
" assistant. The assistant gives helpful, detailed, and polite answers to"
" the user's questions. USER: "
)
prompt += (
"<img>" + "<imgpad>" * 256 + "</img>" + "\n" + " ASSISTANT:" + "\n" + query
)
return prompt
def _strip(self, text: str):
"""
Removes any occurrences of the substrings in remove_list from the end of text.
Parameters
----------
text : str
The original string.
Returns
-------
str
The trimmed string.
"""
remove_list = [r"\quad", r"\\", r"\,", " c c c c", " l l l l l"]
changed = True
while changed:
changed = False
for substr in remove_list:
if text.endswith(substr):
text = text[: -len(substr)]
changed = True
return text.strip()
@torch.inference_mode()
def predict(
self,
images: List[Union[Image.Image, np.ndarray]],
labels: List[str],
temperature: Optional[float] = 0.0,
) -> List[str]:
"""
Predicts the textual representation of input images (code or LaTeX).
Parameters
----------
images : List[Union[Image.Image, np.ndarray]]
List of images to be processed, provided as PIL Image objects or numpy arrays.
labels : List[str]
List of labels indicating the type of each image ('code' or 'formula').
temperature : Optional[float]
Sampling temperature for generation, by default set to 0.0.
Returns
-------
List[str]
List of predicted textual outputs for each input image in the given input
order.
Raises
------
TypeError
If any of the input images is not of a supported type (PIL Image or numpy array).
Exception
In case the temperature is an invalid number.
"""
if (
temperature is None
or not (isinstance(temperature, float) or isinstance(temperature, int))
or temperature < 0
):
raise Exception("Temperature must be a number greater or equal to 0.")
do_sample = True
if temperature == 0:
do_sample = False
temperature = None
if len(labels) != len(images):
raise Exception(
"The number of images must be the same as the number of labels."
)
images_tmp = []
for image in images:
if isinstance(image, Image.Image):
image = image.convert("RGB")
elif isinstance(image, np.ndarray):
image = Image.fromarray(image).convert("RGB")
else:
raise TypeError("Not supported input image format")
images_tmp.append(image)
images_tensor = torch.stack(
[self._image_processor(img) for img in images_tmp]
).to(self._device)
prompts = [self._get_prompt(label) for label in labels]
tokenized = self._tokenizer(prompts, padding=True, return_tensors="pt")
tokenized = {k: v.to(self._device) for k, v in tokenized.items()}
prompt_ids = tokenized["input_ids"]
attention_mask = tokenized["attention_mask"]
stopping_criteria = StoppingCriteriaList(
[
StopOnString(self._tokenizer, r" \quad \quad \quad \quad"),
StopOnString(self._tokenizer, r" \\ \\ \\ \\"),
StopOnString(self._tokenizer, r" \, \, \, \,"),
StopOnString(self._tokenizer, r" c c c c c c c c c c c c c c c c"),
StopOnString(self._tokenizer, r" l l l l l l l l l l l l l l l l l"),
]
)
if self._device == "cpu":
output_ids_list = self._model.generate(
input_ids=prompt_ids,
attention_mask=attention_mask,
images=images_tensor,
do_sample=do_sample,
temperature=temperature,
max_new_tokens=4096 - prompt_ids.shape[1],
use_cache=True,
no_repeat_ngram_size=200,
stopping_criteria=stopping_criteria,
)
else:
with torch.autocast(device_type=self._device, dtype=torch.bfloat16):
output_ids_list = self._model.generate(
prompt_ids,
images=images_tensor,
do_sample=do_sample,
temperature=temperature,
max_new_tokens=4096 - prompt_ids.shape[1],
use_cache=True,
no_repeat_ngram_size=200,
stopping_criteria=stopping_criteria,
)
outputs = self._tokenizer.batch_decode(
output_ids_list[:, prompt_ids.shape[1] :], skip_special_tokens=True
)
outputs = [self._strip(output) for output in outputs]
return outputs
These settings lead the formula recognizer to produce extremely long formulas of random characters; in some cases the output reaches or exceeds the 4096-token limit, which is why it takes so long to process the formulas. I first tried to adjust the prompt, but that led to more problems. So what I did instead was adjust the predict function:
# my_code_formula_predictor.py
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
import logging
from typing import List, Optional, Union
import numpy as np
import torch
from PIL import Image
from transformers import( AutoTokenizer,
GenerationConfig,
StoppingCriteriaList,
MaxLengthCriteria,
MaxTimeCriteria
)
from docling_ibm_models.code_formula_model.models.sam_opt import SamOPTForCausalLM
from docling_ibm_models.code_formula_model.models.sam_opt_image_processor import (
SamOptImageProcessor,
)
_log = logging.getLogger(__name__)
@torch.inference_mode()
def predict(
self,
images: List[Union[Image.Image, np.ndarray]],
labels: List[str],
temperature: Optional[float] = 0.0,
max_generation_time: float = 30.0, # Robust timeout
max_new_tokens: int = 512, # Limit for runaway tokens
) -> List[str]:
"""
Predicts the textual representation of input images (code or LaTeX).
Parameters
----------
images : List[Union[Image.Image, np.ndarray]]
List of images to be processed, provided as PIL Image objects or numpy arrays.
labels : List[str]
List of labels indicating the type of each image ('code' or 'formula').
temperature : Optional[float]
Sampling temperature for generation, by default set to 0.0.
Returns
-------
List[str]
List of predicted textual outputs for each input image in the given input
order.
Raises
------
TypeError
If any of the input images is not of a supported type (PIL Image or numpy array).
Exception
In case the temperature is an invalid number.
Modifications
-------------
[03/26/2025]
- Implemented robust stopping criteria using `MaxLengthCriteria` and `MaxTimeCriteria`
for controlled generation, replacing the previous custom string-based stopping criteria.
- Updated generation parameters (`max_new_tokens` set explicitly to 512 and
`no_repeat_ngram_size` reduced from 200 to 20) to mitigate runaway token generation issues.
- Integrated Hugging Face's `GenerationConfig` for a structured and maintainable definition
of model generation parameters.
- Enhanced logging to clearly indicate generation termination conditions (timeouts or token limits),
aiding debugging and monitoring.
"""
_log.info("Starting prediction for %d images with labels: %s", len(images), labels)
try:
if temperature is None or not isinstance(temperature, (float, int)) or temperature < 0:
raise Exception("Temperature must be a number greater or equal to 0.")
do_sample = temperature != 0.0
actual_temperature = temperature if do_sample else 1.0 # standard practice
if len(labels) != len(images):
raise Exception("The number of images must be the same as the number of labels.")
images_tmp = []
for idx, image in enumerate(images):
try:
if isinstance(image, Image.Image):
image = image.convert("RGB")
elif isinstance(image, np.ndarray):
image = Image.fromarray(image).convert("RGB")
else:
raise TypeError("Not supported input image format")
images_tmp.append(image)
_log.debug("Processed image %d successfully", idx)
except Exception as e:
_log.error("Error processing image %d: %s", idx, e, exc_info=True)
raise
images_tensor = torch.stack(
[self._image_processor(img) for img in images_tmp]
).to(self._device)
_log.debug("Created images tensor with shape: %s", images_tensor.shape)
prompts = [self._get_prompt(label) for label in labels]
# _log.debug("Prompts: %s", prompts)
tokenized = self._tokenizer(prompts, padding=True, return_tensors="pt")
tokenized = {k: v.to(self._device) for k, v in tokenized.items()}
prompt_ids = tokenized["input_ids"]
attention_mask = tokenized["attention_mask"]
_log.debug("Tokenized prompt_ids shape: %s", prompt_ids.shape)
##############################################
# swapping to GenerationConfig 3/27/2025 00:06AM
##############################################
# Robust stopping criteria implementation:
stopping_criteria = StoppingCriteriaList([
MaxLengthCriteria(max_length=prompt_ids.shape[1] + max_new_tokens),
MaxTimeCriteria(max_generation_time)
])
_log.debug("Configured robust stopping criteria (MaxLength: %d tokens, MaxTime: %.2f seconds).",
prompt_ids.shape[1] + max_new_tokens, max_generation_time)
# Use GenerationConfig for better readability and maintainability
# Restore complete GenerationConfig with detailed beam search parameters:
generation_config = GenerationConfig(
do_sample=False, # Beam search is deterministic
num_beams=6, # Optimal balance between speed and diversity
num_beam_groups=2, # strictly > 1
early_stopping=True, # Stops early once suitable output is found
repetition_penalty=1.3, # Penalize repetitive token sequences
length_penalty=0.9, # Slightly prefer concise outputs
diversity_penalty=0.7, # Improves diversity across beams
max_new_tokens=max_new_tokens, # Prevent runaway generation
no_repeat_ngram_size=3, # Prevent repetitive sequences explicitly
use_cache=True, # Enable caching for efficiency
temperature=actual_temperature, # Explicitly include the temperature setting
)
_log.debug("Restored complete GenerationConfig with beam search: %s", generation_config.to_dict())
if self._device == "cpu":
output = self._model.generate(
input_ids=prompt_ids,
attention_mask=attention_mask,
images=images_tensor,
generation_config=generation_config, # pass everything else via GenerationConfig
stopping_criteria=stopping_criteria, # pass stopping criteria here instead
output_scores=True, # <-- Enables access to beam/token scores
return_dict_in_generate=True # <-- Required to access scores
)
else:
with torch.autocast(device_type=self._device, dtype=torch.bfloat16):
output = self._model.generate(
input_ids=prompt_ids,
attention_mask=attention_mask,
images=images_tensor,
generation_config=generation_config, # pass everything else via GenerationConfig
stopping_criteria=stopping_criteria, # pass stopping criteria here instead
output_scores=True, # <-- Enables access to beam/token scores
return_dict_in_generate=True # <-- Required to access scores
)
output_ids_list = output.sequences
_log.debug("Generation complete, output shape: %s", output_ids_list.shape)
outputs = self._tokenizer.batch_decode(
output_ids_list[:, prompt_ids.shape[1]:], skip_special_tokens=True
)
outputs = [self._strip(output) for output in outputs]
# Log beam scores clearly and correctly:
if hasattr(output, 'sequences_scores') and output.sequences_scores is not None:
scores = output.sequences_scores.cpu().numpy()
for idx, score in enumerate(scores):
_log.info("Output %d beam score: %.4f", idx, score)
else:
_log.warning("Beam scores not available. Ensure num_beams>1 and output_scores=True.")
# Additional logging about generation length and potential truncation
for idx, output in enumerate(outputs):
token_count = len(self._tokenizer.encode(output))
if token_count >= max_new_tokens:
_log.warning("Output %d reached the maximum token limit (%d tokens).", idx, token_count)
else:
_log.debug("Output %d generated %d tokens.", idx, token_count)
_log.info("Prediction completed successfully.")
return outputs
except Exception as e:
_log.error("Error in predict(): %s", e, exc_info=True)
raise
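One way to wire this in without editing the installed package is to monkey-patch the modified function onto the predictor class before building the converter (just a sketch; editing code_formula_predictor.py in site-packages directly works too):

from docling_ibm_models.code_formula_model.code_formula_predictor import CodeFormulaPredictor

import my_code_formula_predictor  # module containing the modified predict() shown above

# Replace the stock predict() with the modified one for all CodeFormulaPredictor instances
CodeFormulaPredictor.predict = my_code_formula_predictor.predict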
Additionally, I had to adjust the following .float() dtype casts in this file, since autocast expects bfloat16: "\wsl.localhost\Ubuntu\home\wstation\miniconda3\envs\newenv\lib\python3.12\site-packages\transformers\generation\utils.py"
In def _sample, around lines 3300-3305:
def _sample(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: Optional["BaseStreamer"],
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_forward = self.__call__
if isinstance(model_kwargs.get("past_key_values"), Cache):
is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
if getattr(self, "hf_quantizer", None) is not None:
is_compileable &= self.hf_quantizer.is_compileable
is_compileable = is_compileable and not generation_config.disable_compile
if is_compileable and (
self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
):
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)
is_prefill = True
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
if is_prefill:
outputs = self(**model_inputs, return_dict=True)
is_prefill = False
else:
outputs = model_forward(**model_inputs, return_dict=True)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
continue
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
# next_token_logits = outputs.logits[:, -1, :].clone().float() changed to enable beam search: 3/27/2025 12:34PM
next_token_logits = outputs.logits[:, -1, :].clone()
next_token_logits = next_token_logits.to(input_ids.device)
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
In def _beam_search, around lines 3772-3779:
def _beam_search(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
**model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
If it's the first time you're diving into Beam Search, we recommend you read the following blog post:
https://huggingface.co/blog/how-to-generate (especially the beam search section).
You can recompute the sequence scores from the individual scores using the `compute_transition_scores` function
(https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores)
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
The sequence used as a prompt for the generation.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`:
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
"""
# 1. init beam_search values
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
do_sample = generation_config.do_sample
early_stopping = generation_config.early_stopping
length_penalty = generation_config.length_penalty
max_length = generation_config.max_length
num_beams = generation_config.num_beams
num_return_sequences = generation_config.num_return_sequences
batch_size_unflattened, cur_len = input_ids.shape
batch_size = batch_size_unflattened // num_beams
# TODO (joao): standardize special cases
if self.__class__.__name__ == "MoshiDepthDecoder":
vocab_size = self.config.audio_vocab_size
elif self.__class__.__name__ == "ImageGPTForCausalImageModeling":
vocab_size = self.get_output_embeddings().out_features
else:
vocab_size = self.config.get_text_config().vocab_size
decoder_prompt_len = cur_len
this_peer_finished = False
# At each beam search step, we want to keep top K [K = (number of EOS tokens + 1) * `num_beams`] candidates
# with the highest log-probabilities, or sample K continuations without replacement. We gather the top K
# (as opposed to `num_beams`, or any number lower than K) so that we have at least `num_beams` sequences
# non-finished to continue the live beam search, in case the top `num_beams` all select an EOS token.
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
beams_to_keep = max(2, 1 + n_eos_tokens) * num_beams
top_num_beam_mask = torch.cat(
(torch.ones((num_beams), dtype=torch.bool), torch.zeros((beams_to_keep - num_beams), dtype=torch.bool)),
dim=0,
).to(input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
# (joao) feature lost in the refactor. Probably won't implement, hurts readbility with minimal gains (there
# are newer low-memory alternatives like the offloaded cache)
sequential = generation_config.low_memory
if sequential:
raise ValueError(
"`low_memory=True` is not supported after the beam search refactor. Please check the discussion in "
"#35802 *after the PR got merged*, and add a comment there if your questions are not yet answered."
)
# 2. init output tuples
all_scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
beam_indices = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# 3. init running tensors and static-shaped placeholders
# per batch, beam-item holding current token in loop and completed sequences
output_fill_value = pad_token_id or eos_token_id[0] if eos_token_id is not None else -1
running_sequences = torch.full(
(batch_size, num_beams, max_length),
fill_value=output_fill_value,
dtype=torch.int64,
device=input_ids.device,
)
running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams)
sequences = running_sequences.clone().detach()
# per batch, beam-item score, logprobs
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
running_beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
running_beam_scores[:, 1:] = -1e9
beam_scores = torch.full((batch_size, num_beams), fill_value=-1e9, dtype=torch.float, device=input_ids.device)
# per batch, beam-item state bit indicating if sentence has finished.
is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device)
# per batch, beam-item state bit indicating if there are valid continuations.
next_token_hits_stopping_criteria = torch.zeros(
(batch_size, num_beams), dtype=torch.bool, device=input_ids.device
)
# per batch selected beam indices
running_beam_indices = torch.full(
(batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=torch.int32, device=input_ids.device
)
beam_indices = running_beam_indices.clone().detach()
# 4. run the generation loop
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# a. Forward current tokens, obtain the logits
flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)
# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
model_outputs = self(**model_inputs, return_dict=True)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
model_outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
continue
# logits = model_outputs.logits[:, -1, :].clone().float() # Clone is needed to avoid keeping a hanging ref (changed to enable beam search: 3/27/2025 12:34PM)
logits = model_outputs.logits[:, -1, :].clone() # Clone is needed to avoid keeping a hanging ref
logits = logits.to(input_ids.device)
# b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
# `temperature`, ...), and add new logprobs to existing running logprobs scores.
log_probs = nn.functional.log_softmax(logits, dim=-1)
log_probs = logits_processor(flat_running_sequences, log_probs)
In def _group_beam_search, between lines 4069-4100:
def _group_beam_search(
self,
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
**model_kwargs,
):
r"""
Generates sequences of token ids for models with a language modeling head using **diverse beam search
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
The sequence used as a prompt for the generation.
beam_scorer (`BeamScorer`):
An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
model_kwargs:
Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
model is an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
device = input_ids.device
batch_beam_size, cur_len = input_ids.shape
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
if return_dict_in_generate and output_scores:
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
else:
beam_indices = None
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
# the same group don't produce same tokens every time.
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
beam_scores[:, ::num_sub_beams] = 0
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# predicted tokens in cur_len step
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
# indices which will form the beams in the next time step
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
# do one decoder step on all beams of all sentences in batch
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
outputs = self(**model_inputs, return_dict=True)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue
if output_scores:
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
if output_logits:
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
raw_logit_score = outputs.logits[:, -1, :].clone()
raw_logit_score = raw_logit_score.to(input_ids.device)
for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
group_size = group_end_idx - group_start_idx
# indices of beams of current group among all sentences in batch
batch_group_indices = []
for batch_idx in range(batch_size):
batch_group_indices.extend(
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
)
group_input_ids = input_ids[batch_group_indices]
# select outputs of beams of current group only
# No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
# .float() is needed to retain precision for later logits manipulations
# next_token_logits = outputs.logits[batch_group_indices, -1, :].float() (changed to enable beam search: 3/27/2025 12:34PM)
next_token_logits = outputs.logits[batch_group_indices, -1, :]
next_token_logits = next_token_logits.to(input_ids.device)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * group_size, vocab_size)
vocab_size = next_token_scores.shape[-1]
next_token_scores_processed = logits_processor(
group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
)
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
# if output_scores: # troubleshooting "Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source." error 03/27/2025 12:37PM
# processed_score[batch_group_indices] = next_token_scores_processed
# changed to enable beam search: 3/27/2025 12:34PM
if output_scores:
if processed_score.dtype != next_token_scores_processed.dtype:
next_token_scores_processed = next_token_scores_processed.to(processed_score.dtype)
processed_score[batch_group_indices] = next_token_scores_processed
# reshape for beam search
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
next_token_scores, next_tokens = torch.topk(
next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True
)
With beam search enabled, however, more VRAM is used, and I kept getting out-of-memory errors even though I was running it on a 48 GB GPU. So what I did then was lower the batch size to 2:
# docling_extract_formulas.py
import logging
import os
# import sys
# sys.path.insert(0, "/wsl.localhost/Ubuntu/home/wstation/projects/rag_project")
import json
from docling_core.types.doc import TextItem
from docling_core.types.doc.labels import DocItemLabel
import docling.models.code_formula_model
docling.models.code_formula_model.CodeFormulaModel.elements_batch_size = 2
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
AcceleratorDevice,
PdfPipelineOptions,
)
from docling.datamodel.settings import settings
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
########################################
# 1) Configure Docling GPU + Pipeline
########################################
# Tells Docling to use the GPU (CUDA) with 10 threads:
accelerator_options = AcceleratorOptions(
num_threads=10,
device=AcceleratorDevice.CUDA
)
pipeline_options = PdfPipelineOptions()
pipeline_options.accelerator_options = accelerator_options
# If your PDFs are text-based, you can disable OCR to speed up:
pipeline_options.do_ocr = False
# If you want table structure (e.g. cell matching):
# pipeline_options.do_table_structure = True
# pipeline_options.table_structure_options.do_cell_matching = True
# Enable formula enrichment so Docling tries to detect formulas
pipeline_options.do_formula_enrichment = True
# Create a DocumentConverter with the pipeline settings
converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(
pipeline_options=pipeline_options
)
}
)
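To give an idea of how the formulas were pulled out afterwards, here is a rough sketch of the extraction step that continues the script above (the input and output file names are placeholders; it relies on the TextItem/DocItemLabel imports already shown and assumes formula enrichment stores the LaTeX as the item text):

result = converter.convert("article.pdf")  # placeholder input PDF

formulas = []
for item in result.document.texts:
    if isinstance(item, TextItem) and item.label == DocItemLabel.FORMULA:
        formulas.append({"latex": item.text})

with open("extracted_formulas.json", "w", encoding="utf-8") as f:
    json.dump(formulas, f, ensure_ascii=False, indent=2)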
The end result is that processing went from 5 minutes down to about 63 seconds for a 15-page PDF article with 16 equations in it. GPU usage with beam search enabled and the batch size set to two is approximately 40 GB.
Despite the large speed-up, the model still struggles to extract the formulas correctly. For comparison, these are the LaTeX equations as they should read:
LaTeX formatted:
1 &= f_{\mathrm{field}} + f_{\mathrm{air}} + f_{\mathrm{dep}} (1)
f_{\mathrm{field}} &= 1 - \left(f_{\mathrm{air}} + f_{\mathrm{dep}}\right) = f_{\mathrm{field} \rightarrow \mathrm{crop}} + f_{\mathrm{field} \rightarrow \mathrm{rest}} (2)
f_{\mathrm{field} \rightarrow \mathrm{crop}} &= f_{\mathrm{field}} \times f_{\mathrm{intercept,crop}} (3)
f_{\mathrm{field} \rightarrow \mathrm{rest}} &= f_{\mathrm{field}} \times \left(1 - f_{\mathrm{intercept,crop}}\right) = f_{\mathrm{field} \rightarrow \mathrm{cover}} + f_{\mathrm{field} \rightarrow \mathrm{soil}} (4)
f_{\mathrm{field} \rightarrow \mathrm{cover}} &= f_{\mathrm{field} \rightarrow \mathrm{rest}} \times f_{\mathrm{eff,cover}} (5)
f_{\mathrm{field} \rightarrow \mathrm{soil}} &= f_{\mathrm{field} \rightarrow \mathrm{rest}} \times \left(1 - f_{\mathrm{eff,cover}}\right) (6)
f_{\mathrm{eff,cover}} &= f_{\mathrm{soil} \rightarrow \mathrm{cover}} \times f_{\mathrm{intercept,cover}} (7)
\mathrm{DT50}_{\mathrm{cover},T} &= 10^{\log_{10}\left(\mathrm{DT50}_{\mathrm{cover},T_{\mathrm{ref}}}\right)-0.01995\times\left(T-T_{\mathrm{ref}}\right)} (8)
k_{\mathrm{degCover}} &= \frac{\ln(2)}{\mathrm{DT50}_{\mathrm{cover},T}} (9)
k_{\mathrm{OCover}} &= k_{\mathrm{volat}} + k_{\mathrm{uptake}} + k_{\mathrm{degCover}} (10)
f_{\mathrm{degCover}} &= f_{\mathrm{field}\rightarrow\mathrm{cover}} \times \frac{k_{\mathrm{degCover}}}{k_{\mathrm{OCover}}} \times \left(1 - e^{-k_{\mathrm{OCover}} \times t_{\mathrm{assess}}}\right) (11)
f_{\mathrm{volCover}} &= f_{\mathrm{field}\rightarrow\mathrm{cover}} \times \frac{k_{\mathrm{volat}}}{k_{\mathrm{OCover}}} \times \left(1 - e^{-k_{\mathrm{OCover}} \times t_{\mathrm{assess}}}\right) (12)
f_{\mathrm{uptakeCover}} &= f_{\mathrm{field}\rightarrow\mathrm{cover}} \times \frac{k_{\mathrm{uptake}}}{k_{\mathrm{OCover}}} \times \left(1 - e^{-k_{\mathrm{OCover}} \times t_{\mathrm{assess}}}\right) (13)
f_{\mathrm{Leaves-cover}} &= f_{\mathrm{field}\rightarrow\mathrm{cover}} \times e^{-k_{\mathrm{OCover}} \times t_{\mathrm{assess}}} (14)
f_{\mathrm{cover-sec}} &= f_{\mathrm{uptakeCover}} + f_{\mathrm{Leaves-cover}} (15)
\mathrm{IS} &= \sum_{p,c}\left(m_{\mathrm{emi},p,c} \times CF_{p,c}\right) (16)
\%\ \mathrm{change} &= \left[\frac{\mathrm{IS\ without\ cover} - \mathrm{IS\ with\ cover}}{\mathrm{IS\ without\ cover}}\right] \times 100 (17)
But the extracted equations are quite messy:
[
{
"latex": "1 & = f _ { \\text{field} } + f_{\\text{air} }+ f_{dep} & & ( 1 )"
},
{
"latex": "f _ { f i e l d } = 1 - ( f _ { a i r } + f ^ { \\ } _ { \\text{dep} } ) = { f _{\\text{field} \\rightarrow c \\tt r o p } } + { f } _{{ \\text`field\\rightarrow rest } \\quad ( 2 )"
},
{
"latex": "f _ { \\text{field} \\rightarrow \\text{\\tt crop} } = f _ {\\text{ field} } \\times f_{intercept,\\tt crop} \\quad \\quad ( 3 )"
},
{
"latex": "\\begin{array} { c c c } \\| \\L a p - & f _ { \\text{field} \\rightarrow \\text{\\ } } & f_{\\text{fields} \\times (1 - f_{intercept,crop} ) = f_{{\\text{\\text{fold} \\to cover} + f_{ {\\text{\\bold} } \\to \\text{{{4} } } \\\\ \\tt p t i n g & & & \\quad & & ( 4 ) \\end{array}{}"
},
{
"latex": "f _ { f i e l d \\to \\text{cover} } = f _ { \\text{\\text{field\\to} } } _ { r e s t } \\times f ^ { } { _ { e f f } }, \\text{{over} }"
},
{
"latex": "f _ { f i e l d \\to \\text{soil} } = f _ { \\text{\\text{field} } \\to r e s t } \\times ( 1 - f ^ { \\ } _ { e f f, \\text{{cover} } ) \\quad \\quad ( 6 ) \\ \\text {\\text{where} }"
},
{
"latex": "f _ { \\text{eff,cover} } = f _ {\\text{soil} \\rightarrow \\coter } \\times f ^ { \\ } _ { i \\text{\\text{intercept,cover}} } \\quad \\ \\ ( 7 ) \\stackrel { \\prec - 1 } { \\Theta }."
},
{
"latex": "D T 5 0 _ { \\text{cover}, T } = 1 0 ^ { \\log _ { 1 0 } ( \\text{\\emph{DT50}_{\\text{cover},T_{ref} } ) - 0. 0 1 9 9 5 \\times ( T - T _ {ref} ) } } \\quad ( 8 )"
},
{
"latex": "k _ { \\deg \\text{Cover} } = \\frac { \\ln ( 2 ) } { D T { 5 } { 0 }, \\tilde { \\cos }, T } \\quad \\begin{pmatrix} 9 & \\text{\\emph{whel} } \\\\ \\text{{\\text{after} } \\end{pmetrix}"
},
{
"latex": "k _ { O C o v e r } = k _ { \\text{volat} } + k_{\\text{uptake} }+k_{degCover} \\quad \\ \\ (10 )"
},
{
"latex": "\\rightharpoonup \\rbegin{array} { c c c } & & \\\\ & & \\cdots \\\\ \\text{dati-} & & f _ { \\degCover} = f_{\\field\\to\\cover} \\times \\frac{k_{\\degCover}}{k_\\text{Cover} } \\times (1 - e ^ { ( - k _ { 0 } \\text{\\text{x}$assess} ) } \\\\ & t h e & & & ( ( ( 1 1 ) ) \\\\ \\tilde { \\ } a n l \\end{array}{}"
},
{
"latex": "\\text{alf-} \\quad f _ { \\text{volCover} } = f_{field} \\rightarrow \\tau \\times \\frac { k _ { v o l a t } } { k ^ { \\ } _ { O C o v e r } } \\times ( 1 - e ^ { ( - k \\tt \\tt O c v e n x t { \\tt x } \\tt s s e s ) } ) \\\\ \\tt h e d \\, a"
},
{
"latex": "& ( 8 ) & & f _ { \\text{uptakeCover} = f_{field-cover} \\times \\frac{k_{\\ttuptake} } { k_{OCovert} } \\times ( 1 - e ^ { ( - k _ { OCovertXt-assess)} } ) & \\\\ \\ n p e r - & & & ( ( 1 3 ) ) &"
},
{
"latex": "f _ { \\text{Leaves-cover} } = f _ { f i e l d \\rightarrow \\cot \\times e^{(-k\\text{Cover}^{t}assess} } \\quad ( 1 4 )"
},
{
"latex": "f _ { \\text{cover-sec} } = f _ {\\text{uptakeCover} } + f ^ { \\ } _ { L e a v e s - \\cot e r } \\quad \\ \\ ( 1 5 ) \\stackrel { \\dots } { f o r $ w }"
},
{
"latex": "\\text{IS} = \\sum _ { p, c } ( m _ { \\text{emi}, p, e } \\times C F _ { { p }, { c } } ) \\quad \\quad ( 1 6 ) \\, \\begin{smallmatrix} \\text{\\emph{on} \\, g } \\\\ \\text{{emis} } \\end{ll}"
},
{
"latex": "\\% - \\text{change} = \\left [ \\frac { I S \\text{\\em without cover} - I S with cover} { \\emph {S without cover}} } \\right ] \\times 1 0 0 \\quad \\Theta ^ { \\colon } _ { \\substack { c } \\\\ \\overbrace { \\overline { \\nu } = \\cdots } } \\ \\ \\strut r a c {. }"
}
]
Nevertheless, I believe that simple iterative refinement or post-processing should bring the quality up. Below is a Python file I am working on to do some of that post-processing:
# latex_to_md.py
import re
import json
class LatexToMarkdown:
"""
A flexible converter from raw LaTeX to Markdown-friendly math notation.
Can optionally clean up OCR artifacts, remove alignment symbols, etc.
Example usage:
eq_list = [{"latex": "f _ { f i e l d } = 1 - ( f _ { a i r } + f _ { d e p } ) (1)"}]
converter = LatexToMarkdown(
eq_list,
remove_equation_numbers=True,
remove_alignment_ampersands=True,
remove_environments=True,
collapse_subscripts=True,
replace_rm_with_text=True,
block_display=True
)
results = converter.convert_all()
# Now results is a list of { "index": ..., "markdown": "$$...$$" }
"""
def __init__(
self,
equations_json,
remove_equation_numbers=True,
remove_alignment_ampersands=True,
remove_environments=True,
collapse_subscripts=True,
replace_rm_with_text=True,
block_display=True
):
"""
:param equations_json: List[dict], each with {"latex": "<LaTeX-string>"}
:param remove_equation_numbers: bool
If True, remove occurrences of (1), (2), etc. and \tag{1}, \tag{2}, etc.
:param remove_alignment_ampersands: bool
If True, remove alignment symbols (&) and "\\\\" from the LaTeX.
:param remove_environments: bool
If True, remove \begin{align}, \end{align}, partial \begin{array} etc.
:param collapse_subscripts: bool
If True, merge spaced-out subscripts like f _ { f i e l d } => f_{field}.
:param replace_rm_with_text: bool
If True, replace \mathrm{xyz} with \text{xyz}.
:param block_display: bool
If True, wrap final equation in $$...$$ (display mode).
If False, wrap final equation in $...$ (inline mode).
"""
self.equations_json = equations_json
self.remove_equation_numbers = remove_equation_numbers
self.remove_alignment_ampersands = remove_alignment_ampersands
self.remove_environments = remove_environments
self.collapse_subscripts = collapse_subscripts
self.replace_rm_with_text = replace_rm_with_text
self.block_display = block_display
# ---------------------------------------------------------
# 1) Basic replacements / standardization
# ---------------------------------------------------------
def _replace_rm_with_text(self, eq_str):
"""Replace \mathrm{...} with \text{...} if toggle is True."""
if self.replace_rm_with_text:
eq_str = eq_str.replace(r"\mathrm", r"\text")
return eq_str
def _remove_equation_numbers_fn(self, eq_str):
"""Remove (1), (2), etc. and \\tag{1}, etc. if toggle is True."""
if self.remove_equation_numbers:
# Remove bare (123)
eq_str = re.sub(r"\(\d+\)", "", eq_str)
# Remove \tag{123}
eq_str = re.sub(r"\\tag\s*\{\d+\}", "", eq_str)
return eq_str
# ---------------------------------------------------------
# 2) Alignment & environment handling
# ---------------------------------------------------------
def _remove_alignment_ampersands_fn(self, eq_str):
"""
Remove & and \\\\ if toggled on (commonly used in align/array blocks).
Be careful if you DO want to keep them for actual alignment in KaTeX.
"""
if self.remove_alignment_ampersands:
eq_str = eq_str.replace("&", "")
eq_str = eq_str.replace("\\\\", "")
return eq_str
def _remove_or_simplify_environments_fn(self, eq_str):
"""
Remove or simplify common environment wrappers like
\begin{align}, \end{align}, \begin{array}, etc.
If you want to keep arrays or pmatrix, disable self.remove_environments
or comment out the relevant lines.
"""
if self.remove_environments:
# Remove known environment wrappers that break some markdown renderers
eq_str = re.sub(r"\\begin\{align\*?\}", "", eq_str)
eq_str = re.sub(r"\\end\{align\*?\}", "", eq_str)
# Either fully remove array environment, or partially remove
# from \begin{array}{...} up to \end{array}.
eq_str = re.sub(r"\\begin\{array\}(\[.*?\])?\{.*?\}", "", eq_str)
eq_str = re.sub(r"\\end\{array\}(\{\})?", "", eq_str)
# If you want to remove pmatrix/bmatrix fully, uncomment:
# eq_str = re.sub(r"\\begin\{pmatrix\}.*?\\end\{pmatrix\}", "", eq_str, flags=re.DOTALL)
# eq_str = re.sub(r"\\begin\{bmatrix\}.*?\\end\{bmatrix\}", "", eq_str, flags=re.DOTALL)
return eq_str
# ---------------------------------------------------------
# 3) Collapsing subscript/superscript spacing
# ---------------------------------------------------------
def _collapse_subscript_spaces_fn(self, eq_str):
"""
Turn f _ { f i e l d } => f_{field}, if toggle is True.
"""
if not self.collapse_subscripts:
return eq_str
# (a) reduce multiple spaces inside braces => single space
eq_str = re.sub(
r"\{\s+([^}]*)\s+\}",
lambda m: "{" + " ".join(m.group(1).split()) + "}",
eq_str
)
# (b) remove spaces from purely alphabetical or \command blocks
def remove_inner_spaces(match):
inner = match.group(1) # inside braces
# If it's all letters or backslash-commands, remove spaces
if re.match(r'^[A-Za-z\\]+(\s[A-Za-z\\]+)*$', inner):
return "{" + inner.replace(" ", "") + "}"
else:
return "{" + inner + "}"
eq_str = re.sub(r"\{([^}]+)\}", remove_inner_spaces, eq_str)
return eq_str
# ---------------------------------------------------------
# 4) Final wrapper in $...$ or $$...$$
# ---------------------------------------------------------
def _wrap_in_math_mode(self, eq_str):
eq_str = eq_str.strip()
if self.block_display:
return f"$$ {eq_str} $$"
else:
return f"$ {eq_str} $"
# ---------------------------------------------------------
# Main single-equation converter
# ---------------------------------------------------------
def convert_equation(self, latex_str):
eq_str = latex_str
# 1) Basic replacements
eq_str = self._replace_rm_with_text(eq_str)
eq_str = self._remove_equation_numbers_fn(eq_str)
# 2) Alignment / environment
eq_str = self._remove_alignment_ampersands_fn(eq_str)
eq_str = self._remove_or_simplify_environments_fn(eq_str)
# 3) Subscript cleanup
eq_str = self._collapse_subscript_spaces_fn(eq_str)
# 4) Wrap in math mode
eq_str = self._wrap_in_math_mode(eq_str)
return eq_str
# ---------------------------------------------------------
# Convert all equations in self.equations_json
# ---------------------------------------------------------
def convert_all(self):
"""
Returns: list of dicts like [{"index": 1, "markdown": "$$ ... $$"}, ...]
"""
markdown_equations = []
for idx, eq_dict in enumerate(self.equations_json, 1):
latex_str = eq_dict.get("latex", "")
if not latex_str:
continue
md_eq = self.convert_equation(latex_str)
markdown_equations.append({"index": idx, "markdown": md_eq})
return markdown_equations
def export_markdown_json(self):
"""
Returns a JSON string containing all the converted equations, e.g.:
[
{ "index": 1, "markdown": "$$ f_{field} = 42 $$" },
{ "index": 2, "markdown": "$$ x^2 + y^2 = 1 $$" }
]
"""
converted = self.convert_all()
return json.dumps(converted, ensure_ascii=False, indent=2)
def print_markdown(self):
"""Convenience: just print all converted equations to stdout."""
for eq in self.convert_all():
print(f"{eq['markdown']}\n(Eq. {eq['index']})\n")
But anyways, yeah, expect it to use anywhere between 20 GB and 40 GB of VRAM to get it working properly, or try to offload some of the work to the CPU.
Oh my... 40 GB for a (code/formula) model that's only 512 MB on disk? How is that even possible? Something's off. I can't run the code or formula enrichment at all, but I can run a 4-bit quantized Qwen3 32B model at 50 tok/s on my 3090 GPU.