Problem when using a pretrained Hugging Face Transformers model: "Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.FloatTensor instead"
Hi, maybe it's not a bug and I am just missing something about how to use this library. I am trying to load a pretrained DistilBert model from the Hugging Face Transformers project. This model is an nn.Module, as required by create_spark_torch_model. When I try to use it in a pipeline, I get the following error: "Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.FloatTensor instead"
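For context, the message comes from PyTorch's embedding lookup, which only accepts integer (Long) indices. A minimal sketch that reproduces the same error with plain PyTorch, no SparkTorch involved:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=4)

idx = torch.tensor([1, 2, 3])  # int64 (Long) indices
print(emb(idx).shape)          # works: torch.Size([3, 4])

emb(idx.float())               # float indices raise the same
                               # "Expected tensor for argument #1 'indices'
                               # to have scalar type Long" error
```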
This is my code, in case you want to reproduce the problem:
```python
from transformers import DistilBertTokenizer, DistilBertModel
from pyspark.sql import SparkSession
from pyspark.ml.linalg import VectorUDT, Vectors
from sparktorch import create_spark_torch_model
from pyspark.sql.functions import udf
from pyspark.ml.pipeline import Pipeline

# Assumes a running SparkSession (e.g. the pyspark shell); create one otherwise.
spark = SparkSession.builder.getOrCreate()

# Sample Spanish sentences for the multilingual model.
textos = [('El perro de San Roque no tiene rabo, porque Ramón Ramírez se lo ha robado',),
          ('Tres tristes tigres comen trigo en un trigal',),
          ('Esto es otra frase aleatoria que nada tiene que ver con las dos anteriores',)]
df = spark.createDataFrame(textos, ['t'])

modelName = 'distilbert-base-multilingual-cased'
tokenizer = DistilBertTokenizer.from_pretrained(modelName)
model = DistilBertModel.from_pretrained(modelName)

# Tokenize each sentence into a dense vector of token ids.
tokeniza = udf(lambda x: Vectors.dense(tokenizer.encode(x, add_special_tokens=False,
                                                        max_length=512,
                                                        return_tensors='pt')[0]),
               VectorUDT())
df = df.withColumn('features', tokeniza(df.t))
df.show()

spark_torch_model = create_spark_torch_model(model, inputCol='features',
                                             predictionCol='bertEmb', useVectorOut=True)
p = Pipeline(stages=[spark_torch_model]).fit(df)
df = p.transform(df)
df.show()
```
Am I missing something? Is it a bug?
Thank you
SparkTorch casts all input to a float tensor, see: https://github.com/dmmiller612/sparktorch/blob/master/sparktorch/torch_distributed.py#L114
It seems that the only way is to change the source code. You could hard-code a .long() instead of the .float() to make it work with the BERT model. Better yet, you could add a parameter when creating the model to indicate the input type; a sketch of that idea follows.
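A minimal sketch of that parameterized approach, assuming a local patch to SparkTorch; the `input_dtype` parameter and `_cast_input` helper are hypothetical, not part of the current SparkTorch API:

```python
# Hypothetical sketch of a configurable cast for sparktorch/torch_distributed.py.
# Today the input is cast with .float() unconditionally; an input_dtype parameter
# on create_spark_torch_model could be threaded down to a helper like this.
import torch

def _cast_input(tensor: torch.Tensor, input_dtype: str = 'float') -> torch.Tensor:
    """Cast the batched input to the dtype the wrapped model expects."""
    if input_dtype == 'long':
        # Integer indices for embedding-based models such as DistilBert.
        return tensor.long()
    # Current SparkTorch behavior: float input for ordinary feed-forward models.
    return tensor.float()
```

Until something like that exists upstream, hard-coding `.long()` at the linked line is the quick workaround for this model.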