Is there any suggestion for sampling more points than 4096?

Open · easonnie opened this issue 1 year ago · 2 comments

Thank you for releasing point-e! I was wondering whether there would be a way to sample more than 4096 points.

I tried two-step upsampling (running the upsampler a second time to go from 4096 to 16384 points), but it does not work:

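from PIL import Image
from tqdm.auto import tqdm

from point_e.diffusion.sampler import PointCloudSampler

# device, base_model, base_diffusion, upsampler_model and upsampler_diffusion are
# assumed to be created as in point-e's image2pointcloud.ipynb example.

sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model, upsampler_diffusion],
    diffusions=[base_diffusion, upsampler_diffusion, upsampler_diffusion],
    num_points=[1024, 4096 - 1024, 4096 * 4 - 4096],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0, 3.0, 3.0],
    use_karras=(True, True, True),
    karras_steps=(64, 64, 64),
    sigma_min=(1e-3, 1e-3, 1e-3),
    sigma_max=(120, 160, 160),
    s_churn=(3, 0, 0),
)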

# Load an image to condition on.
# img = Image.open('example_data/cube_stack.jpg')
img = Image.open('example_data/render_img.png')

# Produce a sample from the model.
samples = None
for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[img]))):
    samples = x

It fails with the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In [48], line 7
      5 # Produce a sample from the model.
      6 samples = None
----> 7 for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[img]))):
      8     samples = x

File ~/miniconda/envs/3d_cu116_py38/lib/python3.8/site-packages/tqdm/notebook.py:259, in tqdm_notebook.__iter__(self)
    257 try:
    258     it = super(tqdm_notebook, self).__iter__()
--> 259     for obj in it:
    260         # return super(tqdm...) will not catch exception
    261         yield obj
    262 # NB: except ... [ as ...] breaks IPython async KeyboardInterrupt

File ~/miniconda/envs/3d_cu116_py38/lib/python3.8/site-packages/tqdm/std.py:1195, in tqdm.__iter__(self)
   1192 time = self._time
   1194 try:
-> 1195     for obj in iterable:
   1196         yield obj
   1197         # Update and possibly print the progressbar.
   1198         # Note: does not call self.update(1) for speed optimisation.

File ~/project/text2shape/repos/3DGen/de_package/point-e/point_e/diffusion/sampler.py:135, in PointCloudSampler.sample_batch_progressive(self, batch_size, model_kwargs)
    133 if stage_guidance_scale != 1 and stage_guidance_scale != 0:
    134     for k, v in stage_model_kwargs.copy().items():
--> 135         stage_model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0)
    137 if stage_use_karras:
    138     samples_it = karras_sample_progressive(
    139         diffusion=diffusion,
    140         model=model,
   (...)
    149         guidance_scale=stage_guidance_scale,
    150     )

TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not list
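A note on reading this traceback: lines 133-135 of sampler.py show that, whenever a stage's guidance scale is not 0 or 1, every entry of `stage_model_kwargs` is concatenated with a zeros copy of itself, which only works for tensors. In the snippet above, the third entry of `models` is `upsampler_diffusion` rather than `upsampler_model`. Assuming sampler.py only converts the `images` list into tensors through the model's `cached_model_kwargs` hook (which a diffusion object does not have), the raw list of PIL images would reach `torch.zeros_like` unchanged in stage three. The sketch below is a minimal, hypothetical re-creation of that step (not point-e's own code; `duplicate_for_guidance` and the `embeddings` key are made up) that reproduces the same TypeError.

```python
import torch


def duplicate_for_guidance(model_kwargs):
    """Hypothetical re-creation of the guidance step shown at sampler.py:134-135.

    Each conditioning value is concatenated with an all-zeros copy of itself
    along dim 0 (conditional + unconditional halves of the batch), so every
    value must already be a torch.Tensor.
    """
    out = dict(model_kwargs)
    for k, v in model_kwargs.items():
        out[k] = torch.cat([v, torch.zeros_like(v)], dim=0)
    return out


# Tensor-valued kwarg ("embeddings" is a made-up key): batch size goes 1 -> 2.
guided = duplicate_for_guidance({"embeddings": torch.randn(1, 8)})
print(guided["embeddings"].shape)  # torch.Size([2, 8])

# A raw Python list, like images=[img], triggers the same TypeError as above.
try:
    duplicate_for_guidance({"images": ["stand-in for a PIL image"]})
except TypeError as err:
    print("TypeError:", err)
```

If that reading is right, passing `upsampler_model` as the third model would likely get past this particular error; it says nothing by itself about the quality of a second upsampling pass.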

easonnie, Dec 21 '22 02:12
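On the `num_points` arithmetic: point-e's example notebooks use `num_points=[1024, 4096 - 1024]` to reach 4096 points, which suggests (an assumption here, not confirmed by this issue) that each upsampling stage's output is concatenated with the previous stage's points, so each entry counts only the points a stage adds. Under that assumption the three-stage config above targets 1024, 4096 and 16384 points in total. A hypothetical stdlib-only helper (`per_stage_points` is not part of point-e) that derives the per-stage list from the desired totals:

```python
from itertools import accumulate


def per_stage_points(totals):
    """Convert cumulative point counts (after each stage) into the number of
    new points each stage must generate. Hypothetical helper, not point-e API."""
    if any(b <= a for a, b in zip(totals, totals[1:])):
        raise ValueError("totals must be strictly increasing")
    # First stage generates everything; later stages add the difference.
    return [totals[0]] + [b - a for a, b in zip(totals, totals[1:])]


stages = per_stage_points([1024, 4096, 16384])
print(stages)                    # [1024, 3072, 12288]
print(list(accumulate(stages)))  # [1024, 4096, 16384]
```

The result matches the `[1024, 4096 - 1024, 4096 * 4 - 4096]` list used in the question, and `accumulate` recovers the running totals as a quick consistency check.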