
Potential Bugs in the hybrid_node.py Example Script


Several possible issues have been observed in the example script examples/hybrid_node.py that may lead to unexpected behavior or unfair comparisons. The points below outline each bug, the problematic code, and a suggested improvement:

  1. LightGBM Test Metrics Printing:
    The script prints stale test metrics: the printed value comes from the GNN evaluation rather than from the LightGBM predictions. Bug:

     pred = model.predict(tf_test).numpy()
     print(f"Test: {test_metrics}")
    

    Improvement:

     test_pred = model.predict(tf_test).numpy()
     test_metrics = task.evaluate(test_pred)
     print(f"LightGBM Test metrics: {test_metrics}")
    
  2. State Dict Path Formatting:
    Bug:
    The state dict file path is defined with placeholders in a plain string:

    STATE_DICT_PTH = "results/{args.dataset}_{args.task}_state_dict.pth"
    

    Improvement:
    Use an f-string so that the placeholders are replaced with actual values:

    STATE_DICT_PTH = f"results/{args.dataset}_{args.task}_state_dict.pth"
    
  3. Sample Size Usage for GNN Model:
    Issue:
    The GNN model does not use the sample_size parameter during training: while the LightGBM model is trained on a subsample (the first sample_size rows), the GNN model is trained on the full training set. This discrepancy can lead to an unfair comparison between the two models. A possible fix is sketched at the end of this post.

  4. Entity Table Overwriting:
    Bug:
    The script reassigns the entity_table variable in a loop when creating loaders for each split. For example:

    for split in ["train", "val", "test"]:
        table = task.get_table(split)
        table_input = get_node_train_table_input(table=table, task=task)
        entity_table = table_input.nodes[0]  # This gets overwritten each iteration
        ...
    

    Improvement:
    Instead of overwriting, maintain a mapping for each split and reference the correct table. For instance:

    from typing import Dict  # needed for the annotation below, if not already imported

    entity_table_mapping: Dict[str, str] = {}
    for split in ["train", "val", "test"]:
        table = task.get_table(split)
        table_input = get_node_train_table_input(table=table, task=task)
        entity_table_mapping[split] = table_input.nodes[0]
        ...
    # Later reference the appropriate entity table, e.g., using task.entity_table = entity_table_mapping["train"]
    

Addressing these issues may help improve the robustness of the example script and ensure a fair comparison between the GNN and LightGBM models.
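For point 3, here is a minimal sketch of one way to make the GNN train on the same subsample that LightGBM uses. It assumes the sample size is exposed as args.sample_size, that task and get_node_train_table_input are the objects already used in the script, and that a subsampled relbench Table can be rebuilt from the original one (the Table import path and constructor fields may differ between relbench versions):

    from relbench.base import Table  # import path may vary by relbench version
    from relbench.modeling.graph import get_node_train_table_input

    train_table = task.get_table("train")
    if 0 < args.sample_size < len(train_table.df):
        # Keep only the first `sample_size` rows, mirroring the LightGBM subsampling.
        train_table = Table(
            df=train_table.df.head(args.sample_size),
            fkey_col_to_pkey_table=train_table.fkey_col_to_pkey_table,
            pkey_col=train_table.pkey_col,
            time_col=train_table.time_col,
        )
    table_input = get_node_train_table_input(table=train_table, task=task)
    # Build the train NeighborLoader from this table_input as the script already does;
    # the val/test tables stay untouched so evaluation still runs on the full splits.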


vlad-moroshan · Feb 15, 2025

Thanks for pointing this out @vladislavalerievich ! Since you have proposed the improvements, can you make a quick PR with these changes? We will be happy to merge to main!

rishabh-ranjan · Feb 26, 2025

@rishabh-ranjan could you please review the PR?

vlad-moroshan · May 28, 2025

Closing since the PR is now merged.

rishabh-ranjan · Aug 5, 2025