
Autocasting datatypes

Open vyeevani opened this issue 3 years ago • 2 comments

When converting publicly available models with coremltools, it would be very helpful if the common linear operators automatically cast their inputs to a promotable type. The infrastructure for this appears to be in place already, in the form of cast operators and type promotion checking.

It seems that modifying the type verification in the Operator base class's init would be sufficient: if the base types of the inputs aren't compatible with the operation, cast all of the operator's inputs to the promotable type.
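
To make the proposal concrete, here is a minimal, library-independent sketch of that behavior. Everything in it is hypothetical: `PROMOTION_RANK`, `promote_dtypes`, and `autocast_inputs` are illustrative names, not coremltools APIs, and the promotion order is an assumption chosen for the example rather than the one coremltools actually uses.

```python
# Hypothetical sketch: none of these names are coremltools APIs.
# Promotion order assumed for illustration only: bool < int32 < fp16 < fp32.
PROMOTION_RANK = {"bool": 0, "int32": 1, "fp16": 2, "fp32": 3}


def promote_dtypes(dtypes):
    """Return the highest-ranked dtype name among `dtypes`."""
    unknown = [d for d in dtypes if d not in PROMOTION_RANK]
    if unknown:
        raise TypeError(f"cannot promote unsupported dtypes: {unknown}")
    return max(dtypes, key=PROMOTION_RANK.__getitem__)


def autocast_inputs(inputs, cast):
    """Cast every input to the common promoted dtype.

    `inputs` maps argument names to (value, dtype) pairs. `cast` is a
    callable (value, dtype) -> value that stands in for a cast operator.
    Inputs already in the target dtype are passed through unchanged.
    """
    target = promote_dtypes([dtype for _, dtype in inputs.values()])
    return {
        name: (value if dtype == target else cast(value, target), target)
        for name, (value, dtype) in inputs.items()
    }


# Usage: an int32 input and an fp32 input are unified to fp32.
casted = autocast_inputs(
    {"x": (1, "int32"), "y": (2.0, "fp32")},
    cast=lambda value, dtype: float(value),
)
```

In an operator base class, a step like `autocast_inputs` would run before the existing type check, so the check would only fail for inputs that cannot be promoted at all.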

vyeevani avatar Sep 20 '22 16:09 vyeevani

Spoke with @TobyRoseman offline.

Adding the conversion script where I saw the problem:

import coremltools as ct
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-multi")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-multi", torchscript=True)

input_sample = "sample"

input_ids = tokenizer.encode(input_sample, return_tensors="pt")

input_ids = input_ids.type(torch.IntTensor)

traced_model = torch.jit.trace(model, input_ids)

mlmodel = ct.convert(traced_model, source="pytorch", inputs=[ct.TensorType(shape=input_ids.shape)], convert_to='mlprogram')

vyeevani avatar Sep 20 '22 17:09 vyeevani

Please note that reproducing this requires the following change to build_einsum_val, which works around https://github.com/apple/coremltools/issues/1359. The fix from https://github.com/apple/coremltools/issues/1604 is also needed to reproduce the problem.

elif parsed_vectors == ([0], [1], [0, 1]):
    x_1 = mb.reshape(x=a_var, shape=[a_var.shape[0], 1])
    x_1 = mb.cast(x=x_1, dtype="fp32")
    x_2 = mb.reshape(x=b_var, shape=[b_var.shape[0], 1])
    print(x_1, x_2)
    #print(types.promote_types(x_1, x_2))
    x = mb.matmul(x=x_1, y=x_2, transpose_x=False, transpose_y=True, name=name)
else:
    raise NotImplementedError(
        "Einsum unsupported equation format: ", equation, a_var.shape, b_var.shape, parsed_vectors, parsed_vectors_rev
    )
return x
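
For readers unfamiliar with this branch: my reading is that `parsed_vectors == ([0], [1], [0, 1])` corresponds to the einsum equation `i,j->ij` (an outer product), which the snippet computes as a reshape followed by a matmul with the second operand transposed. The plain-Python sketch below illustrates only that arithmetic; `outer_via_matmul` is a hypothetical helper, not part of coremltools.

```python
def outer_via_matmul(a, b):
    """Outer product of 1-D lists `a` and `b` via reshape + matmul."""
    col_a = [[x] for x in a]              # reshape a to shape (n, 1)
    col_b = [[y] for y in b]              # reshape b to shape (m, 1)
    row_b = [[row[0] for row in col_b]]   # transpose b to shape (1, m)
    inner = 1                             # shared dimension of the matmul
    # (n, 1) @ (1, m) -> (n, m)
    return [
        [sum(col_a[i][k] * row_b[k][j] for k in range(inner))
         for j in range(len(b))]
        for i in range(len(a))
    ]


# Python promotes int * float implicitly, so mixed inputs just work here.
result = outer_via_matmul([1, 2], [3.0, 4.0])
```

Python performs the int-to-float promotion on its own in this sketch. The explicit `mb.cast` in the snippet above suggests that the converter does not do the same for `mb.matmul`, which is the gap this issue asks to close.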

vyeevani avatar Sep 20 '22 18:09 vyeevani