
Problem when using a pretrained Huggingface Transformers model: "Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.FloatTensor instead"

Open cperezmig opened this issue 4 years ago • 1 comment

Hi, maybe it's not a bug and I am just missing something when using this library. I am trying to load a pretrained DistilBert model from the huggingface-transformers project. This model should qualify as the nn.Module required by create_spark_torch_model. When I try to use it in a pipeline, I get the following error: "Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.FloatTensor instead"
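
For context, the error itself can be reproduced in plain PyTorch: embedding lookups (the first layer of DistilBert) require LongTensor indices, and float indices trigger exactly this message. A minimal sketch:

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
float_ids = torch.tensor([1.0, 2.0, 3.0])
# emb(float_ids)  # raises: expected scalar type Long for argument 'indices'
long_ids = float_ids.long()
print(emb(long_ids).shape)  # works: torch.Size([3, 4])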

This is my code in case you want to reproduce the problem:

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.ml.linalg import VectorUDT, Vectors
from pyspark.ml.pipeline import Pipeline
from sparktorch import create_spark_torch_model
from transformers import DistilBertTokenizer, DistilBertModel

# Create (or reuse) a SparkSession
spark = SparkSession.builder.getOrCreate()

# Three sample Spanish sentences, one per row
textos = [('El perro de San Roque no tiene rabo, por que ramón ramírez se lo ha robado',),
          ('Tres tristes tigres comen trigo en un trigal',),
          ('Esto es otra frase aleatoria que nada tiene que ver con las dos anteriores',)]
df = spark.createDataFrame(textos, ['t'])

modelName = 'distilbert-base-multilingual-cased'
tokenizer = DistilBertTokenizer.from_pretrained(modelName)
model = DistilBertModel.from_pretrained(modelName)

# Tokenize each sentence into token ids and store them as a dense Spark vector
tokeniza = udf(lambda x: Vectors.dense(tokenizer.encode(x, add_special_tokens=False, max_length=512, return_tensors='pt')[0]), VectorUDT())
df = df.withColumn('features', tokeniza(df.t))
df.show()

# Wrap the pretrained DistilBert model as a Spark pipeline stage
spark_torch_model = create_spark_torch_model(model, inputCol='features', predictionCol='bertEmb', useVectorOut=True)
p = Pipeline(stages=[spark_torch_model]).fit(df)
df = p.transform(df)
df.show()

Am I missing something? Is it a bug?

Thank you

cperezmig · Jun 01 '20 09:06

SparkTorch casts all input to a float tensor; see: https://github.com/dmmiller612/sparktorch/blob/master/sparktorch/torch_distributed.py#L114

It seems that the only way is to change the source code. You could hard-code a .long() in place of the .float() to use it with the BERT model. Better yet, you could add a parameter when creating the model to indicate the input type, as sketched below.
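
A minimal sketch of that second option, assuming the cast happens in a helper that converts each batch to a tensor (convert_batch and the input_dtype parameter are hypothetical, not part of SparkTorch's current API):

import torch

def convert_batch(data, input_dtype=torch.float32):
    # Hypothetical helper: replaces the hard-coded .float() cast so callers
    # can request torch.long for embedding-based models such as DistilBert.
    return torch.as_tensor(data, dtype=input_dtype)

# Usage sketch with the proposed parameter (not the current signature):
# create_spark_torch_model(model, inputCol='features', predictionCol='bertEmb',
#                          useVectorOut=True, input_dtype=torch.long)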

piekill · Aug 20 '20 04:08