
Truncation of sequences that are beyond the model's maximum length

Open · MootezSaaD opened this issue on Jan 14, 2024 · 2 comments

Hi! First, I'd like to thank you for this library :-) I'm really enjoying it.

I tried to tokenize a sequence of around 4K tokens and feed it to a RoBERTa-based model (CodeBERT). This led to the following error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
[<ipython-input-12-a373b5333f39>](https://localhost:8080/#) in <cell line: 1>()
      2    ids = input_sentence.padded_tensor(padding_id=0, pad_left=True)
      3    mask = input.attention_mask(pad_left=True)
----> 4    model_output = encoder(piece_ids=ids, attention_mask=mask)

10 frames
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

[/usr/local/lib/python3.10/dist-packages/curated_transformers/models/transformer.py](https://localhost:8080/#) in forward(self, piece_ids, attention_mask, positions, type_ids)
    122         type_ids: Optional[Tensor] = None,
    123     ) -> ModelOutput:
--> 124         embeddings = self.embeddings(piece_ids, positions=positions, type_ids=type_ids)
    125         layer_output = embeddings
    126 

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

[/usr/local/lib/python3.10/dist-packages/curated_transformers/models/roberta/embeddings.py](https://localhost:8080/#) in forward(self, piece_ids, positions, type_ids)
     96         if positions is None:
     97             positions = self._get_positions(piece_ids)
---> 98         return super().forward(
     99             piece_ids,
    100             positions=positions,

[/usr/local/lib/python3.10/dist-packages/curated_transformers/layers/transformer.py](https://localhost:8080/#) in forward(self, piece_ids, positions, type_ids)
    180             if positions is None:
    181                 positions = self._get_positions(piece_ids)
--> 182             position_embeddings = self.position_embeddings(positions)
    183             embeddings += position_embeddings
    184 

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py](https://localhost:8080/#) in forward(self, input)
    160 
    161     def forward(self, input: Tensor) -> Tensor:
--> 162         return F.embedding(
    163             input, self.weight, self.padding_idx, self.max_norm,
    164             self.norm_type, self.scale_grad_by_freq, self.sparse)

[/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py](https://localhost:8080/#) in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2231         # remove once script supports set_grad_enabled
   2232         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2233     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   2234 
   2235 

IndexError: index out of range in self

For reference, here is the code I was using:

import torch
from curated_transformers.models import RoBERTaEncoder
from curated_transformers.tokenizers import AutoTokenizer

MODEL_TAG = "microsoft/codebert-base"
tokenizer = AutoTokenizer.from_hf_hub(name=MODEL_TAG, revision="main")
model = RoBERTaEncoder.from_hf_hub(
    name=MODEL_TAG,
    revision="main",
)
code = [
   'void avcodec_string(char *buf, int buf_size, AVCodecContext *enc, int encode)\n\n{\n\n    const char *codec_type;\n\n    const char *codec_name;\n\n    const char *profile = NULL;\n\n    const AVCodec *p;\n\n    int64_t bitrate;\n\n    int new_line = 0;\n\n    AVRational display_aspect_ratio;\n\n    const char *separator = enc->dump_separator ? (const char *)enc->dump_separator : ", ";\n\n\n\n    if (!buf || buf_size <= 0)\n\n        return;\n\n    codec_type = av_get_media_type_string(enc->codec_type);\n\n    codec_name = avcodec_get_name(enc->codec_id);\n\n    if (enc->profile != FF_PROFILE_UNKNOWN) {\n\n        if (enc->codec)\n\n            p = enc->codec;\n\n        else\n\n            p = encode ? avcodec_find_encoder(enc->codec_id) :\n\n                        avcodec_find_decoder(enc->codec_id);\n\n        if (p)\n\n            profile = av_get_profile_name(p, enc->profile);\n\n    }\n\n\n\n    snprintf(buf, buf_size, "%s: %s", codec_type ? codec_type : "unknown",\n\n             codec_name);\n\n    buf[0] ^= \'a\' ^ \'A\'; /* first letter in uppercase */\n\n\n\n    if (enc->codec && strcmp(enc->codec->name, codec_name))\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf), " (%s)", enc->codec->name);\n\n\n\n    if (profile)\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf), " (%s)", profile);\n\n    if (   enc->codec_type == AVMEDIA_TYPE_VIDEO\n\n        && av_log_get_level() >= AV_LOG_VERBOSE\n\n        && enc->refs)\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 ", %d reference frame%s",\n\n                 enc->refs, enc->refs > 1 ? "s" : "");\n\n\n\n    if (enc->codec_tag) {\n\n        char tag_buf[32];\n\n        av_get_codec_tag_string(tag_buf, sizeof(tag_buf), enc->codec_tag);\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 " (%s / 0x%04X)", tag_buf, enc->codec_tag);\n\n    }\n\n\n\n    switch (enc->codec_type) {\n\n    case AVMEDIA_TYPE_VIDEO:\n\n        {\n\n            char detail[256] = "(";\n\n\n\n            av_strlcat(buf, separator, buf_size);\n\n\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 "%s", enc->pix_fmt == AV_PIX_FMT_NONE ? 
"none" :\n\n                     av_get_pix_fmt_name(enc->pix_fmt));\n\n            if (enc->bits_per_raw_sample && enc->pix_fmt != AV_PIX_FMT_NONE &&\n\n                enc->bits_per_raw_sample < av_pix_fmt_desc_get(enc->pix_fmt)->comp[0].depth)\n\n                av_strlcatf(detail, sizeof(detail), "%d bpc, ", enc->bits_per_raw_sample);\n\n            if (enc->color_range != AVCOL_RANGE_UNSPECIFIED)\n\n                av_strlcatf(detail, sizeof(detail), "%s, ",\n\n                            av_color_range_name(enc->color_range));\n\n\n\n            if (enc->colorspace != AVCOL_SPC_UNSPECIFIED ||\n\n                enc->color_primaries != AVCOL_PRI_UNSPECIFIED ||\n\n                enc->color_trc != AVCOL_TRC_UNSPECIFIED) {\n\n                if (enc->colorspace != (int)enc->color_primaries ||\n\n                    enc->colorspace != (int)enc->color_trc) {\n\n                    new_line = 1;\n\n                    av_strlcatf(detail, sizeof(detail), "%s/%s/%s, ",\n\n                                av_color_space_name(enc->colorspace),\n\n                                av_color_primaries_name(enc->color_primaries),\n\n                                av_color_transfer_name(enc->color_trc));\n\n                } else\n\n                    av_strlcatf(detail, sizeof(detail), "%s, ",\n\n                                av_get_colorspace_name(enc->colorspace));\n\n            }\n\n\n\n            if (av_log_get_level() >= AV_LOG_DEBUG &&\n\n                enc->chroma_sample_location != AVCHROMA_LOC_UNSPECIFIED)\n\n                av_strlcatf(detail, sizeof(detail), "%s, ",\n\n                            av_chroma_location_name(enc->chroma_sample_location));\n\n\n\n            if (strlen(detail) > 1) {\n\n                detail[strlen(detail) - 2] = 0;\n\n                av_strlcatf(buf, buf_size, "%s)", detail);\n\n            }\n\n        }\n\n\n\n        if (enc->width) {\n\n            av_strlcat(buf, new_line ? 
separator : ", ", buf_size);\n\n\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     "%dx%d",\n\n                     enc->width, enc->height);\n\n\n\n            if (av_log_get_level() >= AV_LOG_VERBOSE &&\n\n                (enc->width != enc->coded_width ||\n\n                 enc->height != enc->coded_height))\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         " (%dx%d)", enc->coded_width, enc->coded_height);\n\n\n\n            if (enc->sample_aspect_ratio.num) {\n\n                av_reduce(&display_aspect_ratio.num, &display_aspect_ratio.den,\n\n                          enc->width * enc->sample_aspect_ratio.num,\n\n                          enc->height * enc->sample_aspect_ratio.den,\n\n                          1024 * 1024);\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         " [SAR %d:%d DAR %d:%d]",\n\n                         enc->sample_aspect_ratio.num, enc->sample_aspect_ratio.den,\n\n                         display_aspect_ratio.num, display_aspect_ratio.den);\n\n            }\n\n            if (av_log_get_level() >= AV_LOG_DEBUG) {\n\n                int g = av_gcd(enc->time_base.num, enc->time_base.den);\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", %d/%d",\n\n                         enc->time_base.num / g, enc->time_base.den / g);\n\n            }\n\n        }\n\n        if (encode) {\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", q=%d-%d", enc->qmin, enc->qmax);\n\n        } else {\n\n            if (enc->properties & FF_CODEC_PROPERTY_CLOSED_CAPTIONS)\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", Closed Captions");\n\n            if (enc->properties & FF_CODEC_PROPERTY_LOSSLESS)\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", lossless");\n\n        }\n\n        break;\n\n    case AVMEDIA_TYPE_AUDIO:\n\n        av_strlcat(buf, separator, buf_size);\n\n\n\n        if (enc->sample_rate) {\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     "%d Hz, ", enc->sample_rate);\n\n        }\n\n        av_get_channel_layout_string(buf + strlen(buf), buf_size - strlen(buf), enc->channels, enc->channel_layout);\n\n        if (enc->sample_fmt != AV_SAMPLE_FMT_NONE) {\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", %s", av_get_sample_fmt_name(enc->sample_fmt));\n\n        }\n\n        if (   enc->bits_per_raw_sample > 0\n\n            && enc->bits_per_raw_sample != av_get_bytes_per_sample(enc->sample_fmt) * 8)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     " (%d bit)", enc->bits_per_raw_sample);\n\n        break;\n\n    case AVMEDIA_TYPE_DATA:\n\n        if (av_log_get_level() >= AV_LOG_DEBUG) {\n\n            int g = av_gcd(enc->time_base.num, enc->time_base.den);\n\n            if (g)\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", %d/%d",\n\n                         enc->time_base.num / g, enc->time_base.den / g);\n\n        }\n\n        break;\n\n    case AVMEDIA_TYPE_SUBTITLE:\n\n        if (enc->width)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", %dx%d", enc->width, enc->height);\n\n        break;\n\n    
default:\n\n        return;\n\n    }\n\n    if (encode) {\n\n        if (enc->flags & AV_CODEC_FLAG_PASS1)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", pass 1");\n\n        if (enc->flags & AV_CODEC_FLAG_PASS2)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", pass 2");\n\n    }\n\n    bitrate = get_bit_rate(enc);\n\n    if (bitrate != 0) {\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 ", %"PRId64" kb/s", bitrate / 1000);\n\n    } else if (enc->rc_max_rate > 0) {\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 ", max. %"PRId64" kb/s", (int64_t)enc->rc_max_rate / 1000);\n\n    }\n\n}\n',
]
with torch.no_grad():
    input_sentence = tokenizer(code)
    ids = input_sentence.padded_tensor(padding_id=0, pad_left=False)
    mask = input_sentence.attention_mask(pad_left=False)
    model_output = model(piece_ids=ids, attention_mask=mask)

I went through the API docs and skimmed the source code, and it appears that truncation is not supported. The IndexError comes from looking up position embeddings past the model's maximum sequence length (512 pieces for CodeBERT), so any longer input fails. Note that when I manually truncated the sequence, I was able to feed it to the RoBERTa encoder.
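In case it helps others, here is the manual truncation I ended up with (a minimal sketch; it assumes the tokenizer output exposes the raw piece IDs as a mutable list of lists via pieces.ids, and reuses the tokenizer and model from above):

import torch

MAX_LEN = 512  # CodeBERT/RoBERTa position limit, including <s> and </s>

with torch.no_grad():
    pieces = tokenizer(code)
    # Truncate each sequence to the model's maximum length, keeping the
    # final </s> piece so the sequence stays well-formed.
    for seq in pieces.ids:
        if len(seq) > MAX_LEN:
            eos = seq[-1]
            del seq[MAX_LEN - 1:]
            seq.append(eos)
    ids = pieces.padded_tensor(padding_id=0, pad_left=False)
    mask = pieces.attention_mask(pad_left=False)
    model_output = model(piece_ids=ids, attention_mask=mask)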

MootezSaaD · Jan 14 '24

Thanks for the report! As you surmised, we don't currently support the truncation of inputs, but that error message can definitely be improved. We'll look into it, but please feel free to contribute a PR if you'd like to sort it out yourself 😃
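For anyone picking this up, such a fix might look like the following guard (a hypothetical sketch with illustrative names, not the actual curated-transformers internals):

from torch import Tensor
from torch.nn import Embedding

def check_input_length(piece_ids: Tensor, position_embeddings: Embedding) -> None:
    # Fail with a clear message instead of an opaque IndexError when the
    # input is longer than the position embedding table.
    n_positions = position_embeddings.weight.shape[0]
    if piece_ids.shape[1] > n_positions:
        raise ValueError(
            f"Sequence length {piece_ids.shape[1]} exceeds the model's "
            f"maximum of {n_positions} positions; truncate the input first."
        )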

shadeMe · Jan 14 '24

Just wanted to add that we do support longer sequences with Curated Transformers in spaCy. We should probably provide something similar in Curated Transformers that could be used as an extension.
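For illustration, a rough sketch of what such a strided-window extension could look like (a hypothetical helper, not the spaCy implementation; it assumes AttentionMask from curated_transformers.layers and the last_hidden_layer_state attribute of the model output):

import torch
from curated_transformers.layers import AttentionMask

def encode_strided(model, ids: torch.Tensor, window: int = 512, stride: int = 256) -> torch.Tensor:
    # ids: (1, seq_len) piece IDs for a single, unpadded sequence. Runs the
    # encoder over overlapping windows that each fit the position table and
    # averages the hidden states where windows overlap.
    seq_len = ids.shape[1]
    hidden = counts = None
    for start in range(0, seq_len, stride):
        end = min(start + window, seq_len)
        chunk = ids[:, start:end]
        mask = AttentionMask(torch.ones_like(chunk, dtype=torch.bool))
        out = model(piece_ids=chunk, attention_mask=mask).last_hidden_layer_state
        if hidden is None:
            hidden = out.new_zeros((1, seq_len, out.shape[-1]))
            counts = out.new_zeros((1, seq_len, 1))
        hidden[:, start:end] += out
        counts[:, start:end] += 1
        if end == seq_len:
            break
    return hidden / counts

Note that positions restart at each window, so pieces near window edges see less context than in a true long-context model; a larger overlap (smaller stride) mitigates this at extra compute cost.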

danieldk · Jan 31 '24