
end_class support for Autoregressive

Open · urchade opened this issue 3 years ago · 1 comment

end_class is not used for the Autoregressive module: https://github.com/harvardnlp/pytorch-struct/blob/7146de5659ff17ad7be53023c025ffd099866412/torch_struct/autoregressive.py#L49

urchade · Feb 04 '22 15:02

Hi, I'm also in need of an autoregressive model with end_class. Here's my approach: https://github.com/CarlossShi/pytorch-struct/commit/b5a56e8aee2742586b4e432fafd2c6b7be63273c. I use a variable active to record whether each sequence has already emitted end_class; once no sequence is still alive, the for loop breaks early to save time (a minimal sketch of the idea follows the list below). I am not very familiar with NLP, so I am not sure whether this is common practice. In addition, some problems remain to be solved:

  • It may be necessary to add an output indicating the effective length of each sentence.
  • The _beam_search method does not work as expected (see example below).
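
For reference, here is a minimal standalone sketch of the active-mask idea, which also records the effective length mentioned in the first bullet. Everything here (step_fn, greedy_with_end_class, the shapes) is illustrative and is not the API of the linked commit:

import torch

def greedy_with_end_class(step_fn, state, start, n_steps, end_class):
    # step_fn(tokens, state) -> (logits, state); tokens: (batch,), logits: (batch, C)
    batch = start.shape[0]
    active = torch.ones(batch, dtype=torch.bool)  # sequences that have not emitted end_class yet
    lengths = torch.full((batch,), n_steps)       # effective length of each sequence
    tokens, out = start, []
    for i in range(n_steps):
        logits, state = step_fn(tokens, state)
        tokens = logits.argmax(-1)
        # finished sequences keep emitting end_class as padding
        tokens = torch.where(active, tokens, torch.full_like(tokens, end_class))
        out.append(tokens)
        just_ended = active & (tokens == end_class)
        lengths[just_ended] = i + 1               # record where each sequence stopped
        active = active & ~just_ended
        if not active.any():                      # no sequence alive: break to save time
            break
    return torch.stack(out, dim=1), lengths       # (batch, <= n_steps), (batch,)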

Following the official documentation on Autoregressive / Beam Search, I put together the examples below.

import torch
import matplotlib.pyplot as plt
import sys
sys.path.insert(1, 'pytorch-struct')
import torch_struct
batch, layer, H, C, N, K = 3, 1, 5, 4, 10, 2  # K: sample shape
init = (torch.rand(batch, layer, H),
        torch.rand(batch, layer, H))


def t(a):
    # transpose the first two dims (batch <-> layer) of every tensor in a tuple/list
    return [x.transpose(0, 1) for x in a]


def show_ar(chain):
    # visualize a one-hot (N, C) path as a C-by-N image
    plt.imshow(chain.detach().transpose(0, 1))


class RNN_AR(torch.nn.Module):
    def __init__(self, sparse=True):
        super().__init__()
        self.sparse = sparse
        self.rnn = torch.nn.RNN(H, H, batch_first=True)
        self.proj = torch.nn.Linear(H, C)
        if sparse:
            self.embed = torch.nn.Embedding(C, H)
        else:
            self.embed = torch.nn.Linear(C, H)

    def forward(self, inputs, state):
        """

        @param inputs: {Tensor: (batch, 1)}
        @param state:  e.g. ({Tensor: (batch, layer, H)}, {Tensor: (batch, layer, H)})
        @return: {Tensor: (batch, layer, C)}, [{Tensor: (batch, layer, H)}]
        """
        if not self.sparse and inputs.dim() == 2:
            inputs = torch.nn.functional.one_hot(inputs, C).float()
        inputs = self.embed(inputs)  # {Tensor: (batch, 1, H)}
        out, state = self.rnn(inputs, t(state)[0])  # out: {Tensor: (batch, layer, H)}, t(state)[0] & state: {Tensor: (layer, batch, H)}
        out = self.proj(out)  # {Tensor: (batch, layer, C)}
        return out, t((state,))  # t((state,))[0]: {Tensor: (batch, layer, H)}
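
# (Not in the original example.) Optional shape sanity check for the wrapper
# above, to make its contract explicit; with layer == 1 the RNN state round-trips
# between (layer, batch, H) inside nn.RNN and (batch, layer, H) outside.
_logits, _state = RNN_AR()(torch.zeros(batch, 1, dtype=torch.long), init)
assert _logits.shape == (batch, 1, C) and _state[0].shape == (batch, layer, H)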


dist = torch_struct.Autoregressive(RNN_AR(), init, C, N, end_class=1)  # end_class is accepted here but, per this issue, ignored during decoding

path, scores, logits = dist.greedy_max()  # path, logits: {Tensor: (batch, N, C)}, scores: {Tensor: (batch,)}
for b in range(batch):
    plt.subplot(1, batch, b + 1)
    plt.axis('off')
    show_ar(path[b])
plt.suptitle('dist.greedy_max()')
plt.show()

out = dist.sample(torch.Size([K]))  # {Tensor: (K, batch, N, C)}
for k in range(K):
    for b in range(batch):
        plt.subplot(K, batch, batch * k + b + 1)
        plt.axis('off')
        show_ar(out[k, b])
plt.suptitle('dist.sample(torch.Size([K]))')
plt.show()

out = dist.beam_topk(K)  # {Tensor: (K, batch, N, C)}, first output of _beam_search
for k in range(K):
    for b in range(batch):
        plt.subplot(K, batch, batch * k + b + 1)
        plt.axis('off')
        show_ar(out[k, b])
plt.suptitle('dist.beam_topk(K)')
plt.show()

[Output figures: dist.greedy_max(), dist.sample(torch.Size([K])), dist.beam_topk(K)]

In the example above, end_class is set to 1. I expect that once all sentences have emitted end_class (i.e., there is a yellow square in the second row of each array), the remaining columns are truncated. The sample method seems to work as expected, but _beam_search does not. I'm not very familiar with the beam search code, so I'm stuck here; a possible post-hoc workaround is sketched below.
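
As a stopgap until _beam_search handles end_class, one could truncate the one-hot paths it returns after the fact. A hedged sketch, where truncate_after_end is a hypothetical helper and not part of pytorch-struct:

import torch

def truncate_after_end(path, end_class):
    # path: one-hot (..., N, C); zero out every column after the first end_class
    classes = path.argmax(-1)            # (..., N)
    is_end = (classes == end_class).long()
    before = is_end.cumsum(-1) - is_end  # end_classes seen strictly before each position
    keep = (before == 0).unsqueeze(-1)   # keep up to and including the first end
    return path * keep

# e.g. truncate_after_end(dist.beam_topk(K), end_class=1)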

Hope that helps; any further support would be greatly appreciated.

CarlossShi · Sep 18 '22 13:09