Integrated-Gradients

Can you provide an example of text like Sentiment classification?

Open lemon234071 opened this issue 5 years ago • 2 comments

Specifically, how should word embeddings be handled?

lemon234071, Sep 15 '19 03:09

Sure, we will soon add a text example.

The key idea is to attribute from the output to the embedding tensor rather than to the input tokens. Token ids are discrete, so the network is only differentiable up to the embedding tensor.

This yields an attribution for each dimension of each token embedding. Then, for each token, we sum the attributions along all embedding dimensions to obtain a token-level attribution.
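To make the reduction step concrete, here is a minimal NumPy sketch. The attribution values are made up purely for illustration (3 tokens, 4 embedding dimensions); they do not come from any real model.

```python
import numpy as np

# Hypothetical embedding-level attributions: <num_tokens=3, emb_dims=4>.
emb_attr = np.array([
    [0.10, -0.05, 0.20, 0.05],
    [-0.30, 0.10, 0.00, -0.10],
    [0.02, 0.03, -0.01, 0.01],
])

# Sum over the embedding axis to get one attribution score per token.
token_attr = emb_attr.sum(axis=-1)  # shape: <num_tokens>
```

Positive and negative contributions within a token can cancel in the sum, so a small token-level score does not necessarily mean every embedding dimension had a small attribution.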

Here is some example code:

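import numpy as np

def ig_text(inp, label, t_embedding, t_label, t_grad, baseline=None, steps=20):
  # Assumes `sess` (the TensorFlow session), `t_inp` (input token placeholder) and
  # `y_probs` (prediction tensor) are defined in the enclosing scope.
  # Args:
  # - inp: Input tokens (or token ids) whose prediction (for the provided label) must be explained
  # - label: Prediction label
  # - t_embedding: Embedding tensor
  # - t_label: Placeholder tensor specifying the prediction label for which gradients must be computed
  # - t_grad: Tensor computing gradients of prediction w.r.t. the embedding tensor
  # - baseline: Baseline tokens (defaults to all-zero token ids)
  # - steps: Number of interpolation steps between baseline and input
  if baseline is None:
    baseline = np.zeros_like(inp)  # note: 0*inp would give [] if inp is a Python list
  embs = sess.run(t_embedding, {t_inp: [inp, baseline]})  # <batch, num_tokens, emb_dims>
  inp_emb = embs[0, :, :]
  baseline_emb = embs[1, :, :]
  # Linearly interpolate from the baseline embedding to the input embedding.
  scaled_embs = [baseline_emb + (float(i)/steps)*(inp_emb-baseline_emb) for i in range(0, steps+1)]
  # Feed the interpolated embeddings directly into the embedding tensor.
  feed = {t_embedding: scaled_embs, t_label: label}
  grads, scores = sess.run([t_grad, y_probs], feed_dict=feed)  # shapes: <steps+1, inp_emb.shape>, <steps+1, num_labels>
  # Riemann-sum approximation of the path integral (the baseline point itself is skipped).
  ig = (inp_emb-baseline_emb)*np.average(grads[1:,:,:], axis=0)  # shape: <inp_emb.shape>
  token_ig = np.sum(ig, axis=-1)  # shape: <num_tokens>
  return token_ig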
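
For readers without a TensorFlow graph at hand, the same steps can be tried on a toy model in plain NumPy. This is only a sketch under stated assumptions: `toy_prob`, `toy_grad`, `ig_tokens` and the weight vector `w` are hypothetical stand-ins (a sigmoid over the mean-pooled embedding with an analytic gradient), and the baseline is an all-zero embedding rather than the embedding of token id 0 used above.

```python
import numpy as np

def toy_prob(emb, w):
    # Hypothetical stand-in for a classifier: sigmoid of w . mean-pooled embedding.
    return 1.0 / (1.0 + np.exp(-(emb.mean(axis=0) @ w)))

def toy_grad(emb, w):
    # Analytic gradient of toy_prob w.r.t. every embedding entry: <num_tokens, emb_dims>.
    p = toy_prob(emb, w)
    return np.tile(p * (1.0 - p) * w / emb.shape[0], (emb.shape[0], 1))

def ig_tokens(inp_emb, baseline_emb, w, steps=50):
    # Interpolation points 1/steps, 2/steps, ..., 1 (the baseline point is skipped).
    alphas = np.arange(1, steps + 1) / steps
    grads = np.stack([
        toy_grad(baseline_emb + a * (inp_emb - baseline_emb), w) for a in alphas
    ])
    # Average the gradients along the path and scale by the input-baseline difference.
    ig = (inp_emb - baseline_emb) * grads.mean(axis=0)
    # Sum over embedding dimensions to get one score per token.
    return ig.sum(axis=-1)

rng = np.random.default_rng(0)
inp_emb = rng.normal(size=(5, 8))        # 5 tokens, 8 embedding dimensions
baseline_emb = np.zeros_like(inp_emb)    # assumed all-zero baseline embedding
w = rng.normal(size=8)

token_ig = ig_tokens(inp_emb, baseline_emb, w, steps=200)
gap = toy_prob(inp_emb, w) - toy_prob(baseline_emb, w)
```

A useful sanity check is that `token_ig.sum()` should be close to `gap`, the difference between the prediction at the input and at the baseline. If the two differ noticeably, increasing `steps` usually helps.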

ankurtaly, Sep 16 '19 22:09

Thanks! Thank you very much for the quick reply! It's very helpful. I just did not know how to obtain a token-level attribution from the embedding attributions; I hadn't figured out the principle. Best wishes.

lemon234071, Sep 17 '19 02:09