COCO mAP metric
Description: Mean Average Precision (mAP) metric for object detection on the COCO dataset.

Checklist:
- [x] New tests are added (if a new feature is added)
- [x] New doc strings: description and/or example code are in RST format
- [x] Documentation is updated (if required)
I think it would be great to get rid of these complex nested structures and break the large functions up into smaller ones.
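For example (just a hypothetical sketch, not code from this PR), the inner comprehension of a nested dict-of-lists construction like `get_X` in the benchmark further down could be pulled out into a named helper:

```python
from typing import Dict, List

import torch


def _random_class_detections(device: torch.device, n_iou_thresh: int = 3, max_det: int = 25) -> List[torch.Tensor]:
    # One small, named function builds the per-class list of detection tensors.
    n_samples = torch.randint(1, 100, ()).item()
    return [
        torch.randn((n_iou_thresh, torch.randint(1, max_det, ()).item()), dtype=torch.double, device=device)
        for _ in range(n_samples)
    ]


def get_X(device: torch.device, nc: int, c_existence_prob: float = 0.7) -> Dict[int, List[torch.Tensor]]:
    # The outer function is then a flat, readable comprehension.
    return {c: _random_class_detections(device) for c in range(nc) if torch.rand(()).item() < c_existence_prob}
```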

Hi @AlexanderChaptykov, thanks for your comment. Yes, that would be nice, but I need a bit of time to add a few commits first.
@vfdev-5, it's finally ready for review.
@vfdev-5, I couldn't find where you wanted me to benchmark `allgather_object` vs `all_gather_tensors_with_shapes`, so I am putting the result here.
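In short, `allgather_object` below gathers the whole per-class dict via the pickle-based `torch.distributed.all_gather_object`, while `allgather` first exchanges only the per-class tensor sizes and then gathers the raw tensors with `all_gather_tensors_with_shapes`: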
```python
import itertools
import faulthandler
from argparse import ArgumentParser
from datetime import timedelta
from typing import Dict, List, Optional

from typing_extensions import Literal

import torch
import torch.distributed as dist

from ignite import distributed as idist
from ignite.utils import manual_seed
from torch.utils.benchmark import Timer


def get_X(device: torch.device, nc: int) -> Dict[int, List[torch.Tensor]]:
    # Random per-class detections: each class exists with probability
    # `c_existence_prob` and holds a variable number of
    # (n_iou_thresh, n_detections) tensors with varying n_detections.
    n_iou_thresh = 3
    c_existence_prob = 0.7
    n_det = 25
    return {
        c: [
            torch.randn(
                (n_iou_thresh, torch.randint(1, n_det, ()).item()),
                device=device,
                dtype=torch.double,
            )
            for _ in range(torch.randint(1, 100, ()).item())
        ]
        for c in range(nc)
        if torch.rand(()).item() < c_existence_prob
    }


def allgather_object(X: Dict[int, List[torch.Tensor]], nc: int) -> Dict[int, torch.Tensor]:
    # Baseline: gather the whole Python dict via pickle-based
    # `all_gather_object`, then concatenate the per-class tensors.
    all_Xs = [None for _ in range(idist.get_world_size())]
    dist.all_gather_object(all_Xs, X)
    Xs = {}
    for c in range(nc):
        Xc = list(itertools.chain(*[all_Xs[r][c] for r in range(idist.get_world_size()) if c in all_Xs[r]]))
        if Xc:
            Xs[c] = torch.cat(Xc, dim=-1)
    return Xs


def allgather(X: Dict[int, List[torch.Tensor]], device: torch.device, nc: int) -> Dict[int, torch.Tensor]:
    # Shape-aware variant: exchange only the per-class sizes first, then
    # gather the raw tensors with `all_gather_tensors_with_shapes`.
    d1_size_per_c = torch.tensor([sum([t.size(1) for t in X[c]]) if c in X else 0 for c in range(nc)], device=device)
    d1_size_per_c_across_ranks = torch.stack(idist.all_gather(d1_size_per_c).split(split_size=nc))
    # Find some (rank, class) pair that holds data, so the common
    # first-dim size can be broadcast from that rank.
    a_nonempty_rank, a_nonempty_c = list(zip(*torch.where(d1_size_per_c_across_ranks != 0))).pop(0)
    a_nonempty_rank = a_nonempty_rank.item()
    a_nonempty_c = a_nonempty_c.item()
    d0_size = idist.broadcast(
        torch.tensor(X[a_nonempty_c][-1].shape[:-1], device=device) if idist.get_rank() == a_nonempty_rank else None,
        a_nonempty_rank,
        safe_mode=True,
    ).item()
    Xs = {}
    for c in range(nc):
        d1_size_across_ranks = d1_size_per_c_across_ranks[:, [c]]
        if d1_size_across_ranks.any():
            shape_across_ranks = [(d0_size, d1_size_in_rank.item()) for d1_size_in_rank in d1_size_across_ranks]
            # Ranks with no detections for class `c` contribute an empty
            # (d0_size, 0) tensor so that every rank participates.
            Xs[c] = torch.cat(
                idist.utils.all_gather_tensors_with_shapes(
                    torch.cat(X[c], dim=-1) if c in X else torch.empty((d0_size, 0), dtype=torch.double, device=device),
                    shape_across_ranks,
                ),
                dim=-1,
            )
    return Xs


setup_statement = """
from __main__ import get_X, allgather_object, allgather
X = get_X(device, nc)
"""


def benchmark(rank: int, backend: Optional[Literal['gloo', 'nccl']] = 'gloo'):
    if rank == 0:
        faulthandler.enable()
    manual_seed(41 + rank)
    device = torch.device('cpu') if backend == 'gloo' else torch.device(f'cuda:{rank}')
    nc = 30
    _globals = {'backend': backend, 'device': device, 'nc': nc}
    allgather_timer = Timer(stmt="allgather(X, device, nc)", setup=setup_statement, globals=_globals)
    allgather_object_timer = Timer(stmt="allgather_object(X, nc)", setup=setup_statement, globals=_globals)
    allgather_mean_time_in_ms = allgather_timer.timeit(3).mean * 1e3
    allgather_object_mean_time_in_ms = allgather_object_timer.timeit(3).mean * 1e3
    if rank == 0:
        print(f"allgather time: {allgather_mean_time_in_ms:>10.1f} ms")
        print(f"allgather_object time: {allgather_object_mean_time_in_ms:>10.1f} ms")


if __name__ == "__main__":
    argparser = ArgumentParser()
    argparser.add_argument("-b", "--backend", choices=['gloo', 'nccl'], default='gloo')
    argparser.add_argument("-n", "--nprocs", type=int, default=4)
    args = argparser.parse_args()
    idist.spawn(args.backend, benchmark, (args.backend,), nproc_per_node=args.nprocs, timeout=timedelta(seconds=5))
```
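The script can be launched with, e.g., `python benchmark_gather.py --backend gloo --nprocs 4` (the filename is hypothetical); each spawned rank times both variants over three invocations, and rank 0 prints the mean times.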
Results on Gloo:

```
allgather time:              76.1 ms
allgather_object time:      846.5 ms
```
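So on this synthetic workload (Gloo backend, with the defaults of 4 processes and 30 classes), the shape-based gather is roughly 11× faster than the pickle-based `all_gather_object` path.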