
SageMaker Estimator doesn't support checkpoint_s3_uri with Heterogeneous Clusters

Open brunopistone opened this issue 6 months ago • 0 comments

Describe the bug The PyTorch estimator doesn't allow setting checkpoint_s3_uri when working with a heterogeneous cluster; it fails with the following error:

│ /Users/bpistone/miniforge3/envs/ray-env/lib/python3.12/site-packages/sagemaker/estimator.py:3646 │
│ in _validate_and_set_debugger_configs                                                            │
│                                                                                                  │
│   3643 │   │   │   │   │   │   "the debugger_hook_config is disabled."                           │
│   3644 │   │   │   │   │   )                                                                     │
│   3645 │   │   │   │   │   self.debugger_hook_config = False                                     │
│ ❱ 3646 │   │   │   │   elif self.instance_count > 1 or (                                         │
│   3647 │   │   │   │   │   hasattr(self, "distribution")                                         │
│   3648 │   │   │   │   │   and self.distribution is not None  # pylint: disable=no-member        │
│   3649 │   │   │   │   ):                                                                        │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: '>' not supported between instances of 'NoneType' and 'int'
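
From a quick read of estimator.py, the root cause appears to be that self.instance_count is None when instance_groups is passed instead of instance_count, so the comparison self.instance_count > 1 raises the TypeError. Below is a minimal sketch of a None-safe check (hypothetical, not the SDK's actual code; it assumes each InstanceGroup exposes its instance_count as an attribute):

def _is_multi_instance(estimator) -> bool:
    # Hypothetical helper: treat a job as multi-instance if either the
    # plain instance_count is > 1 or the instance group counts sum to > 1.
    if estimator.instance_count is not None:
        return estimator.instance_count > 1
    if getattr(estimator, "instance_groups", None):
        return sum(group.instance_count for group in estimator.instance_groups) > 1
    return False

The failing elif could then read elif _is_multi_instance(self) or (...), keeping the current behavior for homogeneous jobs while handling heterogeneous clusters.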

To reproduce

from sagemaker.instance_group import InstanceGroup
from sagemaker.pytorch import PyTorch

# Placeholders -- substitute values from your own environment
model_id = "org/model-name"  # model identifier used to build the job name
bucket_name = "my-bucket"  # S3 bucket for outputs and checkpoints
role = "arn:aws:iam::123456789012:role/MySageMakerRole"  # execution role ARN
image_uri = "<your-training-image-uri>"  # training container image

instance_groups = [
    InstanceGroup(
        instance_group_name="head-instance-group",
        instance_type="ml.t3.xlarge",
        instance_count=1,
    ),
    InstanceGroup(
        instance_group_name="worker-instance-group",
        instance_type="ml.g5.xlarge",
        instance_count=4,
    ),
]

# define Training Job Name
job_name = f"train-{model_id.split('/')[-1].replace('.', '-')}-sft"

output_path = f"s3://{bucket_name}/{job_name}"

estimator = PyTorch(
    source_dir="./scripts",
    entry_point="launcher.py",
    output_path=output_path,
    base_job_name=job_name,
    role=role,
    instance_groups=instance_groups,
    max_run=432000,
    image_uri=image_uri,
    environment={
        "head_instance_group": "head-instance-group",
        "head_num_cpus": "0",
        "head_num_gpus": "0",
    },
    hyperparameters={
        "entrypoint": "train_ray.py",
        "config": "/opt/ml/input/data/config/args.yaml",  # path to TRL config which was uploaded to s3
    },
    enable_remote_debug=True,
    checkpoint_local_path="/opt/ml/checkpoints", 
    checkpoint_s3_uri=output_path + "/checkpoint", 
)
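
A possible workaround (untested sketch): the branch that performs the comparison appears to run only when debugger_hook_config is truthy, so passing debugger_hook_config=False explicitly should skip the failing check while leaving S3 checkpointing itself intact:

estimator = PyTorch(
    # ... same arguments as above ...
    source_dir="./scripts",
    entry_point="launcher.py",
    role=role,
    instance_groups=instance_groups,
    image_uri=image_uri,
    checkpoint_local_path="/opt/ml/checkpoints",
    checkpoint_s3_uri=output_path + "/checkpoint",
    debugger_hook_config=False,  # skips the `instance_count > 1` comparison
)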

This error cannot currently be reproduced with ModelTrainer, because ModelTrainer has its own bug with heterogeneous clusters, reported in https://github.com/aws/sagemaker-python-sdk/issues/5225

Expected behavior The estimator should be created successfully, and the training job should start with estimator.fit(inputs=data, wait=False).
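
For completeness, a sketch of the fit call (the S3 path and channel name here are hypothetical; TrainingInput accepts an instance_groups parameter to route a channel to specific groups in a heterogeneous cluster):

from sagemaker.inputs import TrainingInput

# Hypothetical channel routed only to the head instance group.
data = {
    "config": TrainingInput(
        s3_data=f"s3://{bucket_name}/{job_name}/config",  # hypothetical path
        instance_groups=["head-instance-group"],
    )
}

estimator.fit(inputs=data, wait=False)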

Screenshots or logs See the traceback in the bug description above.

System information

  • SageMaker Python SDK version: 2.271.
  • Framework name (e.g. PyTorch) or algorithm (e.g. KMeans): PyTorch
  • Framework version: 2.6.0
  • Python version: 3.12
  • CPU or GPU: CPU and GPU
  • Custom Docker image (Y/N): N


brunopistone · Jul 01 '25 07:07