
TPU deadlock

agemagician opened this issue on Aug 26, 2020

Description

Hello,

I am trying to train a Reformer model using Trax and JAX. Training runs fine on Google Colab, but when I run the same code on a Google Cloud server + TPU, it hangs at "trax.supervised.Trainer".

The warning is as follows:

2020-08-26 17:46:37.421334: W external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc:601] TPU Execute is taking a long time. This might be due to a deadlock between multiple TPU cores or a very slow program.
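
A quick device check (a minimal sketch, using the same tpu_driver flags as in the script below) can confirm whether the TPU cores are visible at all before the Trainer is created:

import jax
from jax.config import config

config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + "10.206.164.18"

# Expect one entry per TPU core (e.g. 8 on a v2-8/v3-8 slice); a hang or
# error already at this point would indicate connectivity rather than the model.
print(jax.devices())
print(jax.device_count())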

Environment information

Ubuntu

$ pip freeze | grep trax
trax==1.3.4

$ pip freeze | grep tensor
mesh-tensorflow==0.1.16
tensor2tensor==1.15.7
tensorboard==2.3.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.0
tensorflow-addons==0.11.1
tensorflow-datasets==3.2.1
tensorflow-estimator==2.3.0
tensorflow-gan==2.0.0
tensorflow-hub==0.9.0
tensorflow-metadata==0.23.0
tensorflow-probability==0.7.0
tensorflow-text==2.3.0

$ pip freeze | grep jax
jax==0.1.75
jaxlib==0.1.52

$ python -V
Python 3.6.10 :: Anaconda, Inc.

For bugs: reproduction and error logs

Steps to reproduce:

...

import os
import glob
import random
import requests

# Configure the remote TPU backend before importing jax itself.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + "10.206.164.18"
print(config.FLAGS.jax_backend_target)

from tensorflow.compat.v1.io.gfile import GFile
import gin
import jax
import jax.numpy as jnp
import numpy as np

import trax
from trax.data import inputs

from scipy.special import softmax

import sentencepiece as spm
from sentencepiece import SentencePieceProcessor

# Write a tiny fake protein dataset (space-separated amino acids) and a vocab file.
def fake_data():
  with open("vocab.txt",'w') as f:
    f.write("[MASK]\nL\nA\nG\nV\nE\nS\nI\nK\nR\nD\nT\nP\nN\nQ\nF\nY\nM\nH\nC\nW\nX\nU\nB\nZ\nO")

  if not os.path.exists('dataset'):
    os.makedirs('dataset')

  with open("dataset/train_0.txt",'w') as f:
    for i in range(50):
      f.write("M A F S A E D V L K E Y D R R R R M E A L L L S L Y Y P N D R K L L D Y K E W S P P R V Q V E C P K A P V E W N N P P S E K G L I V G H F S G I K Y K G E K A Q A S E V D V N K M C C W V S K F K D A M R R Y Q G I Q T C K I P G K V L S D L D A K I K A Y N L T V E G V E G F V R Y S R V T K Q H V A A F L K E L R H S K Q Y E N V N L I H Y I L T D K R V D I Q H L E K D L V K D F K A L V E S A H R M R Q G H M I N V K Y I L Y Q L L K K H G H G P D G P D I L T V K T G S K G V L Y D D S F R K I Y T D L G W K F T P L\n")
      f.write("M S I I G A T R L Q N D K S D T Y S A G P C Y A G G C S A F T P R G T C G K D W D L G E Q T C A S G F C T S Q P L C A R I K K T Q V C G L R Y S S K G K D P L V S A E W D S R G A P Y V R C T Y D A D L I D T Q A Q V D Q F V S M F G E S P S L A E R Y C M R G V K N T A G E L V S R V S S D A D P A G G W C R K W Y S A H R G P D Q D A A L G S F C I K N P G A A D C K C I N R A S D P V Y Q K V K T L H A Y P D Q C W Y V P C A A D V G E L K M G T Q R D T P T N C P T Q V C Q I V F N M L D D G S V T M D D V K N T I N C D F S K Y V P P P P P P K P T P P T P P T P P T P P T P P T P P T P P T P R P V H N R K V M F F V A G A V L V A I L I S T V R W\n")
      f.write("M A S N T V S A Q G G S N R P V R D F S N I Q D V A Q F L L F D P I W N E Q P G S I V P W K M N R E Q A L A E R Y P E L Q T S E P S E D Y S G P V E S L E L L P L E I K L D I M Q Y L S W E Q I S W C K H P W L W T R W Y K D N V V R V S A I T F E D F Q R E Y A F P E K I Q E I H F T D T R A E E I K A I L E T T P N V T R L V I R R I D D M N Y N T H G D L G L D D L E F L T H L M V E D A C G F T D F W A P S L T H L T I K N L D M H P R W F G P V M D G I K S M Q S T L K Y L Y I F E T Y G V N K P F V Q W C T D N I E T F Y C T N S Y R Y E N V P R P I Y V W V L F Q E D E W H G Y R V E D N K F H R R Y M Y S T I L H K R D T D W V E N N P L K T P A Q V E M Y K F L L R I S Q L N R D G T G Y E S D S D P E N E H F D D E S F S S G E E D S S D E D D P T W A P D S D D S D W E T E T E E E P S V A A R I L E K G K L T I T N L M K S L G F K P K P K K I Q S I D R Y F C S L D S N Y N S E D E D F E Y D S D S E D D D S D S E D D C\n")
      f.write("M Y Q A I N P C P Q S W Y G S P Q L E R E I V C K M S G A P H Y P N Y Y P V H P N A L G G A W F D T S L N A R S L T T T P S L T T C T P P S L A A C T P P T S L G M V D S P P H I N P P R R I G T L C F D F G S A K S P Q R C E C V A S D R P S T T S N T A P D T Y R L L I T N S K T R K N N Y G T C R L E P L T Y G I\n")
      f.write("M A R P L L G K T S S V R R R L E S L S A C S I F F F L R K F C Q K M A S L V F L N S P V Y Q M S N I L L T E R R Q V D R A M G G S D D D G V M V V A L S P S D F K T V L G S A L L A V E R D M V H V V P K Y L Q T P G I L H D M L V L L T P I F G E A L S V D M S G A T D V M V Q Q I A T A G F V D V D P L H S S V S W K D N V S C P V A L L A V S N A V R T M M G Q P C Q V T L I I D V G T Q N I L R D L V N L P V E M S G D L Q V M A Y T K D P L G K V P A V G V S V F D S G S V Q K G D A H S V G A P D G L V S F H T H P V S S A V E L N Y H A G W P S N V D M S S L L T M K N L M H V V V A E E G L W T M A R T L S M Q R L T K V L T D A E K D V M R A A A F N L F L P L N E L R V M G T K D S N N K S L K T Y F E V F E T F T I G A L M K H S G V T P T A F V D R R W L D N T I Y H M G F I P W G R D M R F V V E Y D L D G T N P F L N T V P T L M S V K R K A K I Q E M F D N M V S R M V T S\n")
      f.write("M N A K Y D T D Q G V G R M L F L G T I G L A V V V G G L M A Y G Y Y Y D G K T P S S G T S F H T A S P S F S S R Y R Y\n")
      f.write("M R Y T V L I A L Q G A L L L L L L I D D G Q G Q S P Y P Y P G M P C N S S R Q C G L G T C V H S R C A H C S S D G T L C S P E D P T M V W P C C P E S S C Q L V V G L P S L V N H Y N C L P N Q C T D S S Q C P G G F G C M T R R S K C E L C K A D G E A C N S P Y L D W R K D K E C C S G Y C H T E A R G L E G V C I D P K K I F C T P K N P W Q L A P Y P P S Y H Q P T T L R P P T S L Y D S W L M S G F L V K S T T A P S T Q E E E D D Y\n")
      f.write("M Q N P L P E V M S P E H D K R T T T P M S K E A N K F I R E L D K K P G D L A V V S D F V K R N T G K R L P I G K R S N L Y V R I C D L S G T I Y M G E T F I L E S W E E L Y L P E P T K M E V L G T L E S C C G I P P F P E W I V M V G E D Q C V Y A Y G D E E I L L F A Y S V K Q L V E E G I Q E T G I S Y K Y P D D I S D V D E E V L Q Q D E E I Q K I R K K T R E F V D K D A Q E F Q D F L N S L D A S L L S\n")
      f.write("M D S L N E V C Y E Q I K G T F Y K G L F G D F P L I V D K K T G C F N A T K L C V L G G K R F V D W N K T L R S K K L I Q Y Y E T R C D I K T E S L L Y E I K G D N N D E I T K Q I T G T Y L P K E F I L D I A S W I S V E F Y D K C N N I I I N Y F V N E Y K T M D K K T L Q S K I N E V E E K M Q K L L N E K E E E L Q E K N D K I D E L I L F S K R M E E D R K K D R E M M I K Q E K M L R E L G I H L E D V S S Q N N E L I E K V D E Q V E Q N A V L N F K I D N I Q N K L E I A V E D R A P Q P K Q N L K R E R F I L L K R N D D Y Y P Y Y T I R A Q D I N A R S A L K R Q K N L Y N E V S V L L D L T C H P N S K T L Y V R V K D E L K Q K G V V F N L C K V S I S N S K I N E E E L I K A M E T I N D E K R D V\n") 

    with open("dataset/train_1.txt",'w') as f:
      for i in range(50):
        f.write("M A F S A E D V L K E Y D R R R R M E A L L L S L Y Y P N D R K L L D Y K E W S P P R V Q V E C P K A P V E W N N P P S E K G L I V G H F S G I K Y K G E K A Q A S E V D V N K M C C W V S K F K D A M R R Y Q G I Q T C K I P G K V L S D L D A K I K A Y N L T V E G V E G F V R Y S R V T K Q H V A A F L K E L R H S K Q Y E N V N L I H Y I L T D K R V D I Q H L E K D L V K D F K A L V E S A H R M R Q G H M I N V K Y I L Y Q L L K K H G H G P D G P D I L T V K T G S K G V L Y D D S F R K I Y T D L G W K F T P L\n")
        f.write("M S I I G A T R L Q N D K S D T Y S A G P C Y A G G C S A F T P R G T C G K D W D L G E Q T C A S G F C T S Q P L C A R I K K T Q V C G L R Y S S K G K D P L V S A E W D S R G A P Y V R C T Y D A D L I D T Q A Q V D Q F V S M F G E S P S L A E R Y C M R G V K N T A G E L V S R V S S D A D P A G G W C R K W Y S A H R G P D Q D A A L G S F C I K N P G A A D C K C I N R A S D P V Y Q K V K T L H A Y P D Q C W Y V P C A A D V G E L K M G T Q R D T P T N C P T Q V C Q I V F N M L D D G S V T M D D V K N T I N C D F S K Y V P P P P P P K P T P P T P P T P P T P P T P P T P P T P P T P R P V H N R K V M F F V A G A V L V A I L I S T V R W\n")
        f.write("M A S N T V S A Q G G S N R P V R D F S N I Q D V A Q F L L F D P I W N E Q P G S I V P W K M N R E Q A L A E R Y P E L Q T S E P S E D Y S G P V E S L E L L P L E I K L D I M Q Y L S W E Q I S W C K H P W L W T R W Y K D N V V R V S A I T F E D F Q R E Y A F P E K I Q E I H F T D T R A E E I K A I L E T T P N V T R L V I R R I D D M N Y N T H G D L G L D D L E F L T H L M V E D A C G F T D F W A P S L T H L T I K N L D M H P R W F G P V M D G I K S M Q S T L K Y L Y I F E T Y G V N K P F V Q W C T D N I E T F Y C T N S Y R Y E N V P R P I Y V W V L F Q E D E W H G Y R V E D N K F H R R Y M Y S T I L H K R D T D W V E N N P L K T P A Q V E M Y K F L L R I S Q L N R D G T G Y E S D S D P E N E H F D D E S F S S G E E D S S D E D D P T W A P D S D D S D W E T E T E E E P S V A A R I L E K G K L T I T N L M K S L G F K P K P K K I Q S I D R Y F C S L D S N Y N S E D E D F E Y D S D S E D D D S D S E D D C\n")
        f.write("M Y Q A I N P C P Q S W Y G S P Q L E R E I V C K M S G A P H Y P N Y Y P V H P N A L G G A W F D T S L N A R S L T T T P S L T T C T P P S L A A C T P P T S L G M V D S P P H I N P P R R I G T L C F D F G S A K S P Q R C E C V A S D R P S T T S N T A P D T Y R L L I T N S K T R K N N Y G T C R L E P L T Y G I\n")
        f.write("M A R P L L G K T S S V R R R L E S L S A C S I F F F L R K F C Q K M A S L V F L N S P V Y Q M S N I L L T E R R Q V D R A M G G S D D D G V M V V A L S P S D F K T V L G S A L L A V E R D M V H V V P K Y L Q T P G I L H D M L V L L T P I F G E A L S V D M S G A T D V M V Q Q I A T A G F V D V D P L H S S V S W K D N V S C P V A L L A V S N A V R T M M G Q P C Q V T L I I D V G T Q N I L R D L V N L P V E M S G D L Q V M A Y T K D P L G K V P A V G V S V F D S G S V Q K G D A H S V G A P D G L V S F H T H P V S S A V E L N Y H A G W P S N V D M S S L L T M K N L M H V V V A E E G L W T M A R T L S M Q R L T K V L T D A E K D V M R A A A F N L F L P L N E L R V M G T K D S N N K S L K T Y F E V F E T F T I G A L M K H S G V T P T A F V D R R W L D N T I Y H M G F I P W G R D M R F V V E Y D L D G T N P F L N T V P T L M S V K R K A K I Q E M F D N M V S R M V T S\n")
        f.write("M N A K Y D T D Q G V G R M L F L G T I G L A V V V G G L M A Y G Y Y Y D G K T P S S G T S F H T A S P S F S S R Y R Y\n")
        f.write("M R Y T V L I A L Q G A L L L L L L I D D G Q G Q S P Y P Y P G M P C N S S R Q C G L G T C V H S R C A H C S S D G T L C S P E D P T M V W P C C P E S S C Q L V V G L P S L V N H Y N C L P N Q C T D S S Q C P G G F G C M T R R S K C E L C K A D G E A C N S P Y L D W R K D K E C C S G Y C H T E A R G L E G V C I D P K K I F C T P K N P W Q L A P Y P P S Y H Q P T T L R P P T S L Y D S W L M S G F L V K S T T A P S T Q E E E D D Y\n")
        f.write("M Q N P L P E V M S P E H D K R T T T P M S K E A N K F I R E L D K K P G D L A V V S D F V K R N T G K R L P I G K R S N L Y V R I C D L S G T I Y M G E T F I L E S W E E L Y L P E P T K M E V L G T L E S C C G I P P F P E W I V M V G E D Q C V Y A Y G D E E I L L F A Y S V K Q L V E E G I Q E T G I S Y K Y P D D I S D V D E E V L Q Q D E E I Q K I R K K T R E F V D K D A Q E F Q D F L N S L D A S L L S\n")
        f.write("M D S L N E V C Y E Q I K G T F Y K G L F G D F P L I V D K K T G C F N A T K L C V L G G K R F V D W N K T L R S K K L I Q Y Y E T R C D I K T E S L L Y E I K G D N N D E I T K Q I T G T Y L P K E F I L D I A S W I S V E F Y D K C N N I I I N Y F V N E Y K T M D K K T L Q S K I N E V E E K M Q K L L N E K E E E L Q E K N D K I D E L I L F S K R M E E D R K K D R E M M I K Q E K M L R E L G I H L E D V S S Q N N E L I E K V D E Q V E Q N A V L N F K I D N I Q N K L E I A V E D R A P Q P K Q N L K R E R F I L L K R N D D Y Y P Y Y T I R A Q D I N A R S A L K R Q K N L Y N E V S V L L D L T C H P N S K T L Y V R V K D E L K Q K G V V F N L C K V S I S N S K I N E E E L I K A M E T I N D E K R D V\n") 


fake_data()

spm.SentencePieceTrainer.train(input='vocab.txt', 
                               model_prefix='protein', 
                               vocab_size=30, 
                               model_type="word", 
                               #user_defined_symbols="<MASK>",
                               pad_id=0,
                               unk_id=1,
                               bos_id=2,
                               eos_id=3,
                               pad_piece="[PAD]",
                               unk_piece="[UNK]",
                               bos_piece="[BOS]",
                               eos_piece="[EOS]")


tokenizer = spm.SentencePieceProcessor(model_file='protein.model')
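# Quick sanity check (illustrative): with this word-level model, "[MASK]"
# should encode to a single id, which mask_seq below relies on.
# print(tokenizer.encode("[MASK]"))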

train_files = glob.glob("dataset/train*",recursive=True)
random.shuffle(train_files)

def mask_seq(seq, mask_prob=0.15):
  # Replace roughly mask_prob of the inner tokens (keeping BOS/EOS) with [MASK].
  seq = np.array(seq)
  min_index = 1
  max_index = len(seq) - 2
  n_mask_tokens = int(max_index * mask_prob + 0.5)
  mask_indices = random.sample(range(min_index, max_index), n_mask_tokens)
  seq_masked = seq
  seq_masked[mask_indices] = tokenizer.encode("[MASK]")[0]
  return seq_masked

def get_seq(train_files):
  # Cycle through the training files forever, yielding one raw line at a time.
  while True:
    for file in train_files:
      with open(file) as fp:
        for line in fp:
          yield line

def get_batch(seq_gen, batch_length):
  # Pack consecutive tokenized sequences into batches of at most batch_length ids.
  batch = []
  while True:
    seq = next(seq_gen)
    seq_ids = tokenizer.encode(seq, add_bos=True, add_eos=True)

    new_batch_len = len(batch) + len(seq_ids)
    if new_batch_len <= batch_length:
      batch = batch + seq_ids
      continue

    # Emit the packed batch and start a new one with the sequence that did
    # not fit. (A single sequence longer than batch_length would be yielded
    # oversized on the next round, so batch_length must exceed the longest
    # tokenized line.)
    next_batch = batch
    batch = seq_ids
    yield next_batch

# Set up the data pipeline.
def my_inputs(n_devices):
  # Yield (inputs, targets, mask) triples with one row per device, each
  # padded to a fixed MAX_BATCH_LENGTH.
  MAX_BATCH_LENGTH = 1024 * 4

  seq_gen = get_seq(train_files)
  batch_gen = get_batch(seq_gen,MAX_BATCH_LENGTH)

  while True:

    inputs = []
    targets = []
    mask = []
    for i in range(n_devices):
      batch_ids = next(batch_gen)
      masked_seq_ids = mask_seq(batch_ids)
      pad_amount = MAX_BATCH_LENGTH - len(batch_ids)

      inputs.append(np.pad(masked_seq_ids, (0,pad_amount)))
      targets.append(np.pad(batch_ids, (0,pad_amount)))
      mask.append(np.pad(np.ones_like(batch_ids, dtype=np.float32),
                          (0,pad_amount),
                          mode='constant'))
    inputs = np.stack(inputs)
    targets = np.stack(targets)
    mask = np.stack(mask)
    yield (inputs, targets, mask)

inp_gen_test = my_inputs(trax.fastmath.device_count())

res = next(inp_gen_test)

print(tokenizer.decode(res[0][0].tolist()))

print(tokenizer.decode(res[1][0].tolist()))
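# The two decoded strings should differ only at the masked positions,
# since inputs are just the masked copy of targets.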


# Configure hyperparameters.
gin.parse_config("""
import trax.layers
import trax.models
import trax.optimizers
import trax.data.inputs
import trax.supervised.trainer_lib

# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.Reformer
n_layers = 15
n_heads = 16
dropout = 0.1
n_tokens = 40000 # They used a much smaller n_tokens = 2048
vocab_size= 30
d_model = 1024
d_ff = 4096

# Done
# Parameters for MultifactorSchedule:
# ==============================================================================
multifactor.constant = 0.088
multifactor.decay_factor = 0.5
multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
multifactor.steps_per_cycle = 100000
multifactor.steps_per_decay = 20000
multifactor.warmup_steps = 8000


# Done
# Parameters for Adam:
# ==============================================================================
Adam.b1 = 0.9
Adam.b2 = 0.98
Adam.eps = 1e-09
Adam.weight_decay_rate = 1e-05

# Done
# Parameters for SelfAttention:
# ==============================================================================
#trax.layers.SelfAttention.attention_dropout = 0.05
#trax.layers.SelfAttention.chunk_len = 64
#trax.layers.SelfAttention.n_chunks_before = 1
#trax.layers.SelfAttention.n_parallel_heads = 1
trax.layers.SelfAttention.causal = False
trax.layers.SelfAttention.chunk_len = None
trax.layers.SelfAttention.masked = False
trax.layers.SelfAttention.n_chunks_after = 0
trax.layers.SelfAttention.n_chunks_before = 0
trax.layers.SelfAttention.n_parallel_heads = None
trax.layers.SelfAttention.predict_drop_len = 64
trax.layers.SelfAttention.predict_mem_len = 192
trax.layers.SelfAttention.share_qk = False
trax.layers.SelfAttention.use_python_loop = False
trax.layers.SelfAttention.use_reference_code = False

# Done
# Parameters for EncDecAttention:
# ==============================================================================
trax.layers.EncDecAttention.masked = True
trax.layers.EncDecAttention.n_parallel_heads = None
trax.layers.EncDecAttention.use_python_loop = False
trax.layers.EncDecAttention.use_reference_code = False

# Done
# Parameters for LSHSelfAttention:
# ==============================================================================
#LSHSelfAttention.attention_dropout = 0.0
#LSHSelfAttention.chunk_len = 64
#LSHSelfAttention.n_buckets = [64, 128]
#LSHSelfAttention.n_chunks_after = 0
#LSHSelfAttention.n_chunks_before = 1
#LSHSelfAttention.n_hashes = 1
#LSHSelfAttention.n_parallel_heads = 1
#LSHSelfAttention.predict_drop_len = 128
#LSHSelfAttention.predict_mem_len = 1024

# Done
# Parameters for Reformer:
# ==============================================================================
Reformer.d_model = %d_model
Reformer.d_ff = %d_ff
Reformer.dropout = %dropout
Reformer.ff_activation = @trax.layers.Relu
Reformer.max_len = %n_tokens
Reformer.mode = 'train'
Reformer.n_heads = %n_heads
Reformer.n_encoder_layers = %n_layers
Reformer.n_decoder_layers = %n_layers
Reformer.input_vocab_size = %vocab_size


""")

# Set up a Trainer.
output_dir = os.path.expanduser('train_dir/')

trainer = trax.supervised.Trainer(
    model=trax.models.Reformer,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.multifactor(),
    inputs=trax.data.inputs.Inputs(my_inputs),
    output_dir=output_dir)
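# (Per the description above, this is around where the run hangs on the
# Cloud TPU setup; on Colab it completes normally.)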


# Run one training step, to make sure the model fits in memory.
# The first time trainer.train_epoch is called, it will JIT the entire network
# architecture, which takes around 2 minutes. The JIT-compiled model is saved
# so subsequent runs will be much faster than the first.
trainer.train_epoch(n_steps=1, n_eval_steps=1)
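
To narrow things down, a bare pmap across the same cores can be tried first (a minimal sketch, independent of trax); if this also hangs, the deadlock is below the Trainer:

import jax
import jax.numpy as jnp

n_dev = jax.device_count()
x = jnp.ones((n_dev, 128))
# A trivial per-core computation: a hang here would implicate the TPU
# driver/backend rather than the Reformer model.
y = jax.pmap(lambda a: a * 2.0)(x)
print(y.shape)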

Error logs:

...

2020-08-26 17:46:37.421334: W external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc:601] TPU Execute is taking a long time. This might be due to a deadlock between multiple TPU cores or a very slow program.
2020-08-26 17:51:46.613101: W external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc:601] TPU Execute is taking a long time. This might be due to a deadlock between multiple TPU cores or a very slow program.

Any idea what could be the problem?

agemagician · Aug 26 '20 15:08