vqtorch

retain_graph=True

Open DiffDynamo opened this issue 2 years ago • 3 comments

Hello! I encountered an error when using "inplace_optimizer" in my code, but the same code works fine when "inplace_optimizer" is not used.

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward

— DiffDynamo, Jun 19 '23 08:06

Could you provide a short description of what you are trying to do and a code snippet that reproduces this error?

Note that each forward call in training mode will update the codebook when `inplace_optimizer` is provided.

— minyoungg, Jun 21 '23 18:06

My goal is to perform discrete representation learning on one-dimensional time-series data, and I have built my own autoencoder for this purpose. I added your quantization layer at the bottleneck of my autoencoder to discretize the continuous latent representations. I am using Python 3.9 and PyTorch 1.12.1. Here is a short code snippet:

from torch.nn import MSELoss
import torch.optim as optim
import torch
import torch.nn as nn
from torch.nn import Conv1d,ConvTranspose1d
from vqtorch.nn import VectorQuant
from torch.utils.data import TensorDataset,DataLoader

# Run on the default CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class ResidualUnit(nn.Module):
    """Dilated 1-D convolution block with a residual (skip) connection.

    Computes ``shortcut(x) + conv1x1(ELU(dilated_conv(x)))`` on inputs in the
    ``(batch, channels, length)`` layout that ``nn.Conv1d`` expects. The
    sequence length is preserved.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels. When it equals
            ``in_channels`` the skip path is the identity (no extra
            parameters, so existing checkpoints still load); otherwise the
            skip path is a 1x1 convolution so both branches can be summed.
        dilation: Dilation of the ``kernel_size``-wide convolution.
        kernel_size: Width of the dilated convolution (default 7).
    """

    def __init__(self, in_channels, out_channels, dilation, kernel_size=7):
        super().__init__()

        self.layers = nn.Sequential(
            # padding="same" keeps the length unchanged for stride 1, also for
            # even kernel sizes (symmetric integer padding shortened the output
            # there and broke the residual sum). For odd kernels it is the same
            # as padding = dilation * (kernel_size - 1) // 2.
            Conv1d(in_channels=in_channels, out_channels=in_channels,
                   kernel_size=kernel_size, stride=1, dilation=dilation,
                   padding="same"),
            nn.ELU(),
        )

        self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=1)

        # Without a projection, `x + out` fails with a shape mismatch whenever
        # the channel counts differ. Created after `self.conv`, so the
        # initialization of the existing layers is unaffected.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv1d(in_channels=in_channels,
                                      out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        out = self.layers(x)
        out = self.conv(out)
        return self.shortcut(x) + out

class EncoderBlock(nn.Module):
    """Encoder stage: 1x1 conv, three dilated residual units, then a strided conv.

    Every sub-module except the final convolution is followed by ``nn.ELU``.
    The final convolution doubles the channel count and uses
    ``kernel_size=stride, stride=stride``, so it downsamples the length by a
    factor of ``stride`` (no change when ``stride == 1``).
    """

    def __init__(self, in_channels, stride=2):
        super().__init__()

        stages = [
            Conv1d(in_channels=in_channels, out_channels=in_channels,
                   kernel_size=1, padding=0, bias=False),
            nn.ELU(),
        ]
        # Residual units with dilations 1, 3, 5, each followed by an ELU.
        for dil in (1, 3, 5):
            stages.append(ResidualUnit(in_channels=in_channels,
                                       out_channels=in_channels, dilation=dil))
            stages.append(nn.ELU())
        stages.append(Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
                             kernel_size=stride, stride=stride))

        # Same element order as an explicit listing, so the state_dict keys
        # (layers.0, layers.2, ..., layers.8) are unchanged.
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        return self.layers(x)
class DecoderBlock(nn.Module):
    """Decoder stage: transposed-conv upsampling, three residual units, 1x1 conv.

    The leading ``ConvTranspose1d`` (``kernel_size=stride, stride=stride``)
    multiplies the sequence length by ``stride``. The trailing bias-free 1x1
    convolution halves the channel count (``in_channels // 2``). Every
    sub-module except that last convolution is followed by ``nn.ELU``.
    """

    def __init__(self, in_channels, stride=2):
        super().__init__()

        modules = [
            ConvTranspose1d(in_channels=in_channels, out_channels=in_channels,
                            kernel_size=stride, stride=stride),
            nn.ELU(),
        ]
        # Residual units with dilations 1, 3, 5, each followed by an ELU.
        for dil in (1, 3, 5):
            modules += [
                ResidualUnit(in_channels=in_channels, out_channels=in_channels,
                             dilation=dil),
                nn.ELU(),
            ]
        # 1x1 projection halving the channel count for the next stage.
        modules.append(Conv1d(in_channels=in_channels, out_channels=in_channels // 2,
                              kernel_size=1, padding=0, bias=False))

        # Same element order as an explicit listing, so state_dict keys match.
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        return self.layers(x)

class Encoder(nn.Module):
    """Convolutional encoder mapping a 1-channel signal to ``D`` latent channels.

    Channel widths go 1 -> C -> 2C -> 4C -> 8C -> 16C -> D. ``conv_2`` uses
    stride 1 and ``conv_3``..``conv_5`` each use stride 2, so the length
    shrinks by roughly 8x overall. ``conv_6`` has no activation.
    """

    def __init__(self, C, D):
        super().__init__()
        self.conv_1 = nn.Sequential(Conv1d(1, C, kernel_size=3, padding=1), nn.ELU())
        self.conv_2 = nn.Sequential(EncoderBlock(C, stride=1), nn.ELU())
        self.conv_3 = nn.Sequential(EncoderBlock(2 * C, stride=2), nn.ELU())
        self.conv_4 = nn.Sequential(EncoderBlock(4 * C, stride=2), nn.ELU())
        self.conv_5 = nn.Sequential(EncoderBlock(8 * C, stride=2), nn.ELU())
        self.conv_6 = nn.Sequential(Conv1d(16 * C, D, kernel_size=3, padding=1))

    def forward(self, x):
        pipeline = (self.conv_1, self.conv_2, self.conv_3,
                    self.conv_4, self.conv_5, self.conv_6)
        for stage in pipeline:
            x = stage(x)
        return x
class Decoder(nn.Module):
    """Convolutional decoder mapping ``D`` latent channels back to 1 channel.

    Channel widths go D -> 16C -> 8C -> 4C -> 2C -> C -> C -> 1.
    ``conv_2``..``conv_4`` each upsample the length by 2 and ``conv_5`` keeps
    it. The final bias-free 1x1 ``conv1`` produces the single output channel.
    """

    def __init__(self, C, D):
        super().__init__()
        # `conv1` is registered before conv_1..conv_6. Keeping that order
        # keeps parameter initialization (under a fixed seed) and
        # `parameters()` order unchanged.
        self.conv1 = Conv1d(C, 1, kernel_size=1, padding=0, bias=False)
        self.conv_1 = nn.Sequential(Conv1d(D, 16 * C, kernel_size=3, padding=1), nn.ELU())
        self.conv_2 = nn.Sequential(DecoderBlock(16 * C, stride=2), nn.ELU())
        self.conv_3 = nn.Sequential(DecoderBlock(8 * C, stride=2), nn.ELU())
        self.conv_4 = nn.Sequential(DecoderBlock(4 * C, stride=2), nn.ELU())
        self.conv_5 = nn.Sequential(DecoderBlock(2 * C, stride=1), nn.ELU())
        self.conv_6 = nn.Sequential(Conv1d(C, C, kernel_size=3, padding=1), nn.ELU())

    def forward(self, x):
        pipeline = (self.conv_1, self.conv_2, self.conv_3, self.conv_4,
                    self.conv_5, self.conv_6, self.conv1)
        for stage in pipeline:
            x = stage(x)
        return x

class VQencoder(nn.Module):
    """Autoencoder with a vqtorch ``VectorQuant`` bottleneck.

    Pipeline: ``Encoder`` -> ``VectorQuant`` (quantizing dim 1, the channel
    axis) -> ``Decoder``.

    Args:
        C: Base channel width of the encoder and decoder.
        D: Number of latent channels produced by the encoder and consumed by
            the decoder.
        num_codes: Number of codebook vectors.
        embedding_dim: ``feature_size`` passed to ``VectorQuant``. With
            ``dim=1`` the quantized vectors are the encoder's ``D`` output
            channels, and the decoder expects ``D`` input channels, so this
            must equal ``D``.
        optemizer: (sic) Optional optimizer factory forwarded to
            ``VectorQuant`` as ``inplace_optimizer``, e.g.
            ``lambda *a, **kw: torch.optim.SGD(*a, **kw, lr=10.0)``. Per the
            vqtorch maintainer, when it is given every training-mode forward
            updates the codebook.

    Raises:
        ValueError: If ``embedding_dim != D``.
    """

    def __init__(self, C, D, num_codes=512, embedding_dim=64, optemizer=None):
        super().__init__()

        # Fail at construction with a clear message rather than on the first
        # forward pass, where the channel mismatch would surface inside the
        # quantizer or the decoder.
        if embedding_dim != D:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must equal D ({D}): "
                "VectorQuant quantizes the encoder's D output channels (dim=1)."
            )

        self.inplace_optimizer = optemizer
        self.encoder = Encoder(C=C, D=D)
        self.vq_layer = VectorQuant(
            feature_size=embedding_dim,     # feature dimension corresponding to the vectors
            num_codes=num_codes,      # number of codebook vectors
            beta=1,           # (default: 0.9) commitment trade-off
            kmeans_init=True,    # (default: False) whether to use kmeans++ init
            norm='l2',           # (default: None) normalization for the input vectors
            cb_norm='l2',        # (default: None) normalization for codebook vectors
            affine_lr=20,      # (default: 0.0) lr scale for affine parameters
            sync_nu=0.2,         # (default: 0.0) codebook synchronization contribution
            replace_freq=20,     # (default: None) frequency to replace dead codes
            inplace_optimizer=self.inplace_optimizer,
            dim=1,              # (default: -1) dimension to be quantized
        )
        self.decoder = Decoder(C=C, D=D)

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Returns:
            ``(out, vq_loss, perplexity, encodings)``: ``out`` is the decoder
            output; the others are ``vq_dict['loss']``,
            ``vq_dict['perplexity']`` and ``vq_dict['q']`` from the
            quantizer.
        """
        e = self.encoder(x)
        z_q, vq_dict = self.vq_layer(e)
        vq_loss = vq_dict['loss']
        perplexity = vq_dict['perplexity']
        encodings = vq_dict['q']
        out = self.decoder(z_q)
        return out, vq_loss, perplexity, encodings

def train(epochs=100, batch_size=64, num_workers=4, learning_rate=1e-3,
          weight_decay=1e-4, vq_loss_weight=0.0):
    """Train ``VQencoder`` on random 1-channel sequences (reproduction setup).

    Every argument defaults to the value hard-coded in the original script.
    ``train()`` therefore behaves as before, except that the training loader
    now reshuffles each epoch.

    Args:
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size.
        num_workers: DataLoader worker processes.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        vq_loss_weight: Weight on the quantizer loss (``vq_dict['loss']``)
            added to the reconstruction loss. The default 0.0 keeps the
            original objective, in which the VQ loss was computed but
            discarded. NOTE(review): check how vqtorch's ``beta`` splits this
            loss before enabling it together with an in-place optimizer.
    """
    loss_function = MSELoss(reduction='mean')

    # Optimizer factory handed to VectorQuant (presumably called there with
    # the codebook parameters -- confirm in vqtorch).
    def inplace_optimizer(*args, **kwargs):
        return torch.optim.SGD(*args, **kwargs, lr=10.0, momentum=0.9)

    net = VQencoder(C=4, D=256, num_codes=1024, embedding_dim=256,
                    optemizer=inplace_optimizer).to(device)
    optimizer = optim.AdamW(net.parameters(), lr=learning_rate,
                            weight_decay=weight_decay, betas=(0.9, 0.95))

    input_data = torch.randn(10000, 1, 2048)
    train_loader = DataLoader(TensorDataset(input_data), batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)

    for _ in range(epochs):
        net.train()
        for (data,) in train_loader:
            data = data.to(device)
            res, vq_loss, _, _ = net(data)
            loss = loss_function(res, data)
            # Only pull the VQ loss into the objective (and its graph into the
            # backward pass) when it is actually weighted in.
            if vq_loss_weight:
                loss = loss + vq_loss_weight * vq_loss
            optimizer.zero_grad()
            # NOTE(review): with an inplace optimizer, the vqtorch maintainer
            # says each training-mode forward updates the codebook, presumably
            # via an internal backward pass. If that pass frees graph buffers
            # this loss also depends on, the call below raises "Trying to
            # backward through the graph a second time". Passing
            # retain_graph=True *here* cannot help, because the buffers are
            # already gone. The reported workaround is retain_graph=True on
            # vqtorch's internal backward call.
            loss.backward()
            optimizer.step()


if __name__ == '__main__':
    train()

— DiffDynamo, Jun 22 '23 02:06

@minyoungg Hi, thanks for the interesting work and the great library. I've encountered the same problem as @DiffDynamo, and setting `retain_graph=True` here seems to work fine. Is this because the computation graph for z is freed by the in-place optimizer's backward pass, so the main optimizer is then unable to do the straight-through estimation? Thanks.

— SeanNobel, Dec 28 '23 05:12