fourier-feature-networks

Is it possible to add a time parameter to the input?

Open darwinharianto opened this issue 3 years ago • 7 comments

From what I understand, this method can reconstruct a crisp image using just coordinates. Is it possible to encode a whole video (x, y, t) using this model?

I have tried feeding x, y, t into it, but the PSNR gets lower the further along t I go: it starts at around 25 for the first few images and drops to 16 by the 25th image. I tried to encode t using standard positional encoding.

If it is possible to encode t, would it be beneficial for NeRF with a t-axis?

darwinharianto, Oct 14 '22 05:10

Fascinating question!

Can you share more about your training data and how you are training? Assuming that you have a training frame for each timestamp, and that the training loop treats x, y, and t the same way, I don't see why PSNR should decrease along the t dimension. Unless I am misreading your question, and the different PSNRs come from different training runs with different numbers of frames?

Note: the research I'm familiar with on encoding NeRF videos (https://arxiv.org/abs/1906.07751, https://nerfies.github.io/) puts a prior on coherent motion between frames. Rather than the "naive" f(u, v, x, y, z, t) approach, these methods learn a motion / vector field / distortion map. For a 2D video like the one you are proposing, the naive approach may be sufficient (but I suspect a coherent motion prior might improve results).
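To make the distinction concrete, here is a toy NumPy sketch of the motion-prior idea, with hand-written functions standing in for trained networks (`canonical`, `deform`, and `deformed_video` are hypothetical names, not code from either paper). The video is a static canonical image queried through a time-dependent displacement field, so a moving object only needs a simple displacement instead of a separately memorized appearance at every t:

```python
import numpy as np

# Hypothetical toy "networks": hand-written functions stand in for trained MLPs.

def canonical(x, y):
    """Static canonical image: a Gaussian blob centred at (0.3, 0.5)."""
    return np.exp(-((x - 0.3) ** 2 + (y - 0.5) ** 2) / (2 * 0.05 ** 2))

def deform(x, y, t, velocity=(0.4, 0.0)):
    """Displacement field mapping a point observed at time t into canonical space.

    Here it is a constant backward translation, so the blob moves with `velocity`.
    """
    return -velocity[0] * t, -velocity[1] * t

def deformed_video(x, y, t):
    """f(x, y, t) = canonical(x + dx, y + dy), where (dx, dy) = deform(x, y, t)."""
    dx, dy = deform(x, y, t)
    return canonical(x + dx, y + dy)

# Render two frames on a 101 x 101 grid over [0, 1] x [0, 1].
xs = np.linspace(0, 1, 101)
X, Y = np.meshgrid(xs, xs, indexing="ij")
frame0 = deformed_video(X, Y, 0.0)  # blob at x = 0.3
frame1 = deformed_video(X, Y, 1.0)  # same blob, shifted to x = 0.7
```

The displacement here is a fixed translation for simplicity; in the motion-prior methods it would itself be a learned network.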

ndahlquist, Oct 14 '22 16:10

This is the notebook I used, which I copied from elsewhere and extended with PE (https://colab.research.google.com/drive/1GbhWcmWd6BXPj8ishqgN2zzGCpGvB8sK?usp=sharing). I ran out of compute units and memory, so I haven't rerun it.

Since a video might contain very different images, I built my input from 5 random images repeated 5 times, to see whether it could work (A,B,C,D,E,A,B,C,D,E,A,B,C,D,E,...,C,D,E). Each image corresponds to one timestamp, so A1, B2, C3, ..., C23, D24, E25.

To train it, I encode x and y using the Fourier features, add PE to those Fourier features (as in the Transformer), and train on all 25 images.
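A minimal NumPy sketch of that setup, assuming the notebook's PositionalEncoding follows the standard Transformer sinusoidal formula over the frame index (`sinusoidal_pe` and `add_frame_pe` are hypothetical names; the real notebook code may differ). Every pixel of frame t receives the same PE vector, added on top of the shared (x, y) Fourier features:

```python
import numpy as np

def sinusoidal_pe(num_positions, d_model):
    """Transformer-style encoding; row p encodes frame index p."""
    assert d_model % 2 == 0, "d_model must be even"
    positions = np.arange(num_positions)[:, None]            # [T, 1]
    div = 10000.0 ** (np.arange(0, d_model, 2) / d_model)    # [d_model / 2]
    pe = np.zeros((num_positions, d_model))
    pe[:, 0::2] = np.sin(positions / div)                     # even channels
    pe[:, 1::2] = np.cos(positions / div)                     # odd channels
    return pe

def add_frame_pe(fourier_xy, num_frames):
    """Add a per-frame PE vector to every pixel's Fourier features.

    fourier_xy: [C, W, H] features of the (x, y) grid, shared by all frames.
    Returns [T, C, W, H]; frame t is fourier_xy + pe[t] broadcast over W and H.
    """
    channels = fourier_xy.shape[0]
    pe = sinusoidal_pe(num_frames, channels)                  # [T, C]
    return fourier_xy[None] + pe[:, :, None, None]

# Stand-in for the Fourier features of a tiny 4 x 4 xy grid with 8 channels.
features = np.random.default_rng(0).normal(size=(8, 4, 4))
inputs = add_frame_pe(features, num_frames=25)                # [25, 8, 4, 4]
```

Because the encoding is added rather than concatenated, the time signal shares channels with the x, y features; concatenating the two along the channel axis is a common alternative that keeps them separate.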

I was expecting frames showing the same image to have similar PSNR, but the PSNR just gradually decreases further along the t axis.

I guessed this was because I added PE to the Fourier features, so I tried scaling down the PE before adding it, but that just made it worse...
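For reference, the per-frame PSNR values discussed in this thread can be computed with a helper like the following; a minimal sketch, assuming pixel values in [0, 1] (pass max_val=255 for uint8 frames). `psnr` and `per_frame_psnr` are illustrative names, not from the notebook:

```python
import numpy as np

def psnr(target, generated, max_val=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_val**2 / MSE)."""
    diff = np.asarray(target, dtype=np.float64) - np.asarray(generated, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_val ** 2 / mse))

def per_frame_psnr(targets, generated, max_val=1.0):
    """PSNR of each frame, e.g. to see whether quality drops along t."""
    return [psnr(t, g, max_val) for t, g in zip(targets, generated)]
```

Since RMSE = max_val * 10^(-PSNR/20), a PSNR of 20 dB corresponds to an RMS error of 10% of the pixel range, and 30 dB to roughly 3%.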

darwinharianto, Oct 15 '22 06:10

Okay, I haven't tried running this code, so I may be misunderstanding it, but here are some initial thoughts:

  1. I would have expected that instead of an xy_grid, you would want to create an xyt_grid. Is the t dimension being encoded properly here?
# Note: this can be done outside of the training loop, since the result at this stage is unchanged during the course of training.
x = GaussianFourierFeatureTransform(2, 128, 10)(xy_grid)

# can I reconstruct videos?

# added positional encoding
pe = PositionalEncoding(256)
inputs = x.repeat(target.shape[0], 1, 1, 1)

inputWithPE = pe.forward(inputs)
  2. It looks like in your training loop, you're treating each frame as a separate "batch" and always going in order from frame 0 to N. I would suggest randomizing the order of the frames (otherwise the loss may not be uniform with respect to t, as you are observing).
for epoch in tqdm(range(5000)):
    optimizer.zero_grad()

    for i in range(target.shape[0]):
      generated = model(inputWithPE[i:i+1])

      loss = torch.nn.functional.l1_loss(target[i:i+1], generated)
      loss.backward()
      optimizer.step()
  3. Your "video" is made up of completely independent frames/images. I suspect it would be easier for the network to learn an actual video (the same subject with movement). This should be fine for your proof of concept, but I expect you'd get a better PSNR with a real video.

ndahlquist, Oct 15 '22 17:10

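The mismatched placement of optimizer.zero_grad() and optimizer.step() in your training loop also looks atypical. Because zero_grad() is called only once per epoch while step() is called once per frame, gradients keep accumulating across frames, so the update for frame i applies the summed gradients of frames 0 through i. I suspect you might have intended: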

for epoch in tqdm(range(5000)):
    for i in range(target.shape[0]):
      optimizer.zero_grad()

      generated = model(inputWithPE[i:i+1])

      loss = torch.nn.functional.l1_loss(target[i:i+1], generated)
      loss.backward()
      optimizer.step()
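To see why the placement matters, here is a pure-Python toy that mimics PyTorch's accumulate-on-backward behaviour for a single scalar parameter, using plain SGD and a squared-error loss per frame (a hypothetical stand-in, not the notebook's model or optimizer):

```python
def run_epoch(w, targets, lr, zero_grad_every_step):
    """One epoch of per-frame SGD on loss_i(w) = (w - target_i) ** 2.

    `grad` plays the role of a parameter's .grad: backward() adds to it,
    and it is only reset where zero_grad() would be called.
    """
    grad = 0.0                          # zero_grad() once at the start of the epoch
    for target in targets:
        if zero_grad_every_step:
            grad = 0.0                  # zero_grad() before every frame (corrected loop)
        grad += 2.0 * (w - target)      # backward(): accumulate d(loss_i)/dw
        w -= lr * grad                  # step(): plain SGD update
    return w

fixed = run_epoch(0.0, [1.0, 2.0], lr=0.1, zero_grad_every_step=True)
buggy = run_epoch(0.0, [1.0, 2.0], lr=0.1, zero_grad_every_step=False)
```

In the buggy version, the second update still includes frame 0's stale gradient, which is exactly the accumulation that moving zero_grad() inside the frame loop avoids.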

ndahlquist, Oct 15 '22 17:10

Thank you for the feedback. I updated the notebook with the changes. After applying them and rerunning, the PSNR is now more similar across all timesteps. However, it seems to peak at around 20 (I was hoping for 30 to 60). I guess this approach is not enough?

  1. Ah yes, I wasn't sure how to change the Fourier feature class, but I have now made it accept an xyt_grid. I hope I changed it correctly:
import numpy as np
import torch


class GaussianFourierFeatureTimestepTransform(torch.nn.Module):
    """
    An implementation of Gaussian Fourier feature mapping extended with a timestep axis.

    Based on
    "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
       https://arxiv.org/abs/2006.10739
       https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html

    Given an input of size [batches, num_input_channels, width, height, timestep],
     returns a tensor of size [batches, mapping_size*2, width, height, timestep].
    """

    def __init__(self, num_input_channels, mapping_size=256, scale=10):
        super().__init__()

        self._num_input_channels = num_input_channels
        self._mapping_size = mapping_size
        # Random Gaussian frequency matrix, shared by all input channels (x, y and t).
        self._B = torch.randn((num_input_channels, mapping_size)) * scale

    def forward(self, x):
        assert x.dim() == 5, 'Expected 5D input (got {}D input)'.format(x.dim())

        batches, channels, width, height, timestep = x.shape

        assert channels == self._num_input_channels,\
            "Expected input to have {} channels (got {} channels)".format(self._num_input_channels, channels)

        # Make shape compatible for matmul with _B.
        # From [B, C, W, H, t] to [(B*W*H*t), C].
        x = x.permute(0, 2, 3, 4, 1).reshape(batches * width * height * timestep, channels)

        x = x @ self._B.to(x.device)

        # From [(B*W*H*t), mapping_size] to [B, W, H, t, mapping_size]
        x = x.view(batches, width, height, timestep, self._mapping_size)
        # From [B, W, H, t, mapping_size] to [B, mapping_size, W, H, t]
        x = x.permute(0, 4, 1, 2, 3)

        x = 2 * np.pi * x
        # Output: [B, mapping_size*2, W, H, t]
        return torch.cat([torch.sin(x), torch.cos(x)], dim=1)
  2. Ah, you are right, I didn't realize that might be the problem. I now shuffle the frames so they come in random order, and I fixed the placement of optimizer.zero_grad() and optimizer.step():
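import random

import torch
from tqdm import tqdm

for epoch in tqdm(range(1000)):
    # Visit every frame exactly once per epoch, in a fresh random order.
    imageIndices = list(range(target.shape[0]))
    for idx in random.sample(imageIndices, len(imageIndices)):
        optimizer.zero_grad()  # clear gradients left over from the previous frame
        generated = model(x[:, :, :, :, idx])  # Fourier features of frame idx
        loss = torch.nn.functional.l1_loss(target[idx:idx+1], generated)
        loss.backward()
        optimizer.step()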
  3. Yes, I was thinking I could maybe take a Marvel movie and encode it in this model. Movies have sequences that are similar and cuts that are really different, which is why I am using different images in this sample.

darwinharianto, Oct 17 '22 02:10

Glad to hear it's working better! I don't have much to add on the Fourier feature implementation or the PSNR result.

My only vague thought would be to make sure that x, y, and t are scaled to similar ranges before being passed into the Fourier feature transform. For example, if x and y are normalized coordinates in the range [0, 1] but t is in the range [0, NUM_FRAMES], that could make it harder for the network to converge to a good result. But that's really just a shot in the dark.
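To put a rough number on this hunch, here is a small NumPy sketch, assuming a frequency matrix drawn like `_B` in the class above (Gaussian, scale 10, one row each for x, y, t); `B`, `b_t`, and `NUM_FRAMES = 25` are illustrative assumptions. It compares how far each Fourier feature's phase advances between consecutive frames when t is left in frame units versus normalized to [0, 1):

```python
import numpy as np

NUM_FRAMES = 25  # assumption: matches the 25-frame experiment in this thread
rng = np.random.default_rng(0)

# Hypothetical stand-in for _B: rows are the x, y, t input channels.
B = rng.normal(size=(3, 256)) * 10.0
b_t = B[2]  # frequencies that multiply the t coordinate

# Each feature is sin/cos(2*pi*(b . v)), so advancing one frame shifts its
# phase by 2*pi*|b_t|*dt, where dt is the spacing between consecutive t values.
phase_step_raw = 2 * np.pi * np.abs(b_t) * 1.0                # t = 0, 1, ..., 24
phase_step_norm = 2 * np.pi * np.abs(b_t) * (1 / NUM_FRAMES)  # t = 0, 1/25, ..., 24/25

print("median phase step, raw t:        %.1f rad" % np.median(phase_step_raw))
print("median phase step, normalized t: %.1f rad" % np.median(phase_step_norm))
```

With t in frame units the phase step is NUM_FRAMES times larger, so most features wrap around at least one full cycle between neighbouring frames and the encodings of adjacent frames become essentially unrelated. Normalizing t, or possibly using a smaller scale for the t row of B, keeps them correlated.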

ndahlquist, Oct 17 '22 16:10

Thanks for the feedback!

They are all scaled from 0 to 1. The x and y coordinates share the same array, while t is spaced by the number of frames:

import numpy as np

# target shape is (25, 512, 512)
coords = np.linspace(0, 1, target.shape[2], endpoint=False)  # for x and y
timestep = np.linspace(0, 1, target.shape[0], endpoint=False)  # for t
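One possible way to turn these coordinate arrays into the [batches, 3, width, height, timestep] input that GaussianFourierFeatureTimestepTransform expects, plus a NumPy equivalent of its forward pass for checking shapes. This is a sketch assuming frames are indexed along the last axis; `make_xyt_grid` and `fourier_features` are hypothetical helper names, not code from the notebook:

```python
import numpy as np

def make_xyt_grid(width, height, num_frames):
    """Return a [1, 3, W, H, T] grid of normalized (x, y, t) coordinates in [0, 1)."""
    xs = np.linspace(0, 1, width, endpoint=False)
    ys = np.linspace(0, 1, height, endpoint=False)
    ts = np.linspace(0, 1, num_frames, endpoint=False)
    grid = np.stack(np.meshgrid(xs, ys, ts, indexing="ij"), axis=0)  # [3, W, H, T]
    return grid[None]                                                # add batch axis

def fourier_features(grid, B):
    """NumPy equivalent of GaussianFourierFeatureTimestepTransform.forward.

    grid: [N, C, W, H, T]; B: [C, M]. Returns [N, 2M, W, H, T] (sin then cos).
    """
    proj = 2 * np.pi * np.einsum("ncwht,cm->nmwht", grid, B)
    return np.concatenate([np.sin(proj), np.cos(proj)], axis=1)

grid = make_xyt_grid(8, 8, 25)
B = np.random.default_rng(0).normal(size=(3, 16)) * 10.0  # hypothetical _B, mapping_size=16
feats = fourier_features(grid, B)                          # [1, 32, 8, 8, 25]
frame_3 = feats[..., 3]                                    # like x[:, :, :, :, 3] in the loop
```

Precomputing features for every frame at once is memory-hungry: at 512×512×25 with the class default mapping_size=256, the float32 feature tensor alone is about 13.4 GB, so encoding one frame at a time, e.g. `fourier_features(grid[..., t:t + 1], B)`, may be preferable.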

One thing that bothers me is that there are artifacts in the generated images: pixels that are black, white, or reddish (see the images below). I'm not sure what causes them.

Source: [image: Screen Shot 2022-10-18 at 12 41 35]

Result: [image: Screen Shot 2022-10-18 at 12 24 02]

darwinharianto, Oct 18 '22 03:10