
[Checkpoint: AIR] Saved checkpoint folders do not include the correct training iteration number.


What happened + What you expected to happen

When the frequency parameter is enabled in the Keras Callback (from ray.air.callbacks.keras import Callback), the checkpoint folders do not include the correct training iteration number.

If we set frequency=1, the checkpoints follow the naming convention checkpoint_{(iteration-1):06d}, but if we set frequency>1, the saved checkpoint folders carry no information about the iteration number and are simply numbered consecutively. This is because of the way checkpoint folders are created here: https://github.com/ray-project/ray/blob/master/python/ray/train/_internal/checkpoint.py#L228. That code simply increments self._latest_checkpoint_id without considering the frequency parameter.

While using frequency=1

# outputs checkpoints: ['checkpoint_000002', 'checkpoint_000004', 'checkpoint_000000', 'checkpoint_000003', 'checkpoint_000001']

While using frequency=2

# output checkpoints: ['checkpoint_000000', 'checkpoint_000001']

But ideally the numbering should be ['checkpoint_000002', 'checkpoint_000004'].
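
To make the mismatch concrete, here is a minimal standalone sketch (plain Python, not Ray's actual checkpointing code) contrasting the consecutive counter with the iteration-derived names one would expect:

# Plain-Python sketch (not Ray code) of the two numbering schemes.
epochs = 5
frequency = 2

# Current behavior: an internal counter that increments once per saved
# checkpoint, regardless of which training iteration triggered the save.
latest_checkpoint_id = 0
current_names = []
for iteration in range(1, epochs + 1):
    if iteration % frequency == 0:
        current_names.append(f"checkpoint_{latest_checkpoint_id:06d}")
        latest_checkpoint_id += 1

# Expected behavior: names derived from the training iteration itself.
expected_names = [
    f"checkpoint_{iteration:06d}"
    for iteration in range(1, epochs + 1)
    if iteration % frequency == 0
]

print(current_names)   # ['checkpoint_000000', 'checkpoint_000001']
print(expected_names)  # ['checkpoint_000002', 'checkpoint_000004']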

Versions / Dependencies

2.0.0

Reproduction script

The following script, a minor modification of the test https://github.com/ray-project/ray/blob/releases/2.0.0/python/ray/air/tests/test_keras_callback.py, can be used to reproduce the bug.

import tempfile
from pathlib import Path

import tensorflow as tf

import ray
from ray.air import session
from ray.air.callbacks.keras import Callback
from ray.air.config import RunConfig, ScalingConfig
from ray.train.constants import TRAIN_DATASET_KEY
from ray.train.tensorflow import TensorflowTrainer, prepare_dataset_shard


def get_dataset(a=5, b=10, size=1000):
    items = [i / size for i in range(size)]
    dataset = ray.data.from_items([{"x": x, "y": a * x + b} for x in items])
    return dataset


def build_model() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=()),
            # Add feature dimension, expanding (batch_size,) to (batch_size, 1).
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Dense(1),
        ]
    )
    return model


def train_func(config: dict, ckpt_freq=1):
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # Model building/compiling need to be within `strategy.scope()`.
        multi_worker_model = build_model()
        multi_worker_model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=config.get("lr", 1e-3)),
            loss=tf.keras.losses.mean_squared_error,
            metrics=[tf.keras.metrics.mean_squared_error],
        )

    dataset = session.get_dataset_shard("train")

    def to_tf_dataset(dataset, batch_size):
        def to_tensor_iterator():
            for batch in dataset.iter_tf_batches(
                batch_size=batch_size, dtypes=tf.float32
            ):
                yield batch["x"], batch["y"]

        output_signature = (
            tf.TensorSpec(shape=(None), dtype=tf.float32),
            tf.TensorSpec(shape=(None), dtype=tf.float32),
        )
        tf_dataset = tf.data.Dataset.from_generator(
            to_tensor_iterator, output_signature=output_signature
        )
        return prepare_dataset_shard(tf_dataset)

    tf_dataset = to_tf_dataset(dataset=dataset, batch_size=32)
    multi_worker_model.fit(
        tf_dataset,
        callbacks=[Callback(frequency=ckpt_freq)],
        epochs=config["epochs"],
    )


def test_keras_callback_e2e(ckpt_freq=1):
    epochs = 5
    config = {
        "epochs": epochs,
    }
    # Use mkdtemp so the directory persists; TemporaryDirectory().name points at
    # a directory that is deleted as soon as the object is garbage collected.
    tempdir = tempfile.mkdtemp()
    print(tempdir)
    trainer = TensorflowTrainer(
        train_loop_per_worker=lambda config: train_func(config, ckpt_freq=ckpt_freq),
        train_loop_config=config,
        scaling_config=ScalingConfig(num_workers=2),
        datasets={TRAIN_DATASET_KEY: get_dataset()},
        run_config=RunConfig(local_dir=tempdir)
    )
    checkpoint = trainer.fit().checkpoint
    base_path = Path(checkpoint._local_path).parent
    ckpts = [ckpt_dir.name for ckpt_dir in base_path.iterdir() if "checkpoint_00000" in str(ckpt_dir)]
    return ckpts
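
# Added for completeness (assumes a working Ray + TensorFlow installation):
# invoking the test for both frequencies reproduces the outputs quoted above.
if __name__ == "__main__":
    print("frequency=1:", test_keras_callback_e2e(ckpt_freq=1))
    print("frequency=2:", test_keras_callback_e2e(ckpt_freq=2))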

Issue Severity

High: It blocks me from completing my task.

n30111 avatar Oct 19 '22 13:10 n30111

Thanks for reporting the issue @n30111! Indeed, definitely something we should fix.

I think we should switch to using the Tune Session API internally instead of tune.checkpoint_dir, and then on the Tune side it can fill in the checkpoint step with the training_iteration from the corresponding metrics (see the sketch below). cc @xwjiang2010 @Yard1

amogkam avatar Oct 19 '22 20:10 amogkam
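
For illustration only, a hypothetical helper (not Ray's actual code) showing the direction being suggested, i.e. deriving the folder name from the training_iteration reported with the checkpoint rather than from an internal counter:

# Hypothetical sketch: take the checkpoint step from the reported metrics.
def checkpoint_dir_from_metrics(metrics: dict) -> str:
    return f"checkpoint_{metrics['training_iteration']:06d}"

print(checkpoint_dir_from_metrics({"training_iteration": 4, "loss": 0.12}))
# checkpoint_000004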

The same issue exists for HuggingFaceTrainer: when using steps as the saving frequency, e.g. every 1000 steps, the first checkpoint is named checkpoint_000000 rather than something reflecting step 1000.

dumpmemory avatar Oct 20 '22 03:10 dumpmemory

How is this impacting workloads, aside from the Keras callback not saving the epoch? As far as I understand, the most important thing is that we have an incremental counter for checkpoints. The actual epoch/iteration number should be saved inside the checkpoint itself (which is indeed the case with Huggingface, but not with the Keras callback).

Yard1 avatar Oct 26 '22 19:10 Yard1
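
As a hedged sketch of what "saving the iteration inside the checkpoint itself" can look like with the AIR session API (illustrative only; the Keras callback would need to do the equivalent internally):

from ray.air import session
from ray.air.checkpoint import Checkpoint

# Report the epoch alongside the checkpoint data so the iteration number
# travels with the checkpoint, independent of the folder name on disk.
def report_checkpoint(epoch: int, weights) -> None:
    ckpt = Checkpoint.from_dict({"epoch": epoch, "model_weights": weights})
    session.report({"epoch": epoch}, checkpoint=ckpt)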

How is this impacting workloads, aside from the Keras callback not saving the epoch? As far as I understand, the most important thing is that we have an incremental counter for checkpoints. The actual epoch/iteration number should be saved inside the checkpoint itself (which is indeed the case with Huggingface, but not with the Keras callback).

But keeping the checkpoint numbering consistent with the HuggingFace checkpoint numbering would be more convenient for managing checkpoints.

dumpmemory avatar Oct 27 '22 06:10 dumpmemory

@amogkam I'm not exactly sure I follow. How would the Tune Session know about the application-specific details (frequency, etc.)?

xwjiang2010 avatar Nov 03 '22 22:11 xwjiang2010

I haven't set checkpoint_frequency in CheckpointConfig

dumpmemory avatar Nov 04 '22 04:11 dumpmemory

@amogkam any update on this issue?

n30111 avatar Jan 09 '23 10:01 n30111

@justinvyu does #36220 resolve this?

anyscalesam avatar May 15 '24 18:05 anyscalesam