attend-infer-repeat-pytorch
Assertion for number of objects
In `spatial_transform.py`, the function `add_bounding_boxes` contains the assertion `assert n_obj <= z_wheres.shape[1]`. Shouldn't the index here be 0, since `z_wheres` has shape (max_objects, features)?
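Here is a quick illustration of the mismatch. It is a minimal sketch that assumes `z_wheres` is a 2-D array of shape (max_objects, 3), as described in the question. The values of `max_objects` and `n_obj` are hypothetical. NumPy stands in for PyTorch because only `.shape` indexing is involved, and `torch.Tensor` exposes the same thing.

```python
import numpy as np

max_objects = 5  # hypothetical maximum number of objects
n_obj = 4        # hypothetical number of objects to draw

# z_wheres laid out as (max_objects, features), with 3 features per object
z_wheres = np.zeros((max_objects, 3))

# shape[1] is the feature dimension, shape[0] is the object dimension
print(z_wheres.shape[1])  # 3
print(z_wheres.shape[0])  # 5

# Checking against shape[1] compares n_obj with the number of features,
# so it rejects n_obj = 4 even though 4 <= max_objects.
print(n_obj <= z_wheres.shape[1])  # False
print(n_obj <= z_wheres.shape[0])  # True
```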
Hi, I'm terribly sorry for the very late reply.
In principle z_wheres could have shape (1, max_n_objects, 3) or (max_n_objects, 3). I think the right thing to do is to change the index to 0 as you suggest, and move this assertion after this block:
```python
if len(z_wheres.shape) == 3:
    assert z_wheres.shape[0] == 1
    z_wheres = z_wheres[0]
```
so that z_wheres always has shape (max_n_objects, 3).
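The proposed change can be sketched as a standalone helper. `check_n_objects` is a hypothetical name and is not a function in the repository. NumPy again stands in for `torch`. The helper follows the suggestion in this thread: drop a leading size-1 dimension if present, then check `n_obj` against index 0.

```python
import numpy as np


def check_n_objects(z_wheres, n_obj):
    """Hypothetical helper mirroring the suggested fix; the real
    add_bounding_boxes in spatial_transform.py may differ."""
    # Accept an optional leading dimension of size 1 and drop it
    if len(z_wheres.shape) == 3:
        assert z_wheres.shape[0] == 1
        z_wheres = z_wheres[0]
    # z_wheres now has shape (max_n_objects, 3), so index 0 counts objects
    assert n_obj <= z_wheres.shape[0]
    return z_wheres


# Both accepted layouts normalize to the same 2-D shape
batched = check_n_objects(np.zeros((1, 5, 3)), n_obj=4)
unbatched = check_n_objects(np.zeros((5, 3)), n_obj=4)
print(batched.shape, unbatched.shape)  # (5, 3) (5, 3)
```

Normalizing the shape first means the assertion only needs one index. Without that step, it would have to handle both layouts.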
Feel free to update that and open a PR :) Thank you!