sagemaker-python-sdk
Setting `load_in_8bit=True` in `DJLModel` or `HuggingFaceAccelerateModel` forces you to set `dtype=int8` which is incorrect behaviour
Describe the bug
Setting `load_in_8bit=True` in `DJLModel` or `HuggingFaceAccelerateModel` forces you to set `dtype=int8` as well, which is not correct behaviour.
To reproduce
import sagemaker
from sagemaker.djl_inference.model import DJLModel, HuggingFaceAccelerateModel


def deploy_djl(
    endpoint_name: str,
    s3_model_path: str,
    num_gpus: int = 1,
    initial_instance_count: int = 1,
    instance_type: str = "ml.g5.2xlarge",
    health_check_timeout: int = 600,
):
    session, role = sagemaker.session.Session(), sagemaker.get_execution_role()
    region = session._region_name

    # dtype='fp16' combined with load_in_8bit=True triggers the ValueError
    djl_model = HuggingFaceAccelerateModel(
        model_id=s3_model_path,
        role=role,
        dtype='fp16',
        load_in_8bit=True,
        task="text-generation",
        number_of_partitions=num_gpus,
    )

    return djl_model.deploy(
        instance_type=instance_type,
        initial_instance_count=initial_instance_count,
        endpoint_name=endpoint_name,
        container_startup_health_check_timeout=health_check_timeout
    )


model_id = "decapoda-research/llama-7b-hf"
predictor = deploy_djl(
    endpoint_name="DJLEndpoint",
    s3_model_path=model_id
)
Expected behavior
The SDK should accept `dtype='fp16'` together with `load_in_8bit=True`, just as the standalone DJL-serving container does with the equivalent serving.properties file:
engine=MPI
option.model_id=TheBloke/Llama-2-7B-fp16
option.entryPoint=djl_python.huggingface
option.tensor_parallel_degree=1
option.dtype=fp16
option.load_in_8bit=True
option.model_loading_timeout=900
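For comparison, the file above can be produced from a plain dictionary. This is only a sketch: `render_serving_properties` is a hypothetical helper written for this issue, not part of the SageMaker SDK or DJL Serving. It shows that nothing in the file format itself ties `load_in_8bit` to `dtype`.

```python
def render_serving_properties(engine, options):
    """Render an engine name and option dict as serving.properties text.

    Hypothetical helper for illustration; not an SDK or DJL Serving function.
    """
    lines = [f"engine={engine}"]
    for key, value in options.items():
        # Every option is written as option.<key>=<value>
        lines.append(f"option.{key}={value}")
    return "\n".join(lines)


text = render_serving_properties(
    "MPI",
    {
        "model_id": "TheBloke/Llama-2-7B-fp16",
        "entryPoint": "djl_python.huggingface",
        "tensor_parallel_degree": 1,
        "dtype": "fp16",
        "load_in_8bit": True,
        "model_loading_timeout": 900,
    },
)
print(text)
```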
Screenshots or logs
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/sagemaker/djl_inference/model.py:1092, in HuggingFaceAccelerateModel.generate_serving_properties(self, serving_properties)
1090 if self.load_in_8bit:
1091 if self.dtype != "int8":
-> 1092 raise ValueError("Set dtype='int8' to use load_in_8bit")
1093 serving_properties["option.load_in_8bit"] = self.load_in_8bit
1094 if self.dtype == "int8":
ValueError: Set dtype='int8' to use load_in_8bit
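The check in the traceback can be reproduced in isolation. The sketch below is not the SDK's code. `strict_properties` mirrors only the lines visible in the traceback, and `relaxed_properties` is a hypothetical variant that passes `load_in_8bit` through regardless of `dtype`, which is roughly what the expected behavior would need. Both function names, and the choice to write `option.dtype` into the result, are assumptions made for illustration.

```python
def strict_properties(dtype, load_in_8bit):
    """Mirror of the check visible in the traceback (illustrative only)."""
    props = {"option.dtype": dtype}
    if load_in_8bit:
        if dtype != "int8":
            raise ValueError("Set dtype='int8' to use load_in_8bit")
        props["option.load_in_8bit"] = load_in_8bit
    return props


def relaxed_properties(dtype, load_in_8bit):
    """Hypothetical variant: load_in_8bit no longer depends on dtype."""
    props = {"option.dtype": dtype}
    if load_in_8bit:
        props["option.load_in_8bit"] = load_in_8bit
    return props


try:
    strict_properties("fp16", True)
except ValueError as err:
    print("strict:", err)

print("relaxed:", relaxed_properties("fp16", True))
```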
System information
- SageMaker Python SDK version: 2.188.0
- Framework name (eg. PyTorch) or algorithm (eg. KMeans): PyTorch
- Framework version: 2.0.1
- Python version: 3.10
- CPU or GPU: GPU
- Custom Docker image (Y/N): N
Additional context N/A