lit icon indicating copy to clipboard operation
lit copied to clipboard

How to visualize attentions

Open pratikchhapolika opened this issue 3 years ago • 0 comments

Here is the code.

import sys
from absl import app
from absl import flags
from absl import logging

from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
# Use the regular GLUE data loaders, because these are very simple already.
from lit_nlp.examples.datasets import glue
from lit_nlp.lib import utils
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types
import tensorflow_datasets as tfds
from transformers import BertTokenizer, BertForSequenceClassification
import pandas as pd
import torch
import transformers



# Load the labelled utterances from the 'master_data' sheet, keep the first
# 100 rows flagged as training data (train == 1), and retain only the text
# and an integer label column.
df = pd.read_excel("data.xlsx", sheet_name='master_data')
print(df.shape)
df = (
    df.loc[df['train'] == 1, ['UTTERANCE', 'label']]
    .head(100)
    # astype returns a new frame, so no column is assigned in place on a
    # filtered slice (which may raise pandas' SettingWithCopyWarning).
    .astype({'label': int})
)
print(df.head(2))




def load_tfds(*args, do_sort=True, **kw):
    """Return the rows of a pandas DataFrame as a list of lists.

    Despite its name (kept from LIT's GLUE loader), this does not read from
    TFDS.

    Args:
      *args: if the first positional argument is a pandas DataFrame, its rows
        are returned; otherwise the module-level ``df`` is used.
      do_sort: ignored; kept only for signature compatibility.
      **kw: ignored; kept only for signature compatibility.

    Returns:
      One list per row, with cell values in column order.
    """
    # Honor the DataFrame the caller passed instead of silently reading the
    # global; fall back to the global for argument-less calls.
    data = args[0] if args and isinstance(args[0], pd.DataFrame) else df
    return data.values.tolist()



class SST2Data(lit_dataset.Dataset):
    """Binary utterance-classification dataset built from a pandas DataFrame.

    The class name is kept from LIT's SST-2 example. The examples come from
    the DataFrame passed to the constructor, not from TFDS.
    """

    LABELS = ['0', '1']

    def __init__(self, data):
        """Builds one LIT example per row of ``data``.

        Args:
          data: DataFrame whose first column is the utterance text and whose
            second column is an integer index into LABELS (the module-level
            loading code keeps exactly ['UTTERANCE', 'label']).
        """
        # Use the DataFrame actually passed in (previously ignored in favor
        # of the global), and do not dump every example to stdout.
        self._examples = [
            {
                'sentence': row[0],
                'label': self.LABELS[row[1]],
            }
            for row in data.values.tolist()
        ]

    def spec(self):
        """Field types of each example: text segment plus binary label."""
        return {
            'sentence': lit_types.TextSegment(),
            'label': lit_types.CategoryLabel(vocab=self.LABELS)
        }



FLAGS = flags.FLAGS

# Turn on LIT's development-demo mode by default. The flag is not defined in
# this file; presumably it comes from lit_nlp.server_flags.
FLAGS.set_default("development_demo", True)

# Location of the model handed to SimpleSentimentModel by main(). A ".tar.gz"
# value is downloaded and extracted by main() before use.
flags.DEFINE_string(
    name="model_path",
    default=(
        "https://storage.googleapis.com/what-if-tool-resources/"
        "lit-models/sst2_tiny.tar.gz"
    ),
    help=(
        "Path to trained model, in standard transformers format, e.g. as "
        "saved by model.save_pretrained() and tokenizer.save_pretrained()"
    ),
)


def _from_pretrained(cls, *args, **kw):
    """Load a transformers model in PyTorch, with fallback to TF2/Keras weights."""
    try:
        return cls.from_pretrained(*args, **kw)
    except OSError as e:
        logging.warning("Caught OSError loading model: %s", e)
        logging.warning(
            "Re-trying to convert from TensorFlow checkpoint (from_tf=True)")
        return cls.from_pretrained(*args, from_tf=True, **kw)


class SimpleSentimentModel(lit_model.Model):
    """Binary sentiment classifier wrapping a Hugging Face BERT model for LIT.

    Besides class probabilities, word-piece tokens and the last-layer [CLS]
    embedding, the model can export one attention tensor per transformer
    layer ("layer_<k>_attention") so LIT's attention module can visualize it.
    """

    LABELS = ["0", "1"]  # negative, positive

    def __init__(self, model_name_or_path, output_attention=True):
        """Loads the tokenizer and the classifier.

        Args:
          model_name_or_path: kept for compatibility with LIT's example models.
            NOTE(review): currently ignored; the 'bert-base-uncased' checkpoint
            is always loaded, so the classification head is presumably freshly
            initialized. Pass this path to from_pretrained() once a
            fine-tuned checkpoint exists.
          output_attention: if True, per-layer attention weights are returned
            by predict_minibatch() and declared in output_spec().
        """
        self.output_attention = output_attention
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # This is a just a regular PyTorch model. Hidden states feed "cls_emb";
        # attentions are only requested when they will be returned.
        self.model = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            num_labels=2,
            output_hidden_states=True,
            output_attentions=output_attention)
        self.model.eval()
        # Move the model to the GPU once here rather than on every minibatch.
        if torch.cuda.is_available():
            self.model.cuda()

    @staticmethod
    def _attention_key(layer_index):
        """Output key for the attention of 0-based layer ``layer_index``."""
        return f"layer_{layer_index + 1}_attention"

    ##
    # LIT API implementation
    def max_minibatch_size(self):
        """Batch size that lit_model.Model.predict() uses for minibatches."""
        # Alternately, you can just override predict() and handle batching yourself.
        return 32

    def predict_minibatch(self, inputs):
        """Runs the model on a batch and yields one output dict per example.

        Args:
          inputs: list of dicts, each with a "sentence" string.

        Yields:
          dicts with "probas", "cls_emb", "tokens" (word pieces without the
          leading [CLS] and trailing [SEP]) and, when output_attention is on,
          one "layer_<k>_attention" array per layer aligned to "tokens".
        """
        # Preprocess to ids and masks, and make the input batch.
        encoded_input = self.tokenizer.batch_encode_plus(
            [ex["sentence"] for ex in inputs],
            return_tensors="pt",
            add_special_tokens=True,
            max_length=256,
            padding="longest",
            truncation="longest_first")

        # Inputs must live on the same device as the model (moved in __init__).
        if torch.cuda.is_available():
            for tensor in encoded_input:
                encoded_input[tensor] = encoded_input[tensor].cuda()
        # Run a forward pass.
        with torch.no_grad():  # remove this if you need gradients.
            out = self.model(**encoded_input)

        # Post-process outputs.
        batched_outputs = {
            "probas": torch.nn.functional.softmax(out.logits, dim=-1),
            "input_ids": encoded_input["input_ids"],
            "ntok": torch.sum(encoded_input["attention_mask"], dim=1),
            "cls_emb": out.hidden_states[-1][:, 0],  # last layer, first token
        }
        attention_keys = []
        if self.output_attention:
            # out.attentions holds one tensor per layer; per the transformers
            # docs each is (batch, num_heads, seq_len, seq_len).
            for i, layer_attention in enumerate(out.attentions):
                key = self._attention_key(i)
                batched_outputs[key] = layer_attention
                attention_keys.append(key)

        # Return as NumPy for further processing.
        detached_outputs = {k: v.cpu().numpy() for k, v in batched_outputs.items()}
        # Unbatch outputs so we get one record per input example.
        for output in utils.unbatch_preds(detached_outputs):
            ntok = output.pop("ntok")
            # Keep positions 1..ntok-2: drops [CLS], [SEP] and any padding.
            keep = slice(1, ntok - 1)
            output["tokens"] = self.tokenizer.convert_ids_to_tokens(
                output.pop("input_ids")[keep])
            # Slice query and key axes the same way so the attention matrices
            # align with "tokens". Rows no longer sum to 1 once [CLS]/[SEP]
            # columns are removed.
            for key in attention_keys:
                output[key] = output[key][:, keep, keep]
            yield output

    def input_spec(self) -> lit_types.Spec:
        """Fields read from each example; the gold label is optional."""
        return {
            "sentence": lit_types.TextSegment(),
            "label": lit_types.CategoryLabel(vocab=self.LABELS, required=False)
        }

    def output_spec(self) -> lit_types.Spec:
        """Fields produced per example, including per-layer attention heads."""
        spec = {
            "tokens": lit_types.Tokens(),
            "probas": lit_types.MulticlassPreds(parent="label", vocab=self.LABELS,
                                                null_idx=0),
            "cls_emb": lit_types.Embeddings(),
        }
        if self.output_attention:
            # Declaring AttentionHeads aligned to "tokens" is what enables
            # LIT's attention visualization module.
            for i in range(self.model.config.num_hidden_layers):
                spec[self._attention_key(i)] = lit_types.AttentionHeads(
                    align_in="tokens", align_out="tokens")
        return spec


def get_wsgi_app():
    """Builds the LIT app for an external WSGI server such as gunicorn.

    Sets external-server / demo-mode flag defaults, then parses flags with
    known_only=True instead of going through app.run(main), so gunicorn's own
    command-line flags do not cause a parse error.

    Returns:
      Whatever main() returns for the remaining argv (per the original intent,
      the LitApp instance when server_type is "external").
    """
    for flag_name, default_value in (("server_type", "external"),
                                     ("demo_mode", True)):
        FLAGS.set_default(flag_name, default_value)
    remaining_argv = flags.FLAGS(sys.argv, known_only=True)
    return main(remaining_argv)


def main(_):
    """Starts the LIT server with the sentiment model and the DataFrame dataset.

    The positional argument (leftover argv from absl) is unused.

    Returns:
      The result of dev_server.Server(...).serve().
    """
    model_path = FLAGS.model_path
    # A directory is used as-is; a ".tar.gz" archive is first downloaded and
    # extracted into the transformers cache.
    # NOTE(review): SimpleSentimentModel loads 'bert-base-uncased' regardless
    # of this path, and transformers.file_utils.cached_path is missing from
    # newer transformers releases; confirm the installed version.
    if model_path.endswith(".tar.gz"):
        model_path = transformers.file_utils.cached_path(
            model_path, extract_compressed_file=True)

    # Model is built before the dataset, then server flags are read, matching
    # left-to-right argument evaluation. The dataset comes from the module-level
    # DataFrame, not from TFDS.
    lit_demo = dev_server.Server(
        {"sst": SimpleSentimentModel(model_path)},
        {"sst_dev": SST2Data(df)},
        **server_flags.get_flags())
    return lit_demo.serve()


if __name__ == "__main__":
    app.run(main)



Screenshot 2022-03-23 at 10 53 00 PM

pratikchhapolika avatar Mar 23 '22 11:03 pratikchhapolika