KoBART-summarization icon indicating copy to clipboard operation
KoBART-summarization copied to clipboard

input, output 문장길이 조정

Open BrainNim opened this issue 2 years ago • 2 comments

안녕하세요, 만들어주신 패키지 재밌게 만져보고 있는데요,
블로그들 크롤링해서 요약해보며 놀다가 몇가지 궁금한 점이 있어서 문의드립니다.

  1. input 문장의 길이제한
# define summarize function
def summarize(text):
    raw_input_ids = tokenizer.encode(text)
    input_ids = [tokenizer.bos_token_id] + raw_input_ids + [tokenizer.eos_token_id]

    summary_ids = model.generate(torch.tensor([input_ids]),  num_beams=4,  max_length=512,  eos_token_id=1)
    result = tokenizer.decode(summary_ids.squeeze().tolist(), skip_special_tokens=True)
    return result

# read csv
df = pd.read_csv('textlist.csv', encoding='ANSI')
text = df.text[0]


# 정상작동
>>> summarize(text[:1998])  # text[:1998] = "(생략...) 했으나 맛은 괜찮었어요 평범했던 그린"
'백운호수점 분당에서는 20 25분정도 밖에 안 걸리는 곳이라 나들이삼아 기분전환하러 가기 좋은 곳인 것 같아요.'

# Traceback
>>> summarize(text[:1999])  # text[:1999] = "(생략...) 으나 맛은 괜찮었어요 평범했던 그린샐"
>>> summarize(text[:2001])  # text[:2001] = "(생략...)  맛은 괜찮었어요 평범했던 그린샐러드"
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 4, in summarize
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\autograd\grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\transformers\generation_utils.py", line 927, in generate
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\transformers\generation_utils.py", line 412, in _prepare_encoder_decoder_kwargs_for_generation 
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\transformers\models\bart\modeling_bart.py", line 752, in forward
    embed_pos = self.embed_positions(input_shape)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\transformers\models\bart\modeling_bart.py", line 124, in forward
    return super().forward(positions + self.offset)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\nn\modules\sparse.py", line 160, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\nn\functional.py", line 2044, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self

글자 길이 1998까지는 잘 들어가는데 길이가 1999 이상인 경우 out of range가 나타났습니다. 더 긴 문장을 넣고 싶을 경우에는 어떻게 수정하면 될지 조언을 구하고 싶습니다.

  1. output 문장길이의 제한 위의 예시는 꽤 깔끔한 한 문장으로 요약이 되었지만, 그렇지 않은 경우들도 있었습니다.
text = df.text[2]

>>> summarize(text[:500])
'엄마 생신 평일이기도 했고 엄마 생일인데 엄마가 상차리는 것도 웃기기도 하고 내가 먼저 새언니에게 여자들끼리 맛있는거 먹으러 가자 제안 새언니가 자기도 그 생각했다면서 ᄒᄒ 그렇게 한달 전부터 계획된 온나카이 장소 검색하다가선희가 올라 백운호수
점에 갔다가 찍은 사진을 올렸서 급 생각해낸 올라 한 4 5년 전 어느날 선희가 밥먹자 하고 부른 곳이였는데 나는 얼떨결에 갔다가 제일 비싼 코스요리 얻어먹고 왔다 아무생각없이 얻어먹고와서 좋았던 기억만 있었는데 분위기가 꽤 좋았어서 _ 엄마 생신 장 
소도 여기로 정했다 일단 예약하고 나서 검색검색검색 후기가 반반 나뉘는데 급 걱정 가격대비 별로면 어쩌남 좋은날인데 맛없으면 어쩌남 결론부터 말하면 나는 너무 맛있었다                                   직원들이 불친절하다는 글을 많이 봤는데 사 
실 나는 딱 이정도가 좋다 2층홀 내부가 작은 편인데너무 쳐다보거나 말을 많이 붙히면 더 불편하고 싫었을 것 같다 계산해주시는분도 그냥 적정선을 지켜 일하고 계시는 느낌이지 딱딱선을 지켜 일하고 계시는 느낌이지 딱딱선을 지켜 일하고 계시는 느낌이지
 딱정도가 좋다 2층홀 내부가 작은 편인데 2층홀 내부가 작은 편인데너무 쳐다보거나 말을 많이 붙히면 더 불편하고 싫었을 것 같다 계산해주시는분도 그냥 적정선을 지켜 일하고 계시는 느낌이지 딱딱선을 지켜 일하고 계시는 느낌이지 딱딱선을 지켜 일하고 
계시는 느낌이지 딱딱선을 지켜 일하고 계시는 느낌이지 딱딱선을 지켜 일하고 계시는 느낌이지 딱딱선을 지켜 일하고 계시는 느낌이지 딱딱선을 지켜 일하고 계시는 느낌이지 딱정도가 좋다 2층홀 내부가 작은 편인데 직원들이 불친절하다는 글을 많이 봤는데
 사실 나는 딱 이정도가 좋다 2층홀 내부가 작은 편인데너무 쳐다보거나 말을 많이 붙히면 더 불편하고 싫었을 것 같다 계산해주시는분도 그냥 적정선을 지켜 일하고 계시는 느낌이다 계산해주시는분도 그냥 적정선을 지켜 일하고 계시는 느낌이다 계산해주시 
는분도 그냥 적정선을 지켜 일하고 계시는 느낌이지 딱정도가 좋다 2층홀 내부가 작은 편인데 1층홀 내부가 작은 편인데너무 쳐다보거나 말을 많이 붙히면 더 불편하고 싫었을 것 같다 계산해주시는분도'

# len(summarize(text[:500])) = 1086

위의 경우는 오히려 글자가 늘어났는데요,

summary_ids = model.generate(torch.tensor([input_ids]),  num_beams=4,  max_length=512,  eos_token_id=1)

에서 max_length를 200, 100으로 줄일 경우 요약문장 길이도 어느정도 줄어들기는 하였으나 여전히 300~500자 정도 되었습니다.

어떻게 하면 좀 더 합리적으로 짧은 문장을 만들게 할 수 있을까요?

감사합니다.

BrainNim avatar May 13 '22 08:05 BrainNim

제가 알고있는 내용 바탕으로 궁금하신 부분에 대해 정말 간단히 답변드려보겠습니다.

일단 2개 이슈 모두 Model의 config와 관련있다고 판단 됩니다.

1. input 문장의 길이제한

현재 len 함수를 사용하여 문자열 길이를 구하시고 계십니다. 하지만 SKT KoBART model의 input max_length는 token 단위 기준으로 "max_length=1026" 입니다. special token 포함 1026개 보다 많은 token을 model에 넣어 발생한 에러라고 생각됩니다.

encode 하실 때 truncate 해주시면 해결될 것 같습니다.

2. output 문장길이의 제한

generate method안 parameter인 max_length 또한 token 단위이고, 생성된 요약문을 tokenize 시켜 길이를 구해보면 설정하신 "max_length=512" 보다 적은 token으로 생성된 걸 알 수 있습니다.

사실 정해진 합리적인 길이는 없습니다. 저는 공신력 있는 PORORO(Dacon), fairseq(CNN_DM, XSUM) repo에서 설정한 max_length hyperparameter를 참고하여 결정하는 것도 괜찮은 방법이라고 생각합니다.

Jinhyeong-Lim avatar May 14 '22 06:05 Jinhyeong-Lim

encode 할 때 truncate 하려면 어떻게 해야하나요?? kobart-transformerskobart_tokenizer(["한국어", "BART 모델을", "소개합니다."], truncation=True, padding=True)에서처럼 truncation=True를 붙이는건가 하고 실험삼아

raw_input_ids = tokenizer.encode(text, truncation=True)

를 해봤는데 역시나 raw_input_ids의 길이가 truncation을 했을 때와 안했을 때 길이가 같았습니다ㅜㅜ

조금 더 자세히 설명해주시면 감사드리겠습니다ㅜㅜㅜㅜ

BrainNim avatar May 16 '22 01:05 BrainNim