sagemaker-python-sdk
sagemaker-python-sdk copied to clipboard
Passing PipelineVariable as hyperparameters for Framework Estimator fails
Describe the bug: Generating a SageMaker Pipeline definition fails if the hyperparameters passed to a training step are given as a PipelineVariable (here, a ParameterString) instead of a dict.
Expected behavior
Should PipelineVariable(s) be supported for hyperparameters?
In the Estimator's constructor, hyperparameters are defined as: Optional[Dict[str, Union[str, sagemaker.workflow.entities.PipelineVariable]]] = None.
Screenshots or logs
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-5-86301c565ea6> in <cell line: 1>()
----> 1 pipeline.build_and_deploy(experiment_config=experiment_config)
~/projects/src/ml-scale/labs-MLTools_CLI/src/trlabs_mltools_cli/pipeline.py in build_and_deploy(self, experiment_config, parameters_override)
98 service responce : see https://sagemaker.readthedocs.io/en/stable/workflows/pipelines/sagemaker.workflow.pipelines.html#sagemaker.workflow.pipeline.Pipeline.upsert
99 """ # noqa
--> 100 pipeline = self.build(experiment_config=experiment_config, parameters_override=parameters_override)
101 logger.info(f"Deploying SageMaker Pipeline: {pipeline} ...")
102 return pipeline.upsert(self.experiment.workspace.iam_role)
~/projects/src/ml-scale/labs-MLTools_CLI/src/trlabs_mltools_cli/pipeline.py in build(self, experiment_config, parameters_override)
85 self.get_template(parameters_override), DeploymentContext.from_workspace(self.experiment.workspace)
86 ).build_pipeline(self._base_ppl_name(), experiment_config=ppl_experiment_config)
---> 87 logger.debug(f"Built SageMaker Pipeline: {pipeline.definition()}")
88 return pipeline
89
~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/pipeline.py in definition(self)
319 def definition(self) -> str:
320 """Converts a request structure to string representation for workflow service calls."""
--> 321 request_dict = self.to_request()
322 self._interpolate_step_collection_name_in_depends_on(request_dict["Steps"])
323 request_dict["PipelineExperimentConfig"] = interpolate(
~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/pipeline.py in to_request(self)
103 if self.pipeline_experiment_config is not None
104 else None,
--> 105 "Steps": list_to_request(self.steps),
106 }
107
~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/utilities.py in list_to_request(entities)
51 for entity in entities:
52 if isinstance(entity, Entity):
---> 53 request_dicts.append(entity.to_request())
54 elif isinstance(entity, StepCollection):
55 request_dicts.extend(entity.request_dicts())
~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/steps.py in to_request(self)
497 def to_request(self) -> RequestType:
498 """Updates the request dictionary with cache configuration."""
--> 499 request_dict = super().to_request()
500 if self.cache_config:
501 request_dict.update(self.cache_config.config)
~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/steps.py in to_request(self)
349 def to_request(self) -> RequestType:
350 """Gets the request structure for `ConfigurableRetryStep`."""
--> 351 step_dict = super().to_request()
352 if self.retry_policies:
353 step_dict["RetryPolicies"] = self._resolve_retry_policy(self.retry_policies)
~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/steps.py in to_request(self)
118 "Name": self.name,
119 "Type": self.step_type.value,
--> 120 "Arguments": self.arguments,
121 }
122 if self.depends_on:
~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/steps.py in arguments(self)
476 request_dict = self.step_args
477 else:
--> 478 self.estimator._prepare_for_training(self.job_name)
479 train_args = _TrainingJob._get_train_args(
480 self.estimator, self.inputs, experiment_config=dict()
~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/trlabs_mltools/sagemaker/factory.py in _prepare_for_training(self, job_name)
91 )
92
---> 93 super()._prepare_for_training(job_name=job_name)
94
95
~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/estimator.py in _prepare_for_training(self, job_name)
2880 constructor if applicable.
2881 """
-> 2882 super(Framework, self)._prepare_for_training(job_name=job_name)
2883
2884 self._validate_and_set_debugger_configs()
~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/estimator.py in _prepare_for_training(self, job_name)
705 # Modify hyperparameters in-place to point to the right code directory and
706 # script URIs
--> 707 self._script_mode_hyperparam_update(code_dir, script)
708
709 self._prepare_rules()
~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/estimator.py in _script_mode_hyperparam_update(self, code_dir, script)
2898 hyperparams[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
2899
-> 2900 self._hyperparameters.update(hyperparams)
2901
2902 def _validate_and_set_debugger_configs(self):
AttributeError: 'ParameterString' object has no attribute 'update'
System information A description of your system. Please provide:
- SageMaker Python SDK version: 2.107.0
- PyTorch Estimator:
Hi @AndreiVoinovTR, thanks for using SageMaker Pipelines! Could you also share the code snippet showing how you defined the hyperparameters as a PipelineVariable, so that we can reproduce the issue?
Expected behavior: Should PipelineVariable(s) be supported for hyperparameters? In the Estimator's constructor, hyperparameters are defined as: Optional[Dict[str, Union[str, sagemaker.workflow.entities.PipelineVariable]]] = None.
Currently, a PipelineVariable can only stand in for Python primitive types, e.g. str, int, float, and bool. Given the expected behavior above, this issue appears to be a feature request to make PipelineVariable support dict and list types as well. Relabeling this issue as a feature request; we already have a backlog item tracking it.