
[Question]: Controlling the CUDA OOM error for SequenceTagger via the use of tokens threshold

Open adhadse opened this issue 10 months ago • 4 comments

Question

While trying to deploy the NER model, the length of the generated token sequence often cannot be controlled, even though the number of characters can be limited.

  • With that said, I tried to limit the number of tokens per batch by making the batch size of every inference batch dynamic.
  • Each batch has a fixed limit on total tokens, but the batch size (number of sentences) varies. This was done to mitigate memory issues.

However, it didn't play nicely. Even though the total token count per batch was fixed,

  • I still got CUDA memory errors every so often on a Colab T4 instance with 16 GB.
  • Plus, it was slower than simply iterating over fixed-size batches of 16 sentences with mini_batch_size=16. (Passing all the sentences I want inference for at once would normally fail with a CUDA OOM error a lot of the time.)

Here is the code I tried to work with:

import humanize
import pandas as pd
import torch
from flair.data import Sentence
from flair.models import SequenceTagger


class NERExtractionError(Exception):
    def __init__(self, error, articles_tokens_count=0):
        super().__init__(error)
        self.name = "NERExtractionError"
        self.error = error
        self.batch_tokens_count = articles_tokens_count

class NERExtractionWithAutoBatching:
    def __init__(self, expected_total_tokens_in_batch=35_000):
        self.articles_tokens_count: list[int] = []
        self.expected_total_tokens_in_batch = expected_total_tokens_in_batch

    def generate_batches_of_n_token_lengths(self, df: pd.DataFrame):
        batches = []
        df["tokens_count_cumsum"] = df["token_count"].cumsum(axis="index")

        while df.shape[0] != 0:
            if df.iloc[0]["tokens_count_cumsum"] > self.expected_total_tokens_in_batch:
                # the first article alone exceeds the budget -> batch of one
                batch = df.iloc[[0]]
                batches.append(batch)
                df = df.drop(index=batch.index)
                df["tokens_count_cumsum"] = df["token_count"].cumsum(axis="index")
            else:
                # take as many articles as fit into the token budget
                batch = df[df["tokens_count_cumsum"] <= self.expected_total_tokens_in_batch]
                batches.append(batch)
                df = df.drop(index=batch.index)
                df["tokens_count_cumsum"] = df["token_count"].cumsum(axis="index")
        return batches

    def inference(self, articles: pd.DataFrame):
        total_tokens_in_each_article: list[int] = articles["token_count"].tolist()
        sentences = articles["sentence"].tolist()

        # NOTE: the tagger is (re)loaded on every call to inference()
        flair_model = SequenceTagger.load("flair/ner-english-ontonotes-large")
        all_articles_ner = []

        try:
            flair_model.predict(
                sentences, mini_batch_size=len(sentences),
                embedding_storage_mode="none"  # flair expects the string "none", not None
            )
        except Exception as e:
            del sentences
            torch.cuda.empty_cache()
            raise NERExtractionError(
                e,
                articles_tokens_count=total_tokens_in_each_article
            )

        # each sentence is an article instance
        for sentence in sentences:
            article_ner = []
            for entity in sentence.get_labels('ner'):
                ner_text_and_label = {
                    "text": entity.data_point.text,
                    "labels": entity.to_dict()
                }
                article_ner.append(ner_text_and_label)
            all_articles_ner.append(article_ner)
            sentence.clear_embeddings()

        torch.cuda.empty_cache()

        return all_articles_ner
    
    def flair_inference(self, df):
        """Main function, pass it a pd.DataFrame containing "inputs" col filled with hundreds and thousands of characters
        CHARACTERS ARE CAPPED AT 20_000
        """
        flair_output = []

        df["sentence"] = df["inputs"].apply(lambda x: Sentence(x))
        df["token_count"] = df["sentence"].apply(lambda x: len(x.tokens))

        batches = self.generate_batches_of_n_token_lengths(df)

        for i, batch in enumerate(batches):
            print(f"[INFO] Batch #{i} token count:"
                  f"{humanize.intcomma(sum(batch['token_count'].tolist()))} "
                  f"of dynamic batch size of {len(batch)}")
            output = self.inference(batch)
            flair_output.extend(output)

        return flair_output

Is this a valid path to take, or are there better options out there for running NER over a huge amount of data?

adhadse avatar Oct 16 '23 05:10 adhadse

Hi @adhadse, I think there is no definitive guarantee of never going OOM; however, you could use a sentence splitter to split your 20k-character paragraphs into smaller sentences. The smaller sentences should then be easier to process with batch inference.
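For illustration, a minimal sketch of that approach, assuming a recent flair version where SegtokSentenceSplitter is exposed from flair.splitter:

# Minimal sketch, assuming flair.splitter.SegtokSentenceSplitter is available
from flair.models import SequenceTagger
from flair.splitter import SegtokSentenceSplitter

splitter = SegtokSentenceSplitter()
tagger = SequenceTagger.load("flair/ner-english-ontonotes-large")

article_text = "..."  # one of your ~20k-character articles
sentences = splitter.split(article_text)  # many short Sentence objects

# short sentences keep the per-batch memory footprint small
tagger.predict(sentences, mini_batch_size=16)

for sentence in sentences:
    for entity in sentence.get_labels("ner"):
        print(entity.data_point.text, entity.value)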

helpmefindaname avatar Oct 16 '23 08:10 helpmefindaname

@helpmefindaname Thanks for the response. With that, I have no idea how I would keep track of hundreds of articles, because those few batches of articles will result in thousands of Sentence objects, and it isn't possible to trace them back to their article.

I guess the only approach with the sentence splitter would be to iterate over each article and create a batch of Sentences per article. This would obviously be expected to have a lower chance of an OOM error, but it might reduce the speed of inference (a trade-off, I guess).

Please correct me if I'm wrong. By the way, why can't tokens be used as a measure to control OOM occurrences?

adhadse avatar Oct 17 '23 06:10 adhadse

For tracking, you can use metadata:

sentences = []
# sentence_splitter: e.g. flair's SegtokSentenceSplitter
for article_id, article_text in enumerate(articles):
    for sentence_id, sentence in enumerate(sentence_splitter.split(article_text)):
        sentence.add_metadata("article_id", article_id)
        sentence.add_metadata("sentence_id", sentence_id)
        sentences.append(sentence)

and then later use sentence.get_metadata("article_id") and sentence.get_metadata("sentence_id") to trace the sentence back to its origin. If you only need the extracted entities, this is very simple. If you also want the position of each entity within the article, you need to add up the entity positions as described in https://github.com/flairNLP/flair/issues/3322#issuecomment-1742687899
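As a rough sketch (variable names are hypothetical, assuming the metadata was added as above), grouping the extracted entities back per article after prediction could look like this:

# Rough sketch: group extracted entities back by article using the metadata
# that was attached before prediction (variable names are hypothetical).
from collections import defaultdict

tagger.predict(sentences, mini_batch_size=16)

entities_per_article = defaultdict(list)
for sentence in sentences:
    article_id = sentence.get_metadata("article_id")
    for entity in sentence.get_labels("ner"):
        entities_per_article[article_id].append(
            {"text": entity.data_point.text, "labels": entity.to_dict()}
        )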

Since the model you are using relies on transformer embeddings, which use the attention mechanism with O(n²) complexity with respect to the sequence length, you should gain speed by using more, smaller sentences instead of fewer, larger ones. Keep in mind that the number of tokens to compute stays about the same, so there is no reason to expect a slowdown.
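To make the quadratic argument concrete with some back-of-the-envelope numbers:

# Back-of-the-envelope illustration of the O(n^2) attention cost:
# one sequence of 1000 subtokens vs. ten sequences of 100 subtokens each.
one_long = 1000 ** 2         # 1,000,000 pairwise attention scores
ten_short = 10 * 100 ** 2    # 100,000 pairwise attention scores
print(one_long / ten_short)  # 10.0 -> roughly 10x less attention work for the same text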

About the tokens: as said before, you are using a model that uses transformer embeddings. Those embeddings use an embedding-specific sub-token tokenization and compute on that level. So, for example, when you use German electra-base and want to compute the text donaudampfschifffahrtselektrizitätenhauptbetriebswerkbauunterbeamtengesellschaft, you will see 1 token, but the tokenizer of gelectra will compute embeddings for the following subtokens: ['don', '##aud', '##ampf', '##schiff', '##fahrts', '##elekt', '##riz', '##itäten', '##haupt', '##betriebs', '##werk', '##bau', '##unter', '##beamten', '##gesellschaft']. The largest memory footprint occurs during that computation, hence the 15 sub-tokens matter more than the final embedding for 1 token.
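If you want to inspect that sub-token split yourself, here is a small sketch using the Hugging Face tokenizer (the checkpoint name deepset/gelectra-base is an assumption; use the one your embedding is based on):

# Small sketch: inspect the sub-token split the transformer embedding computes on.
# The model id "deepset/gelectra-base" is an assumption; swap in the checkpoint
# that your embedding actually loads.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepset/gelectra-base")
word = "donaudampfschifffahrtselektrizitätenhauptbetriebswerkbauunterbeamtengesellschaft"
print(tokenizer.tokenize(word))  # a long list of '##...' sub-tokens for a single flair token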

You might be able to estimate the number of subtokens, e.g. by assuming that on average 1 token yields 3 subtokens, but you cannot guard against cases where you get many more subtokens than that.
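Following that reasoning, one option would be to budget your batches by sub-token count instead of flair token count. A rough sketch (the tokenizer checkpoint is an assumption; pick the backbone your tagger's transformer embeddings actually use):

# Rough sketch: budget batches by sub-token count rather than flair token count.
# "xlm-roberta-large" is an assumed checkpoint; use the backbone your tagger's
# transformer embeddings actually load.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")

def subtoken_count(text: str) -> int:
    return len(tokenizer.tokenize(text))

df["token_count"] = df["inputs"].apply(subtoken_count)
# generate_batches_of_n_token_lengths() then splits on a budget that better
# reflects the actual memory footprint of the transformer forward pass.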

helpmefindaname avatar Oct 17 '23 12:10 helpmefindaname

I am counting tokens based on the Sentence object, which has already tokenized the article:

df["sentence"] = df["inputs"].apply(lambda x: Sentence(x))
df["token_count"] = df["sentence"].apply(lambda x: len(x.tokens))

adhadseKavida avatar Oct 17 '23 12:10 adhadseKavida