
Broadcasting in einsum

Rikorose opened this issue 2 years ago · 1 comment

I have this einsum, which broadcasts the second operand (self.weight) over the b and t axes of x in order to compute the expression:

x = torch.einsum("btgi,gih->btgh", x, self.weight)

This results in the following error:

Error: Translating node #29 "enc.df_fc_emb.0.weight.0" Const Pulsifier(1)

Caused by:
    No pulsifier nor pulsable axis invariant for #29 "enc.df_fc_emb.0.weight.0" Const

Stack backtrace:
   0: <tract_pulse::model::Pulsifier as tract_core::model::translator::Translate<tract_core::model::fact::TypedFact,alloc::boxed::Box<dyn tract_core::ops::TypedOp>,tract_pulse::fact::PulsedFact,alloc::boxed::Box<dyn tract_pulse::ops::PulsedOp>>>::translate_node
             at /home/hendrik/ext_src/tract/pulse/src/model.rs:135:9
   1: tract_core::model::translator::Translate::translate_model_with_mappings
             at /home/hendrik/ext_src/tract/core/src/model/translator.rs:35:27
   2: <tract_core::model::graph::Graph<tract_pulse::fact::PulsedFact,alloc::boxed::Box<dyn tract_pulse::ops::PulsedOp>> as tract_pulse::model::PulsedModelExt>::new_with_mapping
             at /home/hendrik/ext_src/tract/pulse/src/model.rs:28:9
   3: <tract_core::model::graph::Graph<tract_pulse::fact::PulsedFact,alloc::boxed::Box<dyn tract_pulse::ops::PulsedOp>> as tract_pulse::model::PulsedModelExt>::new
             at /home/hendrik/ext_src/tract/pulse/src/model.rs:20:12
   4: df_tract::init_model
             at ./libDF/src/bin/tract.rs:74:18
   5: df_tract::main
             at ./libDF/src/bin/tract.rs:102:9
   6: core::ops::function::FnOnce::call_once
             at /rustc/4b91a6ea7258a947e59c6522cd5898e7c0a6a88f/library/core/src/ops/function.rs:248:5
   7: std::sys_common::backtrace::__rust_begin_short_backtrace
             at /rustc/4b91a6ea7258a947e59c6522cd5898e7c0a6a88f/library/std/src/sys_common/backtrace.rs:122:18
   8: std::rt::lang_start::{{closure}}
             at /rustc/4b91a6ea7258a947e59c6522cd5898e7c0a6a88f/library/std/src/rt.rs:145:18
   9: core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once
             at /rustc/4b91a6ea7258a947e59c6522cd5898e7c0a6a88f/library/core/src/ops/function.rs:280:13
      std::panicking::try::do_call
             at /rustc/4b91a6ea7258a947e59c6522cd5898e7c0a6a88f/library/std/src/panicking.rs:492:40
      std::panicking::try
             at /rustc/4b91a6ea7258a947e59c6522cd5898e7c0a6a88f/library/std/src/panicking.rs:456:19
      std::panic::catch_unwind
             at /rustc/4b91a6ea7258a947e59c6522cd5898e7c0a6a88f/library/std/src/panic.rs:137:14
      std::rt::lang_start_internal::{{closure}}
             at /rustc/4b91a6ea7258a947e59c6522cd5898e7c0a6a88f/library/std/src/rt.rs:128:48
      std::panicking::try::do_call
             at /rustc/4b91a6ea7258a947e59c6522cd5898e7c0a6a88f/library/std/src/panicking.rs:492:40
      std::panicking::try
             at /rustc/4b91a6ea7258a947e59c6522cd5898e7c0a6a88f/library/std/src/panicking.rs:456:19
      std::panic::catch_unwind
             at /rustc/4b91a6ea7258a947e59c6522cd5898e7c0a6a88f/library/std/src/panic.rs:137:14
      std::rt::lang_start_internal
             at /rustc/4b91a6ea7258a947e59c6522cd5898e7c0a6a88f/library/std/src/rt.rs:128:20
  10: std::rt::lang_start
             at /rustc/4b91a6ea7258a947e59c6522cd5898e7c0a6a88f/library/std/src/rt.rs:144:17
  11: main
  12: __libc_start_call_main
  13: __libc_start_main@@GLIBC_2.34
  14: _start

Ideally, this could also be written using ellipsis broadcasting:

x = torch.einsum("...i,...ih->...h", x, self.weight)

See #791
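To make the broadcasting concrete, here is a dependency-free Python sketch. It assumes x has shape (B, T, G, I) and the weight has shape (G, I, H). The helper names broadcast_shapes and einsum_btgi_gih are hypothetical illustrations, not PyTorch or tract APIs. The reference loop spells out what "btgi,gih->btgh" computes: the same weight slice w[g] is reused for every (b, t). The shape helper shows why "...i,...ih->...h" lines up with it: the ellipsis covers (B, T, G) for x and (G,) for the weight, and right-aligned broadcasting merges these into (B, T, G).

```python
from itertools import product


def broadcast_shapes(a, b):
    """Right-aligned broadcasting of two shapes (NumPy/PyTorch rules)."""
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)  # left-pad the shorter shape with 1s
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for da, db in zip(a, b):
        if da != db and 1 not in (da, db):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(db if da == 1 else da)  # a size-1 dim stretches to the other
    return tuple(out)


def einsum_btgi_gih(x, w):
    """Reference for einsum("btgi,gih->btgh", x, w) on nested lists.

    w has no b or t axis: the same w[g] (an I x H matrix) is reused for
    every (b, t), which is the broadcasting the einsum performs implicitly.
    """
    B, T, G, I = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    H = len(w[0][0])
    out = [[[[0.0] * H for _ in range(G)] for _ in range(T)] for _ in range(B)]
    for b, t, g, h in product(range(B), range(T), range(G), range(H)):
        out[b][t][g][h] = sum(x[b][t][g][i] * w[g][i][h] for i in range(I))
    return out


# Ellipsis dims: (B, T, G) for x, (G,) for the weight.
B, T, G = 2, 3, 4
print(broadcast_shapes((B, T, G), (G,)))  # (2, 3, 4)
```

Since both spellings end up with ellipsis dims (B, T, G) and contract over i, they compute the same (B, T, G, H) result.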

Rikorose · Aug 16 '22 13:08

We don't have pulsification for EinSum yet. It is not necessarily super hard to do, so I can have a look, but it would help if you could make a test case (I think you were on your way to doing that when you hit #791).

kali · Aug 17 '22 09:08
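Pulsifying a model in tract means turning it into a streaming model that consumes the time axis in fixed-size chunks ("pulses"). The Python sketch below is only a conceptual illustration of why this einsum is a reasonable candidate, not tract's implementation. The helper names einsum_btgi_gih and run_pulsed are hypothetical. It assumes t is the streaming axis. Because t appears in the input and the output of "btgi,gih->btgh" but is never summed over, running the einsum chunk by chunk along t and concatenating the results gives the same output as running it on the whole sequence. The weight, which has no time axis, is simply reused for each chunk.

```python
def einsum_btgi_gih(x, w):
    """einsum("btgi,gih->btgh") on nested lists x[b][t][g][i], w[g][i][h]."""
    return [
        [
            [
                [
                    sum(x_i * w[g][i][h] for i, x_i in enumerate(x_btg))
                    for h in range(len(w[g][0]))
                ]
                for g, x_btg in enumerate(x_bt)
            ]
            for x_bt in x_b
        ]
        for x_b in x
    ]


def run_pulsed(x, w, pulse):
    """Feed x to the einsum `pulse` time steps at a time, then concatenate.

    Valid only because t is a free (non-contracted) axis: each output step
    depends on the matching input step and the time-invariant weight alone.
    """
    T = len(x[0])
    outputs = [[] for _ in x]
    for start in range(0, T, pulse):
        chunk = [x_b[start:start + pulse] for x_b in x]  # slice along t
        for b, y_b in enumerate(einsum_btgi_gih(chunk, w)):
            outputs[b].extend(y_b)
    return outputs


# Toy tensors: B=1, T=5, G=2, I=2, H=3.
x = [[[[float(4 * t + 2 * g + i) for i in range(2)] for g in range(2)]
      for t in range(5)]]
w = [[[float(g + i * h) for h in range(3)] for i in range(2)] for g in range(2)]
print(run_pulsed(x, w, pulse=2) == einsum_btgi_gih(x, w))  # True
```

A contraction over t (for example a sum over time) would break this equivalence, so whether an operator can be pulsified depends on how it uses the streaming axis.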

I'm gonna close this one. We do have pulsification for Einsum now. Feel free to open new tickets if/when you find issues :)

kali · Oct 28 '22 08:10