
How to deal with model checkpoint compatibility issue?

Open DongChanS opened this issue 3 years ago • 16 comments

Question

I trained a Transformer model (Transformer encoder + Transformer criterion) with wav2letter v0.2.

Unfortunately, I now have to use the flashlight-consolidated version of wav2letter (due to some updates in flashlight), but I cannot load the previous checkpoint directly.

Here is the error message:

E0324 07:40:20.319444 84544 Serializer.h:145 Error while loading "(CHECKPOINT_PATH)": Trying to load an unregistered polymorphic type (w2l::TransformerCriterion).
Make sure your type is registered with CEREAL_REGISTER_TYPE and that the archive you are using was included (and registered with CEREAL_REGISTER_ARCHIVE) prior to calling CEREAL_REGISTER_TYPE.
If your type is already registered and you still see this error, you may need to use CEREAL_REGISTER_DYNAMIC_INIT.

The previous checkpoint (Transformer criterion) indeed has the type w2l::TransformerCriterion, not fl::app::asr::TransformerCriterion as in flashlight-consolidated.

How can I solve this problem?

Additional Context

New version's criterion (registered under the fl::app::asr namespace):

// flashlight/flashlight/app/asr/criterion/TransformerCriterion.h
AMUpdateFunc buildTransformerAmUpdateFunction(
    std::shared_ptr<SequenceCriterion>& crit);
} // namespace asr
} // namespace app
} // namespace fl

CEREAL_REGISTER_TYPE(fl::app::asr::TransformerCriterion)

Old version's criterion (registered under the w2l namespace):

// wav2letter/src/criterion/TransformerCriterion.h
AMUpdateFunc buildTransformerAmUpdateFunction(
    std::shared_ptr<SequenceCriterion>& crit);

} // namespace w2l

CEREAL_REGISTER_TYPE(w2l::TransformerCriterion)

DongChanS avatar Mar 24 '21 08:03 DongChanS

cc @vineelpratap @avidov @jacobkahn @xuqiantong Do we have conversion scripts or any guides/hints on how to do this?

tlikhomanenko avatar Mar 27 '21 05:03 tlikhomanenko

Is it impossible? If not, please let me know how to change the class type of the checkpoint.

DongChanS avatar Apr 02 '21 01:04 DongChanS

It is possible =) @vineelpratap @avidov

tlikhomanenko avatar Apr 05 '21 18:04 tlikhomanenko

Hi, sorry for the delay. We were busy with the Interspeech deadline =). We will aim to provide a script today or tomorrow to do the conversion.

vineelpratap avatar Apr 05 '21 18:04 vineelpratap

Is there any problem with providing the script? I will have to re-train the same model unless I receive the guidelines...

DongChanS avatar Apr 13 '21 07:04 DongChanS

Yep, it is in a PR for now; we need to fix some CI issues, but you can try it: https://github.com/facebookresearch/flashlight/pull/524. Please comment if you have any trouble using it while it is still in the PR.

tlikhomanenko avatar Apr 13 '21 07:04 tlikhomanenko

It is not working...

I built the serialization tools from the above PR, but this error message occurred:

root@cd303acf12b0:~/flashlight/build/bin/asr# ./fl_asr_model_converter old {old_model_path}
WARNING: Logging before InitGoogleLogging() is written to STDERR
I0426 06:20:01.159353 20012 ModelConverter.cpp:105] Saving params from `old binary` model to a binary dump
E0426 06:20:01.160109 95840 Serializer.h:82 Error while loading "{old_model_path}": Trying to load an unregistered polymorphic type (w2l::TransformerCriterion).
Make sure your type is registered with CEREAL_REGISTER_TYPE and that the archive you are using was included (and registered with CEREAL_REGISTER_ARCHIVE) prior to calling CEREAL_REGISTER_TYPE.
If your type is already registered and you still see this error, you may need to use CEREAL_REGISTER_DYNAMIC_INIT.

E0426 06:20:02.160425 95840 Serializer.h:82 Error while loading "{old_model_path}": Trying to load an unregistered polymorphic type (w2l::TransformerCriterion).
Make sure your type is registered with CEREAL_REGISTER_TYPE and that the archive you are using was included (and registered with CEREAL_REGISTER_ARCHIVE) prior to calling CEREAL_REGISTER_TYPE.
If your type is already registered and you still see this error, you may need to use CEREAL_REGISTER_DYNAMIC_INIT.

E0426 06:20:04.160701 95840 Serializer.h:82 Error while loading "{old_model_path}": Trying to load an unregistered polymorphic type (w2l::TransformerCriterion).
Make sure your type is registered with CEREAL_REGISTER_TYPE and that the archive you are using was included (and registered with CEREAL_REGISTER_ARCHIVE) prior to calling CEREAL_REGISTER_TYPE.
If your type is already registered and you still see this error, you may need to use CEREAL_REGISTER_DYNAMIC_INIT.

I think this is because the converter also uses fl::ext::Serializer::load(modelPath, version, cfg, network, criterion), the same call used in Decode.cpp.

DongChanS avatar Apr 26 '21 06:04 DongChanS

It seems your old binary still doesn't have the proper classes registered, so you cannot load the model. Are you sure you are using an old binary that has w2l::TransformerCriterion?

Also cc @vineelpratap.

tlikhomanenko avatar Apr 28 '21 21:04 tlikhomanenko

@DongChanS - {old_model_path} should be replaced with the appropriate path...

Also, can you copy the current fl::app::asr::TransformerCriterion class and create a duplicate class in the same file under the w2l namespace, as w2l::TransformerCriterion?

vineelpratap avatar Apr 28 '21 22:04 vineelpratap

@tlikhomanenko - Yes, I'm sure. But I didn't know that the serialization tool requires the full AM binary (network + criterion), so I tried again with the full AM binary file!

Thanks! I successfully converted the wav2letter v0.2 binary to a flashlight v0.3 binary.

However, I followed a different procedure than @vineelpratap suggested. Is that okay?

  1. Since I only needed the saveToBinaryDump function from wav2letter v0.2, I built the serialization tools in wav2letter v0.2 with a minimal setup:
  // tools/serialization/ModelConverter.cpp (trimmed)
  std::string binaryType = argv[1];
  std::string modelPath = argv[2];
  std::string version;
  if (binaryType == "old") {
    LOG(INFO) << "Saving params from `old binary` model to a binary dump";
    // load network as well: saveToBinaryDump needs it below
    W2lSerializer::load(modelPath, cfg, network, criterion);
    saveToBinaryDump(tempModelPath(modelPath).c_str(), network, criterion);
  } else if (binaryType == "new") {
    LOG(FATAL) << "Unsupported binary type in wav2letter";
  } else {
    LOG(FATAL) << "Incorrect binary type specified.";
  }
  2. I built flashlight v0.3 with the full serialization tools and ran the converter:
root@cd303acf12b0:~/flashlight/build/bin/asr# ./fl_asr_model_converter new /root/flashlight/025_model_last.bin
WARNING: Logging before InitGoogleLogging() is written to STDERR
I0429 07:28:43.208770 34623 ModelConverter.cpp:112] Loading model params from binary dump to `new binary` model
I0429 07:28:52.309464 34623 ModelConverter.cpp:220] Done !

But this model does not work properly. This is the error message from the TransformerCriterion:

I0429 07:31:37.571990 34664 memoryefficient_offline_inference_cpu_consolidated.cpp:329] [Criterion] Number of params: 15070135
I0429 07:31:37.577639 34664 memoryefficient_offline_inference_cpu_consolidated.cpp:362] [ConvLM]: Loading LM from /model/wav2letter/v0.3/lm_model
[ConvLM]: Loading vocabulary from /model/wav2letter/v0.3/lm_vocab
[ConvLM]: vocabulary size of convLM 20552
I0429 07:31:52.202826 34664 memoryefficient_offline_inference_cpu_consolidated.cpp:378] [Decoder] LM constructed.
I0429 07:31:52.203383 34891 memoryefficient_offline_inference_cpu_consolidated.cpp:422] [ConvLM]: Loading LM from /model/wav2letter/v0.3/lm_model
I0429 07:31:52.203384 34889 memoryefficient_offline_inference_cpu_consolidated.cpp:422] [ConvLM]: Loading LM from /model/wav2letter/v0.3/lm_model
I0429 07:31:52.203434 34888 memoryefficient_offline_inference_cpu_consolidated.cpp:487] [Decoder] LexiconFreeSeq2Seq decoder with token-LM loaded in thread: 0
I0429 07:31:52.203383 34890 memoryefficient_offline_inference_cpu_consolidated.cpp:422] [ConvLM]: Loading LM from /model/wav2letter/v0.3/lm_model
terminate called after throwing an instance of 'std::invalid_argument'
  what():  Invalid inputs for transformer block: there should be at least input and mask
*** Aborted at 1619681514 (unix time) try "date -d @1619681514" if you are using GNU date ***
PC: @     0x7efd1920418b gsignal
*** SIGABRT (@0x8768) received by PID 34664 (TID 0x7efd14680980) from PID 34664; stack trace: ***
    @     0x7efd24996631 (unknown)
    @     0x7efd1d87a3c0 (unknown)
    @     0x7efd1920418b gsignal
    @     0x7efd191e3859 abort
    @     0x7efd195fc951 (unknown)
    @     0x7efd1960847c (unknown)
    @     0x7efd196084e7 std::terminate()
    @     0x7efd1960846f std::rethrow_exception()
    @     0x563be81e7f05 main
    @     0x7efd191e50b3 __libc_start_main
    @     0x563be826fe2e _start

DongChanS avatar Apr 29 '21 07:04 DongChanS

This error message is related to the Transformer module in flashlight.

Flashlight v0.3's Transformer requires a mask, unlike wav2letter v0.2's:

// flashlight/flashlight/fl/contrib/modules/Transformer.cpp

std::vector<Variable> Transformer::forward(const std::vector<Variable>& input) {
  // previous step[optionally], input, padMask
  // padMask should be empty if previous step is provided
  // padMask is expected to have "1" on the used positions and "0" on padded
  // positions
  if (input.size() < 2) {
    throw std::invalid_argument(
        "Invalid inputs for transformer block: there should be at least input and mask");
  }
  auto x = input.at(input.size() - 2);
  if (!input.back().isempty() && x.dims(2) != input.back().dims(1)) {
    throw std::invalid_argument(
        "Invalid inputs for transformer block: input and Mask batch sizes are different");
  }

However, the Transformer encoder does not return a mask (since the last layer of the encoder is a Linear layer, the mask is not included in the output):

// flashlight/flashlight/ext/common/SequentialBuilder.cpp

fl::Variable forwardSequentialModuleWithPadMask(
    const fl::Variable& input,
    std::shared_ptr<fl::Module> ntwrk,
    const af::array& inputSizes) {
  // expected input dims T x C x 1 x B
  int T = input.dims(0), B = input.dims(3);
  auto inputMaxSize = af::tile(af::max(inputSizes), 1, B);
  af::array inputNotPaddedSize = af::ceil(inputSizes * T / inputMaxSize);
  auto padMask = af::iota(af::dim4(T, 1), af::dim4(1, B)) <
      af::tile(inputNotPaddedSize, T, 1);
  auto ntwrkSeq = std::dynamic_pointer_cast<fl::Sequential>(ntwrk);
  auto output = input;
  for (auto& module : ntwrkSeq->modules()) {
    auto tr = std::dynamic_pointer_cast<fl::Transformer>(module);
    auto cfr = std::dynamic_pointer_cast<fl::Conformer>(module);
    if (tr != nullptr || cfr != nullptr) {
      output = module->forward({output, fl::noGrad(padMask)}).front();
    } else {
      output = module->forward({output}).front();
    }
  }
  return output.as(input.type());
}

How can I solve this problem?

DongChanS avatar Apr 29 '21 07:04 DongChanS

Sorry, I don't fully follow what happened. So you converted the model, are running decoding in fl v0.3, and see the error on the forward pass of the transformer block, right? Let me check that this works for s2s and that we call the transformer blocks properly everywhere.

tlikhomanenko avatar Apr 29 '21 15:04 tlikhomanenko

@DongChanS I believe you can just do

std::vector<Variable> Transformer::forward(const std::vector<Variable>& input2) {
  auto input = input2;
  if (input2.size() == 1) {
    input.push_back(fl::Variable(af::array(), false));
  }
  if (input.size() < 2) {
    throw std::invalid_argument(
        "Invalid inputs for transformer block: there should be at least input and mask");
  }

I'll let @tlikhomanenko confirm though...

vineelpratap avatar May 06 '21 05:05 vineelpratap

@DongChanS Please change this https://github.com/flashlight/flashlight/blob/master/flashlight/app/asr/criterion/TransformerCriterion.cpp#L284 to

yBatched = layer(i)->forward(std::vector<Variable>({yBatched}), fl::Variable(af::array())).front();

and this https://github.com/flashlight/flashlight/blob/master/flashlight/app/asr/criterion/TransformerCriterion.cpp#L296 to

yBatched = layer(i)->forward(tmp, fl::Variable(af::array())).front();

I will send this fix later, but this should unblock you. Let me know if you still have problems.

tlikhomanenko avatar May 07 '21 06:05 tlikhomanenko

Good!

Since there was a syntax error, I changed these two lines to:

yBatched = layer(i)->forward(std::vector<Variable>({yBatched, fl::Variable(af::array(), false)})).front();
tmp.push_back(fl::Variable(af::array(), false));
yBatched = layer(i)->forward(tmp).front();

Then the model works fine!

DongChanS avatar May 07 '21 09:05 DongChanS

Feel free to send PR on this =)

tlikhomanenko avatar May 08 '21 00:05 tlikhomanenko