tensor2tensor
Enable depth_k != depth_v in local_attention_2d and masked_local_attention_2d
This PR resolves a TODO in tensor2tensor/layers/common_attention_test.py: enable `depth_v != depth_k` for `common_attention.local_attention_2d` and `common_attention.masked_local_attention_2d`.
The modification is simple: alter the shape passed to `scatter_blocks_2d` when generating the output, so that its last dimension is taken from `v` rather than `q`.
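The essence of the shape change can be sketched with plain Python lists (the variable names here are hypothetical, not the actual code in common_attention.py): the scatter target's last dimension must follow `v` once the two depths differ.

```python
# Hypothetical sketch of the shape fix: once depth_k != depth_v, the shape
# used when scattering blocks back must take its last dimension from v.
q_shape = [3, 4, 25, 25, 16]   # [batch, heads, h, w, depth_k]
v_shape = [3, 4, 25, 25, 30]   # [batch, heads, h, w, depth_v]

# Old (buggy) behaviour effectively reused q's shape wholesale:
old_scatter_shape = q_shape                       # last dim = depth_k, wrong
# Fixed: keep q's leading dims, take the last dimension from v:
new_scatter_shape = q_shape[:-1] + v_shape[-1:]   # last dim = depth_v
print(new_scatter_shape)
```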
Tests:
- `pytest tensor2tensor/layers/common_attention_test.py tensor2tensor/layers/common_image_attention_test.py` passes
- Sanity check: if `v` is split along the depth dimension, then applying `local_attention_2d` on each of the two parts separately and concatenating the results should be equivalent to applying `local_attention_2d` with the original `v`; see the code below
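The reason this sanity check must hold is that the attention output is a weighted sum of `v` rows, and matrix multiplication distributes over a concatenation of `v` along its depth axis. A minimal NumPy illustration (toy shapes, unrelated to the tensor2tensor API):

```python
import numpy as np

# Attention output = weights @ v; matmul distributes over concat of v
# along the last (depth) axis, so split-then-concat equals the full result.
rng = np.random.default_rng(0)
weights = rng.random((5, 7))   # attention weights over 7 memory positions
v = rng.random((7, 10))        # values with depth 10
split = 4
part = np.concatenate([weights @ v[:, :split], weights @ v[:, split:]], axis=1)
full = weights @ v
print(np.allclose(part, full))  # True
```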
Sanity check code and output:
```python
import numpy as np
import tensorflow as tf
from tensor2tensor.layers.common_attention import local_attention_2d

# Check that attention commutes with splitting/concatenating v along the depth dimension
# Try various split points
for split_idx in range(1, 30):
  batch, heads, length, depth_k, depth_v, query_shape = 3, 4, 25, 16, 30, (4, 4)
  q = tf.random_normal([batch, heads, length, length, depth_k], dtype=tf.float64)
  k = tf.random_normal([batch, heads, length, length, depth_k], dtype=tf.float64)
  v = tf.random_normal([batch, heads, length, length, depth_v], dtype=tf.float64)
  # Apply attention with the first part of v
  output_part1 = local_attention_2d(
      q,
      k,
      v[:, :, :, :, :split_idx],
      query_shape=query_shape,
      memory_flange=(3, 3))
  # Apply attention with the second part of v
  output_part2 = local_attention_2d(
      q,
      k,
      v[:, :, :, :, split_idx:],
      query_shape=query_shape,
      memory_flange=(3, 3))
  # Put together the results of the two parts
  output_concat = tf.concat([output_part1, output_part2], axis=4)
  # Apply attention with the original v
  output_full = local_attention_2d(
      q,
      k,
      v,
      query_shape=query_shape,
      memory_flange=(3, 3))
  # Compute the difference - should be small
  with tf.Session() as sess:
    res_diff = (output_concat - output_full).eval()
  print(np.abs(res_diff).max())
```

Output:

```
2.220446049250313e-15
2.6645352591003757e-15
2.220446049250313e-15
1.7763568394002505e-15
2.220446049250313e-15
2.220446049250313e-15
2.220446049250313e-15
0.0
2.6645352591003757e-15
1.7763568394002505e-15
2.220446049250313e-15
2.220446049250313e-15
2.6645352591003757e-15
2.220446049250313e-15
2.220446049250313e-15
0.0
2.6645352591003757e-15
2.220446049250313e-15
2.4424906541753444e-15
2.6645352591003757e-15
3.1086244689504383e-15
2.220446049250313e-15
2.220446049250313e-15
0.0
1.3322676295501878e-15
0.0
2.220446049250313e-15
0.0
1.7763568394002505e-15
```
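For intuition, the single-block behaviour this PR enables can be reproduced with a toy NumPy scaled-dot-product attention where `depth_k != depth_v` (a hypothetical simplification that ignores blocking and the memory flange): the output's last dimension comes from `v`.

```python
import numpy as np

# Toy single-block attention with depth_k != depth_v: the output depth
# follows v, which is exactly what the shape fix allows in the 2-D case.
rng = np.random.default_rng(1)
depth_k, depth_v, n = 16, 30, 8          # n = flattened positions in one block
q = rng.standard_normal((n, depth_k))
k = rng.standard_normal((n, depth_k))
v = rng.standard_normal((n, depth_v))

logits = q @ k.T / np.sqrt(depth_k)
weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ v
print(out.shape)  # (8, 30): last dimension taken from v, not q
```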
Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
:memo: Please visit https://cla.developers.google.com/ to sign.
Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.
What to do if you already signed the CLA
Individual signers
- It's possible we don't have your GitHub username or you're using a different email address on your commit. Check your existing CLA data and verify that your email is set on your git commits.
Corporate signers
- Your company has a Point of Contact who decides which employees are authorized to participate. Ask your POC to be added to the group of authorized contributors. If you don't know who your Point of Contact is, direct the Google project maintainer to go/cla#troubleshoot (Public version).
- The email used to register you as an authorized contributor must be the email used for the Git commit. Check your existing CLA data and verify that your email is set on your git commits.
- The email used to register you as an authorized contributor must also be attached to your GitHub account.
ℹ️ Googlers: Go here for more info.
@googlebot I signed it!