When accelerating model training with self.model = torchacc.accelerate(model), GPU utilization is extremely low, close to 0
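For context, the only change relative to the DDP setup pasted below is the wrapping step. A minimal sketch, assuming torchacc.accelerate(model) simply returns the wrapped model as written above (the nn.Sequential stand-in, the learning rate, and the optimizer are only for illustration; the real model is RecNet):

import torch
import torch.nn as nn
import torchacc

# Stand-in model for illustration only; the real model is RecNet.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))

# Previously: model = model.to(gpu_id), followed by the (now commented-out) DDP wrap.
# Now the model is wrapped once with torchacc, and GPU utilization stays near 0.
model = torchacc.accelerate(model)  # assumed to return the wrapped model

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

Everything else (AMP, gradient clipping, the dict-of-tensors batches) is as in the Trainer pasted below.
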
class Trainer:
    def __init__(self, global_rank, gpu_id: int, trainer_config: TrainerConfig, model: RecNet, optimizer,
                 world_size: int, data_cfg: DataConfig):
        self.global_rank = global_rank
        self.config = trainer_config
        self.world_size = world_size
        self.dataloader = Data(data_cfg)
        self.epochs_run = 0
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.optimizer = optimizer
        if self.config.use_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        # load snapshot if available; only necessary on the first node.
        if self.config.snapshot_path is None:
            self.config.snapshot_path = "snapshot.pt"
        self._load_snapshot()
        # wrap with DDP; this step will sync the model across all processes.
        # self.model = torch.compile(DDP(self.model, device_ids=[gpu_id]))
        torch.set_float32_matmul_precision('high')

    def _load_snapshot(self):
        try:
            snapshot = fsspec.open(self.config.snapshot_path)
            with snapshot as f:
                snapshot_data = torch.load(f, map_location="cpu")
        except FileNotFoundError:
            print("Snapshot not found. Training model from scratch")
            return
        snapshot = Snapshot(**snapshot_data)
        self.model.load_state_dict(snapshot.model_state)
        self.optimizer.load_state_dict(snapshot.optimizer_state)
        self.epochs_run = snapshot.finished_epoch
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def cal_loss(self, score, labels):
        click_loss = F.binary_cross_entropy(score["click_score"], labels['click_label'])
        add_loss = F.binary_cross_entropy(score["add_score"], labels['add_label'])
        order_loss = F.binary_cross_entropy(score["add_order_score"], labels['order_label'])
        loss = self.config.click_loss_weight * click_loss + \
               self.config.add_loss_weight * add_loss + self.config.order_loss_weight * order_loss
        return loss

    def _run_batch(self, features, labels, train: bool = True):
        with torch.set_grad_enabled(train), torch.amp.autocast(device_type="cuda", dtype=torch.float16,
                                                                enabled=self.config.use_amp):
            score = self.model(features)
            loss = self.cal_loss(score, labels)
        if train:
            self.optimizer.zero_grad(set_to_none=True)
            if self.config.use_amp:
                self.scaler.scale(loss).backward()
                if self.config.use_clip_grad:
                    # unscale gradients before clipping so the norm is computed on the true gradients
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                if self.config.use_clip_grad:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm_clip)
                self.optimizer.step()
        return loss.item()

    def _run_epoch(self, epoch: int, dataloader: DataLoader, train: bool = True):
        with self.model.join():
            for _iter, (features, labels) in enumerate(dataloader):
                features = {feat_name: torch.as_tensor(data=feat_data, dtype=torch.int, device=self.gpu_id)
                            for feat_name, feat_data in features.items()}
                labels = {label_name: torch.as_tensor(data=label_data, dtype=torch.float, device=self.gpu_id)
                          for label_name, label_data in labels.items()}
                step_type = "Train" if train else "Eval"
                batch_loss = self._run_batch(features, labels, train)
                if _iter % 100 == 99:
                    logging.info(
                        f"{datetime.datetime.now()} [GPU{self.gpu_id}] Epoch {epoch} | Iter {_iter} | {step_type} Loss {batch_loss:.5f}")

    def _eval(self, step, epoch, is_eval):
        logging.info("eval start")
        self.dataloader.config.data_file = self.dataloader.config.test_file
        test_loader = self.dataloader.data_input_fn_torch_eval()
        self.model.eval()
        with torch.no_grad():  # disable gradient computation during evaluation
            loss = 0.0
            _iter = 0
            click_scores = []
            add_scores = []
            order_scores = []
            click_labels = []
            add_labels = []
            order_labels = []
            for _iter, (features, labels) in enumerate(test_loader):
                features = {feat_name: torch.as_tensor(data=feat_data, dtype=torch.int, device=self.gpu_id)
                            for feat_name, feat_data in features.items()}
                labels = {label_name: torch.as_tensor(data=label_data, dtype=torch.float, device=self.gpu_id)
                          for label_name, label_data in labels.items()}
                score = self.model(features, is_eval)
                batch_loss = self.cal_loss(score, labels)
                click_scores.append(score["click_score"].detach().cpu().numpy())
                add_scores.append(score["add_score"].detach().cpu().numpy())
                order_scores.append(score["add_order_score"].detach().cpu().numpy())
                click_labels.append(labels['click_label'].detach().cpu().numpy())
                add_labels.append(labels['add_label'].detach().cpu().numpy())
                order_labels.append(labels['order_label'].detach().cpu().numpy())
                loss += batch_loss.item()
            click_auc = roc_auc_score(np.concatenate(click_labels), np.concatenate(click_scores))
            add_auc = roc_auc_score(np.concatenate(add_labels), np.concatenate(add_scores))
            order_auc = roc_auc_score(np.concatenate(order_labels), np.concatenate(order_scores))
            logging.info(
                f'''{datetime.datetime.now()} [GPU{self.gpu_id}] Epoch {epoch} | Iter {step} | Eval Loss {loss / (_iter + 1):.5f},
                click_auc:{click_auc}, add_auc:{add_auc}, order_auc:{order_auc}''')
            send_msg("product-feeds-category-rank-model-v62:: click_auc: {}, add_auc: {}, order_auc: {}".format(
                click_auc, add_auc, order_auc))
            if add_auc < 0.7:
                send_msg("@majun26 product-feeds-category-rank-model-v62:: add_auc: {} is too low".format(add_auc))
                sys.exit(1)

    def _save_checkpoint(self, epoch=1):
        state_dict = self.model.cpu().state_dict()
        # strip the torch.compile/DDP prefix from parameter names and drop the pos_logit head
        unwanted_prefix = '_orig_mod.module.'
        for k, v in list(state_dict.items()):
            if k.startswith(unwanted_prefix):
                if k.startswith(unwanted_prefix + 'pos_logit'):
                    state_dict.pop(k)
                else:
                    state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        PATH = f'{self.config.checkpoint_path}/checkpoint.pt'
        torch.save(state_dict, PATH)
        logging.info(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

    def train(self):
        self.dataloader.config.data_file = self.dataloader.config.train_file
        train_loader = self.dataloader.data_input_fn_torch_train(self.world_size, self.global_rank)
        for i in range(self.config.max_epochs):
            self._run_epoch(i, train_loader, train=True)

    def eval(self, is_eval):
        self._eval(1, 1, is_eval)

class Optimizer:
    def __init__(self, model: RecNet, opt_config: OptimizerConfig):
        self.model = model
        self.opt_config = opt_config

    def get(self):
        if self.opt_config.name.lower() == 'adam':
            return self.adam()
        if self.opt_config.name.lower() == 'adamw':
            return self.adamw()

    def adam(self):
        return torch.optim.Adam(params=self.model.parameters(),
                                lr=self.opt_config.learning_rate)

    def adamw(self):
        return torch.optim.AdamW(params=self.model.parameters(),
                                 lr=self.opt_config.learning_rate)
@anw90 @qiuxiafei Could you please take a look?
This has since been discussed offline in the DingTalk group.