
Exporter for TFLite

Open jijoongmoon opened this issue 2 years ago • 12 comments

  • NNTrainer needs to export the trained model to TfLite.
  • Currently, we have PoC code for the FullyConnected layer, but we need to develop the tflite interpreter to support more layers.
  • The first target of this issue is to serialize the ResNet-18 model from NNTrainer to TfLite.
  • The to-do list is below:
    • [x] Move the hard-coded function to get the builtin options for Fully Connected #1877
    • [x] Batch Normalization layer Remove Realizer #1883
    • [ ] Conv2D Exporter
    • [ ] Check Channel-First & Channel-Last
    • [ ] Addition Exporter
    • [ ] Activation Layer (ReLU)
    • [ ] Export Resnet Block

jijoongmoon avatar Apr 14 '22 11:04 jijoongmoon

:octocat: cibot: Thank you for posting issue #1879. The person in charge will reply soon.

taos-ci avatar Apr 14 '22 11:04 taos-ci

/cc @mhs4670go , @seanshpark

lemmaa avatar Apr 14 '22 23:04 lemmaa

While drafting the code, I've found some things that need to be changed.

https://github.com/nnstreamer/nntrainer/blob/71415f893217c4ff3c4d0d31a82610cf06cd67a7/nntrainer/compiler/tflite_opnode.cpp#L81-L89

To create a flatbuffers Operator, a BuiltinOptions object needs to be created first, and the code above does that. But to create each BuiltinOptions, some attributes are needed. For example, this is the generated code that creates FullyConnectedOptions.

inline flatbuffers::Offset<FullyConnectedOptions> CreateFullyConnectedOptions(
    flatbuffers::FlatBufferBuilder &_fbb,
    tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
    tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
    bool keep_num_dims = false,
    bool asymmetric_quantize_inputs = false) {
  FullyConnectedOptionsBuilder builder_(_fbb);
  builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
  builder_.add_keep_num_dims(keep_num_dims);
  builder_.add_weights_format(weights_format);
  builder_.add_fused_activation_function(fused_activation_function);
  return builder_.Finish();
}

It means the current code should be changed to something like this:

switch (op_type) {
case tflite::BuiltinOperator_FULLY_CONNECTED:
  return tflite::CreateFullyConnectedOptions(f, fused_activation_function,
                                             weights_format, keep_num_dims,
                                             asymmetric_quantize_inputs)
    .Union();
case tflite::BuiltinOperator_DUMMY_OPERATOR: // placeholder for any other op
  return tflite::CreateDummyOperatorOptions(f, attribute_0, attribute_1, ..).Union();
// ...
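For reference, here is a minimal sketch of where that `.Union()` offset ends up, using the generated schema's CreateOperator helper; `f`, `opcode_index`, `inputs`, and `outputs` are assumed to exist in the surrounding serializer:

// Sketch only: attach the built options to the Operator table. The
// builtin_options_type enum tells TFLite which union member follows.
auto options = tflite::CreateFullyConnectedOptions(f).Union(); // all defaults
auto op = tflite::CreateOperator(f, opcode_index, inputs, outputs,
                                 tflite::BuiltinOptions_FullyConnectedOptions,
                                 options);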

So, the getBuiltinOps function interface needs some additional parameters that are required to create the built-in options.

mhs4670go avatar Apr 20 '22 07:04 mhs4670go

And after reading more of the code, I realized there is already code that sets the built-in options.

template <>
void Exporter::saveTflResult(const std::tuple<props::Unit> &props,
                             const FullyConnectedLayer *self) {
  createIfNull(tf_node);

  /// transpose the kernel to match tflite's fully-connected weight layout,
  /// and pass the bias through unchanged
  auto weight_transform = [](std::vector<const Tensor *> &weights) {
    std::vector<Tensor> new_weights;
    new_weights.reserve(weights.size());
    new_weights.push_back(weights[0]->transpose("0:2:1"));
    new_weights.push_back(*weights[1]);
    return new_weights;
  };
  tf_node->setWeightTransformFn(weight_transform);

  tf_node->setOpType(tflite::BuiltinOperator_FULLY_CONNECTED);
  /// we are probably going to need a flatbuffer builder inside the exporter
  tf_node->setBuiltinOptions(tflite::BuiltinOptions_FullyConnectedOptions,
                             flatbuffers::Offset<void>()); // NOTE HERE
}

So, the getBuiltinOps function above seems unnecessary if the setBuiltinOptions method works properly.

mhs4670go avatar Apr 20 '22 08:04 mhs4670go

I've hidden the above two comments because saveTflResult intentionally calls setBuiltinOptions with flatbuffers::Offset<void>() and creates the real options later in getBuiltinOps. Also, since getBuiltinOps is an API of TfOpNode, the parameters needed to create the built-in options can be obtained from the object itself.
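To make that concrete, here is a minimal sketch, written as a hypothetical free function rather than the actual TfOpNode API, of building options from state the node already holds:

#include <stdexcept>
#include <tensorflow/lite/schema/schema_generated.h>

// Hypothetical sketch, not the actual nntrainer signature: the op type
// (and, in the real class, the layer properties) come from the node itself,
// so the interface needs no extra parameters.
flatbuffers::Offset<void>
getBuiltinOps(flatbuffers::FlatBufferBuilder &fbb,
              tflite::BuiltinOperator op_type) {
  switch (op_type) {
  case tflite::BuiltinOperator_FULLY_CONNECTED:
    // attributes are defaulted here; a real node would fill them in
    return tflite::CreateFullyConnectedOptions(fbb).Union();
  default:
    throw std::runtime_error("unsupported builtin operator");
  }
}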

mhs4670go avatar Apr 20 '22 08:04 mhs4670go

I've implemented the ReLU operator and exported a graph that contains a single ReLU op. It resulted in a weird-looking graph, which I can't capture because of my company policy T_T.

Instead, here is the dump of the flatbuffers file from tfldump.

$ ./tfldump ~/nntrainer/build/single_relu.tflite

===================================================================
Model version: 3
 # sub graphs: 1

Operator Codes: [order] OpCodeName (OpCode Enum)
[0] RELU (code: 19, dep_code: 19, version: 1)

Buffers: B(index) (length) values, if any
B(0) (0) 
B(1) (0) 

-------------------------------------------------------------------
Sub-Graph: #0 main

Operands: T(subgraph index : tensor index) TYPE (shape) (shape_signature) B(buffer index) (variable) OperandName
T(0:0) FLOAT32 (1, 1, 1, 1) (-1, 1, 1, 1) B(0) nntrainer_convertedrelu0:input0

Operators: O(subgraph index : operator index) OpCodeName 
    Option(values) ... <-- depending on OpCode
    I T(tensor index) OperandName <-- as input
    O T(tensor index) OperandName <-- as output
O(0:0) RELU 
    I T(0:0) nntrainer_convertedrelu0:input0
    O T(0:0) nntrainer_convertedrelu0:input0

Inputs/Outputs: I(input)/O(output) T(tensor index) OperandName
I T(0:0) nntrainer_convertedrelu0:input0
O T(0:0) nntrainer_convertedrelu0:input0

===================================================================

And this is the expected output for a single-ReLU graph.

$ ./tfldump ../common-artifacts/ReLU_000.tflite 
Dump: ../common-artifacts/ReLU_000.tflite

===================================================================
Model version: 3
 # sub graphs: 1

Operator Codes: [order] OpCodeName (OpCode Enum)
[0] RELU (code: 19, dep_code: 19, version: 1)

Buffers: B(index) (length) values, if any
B(0) (0) 
B(1) (0) 
B(2) (0) 

SignatureDef

-------------------------------------------------------------------
Sub-Graph: #0 main

Operands: T(subgraph index : tensor index) TYPE (shape) (shape_signature) B(buffer index) (variable) OperandName
T(0:0) FLOAT32 (1, 3, 3, 2) B(1) ifm

T(0:1) FLOAT32 (1, 3, 3, 2) B(2) ofm

Operators: O(subgraph index : operator index) OpCodeName 
    Option(values) ... <-- depending on OpCode
    I T(tensor index) OperandName <-- as input
    O T(tensor index) OperandName <-- as output
O(0:0) RELU 
    I T(0:0) ifm
    O T(0:1) ofm

Inputs/Outputs: I(input)/O(output) T(tensor index) OperandName
I T(0:0) ifm
O T(0:1) ofm

===================================================================

The differences are:

  • (nntrainer) the input and output share the same tensor, and therefore the same buffer
  • per the TFLite schema, the 0th entry of the buffers array must be an empty (sentinel) buffer, but here B(0) backs a real tensor
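A hedged sketch of what fixing the second point could look like when building the model's buffer vector (helper names are from the generated TFLite schema; the surrounding builder state is assumed):

// Sketch: reserve buffer 0 as the empty sentinel the TFLite schema
// requires, then hand out real buffer indices starting from 1.
std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
buffers.push_back(tflite::CreateBuffer(fbb)); // B(0): must stay empty
buffers.push_back(tflite::CreateBuffer(fbb)); // B(1): e.g. input tensor
buffers.push_back(tflite::CreateBuffer(fbb)); // B(2): e.g. output tensor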

mhs4670go avatar Apr 22 '22 06:04 mhs4670go

It resulted in a weird-looking graph, which I can't capture because of my company policy T_T.

You may add a link to an internal page, then :)

myungjoo avatar Apr 22 '22 06:04 myungjoo

Q) Is it okay for the input and output of an operator to share the same buffer?

Yes.

import numpy as np
import tensorflow as tf

# load the graph whose input and output share one buffer
interp = tf.lite.Interpreter(model_path='relu_same_buffer.tflite')
interp.allocate_tensors()

input_details = interp.get_input_details()
output_details = interp.get_output_details()

input_shape = input_details[0]['shape']

# random values in [-1, 1) so ReLU has something to clip
input_data = np.array(2 * np.random.random_sample(input_shape) - 1, dtype=np.float32)

print(input_data)

interp.set_tensor(input_details[0]['index'], input_data)
interp.invoke()

output_data = interp.get_tensor(output_details[0]['index'])
print(output_data)
$ python test.py
[[[[ 0.353481    0.7638699 ]
   [ 0.15733716 -0.17923921]
   [ 0.70841825 -0.3853303 ]]

  [[ 0.947603    0.11945099]
   [-0.01959909 -0.76299024]
   [-0.33803365  0.2676839 ]]

  [[ 0.6767502   0.281072  ]
   [-0.71875     0.9702313 ]
   [-0.26452628  0.03297982]]]]
[[[[0.353481   0.7638699 ]
   [0.15733716 0.        ]
   [0.70841825 0.        ]]

  [[0.947603   0.11945099]
   [0.         0.        ]
   [0.         0.2676839 ]]

  [[0.6767502  0.281072  ]
   [0.         0.9702313 ]
   [0.         0.03297982]]]]

But the tensors shouldn't be the same; that makes an invalid graph that can't be run properly in the tflite interpreter.
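A sketch of the intended shape of the fix, using the generated schema's CreateTensorDirect helper (shapes and buffer indices are illustrative; `fbb` is assumed):

// Sketch: give the op distinct input/output tensors backed by distinct
// buffers, mirroring the expected ReLU_000.tflite dump above.
std::vector<int32_t> shape = {1, 3, 3, 2};
auto ifm = tflite::CreateTensorDirect(fbb, &shape, tflite::TensorType_FLOAT32,
                                      /*buffer=*/1, "ifm");
auto ofm = tflite::CreateTensorDirect(fbb, &shape, tflite::TensorType_FLOAT32,
                                      /*buffer=*/2, "ofm");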

mhs4670go avatar Apr 22 '22 08:04 mhs4670go

There are some issues with the C++ tflite interpreter.

  1. https://www.tensorflow.org/lite/api_docs/cc/class/tflite/interpreter

WARNING: This class is not thread-safe. The client is responsible for ensuring serialized interaction to avoid data races and undefined behavior.

Since gtest uses threads, the tests fail when they are all run at once, but they pass when a single test is run.
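If the failures really do come from concurrent access, one possible workaround is sketched below, assuming the tests share a single interpreter instance (names are illustrative, not existing nntrainer code):

#include <mutex>
#include <tensorflow/lite/interpreter.h>

// Illustrative only: tflite::Interpreter gives no thread-safety guarantees,
// so serialize every Invoke() through one mutex shared by the tests.
static std::mutex interp_mutex;

TfLiteStatus invokeSerialized(tflite::Interpreter &interp) {
  std::lock_guard<std::mutex> lock(interp_mutex);
  return interp.Invoke();
}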

  2. Same-padding Conv2D results in wrong output.

When I run the graph built from the code below,

auto input0 = LayerRepresentation("input", {"name=in0", "input_shape=1:2:2"});

auto conv0 = LayerRepresentation(
  "conv2d",
  {"name=conv0", "filters=1", "kernel_size=2,2", "stride=1,1", "padding=same",
   "bias_initializer=zeros", "weight_initializer=ones", "input_layers=in0"});
..

nntrainer::Tensor in(nntrainer::TensorDim({1, 1, 2, 2}));
in.setValue(1.0f);

/// run interpreter
/*
      Expected: out
      Which is: <N9nntrainer6TensorE at 0x7ffeb6298530>
data addr: 0x556c8ff39d20
Shape: 1:2:2:1
         4 
         0 

         0 
         0 
*/

it resulted in wrong outputs.

But when I ran this graph in Python, it produced the expected result.

>>> output_data
array([[[[4.],
         [2.]],

        [[2.],
         [1.]]]], dtype=float32)

Well, I need to find out why. I think this is related to issue 1 above (thread unsafety).

mhs4670go avatar Apr 29 '22 15:04 mhs4670go

This is the current status.

I've implemented the overall code for exporting ResNet.

The images below were created with #1892. They are part of the ResNet structure.

  1. https://media.github.sec.samsung.net/user/67684/files/4590c34f-1a1c-4bc2-8516-fdcb9f1f6bcc
  2. https://media.github.sec.samsung.net/user/67684/files/92c8d314-1697-4f41-ba72-1eac498de81e

Next steps:

  • Refine the code - doxygen, comments, etc.
  • Do value test
  • Export ResNet graph

Issues to resolve:

  • Resolve https://github.com/nnstreamer/nntrainer/issues/1879#issuecomment-1113438377
  • Remove the MSE (mean squared error) layer added to the graph compiled with MultioutRealizer
  • Transpose output
  • Fuse ReLU nodes
  • Support non-rank-4 tensors (nntrainer only allows rank-4 tensors, but some tensors, like a rank-1 bias, need other ranks)
  • Support data types other than float

mhs4670go avatar Apr 29 '22 15:04 mhs4670go

Thanks for your great work! Let me comment on some of the issues addressed:

  • Remove MSE: it is done with #1895
  • Support data types other than float: currently NNTrainer only supports the float data type. Of course, we have to support other data types like int. This will be future work, considered together with QAT.
  • Fuse activation: definitely, we need to support this by implementing another realizer.
  • Transpose output: of course!
  • Support non-rank-4: NNTrainer only supports rank-4 tensors and treats the bias the same way. We do have effective dimensions in TensorDim, and we could use them.

jijoongmoon avatar May 03 '22 02:05 jijoongmoon

@jijoongmoon Thank you for your comments.

Currently NNTrainer only supports the float data type. Of course, we have to support other data types like int.

Actually, this is needed for tflite export only. For example, the Transpose operator needs an int32 tensor for its perm input.
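For instance, a hedged sketch of how that perm input could be serialized with the generated schema helpers (builder state and the final buffer index are assumed; the layout values are illustrative):

// Sketch: Transpose's `perm` input is a rank-1 INT32 tensor whose raw data
// lives in a Buffer as bytes. These values illustrate NCHW -> NHWC.
const int32_t perm[] = {0, 2, 3, 1};
auto perm_bytes = fbb.CreateVector(reinterpret_cast<const uint8_t *>(perm),
                                   sizeof(perm));
auto perm_buffer = tflite::CreateBuffer(fbb, perm_bytes);

std::vector<int32_t> perm_shape = {4};
auto perm_tensor = tflite::CreateTensorDirect(
  fbb, &perm_shape, tflite::TensorType_INT32,
  /*buffer=*/perm_buffer_index, "perm"); // index of perm_buffer in the model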

mhs4670go avatar May 03 '22 02:05 mhs4670go