How to visualize attention
Here is the code. It wraps a BERT sentiment model and a custom Excel dataset in LIT's model and dataset APIs, then starts the LIT server.
import sys

from absl import app
from absl import flags
from absl import logging

from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.lib import utils

import pandas as pd
import torch
import transformers
from transformers import BertTokenizer, BertForSequenceClassification

# Load the labelled utterances from Excel and keep a small training sample.
df = pd.read_excel("data.xlsx", sheet_name='master_data')
print(df.shape)
df = df[df['train'] == 1]
df = df.head(100)
df = df[['UTTERANCE', 'label']]
df['label'] = df['label'].astype(int)
print(df.head(2))


def load_examples(data):
  """Materialize the DataFrame as a list of [sentence, label] rows.

  The original demo loaded SST-2 from TFDS here, e.g.:
    ds = tfds.load('glue/sst2', split='train', shuffle_files=True)
  This version simply reads the rows prepared above, already in order.
  """
  return data.values.tolist()


class SST2Data(lit_dataset.Dataset):
  """Binary sentiment data in SST-2 format, read from the Excel sheet above.

  Labels follow the SST-2 convention; see
  https://www.tensorflow.org/datasets/catalog/glue#gluesst2.
  """

  LABELS = ['0', '1']

  def __init__(self, data):
    self._examples = []
    for ex in load_examples(data):
      self._examples.append({
          'sentence': ex[0],
          'label': self.LABELS[ex[1]],
      })

  def spec(self):
    return {
        'sentence': lit_types.TextSegment(),
        'label': lit_types.CategoryLabel(vocab=self.LABELS),
    }


FLAGS = flags.FLAGS

FLAGS.set_default("development_demo", True)

flags.DEFINE_string(
    "model_path",
    "https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz",
    "Path to trained model, in standard transformers format, e.g. as "
    "saved by model.save_pretrained() and tokenizer.save_pretrained()")


def _from_pretrained(cls, *args, **kw):
  """Load a transformers model in PyTorch, with fallback to TF2/Keras weights."""
  try:
    return cls.from_pretrained(*args, **kw)
  except OSError as e:
    logging.warning("Caught OSError loading model: %s", e)
    logging.warning(
        "Re-trying to convert from TensorFlow checkpoint (from_tf=True)")
    return cls.from_pretrained(*args, from_tf=True, **kw)


class SimpleSentimentModel(lit_model.Model):
  """Simple sentiment analysis model."""

  LABELS = ["0", "1"]  # negative, positive

  def __init__(self, model_name_or_path):
    self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # This is just a regular PyTorch model. Note that it loads plain
    # bert-base-uncased rather than the fine-tuned checkpoint passed in as
    # model_name_or_path, so the classification head starts untrained; the
    # attention weights still come from the pretrained encoder.
    self.model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=2,
        output_hidden_states=True,
        output_attentions=True)
    self.model.eval()

  ##
  # LIT API implementation
  def max_minibatch_size(self):
    # This tells lit_model.Model.predict() how to batch inputs to
    # predict_minibatch().
    # Alternately, you can just override predict() and handle batching yourself.
    return 32

  def predict_minibatch(self, inputs):
    # Preprocess to ids and masks, and make the input batch.
    encoded_input = self.tokenizer.batch_encode_plus(
        [ex["sentence"] for ex in inputs],
        return_tensors="pt",
        add_special_tokens=True,
        max_length=256,
        padding="longest",
        truncation="longest_first")
    # Check and send to cuda (GPU) if available.
    if torch.cuda.is_available():
      self.model.cuda()
      for tensor in encoded_input:
        encoded_input[tensor] = encoded_input[tensor].cuda()
    # Run a forward pass.
    with torch.no_grad():  # remove this if you need gradients.
      out: transformers.modeling_outputs.SequenceClassifierOutput = self.model(
          **encoded_input)
    # Post-process outputs.
    batched_outputs = {
        "probas": torch.nn.functional.softmax(out.logits, dim=-1),
        "input_ids": encoded_input["input_ids"],
        "ntok": torch.sum(encoded_input["attention_mask"], dim=1),
        "cls_emb": out.hidden_states[-1][:, 0],  # last layer, first token
    }
    # Return as NumPy for further processing.
    detached_outputs = {k: v.cpu().numpy() for k, v in batched_outputs.items()}
    # Unbatch outputs so we get one record per input example.
    for output in utils.unbatch_preds(detached_outputs):
      ntok = output.pop("ntok")
      output["tokens"] = self.tokenizer.convert_ids_to_tokens(
          output.pop("input_ids")[1:ntok - 1])
      yield output

  def input_spec(self) -> lit_types.Spec:
    return {
        "sentence": lit_types.TextSegment(),
        "label": lit_types.CategoryLabel(vocab=self.LABELS, required=False),
    }

  def output_spec(self) -> lit_types.Spec:
    return {
        "tokens": lit_types.Tokens(),
        "probas": lit_types.MulticlassPreds(parent="label", vocab=self.LABELS,
                                            null_idx=0),
        "cls_emb": lit_types.Embeddings(),
    }


def get_wsgi_app():
  """Returns a LitApp instance for consumption by gunicorn."""
  FLAGS.set_default("server_type", "external")
  FLAGS.set_default("demo_mode", True)
  # Parse flags without calling app.run(main), to avoid conflict with
  # gunicorn command line flags.
  unused = flags.FLAGS(sys.argv, known_only=True)
  return main(unused)


def main(_):
  # Normally path is a directory; if it's an archive file, download and
  # extract to the transformers cache.
  model_path = FLAGS.model_path
  if model_path.endswith(".tar.gz"):
    model_path = transformers.file_utils.cached_path(
        model_path, extract_compressed_file=True)

  # Load the model we defined above.
  models = {"sst": SimpleSentimentModel(model_path)}
  # Wrap the Excel data prepared above as the dataset to explore.
  datasets = {"sst_dev": SST2Data(df)}

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  return lit_demo.serve()


if __name__ == "__main__":
  app.run(main)
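
As written, the model asks BERT for attention weights (output_attentions=True) but never returns them, so LIT's Attention module has nothing to draw. Below is a minimal sketch of the extra wiring. It assumes a LIT release whose AttentionHeads type takes align_in/align_out keyword arguments (older releases use a single align=("tokens", "tokens") tuple), and the layer_{i}/attention field names are just a convention borrowed from LIT's bundled GLUE demos, not anything this script already defines.

# 1) In predict_minibatch(), after the forward pass, add the per-layer
#    attention tensors; each is <float32>[batch, num_heads, seq_len, seq_len].
for i, layer_attention in enumerate(out.attentions):
  batched_outputs[f"layer_{i}/attention"] = layer_attention

# 2) In the unbatch loop (after ntok is popped), trim each matrix so it lines
#    up with the "tokens" field, which drops [CLS]/[SEP] and padding.
for key in output:
  if key.endswith("/attention"):
    output[key] = output[key][:, 1:ntok - 1, 1:ntok - 1].copy()

# 3) Declare the new fields in output_spec(), aligned to "tokens" on both axes.
  def output_spec(self) -> lit_types.Spec:
    spec = {
        "tokens": lit_types.Tokens(),
        "probas": lit_types.MulticlassPreds(parent="label", vocab=self.LABELS,
                                            null_idx=0),
        "cls_emb": lit_types.Embeddings(),
    }
    for i in range(self.model.config.num_hidden_layers):
      spec[f"layer_{i}/attention"] = lit_types.AttentionHeads(
          align_in="tokens", align_out="tokens")
    return spec

With those fields in place, run the script directly and open the address it prints (the LIT server defaults to port 5432); the Attention module in the UI then lets you pick a layer and head and see the token-to-token weights for the selected example.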