PaddleNLP
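The function paddlenlp.data.data_collator.DataCollatorWithPadding does not match its description
The docstring says the inputs are padded to the longest sequence in the batch, but in practice they are padded to the model's maximum length (model_max_length).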
@dataclass
class DataCollatorWithPadding:
    """
    Data collator that will dynamically pad the inputs to the longest sequence in the batch.

    Args:
        tokenizer (`paddlenlp.transformers.PretrainedTokenizer`):
            The tokenizer used for encoding the data.
    """

    tokenizer: PretrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pd"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        if "label" in batch:
            batch["labels"] = batch["label"]
            del batch["label"]
        if "label_ids" in batch:
            batch["labels"] = batch["label_ids"]
            del batch["label_ids"]
        return batch
The change I made is:
@dataclass
class DataCollatorWithPadding:
    """
    Data collator that will dynamically pad the inputs to the longest sequence in the *batch*
    (the original implementation in paddlenlp pads sequences to model_max_length).

    Args:
        tokenizer (`paddlenlp.transformers.PretrainedTokenizer`):
            The tokenizer used for encoding the data.
    """

    tokenizer: PretrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pd"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Compute the longest sequence length in this batch and pass it to
        # tokenizer.pad explicitly, overriding self.max_length.
        input_ids = [x["input_ids"] for x in features]
        max_length = max(len(ids) for ids in input_ids)
        print(max_length)  # debug output: the padding target for this batch
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        # Normalize the label key to "labels".
        if "label" in batch:
            batch["labels"] = batch["label"]
            del batch["label"]
        if "label_ids" in batch:
            batch["labels"] = batch["label_ids"]
            del batch["label_ids"]
        return batch
Sorry, the documentation here is indeed unclear. The default behavior should be to pad to the longest sequence in the batch.
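To make the two behaviors discussed in this thread concrete, here is a minimal pure-Python sketch of the difference between padding to the longest sequence in the batch and padding to a fixed maximum length. It does not call PaddleNLP: `pad_batch`, its `pad_token_id` default of 0, and the sample token ids are all hypothetical and only illustrate the idea, not how `tokenizer.pad` is actually implemented.

```python
from typing import Dict, List, Optional


def pad_batch(
    features: List[Dict[str, List[int]]],
    pad_token_id: int = 0,  # assumed pad id, for illustration only
    max_length: Optional[int] = None,
) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: pad `input_ids` to `max_length`, or to the
    longest sequence in the batch when `max_length` is None."""
    lengths = [len(f["input_ids"]) for f in features]
    target = max(lengths) if max_length is None else max_length

    input_ids, attention_mask = [], []
    for f in features:
        ids = f["input_ids"]
        # Sequences already longer than the target are left unchanged
        # (this sketch pads only, it never truncates).
        n_pad = max(target - len(ids), 0)
        input_ids.append(ids + [pad_token_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


features = [{"input_ids": [101, 7, 8, 102]}, {"input_ids": [101, 9, 102]}]

longest = pad_batch(features)               # pad to the longest sequence in the batch
fixed = pad_batch(features, max_length=8)   # pad to a fixed maximum length

print([len(ids) for ids in longest["input_ids"]])  # [4, 4]
print([len(ids) for ids in fixed["input_ids"]])    # [8, 8]
```

With short inputs and a large fixed length, the second variant spends most of each row on padding, which is why the docstring's promise of batch-level padding matters for throughput.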
This issue is stale because it has been open for 60 days with no activity.
This issue was closed because it has been inactive for 14 days since being marked as stale.