ark-nlp
RuntimeError: sparse tensors do not have strides
Running the GlobalPointer model on a single machine with multiple GPUs raises the error above; the same multi-GPU code does not raise it for other models.
I looked into it and found that the cause is the labels being stored as sparse tensors. Sparsifying them probably prevents memory from blowing up, but it breaks multi-GPU runs. It would be better to move label generation into collate_fn and build the labels per batch; that should avoid the problem.
It is indeed the sparsification. The core problem is that scattering the tensors across multiple GPUs involves sparse matrix multiplication, and torch does not support multiplying sparse with dense (or sparse) matrices.
The solution is, as you said, to modify collate_fn. There is no need to move label generation there, though; it is enough to densify the sparse matrices inside collate_fn.
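Before the concrete patch, a minimal standalone sketch of what the densification does. The shapes below are made up for illustration (a tiny class_num × seq_len × seq_len label matrix); only the `to_dense()` step is the point:

```python
import torch

# A tiny sparse COO tensor standing in for one GlobalPointer label matrix
# (class_num=2, seq_len=4 are made-up illustration values)
indices = torch.tensor([[0, 1],   # entity type
                        [1, 2],   # start position
                        [3, 2]])  # end position
values = torch.ones(2)
sparse_label = torch.sparse_coo_tensor(indices, values, size=(2, 4, 4))

# Sparse tensors have no strides, which is what scattering across GPUs trips over
try:
    sparse_label.stride()
except RuntimeError as e:
    print(e)  # e.g. "sparse tensors do not have strides"

# Densifying restores a regular strided tensor that can be scattered safely
dense_label = sparse_label.to_dense()
print(dense_label.shape, dense_label.is_sparse)  # torch.Size([2, 4, 4]) False
```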
The fix is as follows: add the methods below to GlobalPointerNERTask() in global_pointer_bert_named_entity_recognition.py.
```python
# module-level import required by the snippet below
from torch.utils.data._utils.collate import default_collate


def _train_collate_fn(self, batch):
    input_ids = default_collate([f['input_ids'] for f in batch])
    attention_mask = default_collate([f['attention_mask'] for f in batch])
    token_type_ids = default_collate([f['token_type_ids'] for f in batch])
    # Densify the sparse label tensors here so that only strided (dense)
    # tensors are scattered across the GPUs
    label_ids = default_collate([f['label_ids'].to_dense() for f in batch])

    tensors = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'token_type_ids': token_type_ids,
        'label_ids': label_ids,
    }

    return tensors


def _evaluate_collate_fn(self, batch):
    return self._train_collate_fn(batch)
```
After this change, two existing `.to_dense()` calls will raise errors; simply delete them.
This bug will be fixed in the next release.
GlobalPointer's labels are three-dimensional matrices. If all labels are generated up front and kept in memory, training becomes painful on modest hardware once the number of entity types is large and the training corpus is big. To recycle memory more effectively, I think moving label generation into collate_fn is the better approach. Just a small suggestion for the author to consider~
```python
import torch
import warnings

from ark_nlp.factory.loss_function import get_loss
from ark_nlp.factory.utils import conlleval
from ark_nlp.factory.task.base._token_classification import TokenClassificationTask
from ark_nlp.factory.utils.ema import EMA
from torch.utils.data._utils.collate import default_collate


class GlobalPointerNERTask(TokenClassificationTask):
    """
    GlobalPointer named entity recognition Task

    Args:
        module: the deep learning model
        tokenizer: the tokenizer used to encode the raw text
        optimizer: name of the optimizer, or an optimizer object, used for training
        loss_function: name of the loss function, or a loss function object, used for training
        class_num (:obj:`int` or :obj:`None`, optional, defaults to None): number of labels
        scheduler (:obj:`class`, optional, defaults to None): scheduler object
        n_gpu (:obj:`int`, optional, defaults to 1): number of GPUs
        device_ids (optional, defaults to None): GPU ids passed to DataParallel
        device (:obj:`class`, optional, defaults to None): torch.device object; when None, GPU availability is detected automatically
        cuda_device (:obj:`int`, optional, defaults to 0): GPU index; when device is None, device is set from cuda_device
        ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA decay coefficient
    """  # noqa: ignore flake8
    def __init__(
        self,
        module,
        tokenizer,
        optimizer,
        loss_function,
        class_num=None,
        scheduler=None,
        n_gpu=1,
        device_ids=None,
        device=None,
        cuda_device=0,
        ema_decay=None,
    ):
        super(TokenClassificationTask).__init__()

        self.module = module
        self.optimizer = optimizer
        self.tokenizer = tokenizer
        self.loss_function = get_loss(loss_function)
        self.class_num = class_num
        self.scheduler = scheduler
        self.device_ids = device_ids
        self.n_gpu = n_gpu
        self.cuda_device = cuda_device
        self.device = device
        self.ema_decay = ema_decay

        if self.device is None:
            if torch.cuda.is_available():
                if self.cuda_device == -1:
                    self.device = torch.device("cuda")
                else:
                    self.device = torch.device(f"cuda:{self.cuda_device}")
            else:
                self.device = "cpu"

        if self.n_gpu > 1:
            # Wrap the module in DataParallel for multi-GPU training
            self.module.cuda()
            self.module = torch.nn.DataParallel(self.module, device_ids=self.device_ids)
        else:
            self.module.to(self.device)

        if self.ema_decay:
            self.ema = EMA(self.module.parameters(), decay=self.ema_decay)
    def _get_module_inputs_on_train(
        self,
        inputs,
        **kwargs
    ):
        self.train_to_device_cols = list(inputs.keys())

        for col in self.train_to_device_cols:
            if type(inputs[col]) is torch.Tensor:
                inputs[col] = inputs[col].to(self.device)
            else:
                warnings.warn(f"The {col} is not Tensor.\n")

        return inputs

    def _get_module_inputs_on_eval(
        self,
        inputs,
        **kwargs
    ):
        self.evaluate_to_device_cols = list(inputs.keys())

        for col in self.evaluate_to_device_cols:
            if type(inputs[col]) is torch.Tensor:
                inputs[col] = inputs[col].to(self.device)
            else:
                warnings.warn(f"The {col} is not Tensor.\n")

        return inputs

    def _compute_loss(
        self,
        inputs,
        logits,
        verbose=True,
        **kwargs
    ):
        loss = self.loss_function(logits, inputs['label_ids'])

        return loss

    def _on_evaluate_begin_record(self, **kwargs):
        self.evaluate_logs['eval_loss'] = 0
        self.evaluate_logs['eval_step'] = 0
        self.evaluate_logs['eval_example'] = 0

        self.evaluate_logs['labels'] = []
        self.evaluate_logs['logits'] = []
        self.evaluate_logs['input_lengths'] = []

        self.evaluate_logs['numerate'] = 0
        self.evaluate_logs['denominator'] = 0

    def _on_evaluate_step_end(self, inputs, outputs, **kwargs):
        with torch.no_grad():
            # compute loss
            logits, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)

            numerate, denominator = conlleval.global_pointer_f1_score(
                inputs['label_ids'].cpu(),
                logits.cpu()
            )
            self.evaluate_logs['numerate'] += numerate
            self.evaluate_logs['denominator'] += denominator

        self.evaluate_logs['eval_example'] += len(inputs['label_ids'])
        self.evaluate_logs['eval_step'] += 1
        self.evaluate_logs['eval_loss'] += loss.item()

    def _on_evaluate_epoch_end(
        self,
        validation_data,
        epoch=1,
        is_evaluate_print=True,
        id2cat=None,
        **kwargs
    ):
        if id2cat is None:
            id2cat = self.id2cat

        if is_evaluate_print:
            print('eval loss is {:.6f}, precision is:{}, recall is:{}, f1_score is:{}'.format(
                self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step'],
                self.evaluate_logs['numerate'],
                self.evaluate_logs['denominator'],
                2 * self.evaluate_logs['numerate'] / self.evaluate_logs['denominator'])
            )
    def _train_collate_fn(self, batch):
        input_ids_list = []
        input_mask_list = []
        segment_ids_list = []
        global_label_list = []

        for row_ in batch:
            # Tokenize the raw text and build the char-to-token mapping
            tokens = self.tokenizer.tokenize(row_['text'])[:self.tokenizer.max_seq_len - 2]
            token_mapping = self.tokenizer.get_token_mapping(row_['text'], tokens)

            start_mapping = {j[0]: i for i, j in enumerate(token_mapping) if j}
            end_mapping = {j[-1]: i for i, j in enumerate(token_mapping) if j}

            input_ids, input_mask, segment_ids = self.tokenizer.sequence_to_ids(tokens)

            # Build the dense GlobalPointer label matrix for this sample only,
            # so the full label tensor never has to be held in memory
            global_label = torch.zeros((
                self.class_num,
                self.tokenizer.max_seq_len,
                self.tokenizer.max_seq_len)
            )

            for info_ in row_['label']:
                if info_['start_idx'] in start_mapping and info_['end_idx'] in end_mapping:
                    start_idx = start_mapping[info_['start_idx']]
                    end_idx = end_mapping[info_['end_idx']]
                    if start_idx > end_idx or info_['entity'] == '':
                        continue
                    # +1 offsets for the leading [CLS] token
                    global_label[self.cat2id[info_['type']], start_idx + 1, end_idx + 1] = 1

            input_ids_list.append(torch.tensor(input_ids).long())
            input_mask_list.append(torch.tensor(input_mask).long())
            segment_ids_list.append(torch.tensor(segment_ids).long())
            global_label_list.append(global_label.long())

        batch_input_ids = torch.stack(input_ids_list, dim=0)
        batch_attention_mask = torch.stack(input_mask_list, dim=0)
        batch_token_type_ids = torch.stack(segment_ids_list, dim=0)
        batch_labels = torch.stack(global_label_list, dim=0)

        features = {
            'input_ids': batch_input_ids,
            'attention_mask': batch_attention_mask,
            'token_type_ids': batch_token_type_ids,
            'label_ids': batch_labels
        }

        return features

    def _evaluate_collate_fn(self, batch):
        return self._train_collate_fn(batch)
```
This is the version I modified last night. It runs fine and memory no longer blows up; sharing it as a reference for others.
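To put a rough number on the memory concern: each GlobalPointer label is a class_num × max_seq_len × max_seq_len matrix, so materialising every label up front scales badly. A back-of-the-envelope sketch with assumed values (none of these numbers come from this thread):

```python
# Rough dense-label memory estimate; all values below are assumptions for illustration
class_num = 10        # number of entity types
max_seq_len = 128     # maximum sequence length
num_samples = 50_000  # size of the training set
bytes_per_value = 8   # int64 label entries

per_sample = class_num * max_seq_len * max_seq_len * bytes_per_value
total = per_sample * num_samples
print(f"{per_sample / 2**20:.2f} MiB per sample, ~{total / 2**30:.0f} GiB for the whole corpus")
# -> 1.25 MiB per sample, ~61 GiB if every label is materialised up front,
#    versus only batch_size * 1.25 MiB at a time when labels are built in collate_fn
```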
Thanks for the suggestion and for the low-memory adaptation. This change does reduce memory consumption effectively, but it slows down model training.
We are considering introducing a lazy mechanism in v0.1.0 to handle this class of problem.
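For readers wondering what a lazy approach could look like, here is a generic sketch of per-item label construction in a map-style Dataset. This is only an illustration of the idea, not the actual design planned for v0.1.0, and the field names are assumptions:

```python
import torch
from torch.utils.data import Dataset


class LazyGlobalPointerDataset(Dataset):
    """Hypothetical dataset: token features are pre-computed once,
    dense label matrices are built lazily, one sample at a time."""

    def __init__(self, features, span_labels, class_num, max_seq_len):
        # features: list of dicts with 'input_ids', 'attention_mask', 'token_type_ids'
        # span_labels: list of lists of (type_id, start_idx, end_idx) tuples
        self.features = features
        self.span_labels = span_labels
        self.class_num = class_num
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        item = {k: torch.tensor(v).long() for k, v in self.features[index].items()}
        # The dense label matrix only exists while this sample is in use
        label = torch.zeros(self.class_num, self.max_seq_len, self.max_seq_len, dtype=torch.long)
        for type_id, start_idx, end_idx in self.span_labels[index]:
            label[type_id, start_idx, end_idx] = 1
        item['label_ids'] = label
        return item
```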
Haha, I just want to contribute a little rather than only freeloading 😂😂😂