lime
Added support for custom replacement_fn
Rather than removing text, which can create oddities, we may want a way to replace tokens that would otherwise be removed. I added support for a custom replacement_fn, which works like the existing classifier_fn. My particular use case was T5, so I also modified the generation of perturbed data to run in batches rather than one sample at a time.
This partially solves #648
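For context, the contract mirrors classifier_fn: the function receives the tokenized text plus a batch of boolean masks (True means keep the token) and returns one perturbed string per mask. A minimal sketch of that contract, using a hypothetical fixed placeholder instead of a generative model (the full T5 example below does real infilling):

```python
from typing import List


def placeholder_wrapper(text_as_list: List[str], masks: List[List[bool]]) -> List[str]:
    # Hypothetical minimal replacement_fn: swaps each masked-out token for a
    # fixed placeholder string. It only illustrates the expected signature.
    perturbed = []
    for mask in masks:
        tokens = [tok if keep else "UNKWORDZ" for tok, keep in zip(text_as_list, mask)]
        perturbed.append("".join(tokens))
    return perturbed
```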
Full example replacement_fn, using T5 to infill the masked tokens:
```python
from typing import List

from transformers import T5ForConditionalGeneration, T5Tokenizer


def t5_wrapper(text_as_list: List[str], masks: List[List[bool]]) -> List[str]:
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    out_refs = []
    masker_idxs = []
    outs = []
    # Build one masked input per perturbation: kept tokens are copied through
    # (text_as_list is assumed to carry its own separators), masked tokens are
    # replaced with T5 sentinel tokens (<extra_id_0>, <extra_id_1>, ...).
    for mask in masks:
        local_out = ""
        local_out_ref = ""
        local_masker_idx = 0
        for idx in range(len(mask)):
            if mask[idx]:
                local_out += text_as_list[idx]
                local_out_ref += text_as_list[idx]
            else:
                try:
                    local_out += tokenizer.additional_special_tokens[local_masker_idx]
                    local_masker_idx += 1
                except IndexError:
                    # T5 only ships a fixed number of sentinels; skip any further masks.
                    continue
        masker_idxs.append(local_masker_idx)
        outs.append(local_out)
        out_refs.append(local_out_ref)  # unmasked reference, kept for debugging
    model.cuda()  # assumes a CUDA device is available
    batch_size = 50
    if len(outs) > batch_size:
        input_ids = tokenizer(outs, return_tensors="pt", padding=True,
                              max_length=512, truncation=True)
        model_suggestions = []
        # Generate in chunks of batch_size to bound GPU memory use.
        for start in range(0, len(input_ids.input_ids), batch_size):
            local_inputs = {key: value[start:start + batch_size].cuda()
                            for key, value in input_ids.items()}
            outputs = model.generate(**local_inputs)
            model_suggestions.extend(
                tokenizer.batch_decode(outputs, skip_special_tokens=False))
    else:
        input_ids = tokenizer(outs, return_tensors="pt", padding=True)
        for key, value in input_ids.items():
            input_ids[key] = value.cuda()
        outputs = model.generate(**input_ids)
        model_suggestions = tokenizer.batch_decode(outputs, skip_special_tokens=False)
    inversed_data = []
    # Splice the generated spans back in: T5 answers in the form
    # "<extra_id_0> span0 <extra_id_1> span1 ...", so each replacement is the
    # text between one sentinel and the next.
    for idx, suggestion in enumerate(model_suggestions):
        local_out = outs[idx]
        local_masker_idx = masker_idxs[idx]
        present_tokens = [tokenizer.additional_special_tokens[i]
                          for i in range(local_masker_idx)
                          if tokenizer.additional_special_tokens[i] in suggestion]
        for token_idx, present in enumerate(present_tokens):
            start_idx = suggestion.find(present) + len(present)
            if token_idx == len(present_tokens) - 1:
                # Last sentinel present: take everything after it.
                local_out = local_out.replace(present, suggestion[start_idx:])
            else:
                upper_index = suggestion.find(present_tokens[token_idx + 1])
                local_out = local_out.replace(present, suggestion[start_idx:upper_index])
        # Drop sentinels the model never filled in, plus the pad/eos markers
        # that survive skip_special_tokens=False.
        for item in tokenizer.additional_special_tokens:
            local_out = local_out.replace(item, "")
        local_out = local_out.replace(tokenizer.pad_token, "").replace(tokenizer.eos_token, "")
        inversed_data.append(local_out)
    return inversed_data
```
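A sketch of how this would plug in, assuming the new keyword is accepted by explain_instance alongside classifier_fn (the exact placement may differ in the merged code):

```python
from lime.lime_text import LimeTextExplainer

explainer = LimeTextExplainer(class_names=["neg", "pos"])

explanation = explainer.explain_instance(
    "a sentence to explain",
    classifier_fn=my_classifier_fn,  # hypothetical: the model being explained
    replacement_fn=t5_wrapper,
    num_samples=500,
)
```

Batching matters here because explain_instance produces num_samples perturbations per call, which is why the perturbed data generation was switched to batches rather than one sample at a time.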