rust-bert

WIP: NLLB support.

Open npatsakula opened this issue 3 years ago • 1 comments

  • [x] Implement NLLB tokenizer: https://github.com/guillaume-be/rust-tokenizers/pull/76 (waiting for review).
  • [x] Extend language enum: ISO tables were used.
  • [ ] Add resource links (blocked: the model has not been converted yet).
  • [ ] Implement cross-tests.

https://github.com/guillaume-be/rust-bert/issues/277

npatsakula · Aug 30 '22 19:08

Hello @guillaume-be! I've made a few assumptions that need a sanity check:

  1. Based on this line, I assumed that the M2M encoder and decoder should be NLLB-compatible.
  2. I reused the M2M code with small changes, such as tokenizer injection, and assumed that the M2M code doesn't contain hardcoded constants.

Now I'm getting errors like this:

thread 'nllb_translation' panicked at 'called `Result::unwrap()` on an `Err` value: Torch("index out of range in self\nException raised from operator() at /tmp/libtorch-20220807-32886-1r7el0l/aten/src/ATen/native/TensorAdvancedIndexing.cpp:1209 (most recent call first):\nframe #0: at::native::index_select_out_cpu_(at::Tensor const&, long long, at::Tensor const&, at::Tensor&)::$_7::operator()(long long, long long) const + 508 (0x10daaa2a4 in libtorch_cpu.dylib)\nframe #1: at::native::index_select_out_cpu_(at::Tensor const&, long long, at::Tensor const&, at::Tensor&) + 2528 (0x10daa6830 in libtorch_cpu.dylib)\nframe #2: at::native::index_select_cpu_(at::Tensor const&, long long, at::Tensor const&) + 104 (0x10daaa6c0 in libtorch_cpu.dylib)\nframe #3: at::_ops::index_select::call(at::Tensor const&, long long, at::Tensor const&) + 212 (0x10dd36204 in libtorch_cpu.dylib)\nframe #4: at::native::embedding(at::Tensor const&, at::Tensor const&, long long, bool, bool) + 508 (0x10d943634 in libtorch_cpu.dylib)\nframe #5: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, long long, bool, bool), &(torch::autograd::VariableType::(anonymous namespace)::embedding(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, long long, bool, bool))>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, long long, bool, bool> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, long long, bool, bool)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, long long, bool, bool) + 588 (0x10ede0788 in libtorch_cpu.dylib)\nframe #6: at::_ops::embedding::call(at::Tensor const&, at::Tensor const&, long long, bool, bool) + 260 (0x10e015c6c in libtorch_cpu.dylib)\nframe #7: at::embedding(at::Tensor const&, at::Tensor const&, long long, bool, bool) + 84 
(0x1057ebf70 in nllb-d197e1151de7984a)\nframe #8: atg_embedding + 120 (0x1057ebe1c in nllb-d197e1151de7984a)\nframe #9: tch::wrappers::tensor_fallible_generated::_$LT$impl$u20$tch..wrappers..tensor..Tensor$GT$::f_embedding::h87cea000c1cad002 + 304 (0x1057a46d4 in nllb-d197e1151de7984a)\nframe #10: tch::wrappers::tensor_generated::_$LT$impl$u20$tch..wrappers..tensor..Tensor$GT$::embedding::hb3a96d8f5025976d + 68 (0x1057ac468 in nllb-d197e1151de7984a)\nframe #11: _$LT$tch..nn..sparse..Embedding$u20$as$u20$tch..nn..module..Module$GT$::forward::h8a123947dea2df8c + 52 (0x1057c9880 in nllb-d197e1151de7984a)\nframe #12: tch::nn::module::_$LT$impl$u20$tch..wrappers..tensor..Tensor$GT$::apply::h809e888beb935c9f + 44 (0x104e357dc in nllb-d197e1151de7984a)\nframe #13: rust_bert::m2m_100::encoder::M2M100Encoder::forward_t::h6f7a1d9634bb47bb + 116 (0x104d440b4 in nllb-d197e1151de7984a)\nframe #14: rust_bert::m2m_100::m2m_100_model::M2M100ForConditionalGeneration::encode::h5cbabacf24f07bd5 + 56 (0x104d4ac3c in nllb-d197e1151de7984a)\nframe #15: _$LT$rust_bert..m2m_100..m2m_100_model..M2M100Generator$u20$as$u20$rust_bert..pipelines..generation_utils..private_generation_utils..PrivateLanguageGenerator$LT$rust_bert..m2m_100..m2m_100_model..M2M100ForConditionalGeneration$C$rust_tokenizers..vocab..m2m100_vocab..M2M100Vocab$C$rust_tokenizers..tokenizer..m2m100_tokenizer..M2M100Tokenizer$GT$$GT$::encode::h3825ae8f6307768f + 64 (0x104d4c154 in nllb-d197e1151de7984a)\nframe #16: rust_bert::pipelines::generation_utils::LanguageGenerator::generate_from_ids_and_past::h1f0f529920629282 + 3496 (0x104c532f0 in nllb-d197e1151de7984a)\nframe #17: rust_bert::pipelines::generation_utils::LanguageGenerator::generate_indices::h3c67370aceaca893 + 492 (0x104c521d4 in nllb-d197e1151de7984a)\nframe #18: rust_bert::pipelines::generation_utils::LanguageGenerator::generate::hfae34bf2bc36ec46 + 116 (0x104c548d4 in nllb-d197e1151de7984a)\nframe #19: 
rust_bert::pipelines::translation::translation_pipeline::TranslationOption::generate::hb41ac24f476c5b9b + 692 (0x104bf97d0 in nllb-d197e1151de7984a)\nframe #20: rust_bert::pipelines::translation::translation_pipeline::TranslationModel::translate::hd6ac4fa4980e5b7b + 756 (0x104bf8900 in nllb-d197e1151de7984a)\nframe #21: nllb::nllb_translation::h253d86a79d646c8c + 1056 (0x104c02790 in nllb-d197e1151de7984a)\nframe #22: nllb::nllb_translation::_$u7b$$u7b$closure$u7d$$u7d$::h923894370ee73f66 + 20 (0x104c50fb0 in nllb-d197e1151de7984a)\nframe #23: core::ops::function::FnOnce::call_once::hd4406a2fa5927fd1 + 20 (0x104c18228 in nllb-d197e1151de7984a)\nframe #24: test::__rust_begin_short_backtrace::h7926226eae79829f + 12 (0x104d03670 in nllb-d197e1151de7984a)\nframe #25: test::run_test::run_test_inner::_$u7b$$u7b$closure$u7d$$u7d$::hff3d1e1adba0813b + 492 (0x104d02898 in nllb-d197e1151de7984a)\nframe #26: std::sys_common::backtrace::__rust_begin_short_backtrace::h240dc1bf8b9afcc5 + 288 (0x104cd7abc in nllb-d197e1151de7984a)\nframe #27: core::ops::function::FnOnce::call_once$u7b$$u7b$vtable.shim$u7d$$u7d$::h4762fe9dc553e2c3 + 124 (0x104cdd924 in nllb-d197e1151de7984a)\nframe #28: std::sys::unix::thread::Thread::new::thread_start::h7b2f9b83fb320a20 + 48 (0x105864388 in nllb-d197e1151de7984a)\nframe #29: _pthread_start + 148 (0x18a19826c in libsystem_pthread.dylib)\nframe #30: thread_start + 8 (0x18a19308c in libsystem_pthread.dylib)\n")', /Users/mrpink/.cargo/registry/src/github.com-1ecc6299db9ec823/tch-0.8.0/src/wrappers/tensor_generated.rs:5288:87

I suspect this is related to a language-identifier bug in the tokenizer, but I'll check tomorrow.
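
For context, "index out of range in self" raised from an embedding lookup usually means that at least one token id is negative or not smaller than the number of rows in the embedding table, which would fit a tokenizer emitting language ids the model's vocabulary does not cover. A minimal sketch of a check that could be run on the tokenizer output before it reaches the encoder is below. The helper `out_of_range_ids` is hypothetical and is not part of rust-bert, rust-tokenizers or tch; the vocabulary size would have to come from the model config.

```rust
/// Hypothetical debugging helper (not part of rust-bert or tch).
/// Returns the position and value of every token id that falls outside
/// the valid embedding range `0..vocab_size`.
fn out_of_range_ids(ids: &[i64], vocab_size: i64) -> Vec<(usize, i64)> {
    ids.iter()
        .copied()
        .enumerate()
        // Keep only ids that an embedding with `vocab_size` rows cannot index.
        .filter(|&(_, id)| id < 0 || id >= vocab_size)
        .collect()
}
```

With made-up numbers, `out_of_range_ids(&[2, 7, 10, 3, 12], 10)` returns `[(2, 10), (4, 12)]`. If the offending ids turn out to be the language tokens, that would point at the tokenizer rather than at the M2M encoder code.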

npatsakula · Sep 05 '22 17:09

Hello @guillaume-be! Could I ask you to review this PR?

npatsakula · Mar 13 '23 12:03