
Setting `load_in_8bit=True` in `DJLModel` or `HuggingFaceAccelerateModel` forces you to set `dtype=int8`, which is incorrect behaviour

Open: maaquib opened this issue 1 year ago (0 comments)

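Describe the bug

Setting `load_in_8bit=True` in `DJLModel` or `HuggingFaceAccelerateModel` forces you to also set `dtype='int8'`, which is not correct behaviour. 8-bit loading and the model dtype are separate options: the standalone DJL-serving container accepts `option.dtype=fp16` together with `option.load_in_8bit=True`.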

To reproduce

import sagemaker
from sagemaker.djl_inference.model import DJLModel, HuggingFaceAccelerateModel

def deploy_djl(
    endpoint_name: str,
    s3_model_path: str,
    num_gpus: int = 1,
    initial_instance_count: int = 1,
    instance_type: str = "ml.g5.2xlarge",
    health_check_timeout: int = 600,
):
    # Execution role of the current SageMaker notebook or Studio environment
    role = sagemaker.get_execution_role()

    djl_model = HuggingFaceAccelerateModel(
        model_id=s3_model_path,
        role=role,
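        dtype='fp16',       # half-precision compute dtype
        load_in_8bit=True,  # 8-bit weight loading; this combination raises the ValueError shown in the logs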
        task="text-generation",
        number_of_partitions=num_gpus,
    )

    return djl_model.deploy(
        instance_type=instance_type,
        initial_instance_count=initial_instance_count,
        endpoint_name=endpoint_name,
        container_startup_health_check_timeout=health_check_timeout
    )


model_id = "decapoda-research/llama-7b-hf"

predictor = deploy_djl(
    endpoint_name="DJLEndpoint",
    s3_model_path=model_id
)

Expected behavior

The deployment should succeed, just as the standalone DJL-serving container does with the equivalent `serving.properties` file:

engine=MPI
option.model_id=TheBloke/Llama-2-7B-fp16
option.entryPoint=djl_python.huggingface
option.tensor_parallel_degree=1
option.dtype=fp16
option.load_in_8bit=True
option.model_loading_timeout=900
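For anyone reproducing the expected behaviour outside the SDK, the file above can be generated from Python. This is a minimal sketch; `render_serving_properties` is a hypothetical helper, not part of the SageMaker SDK or DJL-serving. It shows that `option.dtype` and `option.load_in_8bit` are simply two independent keys.

```python
def render_serving_properties(options: dict) -> str:
    """Hypothetical helper: render key/value pairs as `key=value` lines."""
    return "".join(f"{key}={value}\n" for key, value in options.items())


# Same settings as the serving.properties file above.
options = {
    "engine": "MPI",
    "option.model_id": "TheBloke/Llama-2-7B-fp16",
    "option.entryPoint": "djl_python.huggingface",
    "option.tensor_parallel_degree": 1,
    "option.dtype": "fp16",        # compute dtype
    "option.load_in_8bit": True,   # set alongside fp16; no int8 dtype required
    "option.model_loading_timeout": 900,
}

text = render_serving_properties(options)
print(text, end="")
```

Writing `text` to a `serving.properties` file in the model directory yields the configuration shown above.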

Screenshots or logs

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/sagemaker/djl_inference/model.py:1092, in HuggingFaceAccelerateModel.generate_serving_properties(self, serving_properties)
   1090 if self.load_in_8bit:
   1091     if self.dtype != "int8":
-> 1092         raise ValueError("Set dtype='int8' to use load_in_8bit")
   1093     serving_properties["option.load_in_8bit"] = self.load_in_8bit
   1094 if self.dtype == "int8":

ValueError: Set dtype='int8' to use load_in_8bit
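One way the check could be relaxed is sketched below. It is a hypothetical, self-contained stand-in for the logic around lines 1090-1094 of `model.py`, not the SDK's actual code. `build_serving_properties` and the `SUPPORTED_DTYPES` list are assumptions made for illustration. The idea is to validate `dtype` on its own and emit `option.load_in_8bit` independently, as the DJL-serving container expects.

```python
from typing import Dict, Optional

# Assumed list of dtypes for this sketch; the SDK's real list may differ.
SUPPORTED_DTYPES = ("fp32", "fp16", "bf16", "int8")


def build_serving_properties(
    model_id: str,
    dtype: str = "fp32",
    load_in_8bit: bool = False,
    tensor_parallel_degree: Optional[int] = None,
) -> Dict[str, str]:
    """Hypothetical helper: build option.* entries with a relaxed 8-bit check."""
    if dtype not in SUPPORTED_DTYPES:
        raise ValueError(
            f"Unsupported dtype {dtype!r}; expected one of {SUPPORTED_DTYPES}"
        )

    props = {
        "option.model_id": model_id,
        "option.dtype": dtype,
    }
    if tensor_parallel_degree is not None:
        props["option.tensor_parallel_degree"] = str(tensor_parallel_degree)
    # Unlike the check in the traceback, 8-bit loading is emitted on its own
    # and does not require dtype == "int8".
    if load_in_8bit:
        props["option.load_in_8bit"] = str(load_in_8bit)
    return props


props = build_serving_properties(
    "TheBloke/Llama-2-7B-fp16", dtype="fp16", load_in_8bit=True, tensor_parallel_degree=1
)
print(props)
```

With this shape, `dtype='fp16'` plus `load_in_8bit=True` produces the same `option.*` values as the standalone `serving.properties`, while an unknown dtype is still rejected.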

System information

  • SageMaker Python SDK version: 2.188.0
  • Framework name: PyTorch
  • Framework version: 2.0.1
  • Python version: 3.10
  • CPU or GPU: GPU
  • Custom Docker image (Y/N): N

Additional context

N/A

maaquib, Oct 27 '23 22:10