
Enable depth_k != depth_v in local_attention_2d and masked_local_attention_2d

Open · sgrigory opened this pull request 4 years ago · 2 comments

This PR resolves a TODO in tensor2tensor/layers/common_attention_test.py: enable depth_v != depth_k for common_attention.local_attention_2d and common_attention.masked_local_attention_2d.

The modification is simple: one only needs to alter the shape passed to scatter_blocks_2d when generating the output, so that its last dimension is taken from v rather than from q.
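As an illustration, here is a minimal sketch of the kind of change involved. The variable names below approximate the actual code in tensor2tensor/layers/common_attention.py rather than reproducing the exact diff; the same one-line change applies to both local_attention_2d and masked_local_attention_2d.

# Before (sketch): the scatter shape reuses the padded shape of q, so the
# output's last dimension is forced to be depth_k.
padded_shape = common_layers.shape_list(q)  # [..., depth_k]
output = scatter_blocks_2d(attn_output, q_indices, padded_shape)

# After (sketch): the last dimension is taken from v instead, so the output
# has depth_v channels and depth_k != depth_v is supported.
padded_shape = common_layers.shape_list(q)[:-1] + [common_layers.shape_list(v)[-1]]
output = scatter_blocks_2d(attn_output, q_indices, padded_shape)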

Tests:

  • pytest tensor2tensor/layers/common_attention_test.py tensor2tensor/layers/common_image_attention_test.py passes
  • Sanity check: since the attention weights depend only on q and k, the output is linear in v channel-wise. Hence, if v is split along the depth dimension, applying local_attention_2d to each of the two parts separately and concatenating the results should be equivalent to applying local_attention_2d to the original v; see the code below
Sanity check code and output:
import numpy as np
import tensorflow as tf

from tensor2tensor.layers.common_attention import local_attention_2d

# Check that attention commutes with splitting/concatenating v along the depth dimension
# Try various split points
for split_idx in range(1, 30):

    batch, heads, length, depth_k, depth_v, query_shape = 3, 4, 25, 16, 30, (4, 4)

    q = tf.random_normal([batch, heads, length, length, depth_k], dtype=tf.float64)
    k = tf.random_normal([batch, heads, length, length, depth_k], dtype=tf.float64)
    v = tf.random_normal([batch, heads, length, length, depth_v], dtype=tf.float64)
    
    # Apply attention with the first part of v
    output_part1 = local_attention_2d(
        q,
        k,
        v[:, :, :, :, :split_idx],
        query_shape=query_shape,
        memory_flange=(3, 3))
    
    # Apply attention with the second part of v
    output_part2 = local_attention_2d(
        q,
        k,
        v[:, :, :, :, split_idx:],
        query_shape=query_shape,
        memory_flange=(3, 3))
    
    # Put together results of two parts
    output_concat = tf.concat([output_part1, output_part2], axis=4)
    
    # Apply attention with the original v
    output_full = local_attention_2d(
        q,
        k,
        v,
        query_shape=query_shape,
        memory_flange=(3, 3))
    
    # Compute the difference between the two results; it should be near zero
    with tf.Session() as sess:
        res_diff = sess.run(output_concat - output_full)
        print(np.abs(res_diff).max())


Output (maximum absolute difference for each split point):

2.220446049250313e-15
2.6645352591003757e-15
2.220446049250313e-15
1.7763568394002505e-15
2.220446049250313e-15
2.220446049250313e-15
2.220446049250313e-15
0.0
2.6645352591003757e-15
1.7763568394002505e-15
2.220446049250313e-15
2.220446049250313e-15
2.6645352591003757e-15
2.220446049250313e-15
2.220446049250313e-15
0.0
2.6645352591003757e-15
2.220446049250313e-15
2.4424906541753444e-15
2.6645352591003757e-15
3.1086244689504383e-15
2.220446049250313e-15
2.220446049250313e-15
0.0
1.3322676295501878e-15
0.0
2.220446049250313e-15
0.0
1.7763568394002505e-15
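All of these maxima are a small multiple of float64 machine epsilon (≈2.22e-16), i.e. the concatenated and full-depth outputs agree up to rounding error.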

sgrigory · Oct 30 '21 14:10

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.


What to do if you already signed the CLA

  • Individual signers
  • Corporate signers

ℹ️ Googlers: Go here for more info.

google-cla[bot] · Oct 30 '21 14:10

@googlebot I signed it!

sgrigory · Oct 30 '21 14:10