mleap icon indicating copy to clipboard operation
mleap copied to clipboard

Tensor to Proto Bug with SparseTensor: "java.lang.IllegalArgumentException: size of dimensions must equals size of values"

Open austinzh opened this issue 2 years ago • 0 comments

Call Stacks

  java.lang.IllegalArgumentException: size of dimensions must equals size of values
  at ml.combust.mleap.tensor.Tensor$.normalizeDimensions(Tensor.scala:63)
  at ml.combust.mleap.tensor.Tensor$.create(Tensor.scala:33)
  at ml.combust.bundle.tensor.TensorSerializer$.fromProto(TensorSerializer.scala:74)
  at ml.combust.bundle.dsl.Value.getTensor(Value.scala:323)

Possible cause: In ml.combust.bundle.tensor.TensorSerializer$.toProto, we save rawValues, but in ml.combust.bundle.tensor.TensorSerializer$.fromProto we load them as a DenseTensor. The rawValues array of a SparseVector is much smaller than the dimensions imply, so it causes this error. I suggest we handle SparseTensor and DenseTensor separately.

  def toProto[T](t: tensor.Tensor[T]): Tensor = {
    val (tpe, values) = t.base.runtimeClass match {
      case tensor.Tensor.BooleanClass =>
        (BasicType.BOOLEAN, BooleanArraySerializer.write(t.rawValues.asInstanceOf[Array[Boolean]]))
      case tensor.Tensor.ByteClass =>
        (BasicType.BYTE, ByteArraySerializer.write(t.rawValues.asInstanceOf[Array[Byte]]))
      case tensor.Tensor.ShortClass =>
        (BasicType.SHORT, ShortArraySerializer.write(t.rawValues.asInstanceOf[Array[Short]]))
      case tensor.Tensor.IntClass =>
        (BasicType.INT, IntArraySerializer.write(t.rawValues.asInstanceOf[Array[Int]]))
      case tensor.Tensor.LongClass =>
        (BasicType.LONG, LongArraySerializer.write(t.rawValues.asInstanceOf[Array[Long]]))
      case tensor.Tensor.FloatClass =>
        (BasicType.FLOAT, FloatArraySerializer.write(t.rawValues.asInstanceOf[Array[Float]]))
      case tensor.Tensor.DoubleClass =>
        (BasicType.DOUBLE, DoubleArraySerializer.write(t.rawValues.asInstanceOf[Array[Double]]))
      case tensor.Tensor.StringClass =>
        (BasicType.STRING, StringArraySerializer.write(t.rawValues.asInstanceOf[Array[String]]))
      case tensor.Tensor.ByteStringClass =>
        (BasicType.BYTE_STRING, ByteStringArraySerializer.write(t.rawValues.asInstanceOf[Array[ByteString]]))
      case _ => throw new IllegalArgumentException(s"unsupported tensor type ${t.base}")
    }
  def fromProto[T](t: Tensor): tensor.Tensor[T] = {
    val dimensions = t.shape.get.dimensions.map(_.size)
    val valueBytes = t.value.toByteArray

    val tn = t.base match {
      case BasicType.BOOLEAN =>
        tensor.Tensor.create(BooleanArraySerializer.read(valueBytes), dimensions)
      case BasicType.BYTE =>
        tensor.Tensor.create(ByteArraySerializer.read(valueBytes), dimensions)
      case BasicType.SHORT =>
        tensor.Tensor.create(ShortArraySerializer.read(valueBytes), dimensions)
      case BasicType.INT =>
        tensor.Tensor.create(IntArraySerializer.read(valueBytes), dimensions)
      case BasicType.LONG =>
        tensor.Tensor.create(LongArraySerializer.read(valueBytes), dimensions)
      case BasicType.FLOAT =>
        tensor.Tensor.create(FloatArraySerializer.read(valueBytes), dimensions)
      case BasicType.DOUBLE =>
        tensor.Tensor.create(DoubleArraySerializer.read(valueBytes), dimensions)
      case BasicType.STRING =>
        tensor.Tensor.create(StringArraySerializer.read(valueBytes), dimensions)
      case BasicType.BYTE_STRING =>
        tensor.Tensor.create(ByteStringArraySerializer.read(valueBytes), dimensions)
      case _ => throw new IllegalArgumentException(s"unsupported tensor type ${t.base}")
    }

austinzh avatar Jun 27 '23 21:06 austinzh