TPU deadlock
Description
Hello,
I am trying to train a Reformer model using Trax and JAX. Training runs fine on Google Colab, but when I run it on a Google Cloud server + TPU, it hangs at "trax.supervised.Trainer".
The warning is as follows:
2020-08-26 17:46:37.421334: W external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc:601] TPU Execute is taking a long time. This might be due to a deadlock between multiple TPU cores or a very slow program.
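For reference, a minimal backend-connectivity check (a sketch, assuming the same grpc target that the script below uses) can be run on the VM before building the Trainer; if the TPU cores are reachable, it should list them:

from jax.config import config
import jax

config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://10.206.164.18"  # same target as in the script below
print(jax.devices())        # should list the TPU cores
print(jax.device_count())   # 8 on a v2-8 / v3-8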
Environment information
Ubuntu
$ pip freeze | grep trax
trax==1.3.4
$ pip freeze | grep tensor
mesh-tensorflow==0.1.16
tensor2tensor==1.15.7
tensorboard==2.3.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.0
tensorflow-addons==0.11.1
tensorflow-datasets==3.2.1
tensorflow-estimator==2.3.0
tensorflow-gan==2.0.0
tensorflow-hub==0.9.0
tensorflow-metadata==0.23.0
tensorflow-probability==0.7.0
tensorflow-text==2.3.0
$ pip freeze | grep jax
jax==0.1.75
jaxlib==0.1.52
$ python -V
Python 3.6.10 :: Anaconda, Inc.
For bugs: reproduction and error logs
Steps to reproduce:
...
import requests
import os
from jax.config import config

# Point JAX at the remote Cloud TPU before any computation runs.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + "10.206.164.18"
print(config.FLAGS.jax_backend_target)

from tensorflow.compat.v1.io.gfile import GFile
import gin
import jax
import trax
from trax.data import inputs
import numpy as np
import jax.numpy as jnp
from scipy.special import softmax
import sentencepiece as spm
from sentencepiece import SentencePieceProcessor
import random, glob
def fake_data():
    # Write a tiny amino-acid vocabulary and two identical training shards.
    with open("vocab.txt", 'w') as f:
        f.write("[MASK]\nL\nA\nG\nV\nE\nS\nI\nK\nR\nD\nT\nP\nN\nQ\nF\nY\nM\nH\nC\nW\nX\nU\nB\nZ\nO")
    if not os.path.exists('dataset'):
        os.makedirs('dataset')
    sequences = [
        "M A F S A E D V L K E Y D R R R R M E A L L L S L Y Y P N D R K L L D Y K E W S P P R V Q V E C P K A P V E W N N P P S E K G L I V G H F S G I K Y K G E K A Q A S E V D V N K M C C W V S K F K D A M R R Y Q G I Q T C K I P G K V L S D L D A K I K A Y N L T V E G V E G F V R Y S R V T K Q H V A A F L K E L R H S K Q Y E N V N L I H Y I L T D K R V D I Q H L E K D L V K D F K A L V E S A H R M R Q G H M I N V K Y I L Y Q L L K K H G H G P D G P D I L T V K T G S K G V L Y D D S F R K I Y T D L G W K F T P L\n",
        "M S I I G A T R L Q N D K S D T Y S A G P C Y A G G C S A F T P R G T C G K D W D L G E Q T C A S G F C T S Q P L C A R I K K T Q V C G L R Y S S K G K D P L V S A E W D S R G A P Y V R C T Y D A D L I D T Q A Q V D Q F V S M F G E S P S L A E R Y C M R G V K N T A G E L V S R V S S D A D P A G G W C R K W Y S A H R G P D Q D A A L G S F C I K N P G A A D C K C I N R A S D P V Y Q K V K T L H A Y P D Q C W Y V P C A A D V G E L K M G T Q R D T P T N C P T Q V C Q I V F N M L D D G S V T M D D V K N T I N C D F S K Y V P P P P P P K P T P P T P P T P P T P P T P P T P P T P P T P R P V H N R K V M F F V A G A V L V A I L I S T V R W\n",
        "M A S N T V S A Q G G S N R P V R D F S N I Q D V A Q F L L F D P I W N E Q P G S I V P W K M N R E Q A L A E R Y P E L Q T S E P S E D Y S G P V E S L E L L P L E I K L D I M Q Y L S W E Q I S W C K H P W L W T R W Y K D N V V R V S A I T F E D F Q R E Y A F P E K I Q E I H F T D T R A E E I K A I L E T T P N V T R L V I R R I D D M N Y N T H G D L G L D D L E F L T H L M V E D A C G F T D F W A P S L T H L T I K N L D M H P R W F G P V M D G I K S M Q S T L K Y L Y I F E T Y G V N K P F V Q W C T D N I E T F Y C T N S Y R Y E N V P R P I Y V W V L F Q E D E W H G Y R V E D N K F H R R Y M Y S T I L H K R D T D W V E N N P L K T P A Q V E M Y K F L L R I S Q L N R D G T G Y E S D S D P E N E H F D D E S F S S G E E D S S D E D D P T W A P D S D D S D W E T E T E E E P S V A A R I L E K G K L T I T N L M K S L G F K P K P K K I Q S I D R Y F C S L D S N Y N S E D E D F E Y D S D S E D D D S D S E D D C\n",
        "M Y Q A I N P C P Q S W Y G S P Q L E R E I V C K M S G A P H Y P N Y Y P V H P N A L G G A W F D T S L N A R S L T T T P S L T T C T P P S L A A C T P P T S L G M V D S P P H I N P P R R I G T L C F D F G S A K S P Q R C E C V A S D R P S T T S N T A P D T Y R L L I T N S K T R K N N Y G T C R L E P L T Y G I\n",
        "M A R P L L G K T S S V R R R L E S L S A C S I F F F L R K F C Q K M A S L V F L N S P V Y Q M S N I L L T E R R Q V D R A M G G S D D D G V M V V A L S P S D F K T V L G S A L L A V E R D M V H V V P K Y L Q T P G I L H D M L V L L T P I F G E A L S V D M S G A T D V M V Q Q I A T A G F V D V D P L H S S V S W K D N V S C P V A L L A V S N A V R T M M G Q P C Q V T L I I D V G T Q N I L R D L V N L P V E M S G D L Q V M A Y T K D P L G K V P A V G V S V F D S G S V Q K G D A H S V G A P D G L V S F H T H P V S S A V E L N Y H A G W P S N V D M S S L L T M K N L M H V V V A E E G L W T M A R T L S M Q R L T K V L T D A E K D V M R A A A F N L F L P L N E L R V M G T K D S N N K S L K T Y F E V F E T F T I G A L M K H S G V T P T A F V D R R W L D N T I Y H M G F I P W G R D M R F V V E Y D L D G T N P F L N T V P T L M S V K R K A K I Q E M F D N M V S R M V T S\n",
        "M N A K Y D T D Q G V G R M L F L G T I G L A V V V G G L M A Y G Y Y Y D G K T P S S G T S F H T A S P S F S S R Y R Y\n",
        "M R Y T V L I A L Q G A L L L L L L I D D G Q G Q S P Y P Y P G M P C N S S R Q C G L G T C V H S R C A H C S S D G T L C S P E D P T M V W P C C P E S S C Q L V V G L P S L V N H Y N C L P N Q C T D S S Q C P G G F G C M T R R S K C E L C K A D G E A C N S P Y L D W R K D K E C C S G Y C H T E A R G L E G V C I D P K K I F C T P K N P W Q L A P Y P P S Y H Q P T T L R P P T S L Y D S W L M S G F L V K S T T A P S T Q E E E D D Y\n",
        "M Q N P L P E V M S P E H D K R T T T P M S K E A N K F I R E L D K K P G D L A V V S D F V K R N T G K R L P I G K R S N L Y V R I C D L S G T I Y M G E T F I L E S W E E L Y L P E P T K M E V L G T L E S C C G I P P F P E W I V M V G E D Q C V Y A Y G D E E I L L F A Y S V K Q L V E E G I Q E T G I S Y K Y P D D I S D V D E E V L Q Q D E E I Q K I R K K T R E F V D K D A Q E F Q D F L N S L D A S L L S\n",
        "M D S L N E V C Y E Q I K G T F Y K G L F G D F P L I V D K K T G C F N A T K L C V L G G K R F V D W N K T L R S K K L I Q Y Y E T R C D I K T E S L L Y E I K G D N N D E I T K Q I T G T Y L P K E F I L D I A S W I S V E F Y D K C N N I I I N Y F V N E Y K T M D K K T L Q S K I N E V E E K M Q K L L N E K E E E L Q E K N D K I D E L I L F S K R M E E D R K K D R E M M I K Q E K M L R E L G I H L E D V S S Q N N E L I E K V D E Q V E Q N A V L N F K I D N I Q N K L E I A V E D R A P Q P K Q N L K R E R F I L L K R N D D Y Y P Y Y T I R A Q D I N A R S A L K R Q K N L Y N E V S V L L D L T C H P N S K T L Y V R V K D E L K Q K G V V F N L C K V S I S N S K I N E E E L I K A M E T I N D E K R D V\n",
    ]
    # Both shards contain the same sequences repeated 50 times.
    for path in ("dataset/train_0.txt", "dataset/train_1.txt"):
        with open(path, 'w') as f:
            for i in range(50):
                for seq in sequences:
                    f.write(seq)
fake_data()
spm.SentencePieceTrainer.train(input='vocab.txt',
                               model_prefix='protein',
                               vocab_size=30,
                               model_type="word",
                               # user_defined_symbols="<MASK>",
                               pad_id=0,
                               unk_id=1,
                               bos_id=2,
                               eos_id=3,
                               pad_piece="[PAD]",
                               unk_piece="[UNK]",
                               bos_piece="[BOS]",
                               eos_piece="[EOS]")
tokenizer = spm.SentencePieceProcessor(model_file='protein.model')
train_files = glob.glob("dataset/train*", recursive=True)
random.shuffle(train_files)
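# Hypothetical sanity check (not in the original run): encode and decode one
# short sequence to confirm the SentencePiece model round-trips token ids.
example_ids = tokenizer.encode("M A F S", add_bos=True, add_eos=True)
print(example_ids, tokenizer.decode(example_ids))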
def mask_seq(seq, mask_prob=0.15):
    # Replace roughly mask_prob of the tokens (excluding the first and last
    # positions) with the [MASK] token id.
    seq = np.array(seq)
    minValue = 1
    maxValue = len(seq) - 2
    max_mask_tokens = int(maxValue * mask_prob + 0.5)
    randomlist = random.sample(range(minValue, maxValue), max_mask_tokens)
    seq_masked = seq  # np.array above already made a copy, so masking in place is fine
    seq_masked[randomlist] = tokenizer.encode("[MASK]")[0]
    return seq_masked
def get_seq(train_files):
    # Cycle over the training shards forever, yielding one raw line at a time.
    while True:
        for file in train_files:
            with open(file) as fp:
                for line in fp:
                    yield line
def get_batch(seq_gen, batch_length):
    # Pack consecutive tokenized sequences into chunks of at most batch_length
    # token ids. Assumes no single sequence exceeds batch_length.
    batch = []
    while True:
        seq = next(seq_gen)
        seq_ids = tokenizer.encode(seq, add_bos=True, add_eos=True)
        new_batch_len = len(batch) + len(seq_ids)
        if new_batch_len <= batch_length:
            batch = batch + seq_ids
            continue
        next_batch = batch
        batch = seq_ids
        yield next_batch
# Set up the data pipeline.
def my_inputs(n_devices):
    MAX_BATCH_LENGTH = 1024 * 4
    seq_gen = get_seq(train_files)
    batch_gen = get_batch(seq_gen, MAX_BATCH_LENGTH)
    while True:
        inputs = []
        targets = []
        mask = []
        for i in range(n_devices):
            batch_ids = next(batch_gen)
            masked_seq_ids = mask_seq(batch_ids)
            # Pad every example to MAX_BATCH_LENGTH so all device shards line up.
            pad_amount = MAX_BATCH_LENGTH - len(batch_ids)
            inputs.append(np.pad(masked_seq_ids, (0, pad_amount)))
            targets.append(np.pad(batch_ids, (0, pad_amount)))
            mask.append(np.pad(np.ones_like(batch_ids, dtype=np.float32),
                               (0, pad_amount),
                               mode='constant'))
        inputs = np.stack(inputs)
        targets = np.stack(targets)
        mask = np.stack(mask)
        yield (inputs, targets, mask)
inp_gen_test = my_inputs(trax.fastmath.device_count())
res = next(inp_gen_test)
print(tokenizer.decode(res[0][0].tolist()))
print(tokenizer.decode(res[1][0].tolist()))
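# Hypothetical shape check (not in the original script): each element of the
# yielded tuple should be a (n_devices, MAX_BATCH_LENGTH) array, i.e.
# (trax.fastmath.device_count(), 4096) here.
assert res[0].shape == (trax.fastmath.device_count(), 1024 * 4)
assert res[1].shape == res[0].shape and res[2].shape == res[0].shape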
# Configure hyperparameters.
gin.parse_config("""
import trax.layers
import trax.models
import trax.optimizers
import trax.data.inputs
import trax.supervised.trainer_lib
# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.Reformer
n_layers = 15
n_heads = 16
dropout = 0.1
n_tokens = 40000 # The Reformer example config uses a much smaller n_tokens = 2048.
vocab_size= 30
d_model = 1024
d_ff = 4096
# Done
# Parameters for MultifactorSchedule:
# ==============================================================================
multifactor.constant = 0.088
multifactor.decay_factor = 0.5
multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
multifactor.steps_per_cycle = 100000
multifactor.steps_per_decay = 20000
multifactor.warmup_steps = 8000
# Done
# Parameters for Adam:
# ==============================================================================
Adam.b1 = 0.9
Adam.b2 = 0.98
Adam.eps = 1e-09
Adam.weight_decay_rate = 1e-05
# Done
# Parameters for SelfAttention:
# ==============================================================================
#trax.layers.SelfAttention.attention_dropout = 0.05
#trax.layers.SelfAttention.chunk_len = 64
#trax.layers.SelfAttention.n_chunks_before = 1
#trax.layers.SelfAttention.n_parallel_heads = 1
trax.layers.SelfAttention.causal = False
trax.layers.SelfAttention.chunk_len = None
trax.layers.SelfAttention.masked = False
trax.layers.SelfAttention.n_chunks_after = 0
trax.layers.SelfAttention.n_chunks_before = 0
trax.layers.SelfAttention.n_parallel_heads = None
trax.layers.SelfAttention.predict_drop_len = 64
trax.layers.SelfAttention.predict_mem_len = 192
trax.layers.SelfAttention.share_qk = False
trax.layers.SelfAttention.use_python_loop = False
trax.layers.SelfAttention.use_reference_code = False
# Done
# Parameters for EncDecAttention:
# ==============================================================================
trax.layers.EncDecAttention.masked = True
trax.layers.EncDecAttention.n_parallel_heads = None
trax.layers.EncDecAttention.use_python_loop = False
trax.layers.EncDecAttention.use_reference_code = False
# Done
# Parameters for LSHSelfAttention:
# ==============================================================================
#LSHSelfAttention.attention_dropout = 0.0
#LSHSelfAttention.chunk_len = 64
#LSHSelfAttention.n_buckets = [64, 128]
#LSHSelfAttention.n_chunks_after = 0
#LSHSelfAttention.n_chunks_before = 1
#LSHSelfAttention.n_hashes = 1
#LSHSelfAttention.n_parallel_heads = 1
#LSHSelfAttention.predict_drop_len = 128
#LSHSelfAttention.predict_mem_len = 1024
# Done
# Parameters for Reformer:
# ==============================================================================
Reformer.d_model = %d_model
Reformer.d_ff = %d_ff
Reformer.dropout = %dropout
Reformer.ff_activation = @trax.layers.Relu
Reformer.max_len = %n_tokens
Reformer.mode = 'train'
Reformer.n_heads = %n_heads
Reformer.n_encoder_layers = %n_layers
Reformer.n_decoder_layers = %n_layers
Reformer.input_vocab_size = %vocab_size
""")
# Set up a Trainer.
output_dir = os.path.expanduser('train_dir/')
trainer = trax.supervised.Trainer(
    model=trax.models.Reformer,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.multifactor(),
    inputs=trax.data.inputs.Inputs(my_inputs),
    output_dir=output_dir)
# Run one training step, to make sure the model fits in memory.
# The first time trainer.train_epoch is called, it will JIT the entire network
# architecture, which takes around 2 minutes. The JIT-compiled model is saved
# so subsequent runs will be much faster than the first.
trainer.train_epoch(n_steps=1, n_eval_steps=1)
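# Hypothetical continuation (not part of the hanging run): if the single
# sanity-check step above ever returned, further training would just repeat
# the same call with more steps per epoch.
for _ in range(10):
    trainer.train_epoch(n_steps=100, n_eval_steps=1)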
Error logs:
...
2020-08-26 17:46:37.421334: W external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc:601] TPU Execute is taking a long time. This might be due to a deadlock between multiple TPU cores or a very slow program.
2020-08-26 17:51:46.613101: W external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc:601] TPU Execute is taking a long time. This might be due to a deadlock between multiple TPU cores or a very slow program.
Any idea what could be the problem?