fnet-pytorch

FourierMatmul function giving some errors

Open · jaytimbadia opened this issue 3 years ago · 0 comments

Hey,

Wonderful translation!

I just implemented it myself, but this FourierMatmul layer is giving a dimension-mismatch error. Could you let me know which dimensions it expects? Please help.

My sample inputs:

```python
import json

from fnet import FNetPretraining
from transformers import FNetTokenizer

with open('config.json', 'r') as f:
    config = json.load(f)

tokenizer = FNetTokenizer.from_pretrained("google/fnet-base")

# padding=True pads only to the longest sentence in the batch (9 tokens here);
# with that setting, max_length only caps truncation.
inputs = tokenizer(['Hello, my dog is so cute', 'Hello world'],
                   return_tensors='pt',
                   padding=True,
                   truncation=True, max_length=512)

# print(inputs) gives:
# {'input_ids': tensor([[    4,  9665, 16680,   275,  3314,    65,   215,  6387,     5],
#         [    4,  9665,   725,     5,     3,     3,     3,     3,     3]]),
#  'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0]])}

input_ids = inputs['input_ids']
token_type_ids = inputs['token_type_ids']

obj1 = FNetPretraining(config=config)
obj1.forward(input_ids, token_type_ids)
```
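The layer in question (from fnet.py):

```python
import torch
from torch import nn
from scipy import linalg


class FourierMMLayer(nn.Module):
    def __init__(self, config):
        super().__init__()

        # DFT matrices for the sequence and hidden axes. The sequence matrix is
        # max_position_embeddings x max_position_embeddings (512 x 512 here),
        # so forward() only accepts inputs padded to exactly that length.
        self.dft_mat_seq = torch.tensor(linalg.dft(config['max_position_embeddings']))
        self.dft_mat_hidden = torch.tensor(linalg.dft(config['hidden_size']))

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden_size)
        hidden_states_complex = hidden_states.type(torch.complex128)
        # Re(F_seq @ X @ F_hidden): i = sequence axis, j/k = hidden axes
        return torch.einsum(
            "...ij,...jk,...ni->...nk",
            hidden_states_complex,
            self.dft_mat_hidden,
            self.dft_mat_seq
        ).real.type(torch.float32)
```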

Error:

```
Traceback (most recent call last):
  File "inference.py", line 22, in <module>
    obj1.forward(input_ids, token_type_ids)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 124, in forward
    self.encoder(input_ids, type_ids)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 113, in forward
    sequence_output = self.encoder(embedding_output)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 94, in forward
    hidden_states = layer_module(hidden_states)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 80, in forward
    fft_output = self.fft(hidden_states)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 62, in forward
    return torch.einsum(
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/functional.py", line 299, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [2, 9, 768]->[2, 1, 1, 9, 768] [768, 768]->[1, 1, 768, 1, 768] [512, 512]->[1, 512, 1, 512, 1]
```
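The remapped shapes in the error point at the cause: the hidden states carry 9 tokens (`padding=True` pads only to the longest sentence in the batch), while `dft_mat_seq` was built for `max_position_embeddings = 512`. Two possible fixes, neither tested against this repo's `fnet.py`: pad every batch to the full length (`padding='max_length', max_length=512` in the tokenizer call), or build the sequence DFT matrix from the actual sequence length at runtime. The sketch below illustrates the second option under stated assumptions: the class name `DynamicFourierMMLayer` is hypothetical, and it reads only the `hidden_size` key that the original layer also uses.

```python
import torch
from torch import nn


class DynamicFourierMMLayer(nn.Module):
    """Hypothetical variant of FourierMMLayer: the sequence DFT matrix is
    built for the actual input length instead of max_position_embeddings."""

    def __init__(self, config):
        super().__init__()
        # The hidden-size matrix is fixed, so keep it as a non-persistent
        # buffer: it follows .to(device) but is not written to state_dict.
        self.register_buffer(
            "dft_mat_hidden", self._dft(config["hidden_size"]), persistent=False
        )

    @staticmethod
    def _dft(n, device=None):
        # FFT of each row of the identity gives W[r, c] = exp(-2*pi*1j*r*c/n),
        # the same complex128 matrix as scipy.linalg.dft(n).
        eye = torch.eye(n, dtype=torch.float64, device=device)
        return torch.fft.fft(eye, dim=-1)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden_size); seq_len may vary per batch.
        seq_len = hidden_states.shape[-2]
        dft_mat_seq = self._dft(seq_len, hidden_states.device)
        return torch.einsum(
            "...ij,...jk,...ni->...nk",
            hidden_states.type(torch.complex128),
            self.dft_mat_hidden,
            dft_mat_seq,
        ).real.type(torch.float32)
```

With the inputs above, this layer would receive `[2, 9, 768]` and use a 9 x 9 sequence matrix. Because the two matrix products together form a 2D DFT over the sequence and hidden axes, `torch.fft.fft2(hidden_states).real` should give the same result without materializing either matrix, which is cheaper for long sequences.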

jaytimbadia, Feb 23 '22 16:02