djl
Ensure Batchifier returns compatible arrays for hybrid engines like OnnxRuntime
Addresses the issue raised in https://deepjavalibrary.slack.com/archives/C01AURG857U/p1724804411722599.
Batchifying or unbatchifying OrtNDArrays produces PtNDArrays, which causes downstream issues, particularly at OrtSymbolBlock:114:
```java
OrtSession.Result results = session.run(container);
NDList ret = evaluateOutput(results);
ret.attach(inputs.head().getManager());
```
Because of the batchify call made during the Translator's processInput(), the initial OrtNDArrays end up as PtNDArrays, and the ret.attach call then incorrectly sets both the main and the alternate manager of the ret NDList to PtNDManager. This breaks a lot of post-processing code, because the returned NDArrays have the wrong manager set.
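To make the failure and the Batchifier-side fix concrete, here is a toy model in plain Java. It does not use DJL and does not reproduce the actual diff of this PR; `Manager`, `Array`, `stack`, `split`, `toManager`, `batchify` and `unbatchify` are all hypothetical stand-ins. It assumes, as described above, that the primary (OnnxRuntime-like) engine delegates stacking and splitting to an alternative (PyTorch-like) engine, so raw results come back owned by the alternative manager. The fixed `batchify`/`unbatchify` hand the results back to the original manager.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Toy model (not DJL code) of a hybrid engine: a primary engine without a
 * native stack/split delegates to an alternative engine, so raw results come
 * back owned by the alternative manager. All names here are hypothetical.
 */
final class HybridBatchifySketch {

    /** Stand-in for an NDManager; alternative is the fallback engine, or null. */
    record Manager(String engine, Manager alternative) {
        Array create(float[]... rows) {
            return new Array(rows, this);
        }
    }

    /** Stand-in for an NDArray: rows of data plus the manager that owns it. */
    record Array(float[][] rows, Manager manager) {}

    /** The engine that actually runs an op the primary engine lacks. */
    private static Manager executor(Manager m) {
        return m.alternative() != null ? m.alternative() : m;
    }

    /** Raw stack: the result lands on the alternative engine (the bug). */
    static Array stack(List<Array> arrays) {
        List<float[]> rows = new ArrayList<>();
        for (Array a : arrays) {
            for (float[] row : a.rows()) {
                rows.add(row.clone());
            }
        }
        return executor(arrays.get(0).manager()).create(rows.toArray(new float[0][]));
    }

    /** Raw split: each piece also lands on the alternative engine. */
    static List<Array> split(Array batch) {
        List<Array> pieces = new ArrayList<>();
        for (float[] row : batch.rows()) {
            pieces.add(executor(batch.manager()).create(row.clone()));
        }
        return pieces;
    }

    /** Re-wrap an array in the target manager unless it already belongs there. */
    static Array toManager(Array a, Manager target) {
        // A real engine would copy the data across; the toy just re-wraps it.
        return a.manager() == target ? a : target.create(a.rows());
    }

    /** Fixed batchify: stack, then hand the batch back to the inputs' manager. */
    static Array batchify(List<Array> inputs) {
        return toManager(stack(inputs), inputs.get(0).manager());
    }

    /** Fixed unbatchify: split, then return every piece to the batch's manager. */
    static List<Array> unbatchify(Array batch) {
        List<Array> out = new ArrayList<>();
        for (Array piece : split(batch)) {
            out.add(toManager(piece, batch.manager()));
        }
        return out;
    }
}
```

Repairing ownership at the batchify boundary means that code like `ret.attach(inputs.head().getManager())` would see the OnnxRuntime manager without any change to OrtSymbolBlock itself.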
There may be a more elegant way to fix this within OrtSymbolBlock, but I couldn't see an easy way to pass the Predictor's NDManager in.
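Fixing this inside OrtSymbolBlock instead would amount to letting the caller choose which manager owns the outputs, rather than taking it from the first input. The toy sketch below illustrates that idea under stated assumptions: `outputManager` is a hypothetical parameter that DJL's real OrtSymbolBlock does not have, the fake "session" simply doubles every value, and `Manager`/`Array` are stand-ins rather than DJL types.

```java
import java.util.ArrayList;
import java.util.List;

/** Toy model (not DJL code): choosing which manager owns a block's outputs. */
final class ExplicitManagerSketch {

    /** Stand-in for an NDManager, identified only by its engine name. */
    record Manager(String engine) {}

    /** Stand-in for an NDArray whose owning manager can be changed by attach. */
    static final class Array {
        final float[] data;
        Manager manager;

        Array(float[] data, Manager manager) {
            this.data = data;
            this.manager = manager;
        }
    }

    /**
     * Stand-in for the forward pass. Passing outputManager explicitly (a
     * hypothetical parameter) avoids inheriting whatever manager the first
     * input happens to have, e.g. a PyTorch one left behind by batchify.
     * Passing null falls back to the current behaviour.
     */
    static List<Array> forward(List<Array> inputs, Manager outputManager) {
        List<Array> outputs = new ArrayList<>();
        for (Array in : inputs) {
            float[] out = new float[in.data.length];
            for (int i = 0; i < out.length; i++) {
                out[i] = in.data[i] * 2f; // fake "session.run"
            }
            outputs.add(new Array(out, null)); // session results start unowned
        }
        Manager target = outputManager != null ? outputManager : inputs.get(0).manager;
        for (Array a : outputs) {
            a.manager = target; // roughly what ret.attach(target) does
        }
        return outputs;
    }
}
```

The trade-off is API surface: this route needs the Predictor's manager threaded through to the block, whereas the Batchifier route keeps OrtSymbolBlock's signature unchanged.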