hypernerf
hypernerf copied to clipboard
InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (3958, 3) and (8, -1, 3)
Hi, while running the demo in a Jupyter notebook, a shape error occurs when the code reaches the rendering step.
Does anyone have an idea how to fix this error?
9 frames
/usr/local/lib/python3.7/dist-packages/hypernerf/evaluation.py in render_image(state, rays_dict, model_fn, device_count, rng, chunk, default_ret_key)
114 lambda x: x[(proc_id * per_proc_rays):((proc_id + 1) * per_proc_rays)],
115 chunk_rays_dict)
--> 116 chunk_rays_dict = utils.shard(chunk_rays_dict, device_count)
117 model_out = model_fn(key_0, key_1, state.optimizer.target['model'],
118 chunk_rays_dict, state.extra_params)
/usr/local/lib/python3.7/dist-packages/hypernerf/utils.py in shard(xs, device_count)
287 if device_count is None:
288 jax.local_device_count()
--> 289 return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)
290
291
/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
176 leaves, treedef = tree_flatten(tree, is_leaf)
177 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 178 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
179
180 tree_multimap = tree_map
/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in <genexpr>(.0)
176 leaves, treedef = tree_flatten(tree, is_leaf)
177 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 178 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
179
180 tree_multimap = tree_map
/usr/local/lib/python3.7/dist-packages/hypernerf/utils.py in <lambda>(x)
287 if device_count is None:
288 jax.local_device_count()
--> 289 return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)
290
291
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in _reshape(a, order, *args)
1727
1728 def _reshape(a, *args, order="C"):
-> 1729 newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
1730 if order == "C":
1731 return lax.reshape(a, newshape, None)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in _compute_newshape(a, newshape)
1723 return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
1724 if core.symbolic_equal_dim(d, -1) else d
-> 1725 for d in newshape)
1726
1727
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in <genexpr>(.0)
1723 return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
1724 if core.symbolic_equal_dim(d, -1) else d
-> 1725 for d in newshape)
1726
1727
/usr/local/lib/python3.7/dist-packages/jax/core.py in divide_shape_sizes(s1, s2)
1407 s2 = s2 or (1,)
1408 handler, ds = _dim_handler_and_canonical(*s1, *s2)
-> 1409 return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])
1410
1411 def same_shape_sizes(s1: Shape, s2: Shape) -> bool:
/usr/local/lib/python3.7/dist-packages/jax/core.py in divide_shape_sizes(self, s1, s2)
1322 return 1
1323 if sz1 % sz2:
-> 1324 raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
1325 return sz1 // sz2
1326
InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (3958, 3) and (8, -1, 3)
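If I guess right, you used 8 GPUs to run the code. `utils.shard` reshapes every array to `(device_count, -1, ...)`, so the number of rays in each chunk (here 3958) must be divisible by the device count. An annoying workaround is to change 8 to 2, so that 3958 divides evenly by 2.
Hoping someone can suggest a better solution :(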