
Transformers backend supports MPS

Open aotsukiqx opened this issue 8 months ago • 1 comment

**LocalAI version:** 2.16.0

**Environment, CPU architecture, OS, and Version:**

Mac Studio M2 Ultra

**Describe the bug**

When using the transformers backend for glm4, `trust_remote_code: true` is not correctly honored by backend/python/transformers/backend.py.

**To Reproduce**
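For context, GLM-4 ships custom modeling code, so loading it through transformers only works with an explicit opt-in. A minimal sketch of why the flag matters (the model id is only an example):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "THUDM/glm-4-9b-chat"  # illustrative model id

# Without trust_remote_code=True, transformers refuses to execute the
# repository's custom modeling code and loading fails.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
```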

**Expected behavior**

**Logs**

**Additional context**

I modified some lines so the backend runs without error:

```python
def LoadModel(self, request, context):
    """
    A gRPC method that loads a model into memory.

    Args:
        request: A LoadModelRequest object that contains the request parameters.
        context: A grpc.ServicerContext object that provides information about the RPC.

    Returns:
        A Result object that contains the result of the LoadModel operation.
    """
    model_name = request.Model
    print(f"request.Model: {request.Model}, request.TrustRemoteCode: {request.TrustRemoteCode}")

    compute = "auto"
    if request.F16Memory:
        compute = torch.bfloat16

    self.CUDA = request.CUDA
    self.OV = False

    # Check whether MPS is available
    if torch.backends.mps.is_available():
        device_map = "mps"
    else:
        print("MPS not supported, using CPU instead.", file=sys.stderr)
        device_map = "cpu"
    ...

    self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True,
                                                   trust_remote_code=request.TrustRemoteCode)
```
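For reference, the added fallback could be factored into a small helper; this is only a sketch (the function name is mine, not part of the patch):

```python
import sys
import torch

def pick_device(use_cuda: bool) -> str:
    """Prefer CUDA when requested and available, then MPS, then CPU."""
    if use_cuda and torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    print("MPS not supported, using CPU instead.", file=sys.stderr)
    return "cpu"
```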
def Embedding(self, request, context):
        """
        A gRPC method that calculates embeddings for a given sentence.

        Args:
            request: An EmbeddingRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            An EmbeddingResult object that contains the calculated embeddings.
        """

        set_seed(request.Seed)
        # Tokenize input
        max_length = 512
        if request.Tokens != 0:
            max_length = request.Tokens
        encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")    

        # Create word embeddings
        if self.CUDA:
            encoded_input = encoded_input.to("cuda")

        with torch.no_grad():    
            model_output = self.model(**encoded_input)

        # Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
#        print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
#        print("Embeddings:", sentence_embeddings, file=sys.stderr)
        return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])
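`mean_pooling` is referenced above but not shown; the usual implementation (an assumption here, not copied from the LocalAI source) averages token embeddings weighted by the attention mask:

```python
import torch

def mean_pooling(model_output, attention_mask):
    # First element of model_output holds the token-level embeddings.
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum the masked embeddings and divide by the number of real tokens.
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
```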

```python
async def _predict(self, request, context, streaming=False):
    set_seed(request.Seed)
    if request.TopP == 0:
        request.TopP = 0.9

    if request.TopK == 0:
        request.TopK = 40

    prompt = request.Prompt
    if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
        prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)

    eos_token_id = self.tokenizer.eos_token_id
    if request.StopPrompts:
        eos_token_id = []
        for word in request.StopPrompts:
            eos_token_id.append(self.tokenizer.convert_tokens_to_ids(word))

    inputs = self.tokenizer(prompt, return_tensors="pt")

    if request.Tokens > 0:
        max_tokens = request.Tokens
    else:
        max_tokens = self.max_tokens - inputs["input_ids"].size()[inputs["input_ids"].dim() - 1]

    if self.CUDA:
        inputs = inputs.to("cuda")
    if XPU and not self.OV:
        inputs = inputs.to("xpu")
        streaming = False
    if torch.backends.mps.is_available():
        inputs = inputs.to("mps")
    ...
```
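To illustrate the tokenizer-template path above, here is how `apply_chat_template` behaves with chat-style messages (the message list is a hypothetical example):

```python
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
# Returns a single prompt string with the model's chat markup applied,
# ending with the tokens that cue the assistant's reply.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
```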

aotsukiqx • Jun 08 '24 06:06