fnet-pytorch
FourierMatmul function giving a dimension-mismatch error
Hey,
Wonderful translation!
I just implemented it myself, but FourierMatmul is giving a dimension-mismatch error. Can you please let me know what dimensions it expects? Any help would be appreciated.
My sample inputs:
import json
from fnet import FNetPretraining
from transformers import FNetTokenizer

with open('config.json', 'r') as f:
    config = json.load(f)

tokenizer = FNetTokenizer.from_pretrained("google/fnet-base")
inputs = tokenizer(['Hello, my dog is so cute', 'Hello world'],
                   return_tensors='pt',
                   padding=True,
                   truncation=True, max_length=512)

# print(inputs) gives:
# {'input_ids': tensor([[   4, 9665, 16680,  275, 3314,   65,  215, 6387,    5],
#                       [   4, 9665,   725,    5,    3,    3,    3,    3,    3]]),
#  'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
#                            [0, 0, 0, 0, 0, 0, 0, 0, 0]])}

input_ids = inputs['input_ids']
token_type_ids = inputs['token_type_ids']

obj1 = FNetPretraining(config=config)
obj1.forward(input_ids, token_type_ids)
The FourierMMLayer from fnet.py (with the imports it needs):

import torch
from torch import nn
from scipy import linalg

class FourierMMLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # DFT matrices over the sequence and hidden dimensions (complex128)
        self.dft_mat_seq = torch.tensor(linalg.dft(config['max_position_embeddings']))
        self.dft_mat_hidden = torch.tensor(linalg.dft(config['hidden_size']))

    def forward(self, hidden_states):
        hidden_states_complex = hidden_states.type(torch.complex128)
        # j (hidden) is contracted with dft_mat_hidden, i (sequence) with dft_mat_seq
        return torch.einsum(
            "...ij,...jk,...ni->...nk",
            hidden_states_complex,
            self.dft_mat_hidden,
            self.dft_mat_seq
        ).real.type(torch.float32)
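From the einsum equation it looks like hidden_states has to be (batch, max_position_embeddings, hidden_size), i.e. the sequence axis must match dft_mat_seq. Here is a minimal shape check I tried (my own sketch, using the FourierMMLayer definition above and assumed config values of 512/768), and it runs cleanly:

import torch

# Assumed config values; 512 and 768 match the DFT matrix shapes in the traceback below.
config = {'max_position_embeddings': 512, 'hidden_size': 768}
layer = FourierMMLayer(config)  # class as pasted above

# "...ij,...jk,...ni->...nk" contracts j (hidden, 768) with dft_mat_hidden (768x768)
# and i (sequence) with dft_mat_seq (512x512), so the sequence axis must be 512.
hidden_states = torch.randn(2, 512, 768)
out = layer(hidden_states)
print(out.shape)  # torch.Size([2, 512, 768])

So the mismatch seems to come from my batch having a sequence length of 9 after padding=True (which only pads to the longest sequence in the batch), not 512.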
Error
Traceback (most recent call last):
  File "inference.py", line 22, in <module>
    obj1.forward(input_ids, token_type_ids)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 124, in forward
    self.encoder(input_ids, type_ids)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 113, in forward
    sequence_output = self.encoder(embedding_output)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 94, in forward
    hidden_states = layer_module(hidden_states)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 80, in forward
    fft_output = self.fft(hidden_states)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 62, in forward
    return torch.einsum(
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/functional.py", line 299, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [2, 9, 768]->[2, 1, 1, 9, 768] [768, 768]->[1, 1, 768, 1, 768] [512, 512]->[1, 512, 1, 512, 1]
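From the remapped shapes at the end of the traceback, my batch is [2, 9, 768] while dft_mat_seq is [512, 512], so the contraction over the sequence axis cannot broadcast (9 vs 512). Padding every sequence to max_position_embeddings at tokenization time should make the shapes line up; a sketch of that workaround is below (my assumption about the intended usage, not something documented in the repo). Is that the expected input, or should the layer build the DFT matrix from the actual sequence length instead?

# Pad to exactly config['max_position_embeddings'] so the sequence axis matches dft_mat_seq.
inputs = tokenizer(['Hello, my dog is so cute', 'Hello world'],
                   return_tensors='pt',
                   padding='max_length',  # pad to max_length, not just to the longest sequence
                   truncation=True,
                   max_length=config['max_position_embeddings'])
obj1.forward(inputs['input_ids'], inputs['token_type_ids'])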