
Type cast error when running inference on an ONNX model


I'm trying to run inference on an ONNX model (created from a LightGBM model) via Kotlin DL, and every method (I tried the Raw ones too) throws class [J cannot be cast to class [[F ([J and [[F are in module java.base of loader 'bootstrap'), or, in the Raw methods, SequenceInfo cannot be cast to class ai.onnxruntime.TensorInfo.

Env:

    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.3.0")
    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:0.3.0-alpha-3")

Code (attachment: twoTierModel.txt):

package co.`fun`

import kotlin.test.*
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import kotlin.random.Random

class ApplicationTest {
    @Test
    fun testRoot() {
        val onnxModel = OnnxInferenceModel.load("/tmp/twoTierModel.onnx")
        onnxModel.reshape(27)
        val features = (1..27).map { Random.nextFloat() }.toFloatArray()
        // Throws: class [J cannot be cast to class [[F
        val prediction = onnxModel.predictSoftly(features, "features")
    }
}

The error occurs at line 124 of OnnxInferenceModel.kt and is caused by an attempt to cast a List<INT64> to Array<FloatArray>. I'm not sure whether the model should always return a 3D tensor or whether the library should check the output types.
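The declared output types can be inspected directly with the onnxruntime Java API; a minimal sketch, reusing the model path from the test above. For a LightGBM classifier exported to ONNX, the label output is typically an INT64 tensor and the probabilities output a sequence of maps, which would match both exceptions:

    import ai.onnxruntime.OrtEnvironment
    import ai.onnxruntime.OrtSession

    fun main() {
        val env = OrtEnvironment.getEnvironment()
        val session = env.createSession("/tmp/twoTierModel.onnx", OrtSession.SessionOptions())
        // Prints TensorInfo, SequenceInfo, or MapInfo for each declared output
        session.outputInfo.forEach { (name, node) -> println("$name -> ${node.info}") }
    }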

Rename the attachment to twoTierModel.onnx to run the test on your machine.

— gasabr, Nov 19 '21

Please add the attachment, @gasabr. Also, could you try to repeat the experiment with the latest version of the onnx dependency:

    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.3.0")
    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:0.3.0")

— zaleslaw, Nov 19 '21

Attachment: twoTierModel.txt

Thanks for the response! I tried 0.3.0 and got the same exception. I also tried the Java onnxruntime directly and was able to run inference on the model with the following code:

        val env = OrtEnvironment.getEnvironment()
        val session = env.createSession("/tmp/twoTierModel.onnx", OrtSession.SessionOptions())
        val features = (1..27).map { Random.nextFloat() }.toFloatArray()
        val buf = FloatBuffer.wrap(features)
        val t1 = OnnxTensor.createTensor(env, buf, longArrayOf(1, 27))
        val inputs = mapOf<String, OnnxTensor>("features" to t1)
        val result = session.run(inputs, setOf("probabilities"))[0].value as ArrayList<HashMap<Long, Float>>
        println(result)
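As a side note, OrtSession, OnnxTensor, and OrtSession.Result are all AutoCloseable, so the same call can be written with use blocks; a sketch under the same assumptions:

    env.createSession("/tmp/twoTierModel.onnx", OrtSession.SessionOptions()).use { session ->
        OnnxTensor.createTensor(env, FloatBuffer.wrap(features), longArrayOf(1, 27)).use { t1 ->
            session.run(mapOf("features" to t1), setOf("probabilities")).use { result ->
                // The probabilities output decodes to a list of per-class maps
                println(result[0].value as ArrayList<HashMap<Long, Float>>)
            }
        }
    }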

— gasabr, Nov 22 '21

Thanks for the example, @gasabr. I hope to fix this in the 0.4 release to cover more cases, but I agree that at the moment the Java onnxruntime is the best choice for you.

— zaleslaw, Nov 22 '21

I want to discuss a couple of things.

Onnx supports multiple output types: tensors, sequences of numbers (or strings), sequences of maps, and plain maps. At first glance it seems possible to decode every output type into an appropriate Kotlin data structure using the OnnxModel metadata. But I have some doubts:

  • I am not 100% sure that this is possible in every scenario.
  • We would need some dynamic casting mechanism to return the appropriate data structure to the user without casts on the user's side (see the usage sketch after the draft below).
Dirty draft of the decoding function:

import ai.onnxruntime.MapInfo
import ai.onnxruntime.OnnxJavaType
import ai.onnxruntime.OrtSession
import ai.onnxruntime.SequenceInfo
import ai.onnxruntime.TensorInfo

// Assumes this lives inside the model class that owns `session: OrtSession`.
@Suppress("UNCHECKED_CAST")
private fun decodeOnnxOutput(onnxOutput: OrtSession.Result): Map<String, Any> {
    val keys = onnxOutput.map { it.key }

    return keys.associateWith { key ->
        val nodeInfo = session.outputInfo[key]
            ?: throw RuntimeException("Unknown output name: $key")

        when (val info = nodeInfo.info) {
            is TensorInfo -> {
                val tensor = onnxOutput.get(key).get().value
                if (info.shape.size == 1) {
                    when (info.type) {
                        OnnxJavaType.FLOAT -> tensor as FloatArray
                        OnnxJavaType.DOUBLE -> tensor as DoubleArray
                        // onnxruntime surfaces both INT8 and UINT8 as a signed ByteArray on the JVM
                        OnnxJavaType.INT8, OnnxJavaType.UINT8 -> tensor as ByteArray
                        OnnxJavaType.INT16 -> tensor as ShortArray
                        OnnxJavaType.INT32 -> tensor as IntArray
                        OnnxJavaType.INT64 -> tensor as LongArray
                        else -> throw RuntimeException("Unsupported tensor type: ${info.type}")
                    }
                } else {
                    when (info.type) {
                        OnnxJavaType.FLOAT -> tensor as Array<FloatArray>
                        OnnxJavaType.DOUBLE -> tensor as Array<DoubleArray>
                        OnnxJavaType.INT8, OnnxJavaType.UINT8 -> tensor as Array<ByteArray>
                        OnnxJavaType.INT16 -> tensor as Array<ShortArray>
                        OnnxJavaType.INT32 -> tensor as Array<IntArray>
                        OnnxJavaType.INT64 -> tensor as Array<LongArray>
                        else -> throw RuntimeException("Unsupported tensor type: ${info.type}")
                    }
                }
            }
            is SequenceInfo -> {
                val elements = onnxOutput.get(key).get().value as List<*>
                if (info.sequenceOfMaps) {
                    when (info.mapInfo.keyType to info.mapInfo.valueType) {
                        OnnxJavaType.INT64 to OnnxJavaType.FLOAT -> elements as List<HashMap<Long, Float>>
                        OnnxJavaType.STRING to OnnxJavaType.FLOAT -> elements as List<HashMap<String, Float>>
                        else -> throw RuntimeException("Unsupported sequence-of-maps types: ${info.mapInfo.keyType} -> ${info.mapInfo.valueType}")
                    }
                } else {
                    when (info.sequenceType) {
                        OnnxJavaType.FLOAT -> elements as List<Float>
                        OnnxJavaType.DOUBLE -> elements as List<Double>
                        OnnxJavaType.INT64 -> elements as List<Long>
                        OnnxJavaType.STRING -> elements as List<String>
                        else -> throw RuntimeException("Unsupported sequence type: ${info.sequenceType}")
                    }
                }
            }
            is MapInfo -> {
                // OnnxValue.getValue() already unwraps an OnnxMap into a java.util Map
                val map = onnxOutput.get(key).get().value
                when (info.keyType) {
                    OnnxJavaType.INT64 -> when (info.valueType) {
                        OnnxJavaType.FLOAT -> map as HashMap<Long, Float>
                        OnnxJavaType.DOUBLE -> map as HashMap<Long, Double>
                        OnnxJavaType.INT64 -> map as HashMap<Long, Long>
                        OnnxJavaType.STRING -> map as HashMap<Long, String>
                        else -> throw RuntimeException("Unsupported map value type: ${info.valueType}")
                    }
                    OnnxJavaType.STRING -> when (info.valueType) {
                        OnnxJavaType.FLOAT -> map as HashMap<String, Float>
                        OnnxJavaType.DOUBLE -> map as HashMap<String, Double>
                        OnnxJavaType.INT64 -> map as HashMap<String, Long>
                        OnnxJavaType.STRING -> map as HashMap<String, String>
                        else -> throw RuntimeException("Unsupported map value type: ${info.valueType}")
                    }
                    else -> throw RuntimeException("Unsupported map key type: ${info.keyType}")
                }
            }
            else -> throw RuntimeException("Unsupported output info: $info")
        }
    }
}
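Even with such a decoder returning Map<String, Any>, the user still casts on their side, which is the second doubt above. A hypothetical call site (decodeOnnxOutput and the "probabilities" output name are taken from this thread):

    val decoded = decodeOnnxOutput(result)
    // The caller still has to know and assert the concrete runtime type
    @Suppress("UNCHECKED_CAST")
    val probabilities = decoded["probabilities"] as List<HashMap<Long, Float>>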

— ermolenkodev, Apr 15 '22

Another thing I want to discuss: it seems reasonable to me to refactor OnnxInferenceModel's predict and predictSoftly methods out into a more specific implementation class (like ClassificationOnnxInferenceModel).

It may be handy if OnnxInferenceModel worked with arbitrary tensors, while classes targeted at specific DL tasks (such as detection or segmentation) used OnnxInferenceModel internally and formatted its output for the specific task.
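A rough sketch of that split, built on the plain onnxruntime API from the workaround above (the class name comes from this proposal; the constructor and predictProbabilities method are hypothetical, not the current KotlinDL API):

    import ai.onnxruntime.OnnxTensor
    import ai.onnxruntime.OrtEnvironment
    import ai.onnxruntime.OrtSession
    import java.nio.FloatBuffer

    class ClassificationOnnxInferenceModel(modelPath: String) : AutoCloseable {
        private val env = OrtEnvironment.getEnvironment()
        private val session = env.createSession(modelPath, OrtSession.SessionOptions())

        // Runs the raw model and formats the classifier's sequence-of-maps
        // output into per-class probabilities for a single sample.
        @Suppress("UNCHECKED_CAST")
        fun predictProbabilities(features: FloatArray, inputName: String, outputName: String): Map<Long, Float> =
            OnnxTensor.createTensor(env, FloatBuffer.wrap(features), longArrayOf(1, features.size.toLong())).use { tensor ->
                session.run(mapOf(inputName to tensor), setOf(outputName)).use { result ->
                    (result[0].value as List<Map<Long, Float>>).first()
                }
            }

        override fun close() = session.close()
    }

OnnxInferenceModel itself could then stay tensor-in/tensor-out, while such task-specific wrappers do the formatting.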

— ermolenkodev, Apr 15 '22