relbench
Potential Bugs in the hybrid_node.py Example Script
Several possible issues have been observed in the example script examples/hybrid_node.py that may lead to unexpected behavior or unfair comparisons. The following points outline the bugs along with the problematic code lines and suggested improvements:
- **LightGBM Test Metrics Printing:**
  The script prints outdated test metrics. For example, it reports the GNN model's `test_metrics` after predicting with LightGBM, because the metrics are never recomputed from the LightGBM predictions.
  Bug:
  ```python
  pred = model.predict(tf_test).numpy()
  print(f"Test: {test_metrics}")
  ```
  Improvement:
  ```python
  pred = model.predict(tf_test).numpy()
  test_metrics = task.evaluate(pred)
  print(f"LightGBM Test metrics: {test_metrics}")
  ```
- **State Dict Path Formatting:**
  Bug: the state dict file path is defined with placeholders in a plain string, so `{args.dataset}` and `{args.task}` are never substituted:
  ```python
  STATE_DICT_PTH = "results/{args.dataset}_{args.task}_state_dict.pth"
  ```
  Improvement: use an f-string so that the placeholders are replaced with actual values:
  ```python
  STATE_DICT_PTH = f"results/{args.dataset}_{args.task}_state_dict.pth"
  ```
- **Sample Size Usage for GNN Model:**
  Issue: the GNN model does not use the `sample_size` parameter during training. While the LightGBM model is trained on a subsampled dataset (the first `sample_size` rows), the GNN model is trained on the full training set. This discrepancy can lead to an unfair comparison between the models.
- **Entity Table Overwriting:**
  Bug: the script reassigns the `entity_table` variable in a loop when creating loaders for each split. For example:
  ```python
  for split in ["train", "val", "test"]:
      table = task.get_table(split)
      table_input = get_node_train_table_input(table=table, task=task)
      entity_table = table_input.nodes[0]  # This gets overwritten each iteration
      ...
  ```
  Improvement: instead of overwriting, maintain a mapping for each split and reference the correct table. For instance:
  ```python
  from typing import Dict

  entity_table_mapping: Dict[str, str] = {}
  for split in ["train", "val", "test"]:
      table = task.get_table(split)
      table_input = get_node_train_table_input(table=table, task=task)
      entity_table_mapping[split] = table_input.nodes[0]
      ...

  # Later, reference the appropriate entity table, e.g. task.entity_table,
  # or look it up in the mapping:
  entity_table = entity_table_mapping["train"]
  ```
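The effect of the two loop patterns in the entity-table item can be reproduced in isolation. In this sketch, `hypothetical_entity_for_split` is a made-up stand-in for `get_node_train_table_input(table=table, task=task).nodes[0]`. It deliberately returns a different name per split so that the overwrite is visible; it says nothing about which values the real script produces.

```python
from typing import Dict, List


def hypothetical_entity_for_split(split: str) -> str:
    # Made-up stand-in for get_node_train_table_input(...).nodes[0].
    return f"entity_{split}"


splits: List[str] = ["train", "val", "test"]

# Reassigning one variable per iteration: only the last split survives.
entity_table = ""
for split in splits:
    entity_table = hypothetical_entity_for_split(split)
print(entity_table)  # entity_test

# Keeping one entry per split: every split's value is still available.
entity_table_mapping: Dict[str, str] = {}
for split in splits:
    entity_table_mapping[split] = hypothetical_entity_for_split(split)
print(entity_table_mapping["train"])  # entity_train
```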
Addressing these issues may help improve the robustness of the example script and ensure a fair comparison between the GNN and LightGBM models.
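One way to remove the sample-size discrepancy is to subsample the training split once and feed that same subset to both models. The sketch below is a minimal, library-free illustration under stated assumptions. The training split is treated as a plain list of `(entity_id, label)` rows, and `subsample_train` is a hypothetical helper, not a relbench function. How the subset would be wired into the GNN loaders in `hybrid_node.py` depends on the script's actual data structures and is not shown.

```python
from typing import List, Sequence, Tuple

# Hypothetical row type for this sketch: (entity_id, label).
Row = Tuple[int, int]


def subsample_train(rows: Sequence[Row], sample_size: int) -> List[Row]:
    # Hypothetical helper (not part of relbench): keep the first
    # `sample_size` rows. Treating a non-positive or oversized sample_size
    # as "use every row" is an assumption made for this sketch.
    if 0 < sample_size < len(rows):
        return list(rows[:sample_size])
    return list(rows)


# Toy training split standing in for the task's train table.
train_rows: List[Row] = [(entity_id, entity_id % 2) for entity_id in range(10)]

sample_size = 4
# Subsample once, then hand the same rows to both models.
shared_rows = subsample_train(train_rows, sample_size)
lightgbm_train_rows = shared_rows
gnn_train_rows = shared_rows

print(len(lightgbm_train_rows), len(gnn_train_rows))  # 4 4
```

Computing the subset in one place, rather than once per model, keeps the two training sets identical even if the subsampling rule later changes (for example, from "first rows" to a seeded random sample).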
Thanks for pointing this out, @vladislavalerievich! Since you have proposed the improvements, can you make a quick PR with these changes? We will be happy to merge to main!
@rishabh-ranjan could you please review the PR?
Closing since the PR is now merged.