Wissam Antoun

Results: 22 comments of Wissam Antoun

@gante I think it's better to replace

```python
flat_x = tf.reshape(x, (-1, x.shape[-1]))
flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
gathered = tf.reshape(gathered, shape_list(indices))
```

with ```python...

Weird! During my TPU and GPU tests, I was using a custom training loop instead of Keras's `.fit()`, though I'm not sure whether that actually matters. In my custom training...
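For reference, a minimal sketch of what such a custom training loop might look like (an assumed generic setup, not the exact loop from these tests):

```python
import tensorflow as tf

# Toy model, optimizer, and loss standing in for the real training setup.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(x, y):
    # One optimization step, roughly what `.fit()` does internally per batch.
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```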

@tmoroder Hey, can I ask about the training throughput/performance you got with the TPUs?

Oh great! I mean, not great in the sense that the model is super slow on TPUs, but great that `model.fit` and my custom training loop hit the same issue....

@Rocketknight1 I read all the discussions you had with Kamal about `torch.gather` and `take_along_axis`. On GPUs I had already enabled XLA via `tf.config.optimizer.set_jit` and via `TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit"`, but...
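For context, the two XLA-enabling paths mentioned here look roughly like this (a sketch; the env var has to be set before TensorFlow initializes, e.g. exported in the shell):

```python
import os

# Auto-JIT flags, set before importing TensorFlow.
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit"

import tensorflow as tf

# Enable XLA JIT compilation globally through the optimizer config.
tf.config.optimizer.set_jit(True)
```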

Running the training with `jit_compile=True` on GPU revealed a new bug, so it is now an XLA/JIT issue rather than a TPU one. View log dump:

```md
2022-07-21 23:36:18.107830: W tensorflow/core/framework/op_kernel.cc:1745]...
```
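For reference, `jit_compile=True` here is Keras's XLA flag on `Model.compile` (a minimal sketch with a placeholder model, not the actual training code):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    jit_compile=True,  # compile the train/eval/predict steps with XLA
)
```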

I confirm it works on GPUs with XLA, and I got a ~20% speedup. I'm still testing on TPUs now and will let you know ASAP.

Weirdly enough, the TPUs didn't seem to care about the changes 😅 even after we removed all the `if` branches

I tried disabling `relative_attention` in DeBERTa, which makes the model a regular BERT, and the performance improved 40x 😅
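A sketch of how that ablation can be expressed with the `transformers` DeBERTa config (using a randomly initialized model for illustration; the actual experiment presumably patched an existing checkpoint's config):

```python
from transformers import DebertaConfig, TFDebertaModel

# With relative_attention=False, DeBERTa's disentangled-attention path is
# skipped and the encoder behaves like a regular BERT-style model.
config = DebertaConfig(relative_attention=False)
model = TFDebertaModel(config)
```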

@Rocketknight1 I tried your suggestions without any success, sadly! Then I tried replacing the whole `take_along_axis` function with `tf.gather(..., batch_dims=2)`, which is equivalent according to a test I made. GPU still...
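A small reconstruction of the kind of equivalence test described (my sketch, not the exact one): gathering along the last axis via the flatten-and-reshape pattern from above matches `tf.gather` with `batch_dims=2` on a 3-D tensor.

```python
import numpy as np
import tensorflow as tf

x = tf.random.normal((2, 4, 8))                                   # (batch, seq, dim)
indices = tf.random.uniform((2, 4, 3), maxval=8, dtype=tf.int32)  # per-position indices

# take_along_axis-style: flatten the batch dims, gather, reshape back.
flat_x = tf.reshape(x, (-1, x.shape[-1]))
flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
gathered = tf.reshape(gathered, indices.shape)

# Direct form: gather with two leading batch dimensions.
direct = tf.gather(x, indices, batch_dims=2)

np.testing.assert_allclose(gathered.numpy(), direct.numpy())
```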