RecBole-GNN
Error when running SRGNN
Dear maintainers, hello! I get an error when running SRGNN. My guess is that the trainer and the interaction used in my main function come from recbole rather than the recbole_gnn framework, but I do not know how to modify or extend the code accordingly. I would really appreciate your help; looking forward to your reply, many thanks!
main function:
from recbole_gnn.config import Config
from recbole_gnn.utils import create_dataset, data_preparation
from recbole.utils import init_logger, init_seed
from recbole_gnn.utils import set_color, get_trainer
from logging import getLogger
from test import SRGNN
if __name__ == '__main__':
    # configurations initialization
    config = Config(
        model=SRGNN,
        dataset='diginetica',
        config_file_list=['config.yaml', 'config_model.yaml'],
    )
    init_seed(config['seed'], config['reproducibility'])
    # logger initialization
    init_logger(config)
    logger = getLogger()
    logger.info(config)
    # dataset filtering
    dataset = create_dataset(config)
    logger.info(dataset)
    # dataset splitting
    train_data, valid_data, test_data = data_preparation(config, dataset)
    model = SRGNN(config, train_data.dataset).to(config['device'])
    logger.info(model)
    # trainer loading and initialization
    # trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)
    trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)
    # model training
    best_valid_score, best_valid_result = trainer.fit(
        train_data, valid_data, saved=True, show_progress=config['show_progress']
    )
    # model evaluation
    test_result = trainer.evaluate(test_data, load_best_model=True, show_progress=config['show_progress'])
    logger.info(set_color('best valid result:', 'yellow') + f': {best_valid_result}')
    logger.info(set_color('test result:', 'yellow') + f': {test_result}')
Both config.yaml and config_model.yaml use the parameters provided by the framework.
Output:

General Hyper Parameters:
gpu_id = 0
use_gpu = True
seed = 2020
state = INFO
reproducibility = True
data_path = dataset/diginetica
checkpoint_dir = saved
show_progress = True
save_dataset = False
dataset_save_path = None
save_dataloaders = False
dataloaders_save_path = None
log_wandb = False
Training Hyper Parameters:
epochs = 500
train_batch_size = 4096
learner = adam
learning_rate = 0.001
neg_sampling = None
eval_step = 1
stopping_step = 10
clip_grad_norm = None
weight_decay = 0.0
loss_decimal_place = 4
Evaluation Hyper Parameters:
eval_args = {'split': {'LS': 'valid_and_test'}, 'mode': 'full', 'order': 'TO', 'group_by': 'user'}
repeatable = True
metrics = ['MRR', 'Precision']
topk = [10, 20]
valid_metric = MRR@10
valid_metric_bigger = True
eval_batch_size = 2000
metric_decimal_place = 5
Dataset Hyper Parameters:
field_separator =
seq_separator =
USER_ID_FIELD = session_id
ITEM_ID_FIELD = item_id
RATING_FIELD = rating
TIME_FIELD = timestamp
seq_len = None
LABEL_FIELD = label
threshold = None
NEG_PREFIX = neg_
load_col = {'inter': ['session_id', 'item_id', 'timestamp']}
unload_col = None
unused_col = None
additional_feat_suffix = None
rm_dup_inter = None
val_interval = None
filter_inter_by_user_or_item = True
user_inter_num_interval = [5,inf)
item_inter_num_interval = [5,inf)
alias_of_user_id = None
alias_of_item_id = None
alias_of_entity_id = None
alias_of_relation_id = None
preload_weight = None
normalize_field = None
normalize_all = None
ITEM_LIST_LENGTH_FIELD = item_length
LIST_SUFFIX = _list
MAX_ITEM_LIST_LENGTH = 20
POSITION_FIELD = position_id
HEAD_ENTITY_ID_FIELD = head_id
TAIL_ENTITY_ID_FIELD = tail_id
RELATION_ID_FIELD = relation_id
ENTITY_ID_FIELD = entity_id
benchmark_filename = None
Other Hyper Parameters:
wandb_project = recbole
require_pow = False
embedding_size = 64
step = 1
loss_type = CE
MODEL_TYPE = ModelType.SEQUENTIAL
gnn_transform = sess_graph
train_neg_sample_args = {'strategy': 'none'}
MODEL_INPUT_TYPE = InputType.POINTWISE
eval_type = EvaluatorType.RANKING
device = cpu
eval_neg_sample_args = {'strategy': 'full', 'distribution': 'uniform'}
06 Mar 13:17 INFO diginetica
The number of users: 72014
Average actions of users: 8.060905669809618
The number of items: 29454
Average actions of items: 19.70902794282416
The number of inters: 580490
The sparsity of the dataset: 99.97263260088765%
Remain Fields: ['session_id', 'item_id', 'timestamp']
06 Mar 13:17 INFO Constructing session graphs.
100%|██████████| 364451/364451 [00:33<00:00, 11034.37it/s]
06 Mar 13:18 INFO Constructing session graphs.
100%|██████████| 72013/72013 [00:07<00:00, 9464.61it/s]
06 Mar 13:18 INFO Constructing session graphs.
100%|██████████| 72013/72013 [00:07<00:00, 9047.17it/s]
06 Mar 13:18 INFO SessionGraph Transform in DataLoader.
06 Mar 13:18 INFO SessionGraph Transform in DataLoader.
06 Mar 13:18 INFO SessionGraph Transform in DataLoader.
06 Mar 13:18 INFO [Training]: train_batch_size = [4096] negative sampling: [{'strategy': 'none'}]
06 Mar 13:18 INFO [Evaluation]: eval_batch_size = [2000] eval_args: [{'split': {'LS': 'valid_and_test'}, 'mode': 'full', 'order': 'TO', 'group_by': 'user'}]
06 Mar 13:18 INFO SRGNN(
(item_embedding): Embedding(29454, 64, padding_idx=0)
(gnncell): SRGNNCell(
(incomming_conv): SRGNNConv()
(outcomming_conv): SRGNNConv()
(lin_ih): Linear(in_features=128, out_features=192, bias=True)
(lin_hh): Linear(in_features=64, out_features=192, bias=True)
)
(linear_one): Linear(in_features=64, out_features=64, bias=True)
(linear_two): Linear(in_features=64, out_features=64, bias=True)
(linear_three): Linear(in_features=64, out_features=1, bias=False)
(linear_transform): Linear(in_features=128, out_features=64, bias=True)
(loss_fct): CrossEntropyLoss()
)
Trainable parameters: 1947264
Train 0: 0%| | 0/89 [00:00<?, ?it/s]
Traceback (most recent call last):
File "E:/ADACONDA/envs/pytorch/pythonproject_test/Next Work/RecBole-GNN-main/main.py", line 41, in
Hi! recbole_gnn rewrites its own main entry point; the dataset construction, model selection, and so on all differ somewhat from recbole.
To run SRGNN, you can git clone this repository, install the required packages, and then run the main script:
python run_recbole_gnn.py --model="SRGNN"
That is all you need.
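If you prefer a Python entry point over the command line, here is a minimal sketch, assuming run_recbole_gnn.py in this repository wraps a quick-start helper analogous to recbole.quick_start.run_recbole (check that script for the exact import before relying on it):

# Minimal sketch; assumes recbole_gnn.quick_start.run_recbole_gnn exists and
# mirrors recbole.quick_start.run_recbole -- verify against run_recbole_gnn.py.
from recbole_gnn.quick_start import run_recbole_gnn

if __name__ == '__main__':
    # Uses RecBole-GNN's own Config, dataset construction, dataloaders and trainer internally.
    run_recbole_gnn(model='SRGNN', dataset='diginetica',
                    config_file_list=['config.yaml', 'config_model.yaml'])

This keeps all of the recbole_gnn-specific pieces (session-graph dataloaders, GNN trainer) wired together in one call instead of assembling them by hand.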
Hello, thank you for the reply! Since I have been running everything through my modified main function, I would still like to keep using it. After talking with another student who uses recbole, he can run the following code successfully, but I still get an error in the training step. The modified code and its output are shown below:
from logging import getLogger
from recbole.utils import init_logger, init_seed, set_color
from recbole_gnn.config import Config
from recbole_gnn.utils import get_trainer
from recbole_gnn.data.dataset import GCEGNNDataset
from recbole_gnn.utils import _get_customized_dataloader
from recbole.data.utils import create_samplers
# === from model ===
from model import myGCEGNN
from test import GCEGNN
if __name__ == '__main__':
    # configurations initialization
    config = Config(
        model=GCEGNN,
        dataset='diginetica',
        config_file_list=['config.yaml', 'config_model.yaml'],
    )
    init_seed(config['seed'], config['reproducibility'])
    # logger initialization
    init_logger(config)
    logger = getLogger()
    logger.info(config)
    # dataset filtering
    dataset = GCEGNNDataset(config)
    # dataset = create_dataset(config)
    logger.info(dataset)
    # dataset splitting
    built_datasets = dataset.build()
    train_dataset, valid_dataset, test_dataset = built_datasets
    train_sampler, valid_sampler, test_sampler = create_samplers(config, dataset, built_datasets)
    train_data = _get_customized_dataloader(config, 'train')(config, train_dataset, train_sampler, shuffle=True)
    valid_data = _get_customized_dataloader(config, 'evaluation')(config, valid_dataset, valid_sampler, shuffle=False)
    test_data = _get_customized_dataloader(config, 'evaluation')(config, test_dataset, test_sampler, shuffle=False)
    # train_data, valid_data, test_data = data_preparation(config, dataset)
    # model loading and initialization
    model = GCEGNN(config, train_data.dataset).to(config['device'])
    logger.info(model)
    # trainer loading and initialization
    # trainer = Trainer(config, model)
    trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)
    # model training
    best_valid_score, best_valid_result = trainer.fit(
        train_data, valid_data, saved=True, show_progress=config['show_progress']
    )
    # model evaluation
    test_result = trainer.evaluate(test_data, load_best_model=True, show_progress=config['show_progress'])
    logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}')
    logger.info(set_color('test result', 'yellow') + f': {test_result}')
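For reference, the data_preparation path that is commented out above is what the first main function used; a minimal sketch, assuming the recbole_gnn.utils helpers wrap the manual build / sampler / dataloader steps from this script, would look like:

from recbole_gnn.config import Config
from recbole_gnn.utils import create_dataset, data_preparation
from test import GCEGNN  # local module, as in the script above

config = Config(model=GCEGNN, dataset='diginetica',
                config_file_list=['config.yaml', 'config_model.yaml'])
# create_dataset is expected to pick the GNN-aware dataset class (e.g. GCEGNNDataset)
# from the config; data_preparation should then build the splits, samplers and
# session-graph dataloaders in a single call.
dataset = create_dataset(config)
train_data, valid_data, test_data = data_preparation(config, dataset)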
Output of the modified main function:
14 Mar 14:36 INFO
General Hyper Parameters:
gpu_id = 0
use_gpu = True
seed = 2020
state = INFO
reproducibility = True
data_path = dataset/diginetica
checkpoint_dir = saved
show_progress = True
save_dataset = False
dataset_save_path = None
save_dataloaders = False
dataloaders_save_path = None
log_wandb = False
Training Hyper Parameters:
epochs = 500
train_batch_size = 4096
learner = adam
learning_rate = 0.001
neg_sampling = None
eval_step = 1
stopping_step = 10
clip_grad_norm = None
weight_decay = 0.0
loss_decimal_place = 4
Evaluation Hyper Parameters:
eval_args = {'split': {'LS': 'valid_and_test'}, 'mode': 'full', 'order': 'TO', 'group_by': 'user'}
repeatable = True
metrics = ['MRR', 'Precision']
topk = [10, 20]
valid_metric = MRR@10
valid_metric_bigger = True
eval_batch_size = 2000
metric_decimal_place = 5
Dataset Hyper Parameters:
field_separator =
seq_separator =
USER_ID_FIELD = session_id
ITEM_ID_FIELD = item_id
RATING_FIELD = rating
TIME_FIELD = timestamp
seq_len = None
LABEL_FIELD = label
threshold = None
NEG_PREFIX = neg_
load_col = {'inter': ['session_id', 'item_id', 'timestamp']}
unload_col = None
unused_col = None
additional_feat_suffix = None
rm_dup_inter = None
val_interval = None
filter_inter_by_user_or_item = True
user_inter_num_interval = [5,inf)
item_inter_num_interval = [5,inf)
alias_of_user_id = None
alias_of_item_id = None
alias_of_entity_id = None
alias_of_relation_id = None
preload_weight = None
normalize_field = None
normalize_all = None
ITEM_LIST_LENGTH_FIELD = item_length
LIST_SUFFIX = _list
MAX_ITEM_LIST_LENGTH = 20
POSITION_FIELD = position_id
HEAD_ENTITY_ID_FIELD = head_id
TAIL_ENTITY_ID_FIELD = tail_id
RELATION_ID_FIELD = relation_id
ENTITY_ID_FIELD = entity_id
benchmark_filename = None
Other Hyper Parameters:
wandb_project = recbole
require_pow = False
MODEL_TYPE = ModelType.SEQUENTIAL
embedding_size = 64
leakyrelu_alpha = 0.2
dropout_local = 0.0
dropout_global = 0.5
dropout_gcn = 0.0
loss_type = CE
gnn_transform = sess_graph
build_global_graph = True
sample_num = 12
hop = 1
train_neg_sample_args = {'strategy': 'none'}
n_layers = 1
n_heads = 1
hidden_size = 64
inner_size = 256
hidden_dropout_prob = 0.2
attn_dropout_prob = 0.2
hidden_act = gelu
layer_norm_eps = 1e-12
initializer_range = 0.02
step = 1
weight = 0.6
reg_weight = 5e-05
MODEL_INPUT_TYPE = InputType.POINTWISE
eval_type = EvaluatorType.RANKING
device = cpu
eval_neg_sample_args = {'strategy': 'full', 'distribution': 'uniform'}
14 Mar 14:36 INFO diginetica
The number of users: 72014
Average actions of users: 8.060905669809618
The number of items: 29454
Average actions of items: 19.70902794282416
The number of inters: 580490
The sparsity of the dataset: 99.97263260088765%
Remain Fields: ['session_id', 'item_id', 'timestamp']
14 Mar 14:37 INFO Reversing sessions.
100%|██████████| 364451/364451 [00:09<00:00, 39432.99it/s]
14 Mar 14:37 INFO Constructing session graphs.
100%|██████████| 364451/364451 [02:15<00:00, 2684.90it/s]
14 Mar 14:39 INFO Reversing sessions.
100%|██████████| 72013/72013 [00:01<00:00, 38924.85it/s]
14 Mar 14:39 INFO Constructing session graphs.
100%|██████████| 72013/72013 [00:28<00:00, 2494.76it/s]
14 Mar 14:40 INFO Reversing sessions.
100%|██████████| 72013/72013 [00:01<00:00, 39521.40it/s]
14 Mar 14:40 INFO Constructing session graphs.
100%|██████████| 72013/72013 [00:30<00:00, 2388.86it/s]
14 Mar 14:40 INFO SessionGraph Transform in DataLoader.
14 Mar 14:40 INFO SessionGraph Transform in DataLoader.
14 Mar 14:40 INFO SessionGraph Transform in DataLoader.
14 Mar 14:40 INFO Constructing global graphs.
Converting: 100%|██████████| 364451/364451 [00:01<00:00, 288417.84it/s]
Sorting: 100%|██████████| 29454/29454 [00:00<00:00, 91150.18it/s]
14 Mar 14:40 INFO GCEGNN(
(item_embedding): Embedding(29454, 64, padding_idx=0)
(pos_embedding): Embedding(20, 64)
(local_agg): LocalAggregator()
(global_agg): ModuleList(
(0): GlobalAggregator()
)
(w_1): Linear(in_features=128, out_features=64, bias=False)
(w_2): Linear(in_features=64, out_features=1, bias=False)
(glu1): Linear(in_features=64, out_features=64, bias=True)
(glu2): Linear(in_features=64, out_features=64, bias=False)
(loss_fct): CrossEntropyLoss()
)
Trainable parameters: 1915584
Train 0: 0%| | 0/89 [00:00<?, ?it/s]
Traceback (most recent call last):
File "E:/ADACONDA/envs/pytorch/pythonproject_test/Next Work/RecBole-GNN-main/main.py", line 52, in
I am also attaching a screenshot of my directory layout.
main.py is the main function and test.py contains the model: