
Supported features

peregilk opened this issue 10 months ago • 21 comments

I mainly wanted to start by thanking you for making MaxText available. I have been using it for a few days, and the first impression is excellent. Getting started was really easy, it seemed very stable, and the performance was fantastic. It also seems to scale very nicely.

There are a few things I have not been able to figure out yet; it might be because of a lack of documentation, or simply because they are not implemented.

  • Is there any support for Flash attention, or any plans for implementing it? This has been a major area where GPUs have been ahead of TPUs. I have noticed that there is now at least an experimental implementation from the JAX team: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py.

  • Training directly from tfds seemed straightforward. However, I was a bit confused about how to implement more advanced data-loader features, for instance probability sampling as explained here (a minimal tf.data sketch of what I mean follows this list). This can be somewhat tricky to do efficiently on multiple TPUs. What is the sensible approach here? Manually sampling into a tfds dataset does not seem very efficient. Are there external libraries that are compatible with MaxText?

  • Are there plans for implementing DPO/RLHF?
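
To make the second question concrete, here is a minimal sketch of the kind of probability sampling I mean, in plain tf.data. It is illustrative only; whether and how this composes with MaxText's multi-host input pipeline is exactly what I'm unsure about:

```python
# Illustrative only: interleave two corpora with 70%/30% sampling
# probabilities using tf.data's sample_from_datasets.
import tensorflow as tf

ds_a = tf.data.Dataset.from_tensor_slices(tf.range(100))         # "corpus A"
ds_b = tf.data.Dataset.from_tensor_slices(tf.range(1000, 1100))  # "corpus B"

mixed = tf.data.Dataset.sample_from_datasets(
    [ds_a.repeat(), ds_b.repeat()], weights=[0.7, 0.3], seed=0
)

for x in mixed.take(10):
    print(int(x))
```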

I also shamelessly wanted to point you to my own repo: https://github.com/peregilk/ttconnect. It is a very simple bash script that ideally should be run on a VM in the same zone. It automatically opens synchronised tmux windows to all the VMs in the pod and lets you type the same command into all of them. This makes it even easier to go from a single TPU to pods.

peregilk avatar Mar 30 '24 13:03 peregilk

Thank you for the comments!

(1) Fused attention is on by default for training! We use "splash attention", which is a custom, faster version. (And we're working on accelerated attention for inference.)

(2) We don't implement more advanced data loaders, though I think they can be implemented in TFDS. It is also easy to plug in your own data loader. Is there a specific data loading solution you'd like us to use?

(3) Yes, DPO is underway!

ttconnect is super cool, thanks for sending!

rwitten avatar Mar 31 '24 00:03 rwitten

Thanks for the answer. Looking forward to the DPO support.

It would of course be fantastic if HuggingFace datasets could be supported natively. I have never really been able to run large non-streaming datasets from HF on the TPUs (disk-size issues on the VMs), but we have been able to wrap HF datasets with split_dataset_by_node to stream on multiple TPUs. I'm not sure whether I could implement something like this in MaxText myself, and I'm not really sure at what level it should be implemented.
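
Roughly, this is what we do today outside MaxText. A minimal sketch, assuming the split_dataset_by_node helper from the datasets library and JAX's process index/count for the rank and world size; the bucket path and the per-host processing step are illustrative:

```python
# Sketch of host-sharded streaming with HF datasets (outside MaxText).
# Assumes gcsfs is installed so gs:// paths resolve; names are illustrative.
import jax
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

ds = load_dataset(
    "json",
    data_files="gs://my-bucket/data/train-*.jsonl",
    split="train",
    streaming=True,
)

# Give each TPU host (JAX process) a disjoint slice of the stream.
ds = split_dataset_by_node(ds, rank=jax.process_index(), world_size=jax.process_count())

for example in ds:
    ...  # tokenize and batch per host
```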

Is there any chance you will support HF datasets in the future?

In any case, a way of preprocessing the data before it is split across the TPUs would be extremely useful for running experiments on dataset building, both for sampling and for filtering based on a field in the dataset.

peregilk avatar Apr 01 '24 13:04 peregilk

Yes, support for HF datasets in MaxText is on the way. @aireenmei

A9isha avatar May 06 '24 19:05 A9isha

Thank you for tagging me on this. Yes, supporting HuggingFace datasets is in our plan. We have some implementations and are running some perf evaluations to understand them better. I will update here when we have it out.

aireenmei avatar May 06 '24 20:05 aireenmei

Hi @peregilk, HuggingFace datasets are supported now. Please check out https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md.

aireenmei avatar May 21 '24 20:05 aireenmei

Really fantastic! This makes it a lot more convenient. Reading jsonlines directly from the buckets looks especially great. Do you support all native HF formats, like jsonl.gz?

peregilk avatar May 21 '24 20:05 peregilk

Yes, jsonl.gz is supported, as well as other formats supported by datasets.load_dataset (https://huggingface.co/docs/datasets/en/loading)
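
For reference, a minimal illustration of what datasets.load_dataset accepts directly (not MaxText-specific; the path below is a placeholder, and compressed JSON lines are decompressed transparently):

```python
# Illustrative only: load_dataset reads compressed JSON lines (and csv,
# parquet, text, ...) directly; the bucket path is made up.
from datasets import load_dataset

ds = load_dataset(
    "json",
    data_files={"train": "gs://my-bucket/data/train-*.jsonl.gz"},
    split="train",
    streaming=True,
)
print(next(iter(ds)))
```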

aireenmei avatar May 21 '24 21:05 aireenmei

@aireenmei Is there more detailed documentation on this? I was, for instance, unable to figure out how to specify the validation set.

peregilk avatar Jun 19 '24 06:06 peregilk

Hi @peregilk, a specific validation set is not supported yet, but it is on our list of items to work on.

aireenmei avatar Jun 24 '24 17:06 aireenmei

Hi @peregilk, eval is supported now: https://github.com/google/maxtext/pull/738

aireenmei avatar Jul 02 '24 12:07 aireenmei

@aireenmei Thanks a lot. Really looking forward to testing this.

Since this seems to be closely related, I am reporting it here. I can open a separate issue if you like:

I am training with:

hf_data_files='gs://mybucket/mydir/train*.jsonl'

There are 256 files in the directory. Close to the end of the first epoch, one of the workers throws this error from maxtext/MaxText/input_pipeline/_input_pipeline_utils.py, line 95, in __getitem__:

The above exception was the direct cause of the following exception:

ValueError: Run out of shards, shard 259 is not available

peregilk avatar Jul 03 '24 14:07 peregilk

Hi @peregilk, this is the expected behavior. With the current implementation, you may not be able to use all the data in your train files. Say you have 256 files and you are using a v4-64, which has 8 hosts. Each host will read 256/8 = 32 shards. The i-th host reads the (8*x + i)-th shards (0 <= x < 32). For example, host 0 reads shards 0, 8, 16, ..., 248; host 7 reads shards 7, 15, ..., 255; and so on. When a host finishes its current shard, it moves on to the next shard assigned to it. But since each shard has a slightly different number of examples, training stops when one of the hosts runs out of data. In the example above, if host 0 is the first to finish its last shard, 248, it will look for shard 248 + 8 = 256, which is not available, and that produces the error you see.
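
A tiny illustration of the striping described above (just the arithmetic, not the actual MaxText code):

```python
# Illustrative only: 256 shards striped across 8 hosts, as described above.
num_shards, num_hosts = 256, 8

for host in range(num_hosts):
    shards = list(range(host, num_shards, num_hosts))
    print(f"host {host}: shards {shards[:3]} ... {shards[-1]}")
# host 0: shards [0, 8, 16] ... 248
# host 7: shards [7, 15, 23] ... 255
```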

aireenmei avatar Jul 03 '24 18:07 aireenmei

Thanks @aireenmei. I'm not sure I understand, though. Why would the logical behaviour here not simply be to restart on the first shard that was given to the host once no more shards are available? Otherwise you would have to duplicate your dataset to train for more than one epoch, right?

peregilk avatar Jul 03 '24 19:07 peregilk

I did not implement the auto-restart because some users may not want their model to see repeated data. I can add multi-epoch support to our backlog. Meanwhile, it should be straightforward to change the shard update logic here: https://github.com/google/maxtext/blob/main/MaxText/input_pipeline/_input_pipeline_utils.py#L105
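
Roughly, the change would look like the following sketch. The function and variable names here are illustrative, not the identifiers used in _input_pipeline_utils.py:

```python
def next_shard(current_shard: int, host_index: int, num_hosts: int, num_shards: int) -> int:
    """Hypothetical wrap-around shard update: instead of raising when a host
    steps past the last shard, restart on that host's first shard."""
    candidate = current_shard + num_hosts
    return candidate if candidate < num_shards else host_index

# e.g. host 0 of 8 finishing shard 248 out of 256 wraps back to shard 0
assert next_shard(248, 0, 8, 256) == 0
```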

aireenmei avatar Jul 03 '24 23:07 aireenmei

OK. Makes sense. Thanks.

peregilk avatar Jul 04 '24 08:07 peregilk

@aireenmei I have tried using your validation support for HF datasets, and I am seeing the same issue there when setting hf_eval_files. Even when the number of shards is divisible by the number of workers, it still crashes asking for the next shard. I can't see any way to limit the number of eval steps so that it does not run out of shards. What am I missing?

peregilk avatar Jul 12 '24 16:07 peregilk

Hi @peregilk, indeed this is a bug; I will fix it. Meanwhile, this flag (https://github.com/google/maxtext/blob/main/MaxText/configs/base.yml#L336) controls the number of eval steps, so you can use that for now. I'll rename it to eval_steps in my next PR for clarity.

aireenmei avatar Jul 15 '24 19:07 aireenmei

Any update on DPO?

Mddct avatar Jul 19 '24 11:07 Mddct