
Support for UMT5

QLutz opened this issue on Sep 13, 2023 · 2 comments

The new UMT5 models from Google are currently the most interesting variant of the original T5 models.

However, trying to convert a UMT5 model using the transformers converter by running:

ct2-transformers-converter --model google/umt5-xl --output_dir ct2-umt5-3b --quantization int8

yields:

Downloading (…)lve/main/config.json: 100%|██████████| 812/812 [00:00<00:00, 4.96MB/s]
Traceback (most recent call last):
  File "/home/user/miniconda3/bin/ct2-transformers-converter", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/user/miniconda3/lib/python3.11/site-packages/ctranslate2/converters/transformers.py", line 1719, in main
    converter.convert_from_args(args)
  File "/home/user/miniconda3/lib/python3.11/site-packages/ctranslate2/converters/converter.py", line 50, in convert_from_args
    return self.convert(
           ^^^^^^^^^^^^^
  File "/home/user/miniconda3/lib/python3.11/site-packages/ctranslate2/converters/converter.py", line 89, in convert
    model_spec = self._load()
                 ^^^^^^^^^^^^
  File "/home/user/miniconda3/lib/python3.11/site-packages/ctranslate2/converters/transformers.py", line 106, in _load
    raise ValueError(
ValueError: No conversion is registered for the model configuration UMT5Config (supported configurations are: BartConfig, BertConfig, BloomConfig, CodeGenConfig, DistilBertConfig, FalconConfig, GPT2Config, GPTBigCodeConfig, GPTJConfig, GPTNeoXConfig, LlamaConfig, M2M100Config, MBartConfig, MPTConfig, MT5Config, MarianConfig, OPTConfig, PegasusConfig, RWConfig, T5Config, WhisperConfig, XLMRobertaConfig)

Is there an easy workaround? Is this something that should be added to the package?

QLutz · Sep 13 '23 08:09

The UMT5 model has a separate relative attention bias for each self-attention layer. The corresponding converter code in ctranslate2/converters/transformers.py can therefore be written as follows:

@register_loader("UMT5Config")
class UMT5Loader(T5Loader):
    @property
    def architecture_name(self):
        return "UMT5ForConditionalGeneration"

    def set_stack(self, spec, module, is_decoder=False):
        self.set_layer_norm(spec.layer_norm, module.final_layer_norm)
        self.set_embeddings(
            spec.embeddings[0] if isinstance(spec.embeddings, list) else spec.embeddings,
            module.embed_tokens,
        )

        spec.scale_embeddings = False

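        # Unlike T5, every UMT5 layer has its own relative attention bias,
        # so each layer's bias is loaded by set_self_attention instead of
        # reusing the first layer's bias across the stack.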
        for layer_spec, block in zip(spec.layer, module.block):
            self.set_self_attention(layer_spec.self_attention, block.layer[0])

            if is_decoder:
                self.set_cross_attention(layer_spec.attention, block.layer[1])

            self.set_ffn(layer_spec.ffn, block.layer[-1])
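With this loader registered in ctranslate2/converters/transformers.py, the ct2-transformers-converter command from the issue should no longer fail with the unsupported-configuration error. Note that the attention change below is still needed for the converted model to produce correct outputs.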

Also, when computing attention, CTranslate2 computes position_bias in the first layer and then reuses it across all subsequent layers instead of recomputing it from each layer's relative attention bias.

https://github.com/OpenNMT/CTranslate2/blob/2203ad5c8baf878a2d08e73095421e2ba033c89c/src/layers/attention.cc#L236-L253

To compute the correct position_bias for every layer, simply disable the if condition in the code below:

// if (position_bias->empty()) {
const dim_t query_length = queries.dim(2);
const dim_t key_length = keys.dim(2);
*position_bias = compute_relative_bias(*relative_attention_bias,
                                        query_length,
                                        key_length,
                                        maximum_relative_position,
                                        is_decoder,
                                        with_cache ? key_length - 1 : 0);
// }

However, this approach may degrade performance for T5 and MT5 models, since they share a single bias across all layers and would now recompute it in every layer instead of reusing the cached value.
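A less invasive alternative, sketched below, would be to keep the cached-bias fast path and recompute the bias only when each layer owns one. The has_per_layer_bias flag is hypothetical; it would have to be derived from the model spec and passed down to the attention layer:

// Sketch only: "has_per_layer_bias" is an assumed flag, true for UMT5
// (one relative attention bias per layer) and false for T5/MT5 (a single
// bias shared across layers), preserving their existing fast path.
if (position_bias->empty() || has_per_layer_bias) {
  const dim_t query_length = queries.dim(2);
  const dim_t key_length = keys.dim(2);
  *position_bias = compute_relative_bias(*relative_attention_bias,
                                         query_length,
                                         key_length,
                                         maximum_relative_position,
                                         is_decoder,
                                         with_cache ? key_length - 1 : 0);
}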

soocheolnoh · Oct 11 '23 07:10