kronfluence

[BUG?] Missing accumulate_iterations call in pairwise score computation when compute_per_module_scores=True

Open • MyDum-bsu opened this issue 7 months ago • 0 comments

Is there a critical bug in score/dot_product.py? accumulate_iterations is only called when compute_per_module_scores=False, which appears to leave module state uncleared between iterations when per-module pairwise scores are computed.

  • File: score/dot_product.py
  • Function: compute_dot_products_with_loader

Problem: accumulate_iterations is only called when compute_per_module_scores=False. It should be called in both branches so that module state is cleared after every iteration. Current code:

```python
with torch.no_grad():
    if score_args.compute_per_module_scores:
        for module in cached_module_lst:
            score_chunks[module.name].append(
                module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).to(device="cpu", copy=True)
            )
    else:
        # ... code for aggregated scores ...
        score_chunks[ALL_MODULE_NAME].append(pairwise_scores)
        accumulate_iterations(model=model, tracked_module_names=tracked_module_names)  # BUG: only reached in the aggregated branch
```

Proposed fix: move accumulate_iterations out of the conditional so it runs after either branch:

```python
with torch.no_grad():
    if score_args.compute_per_module_scores:
        for module in cached_module_lst:
            score_chunks[module.name].append(
                module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).to(device="cpu", copy=True)
            )
    else:
        # ... code for aggregated scores ...
        score_chunks[ALL_MODULE_NAME].append(pairwise_scores)

    # Clear per-iteration module state regardless of which branch ran.
    accumulate_iterations(model=model, tracked_module_names=tracked_module_names)
```
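To make the claimed failure mode concrete, here is a toy, pure-Python simulation. Nothing in it is kronfluence code: `ToyTrackedModule`, `toy_accumulate_iterations`, `run_loop`, `ALL_MODULE_NAME` as a string constant, and the module names are hypothetical stand-ins. The key (unverified) assumption is that a tracked module adds each new batch's score into its stored value whenever that value has not been cleared. Under that assumption, the original placement yields running sums in the per-module branch, while the proposed placement yields per-batch scores and leaves the aggregated branch unchanged.

```python
"""Toy simulation (hypothetical, not kronfluence code) of why per-iteration
module state should be cleared in both branches of the score loop."""

from typing import Dict, List, Optional

ALL_MODULE_NAME = "all_modules"  # hypothetical key mirroring the issue's naming


class ToyTrackedModule:
    """Stand-in for a tracked module.

    ASSUMPTION: if the stored score was not cleared, the next batch's score
    is added to it instead of replacing it.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.stored_score: Optional[float] = None

    def update(self, batch_score: float) -> None:
        if self.stored_score is None:
            self.stored_score = batch_score
        else:
            self.stored_score += batch_score  # stale state leaks into this batch

    def clear(self) -> None:
        self.stored_score = None


def toy_accumulate_iterations(modules: List[ToyTrackedModule]) -> None:
    # Hypothetical analogue of accumulate_iterations: reset per-module state.
    for module in modules:
        module.clear()


def run_loop(
    batch_scores: List[Dict[str, float]], per_module: bool, clear_in_both_branches: bool
) -> Dict[str, List[float]]:
    modules = [ToyTrackedModule("fc1"), ToyTrackedModule("fc2")]
    chunks: Dict[str, List[float]] = {m.name: [] for m in modules}
    chunks[ALL_MODULE_NAME] = []
    for batch in batch_scores:
        for module in modules:
            module.update(batch[module.name])
        if per_module:
            for module in modules:
                chunks[module.name].append(module.stored_score)
        else:
            chunks[ALL_MODULE_NAME].append(sum(m.stored_score for m in modules))
            if not clear_in_both_branches:
                toy_accumulate_iterations(modules)  # placement reported in the issue
        if clear_in_both_branches:
            toy_accumulate_iterations(modules)  # proposed placement
    return chunks


batches = [
    {"fc1": 1.0, "fc2": 10.0},
    {"fc1": 2.0, "fc2": 20.0},
    {"fc1": 3.0, "fc2": 30.0},
]

buggy = run_loop(batches, per_module=True, clear_in_both_branches=False)
fixed = run_loop(batches, per_module=True, clear_in_both_branches=True)
print(buggy["fc1"])  # [1.0, 3.0, 6.0]: running sums, not per-batch scores
print(fixed["fc1"])  # [1.0, 2.0, 3.0]: per-batch scores
```

If the real tracker overwrites its stored score rather than adding to it, the per-module values would still be correct and the missing call would mainly keep stale tensors alive in module storage. Checking what accumulate_iterations actually does in the tracker would settle which case applies.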

Correct me if I'm wrong.

MyDum-bsu, May 30 '25 11:05