OOM with Bert Tokenizer
Hello,
I am using LIME to get interpretations for a classification problem. First, I use the FlauBERT tokenizer (I also tried different tokenizers and had the same problem) to convert my text into tokens. Then I feed the tokens into my model and get the class probabilities with a softmax (all of this is wrapped in a predict function):
import torch.nn.functional as F

def predict(text):
    # LIME passes a list of perturbed text strings here
    encoding = tokenizer(text, padding=True, truncation=True, max_length=100, return_tensors="pt")
    outputs = module(encoding["input_ids"], encoding["attention_mask"])
    probas = F.softmax(outputs.logits, dim=1).detach().numpy()
    return probas
After that, I created the explainer and everything else that goes with it...
explainer = LimeTextExplainer(class_names=['list', 'of', 'my', 'classes']) # I have 11 classes
msg = "here is my dummy txt" # OOM with this message
exp = explainer.explain_instance(msg, predict, top_labels=1)
exp.show_in_notebook(text=msg)
The problem is that with a short message I get my result (for example, msg = "bonjour ca va"), but with a longer message I get an OOM after about a minute.
Could you please tell me if I missed something here? Thanks!!
You can lower the num_samples parameter (default: 5000), for example:
exp = explainer.explain_instance(text, predict, num_features=6, top_labels=2, num_samples=3)
You could batch the module() forward pass within your predict(text) function by taking batch-sized chunks of the texts and concatenating the probas before returning. The num_samples default of 5000 gives you a single batch of 5000 samples, which is almost certainly what is causing the OOM. (also commented here)
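A minimal sketch of that batched predict, assuming tokenizer and module are the objects from the snippet above (batch_size=32 is an arbitrary starting point, tune it to your memory budget):

import numpy as np
import torch
import torch.nn.functional as F

def predict(text, batch_size=32):
    # process the perturbed samples in small chunks instead of all 5000 at once
    all_probas = []
    for i in range(0, len(text), batch_size):
        batch = text[i:i + batch_size]
        encoding = tokenizer(batch, padding=True, truncation=True, max_length=100, return_tensors="pt")
        with torch.no_grad():  # inference only, so skip gradient tracking to save memory
            outputs = module(encoding["input_ids"], encoding["attention_mask"])
        all_probas.append(F.softmax(outputs.logits, dim=1).numpy())
    # stitch the per-batch probabilities back into one (num_samples, num_classes) array
    return np.concatenate(all_probas, axis=0)

Wrapping the forward pass in torch.no_grad() also stops PyTorch from keeping activations around for backprop, which cuts memory further on top of the batching.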