attend-infer-repeat-pytorch

Assertion for number of objects

Open · nileshkumar0726 opened this issue 4 years ago · 1 comment

In spatial_transform.py, the function add_bounding_boxes contains the assertion assert n_obj <= z_wheres.shape[1]. Shouldn't the index here be 0, since z_wheres has shape (max_objects, features)?

nileshkumar0726 · Dec 03 '21 16:12

Hi, I'm terribly sorry for the very late reply.

In principle, z_wheres can have shape (1, max_n_objects, 3) or (max_n_objects, 3). I think the right fix is to change the index to 0 as you suggest, and to move the assertion after this block:

if len(z_wheres.shape) == 3:
    assert z_wheres.shape[0] == 1
    z_wheres = z_wheres[0]

so that z_wheres always has shape (max_n_objects, 3) by the time the assertion runs.
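A minimal sketch of the proposed ordering, for illustration only. It assumes a hypothetical standalone helper named normalize_z_wheres (not the actual add_bounding_boxes signature in spatial_transform.py) and uses NumPy arrays in place of tensors. It first squeezes the optional leading batch dimension, then checks n_obj against index 0:

```python
import numpy as np


def normalize_z_wheres(z_wheres, n_obj):
    """Hypothetical helper sketching the fix discussed in this thread.

    Accepts z_wheres of shape (1, max_n_objects, 3) or (max_n_objects, 3),
    returns it as (max_n_objects, 3), and checks that n_obj fits.
    """
    if len(z_wheres.shape) == 3:
        # Drop the singleton batch dimension: (1, N, 3) -> (N, 3).
        assert z_wheres.shape[0] == 1
        z_wheres = z_wheres[0]
    # z_wheres is now (max_n_objects, 3), so index 0 counts objects.
    # The original check used shape[1], which is the feature size (3).
    assert n_obj <= z_wheres.shape[0]
    return z_wheres


# Usage: both accepted input shapes normalize to (5, 3).
print(normalize_z_wheres(np.zeros((5, 3)), 4).shape)
print(normalize_z_wheres(np.zeros((1, 5, 3)), 5).shape)
```

With the original index, a (5, 3) input gives shape[1] == 3, so n_obj = 4 would fail the assertion even though there is room for 5 objects. PyTorch tensors also expose .shape and support z_wheres[0] indexing, so the same logic should carry over to torch.Tensor inputs.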

Feel free to update that and open a PR :) Thank you!

addtt · Sep 10 '22 11:09