learning-invariances icon indicating copy to clipboard operation
learning-invariances copied to clipboard

The missing file `aug_eq_model`

Open Fangwq opened this issue 3 years ago • 4 comments

As the title says, the file `aug_eq_model.py` is missing. Could you please upload it? Thanks.

Fangwq avatar Feb 18 '22 13:02 Fangwq

Yeah, what happened to that file? Any ideas, @g-benton?

mfinzi avatar Feb 22 '22 21:02 mfinzi

I have an old version:

import torch
import torch.nn as nn
import torch.nn.functional as F
from augerino.utils import expm


class DiffEqAug(nn.Module):
    """Differentiable random rotations and translations of an image batch.

    For every image a latent ``z = (tx, ty, theta)`` is sampled as
    ``noise @ Sigma.T + Mu`` with standard-normal ``noise``; ``Sigma`` and
    ``Mu`` are learnable parameters.  ``forward`` zero-pads the input and
    applies the sampled transform, ``inverse`` applies the transform built
    from ``-z`` and removes the padding again.

    Args:
        padding: number of zero pixels added on every side before the
            transform and removed again by ``inverse``/``crop``.
    """

    def __init__(self, padding=0):
        super().__init__()
        self.Sigma = nn.Parameter(torch.eye(3))
        self.Mu = nn.Parameter(torch.zeros(3))
        self.aug = True
        self.padding = padding
        self.pad = nn.ConstantPad2d(padding, 0.)

    def crop(self, x):
        """Strip ``self.padding`` pixels from each side of the last two dims.

        A padding of 0 must be special-cased: the slice ``0:-0`` is ``0:0``
        and would return an empty tensor instead of the unchanged input.
        """
        if self.padding == 0:
            return x
        p = self.padding
        return x[:, :, p:-p, p:-p]

    def sample_z(self, x):
        """Draw one latent ``(tx, ty, theta)`` per image and store it in ``self.z``."""
        bs = x.size(0)
        # Sample in the parameters' dtype so the matmul with Sigma is always
        # well-typed; `transform` casts the result to the image dtype.
        noise = torch.randn(bs, 3, device=x.device, dtype=self.Sigma.dtype)
        self.z = noise @ self.Sigma.T + self.Mu

    @staticmethod
    def transform(img, z):
        """Resample ``img`` with the per-image affine map described by ``z``.

        Args:
            img: 4-D tensor ``(batch, channels, height, width)``.
            z: tensor ``(batch, 3)``; columns 0 and 1 are the x/y translation
                (divided by ``0.5 * w + 0.5 * h`` before use), column 2 is the
                rotation angle in radians.

        Returns:
            Tensor with the same shape and dtype as ``img``.
        """
        _, _, h, w = img.size()
        angle = z[:, 2]
        cos, sin = angle.cos(), angle.sin()
        shift = z[:, :2] / (.5 * w + .5 * h)
        top = torch.stack([cos, -sin, shift[:, 0]], dim=-1)
        bottom = torch.stack([sin, cos, shift[:, 1]], dim=-1)
        # grid_sample requires the grid to share the image's dtype/device,
        # so the (batch, 2, 3) matrices are built to match `img`.
        affine = torch.stack([top, bottom], dim=1).to(device=img.device,
                                                      dtype=img.dtype)
        flowgrid = F.affine_grid(affine, size=img.size(), align_corners=True)
        return F.grid_sample(img, flowgrid, align_corners=True)

    def forward(self, x):
        """Pad ``x``, sample a fresh ``z`` and return the transformed batch."""
        x = self.pad(x)
        self.sample_z(x)
        return self.transform(x, self.z)

    def inverse(self, x):
        """Apply the transform for ``-self.z`` (from the last ``forward``) and crop.

        NOTE(review): negating ``z`` is the exact inverse only for a pure
        rotation or a pure translation; with both present the exact inverse
        translation would be ``-R^T t``.  Behaviour is kept as in the
        original — confirm whether the approximation is intended.
        """
        x = self.transform(x, -self.z)
        return self.crop(x)


class UniformEqAug(nn.Module):
    """Differentiable random affine transformations of an image batch.

    Six coefficients per image are drawn uniformly between the learnable
    bounds ``lower`` and ``upper``.  They weight six 3x3 generator matrices
    (x-translation, y-translation, rotation, isotropic scaling, axis-aligned
    stretch, shear); the matrix exponential of the weighted sum is used as
    the affine map for resampling.  ``forward`` zero-pads and transforms,
    ``inverse`` applies the map for the negated coefficients and removes the
    padding.

    Args:
        gen_scale: accepted for backward compatibility; not used by this
            implementation.
        trans_scale: factor applied to the two translation generators.
        padding: number of zero pixels added on every side before the
            transform and removed again by ``inverse``/``crop``.
    """

    def __init__(self, gen_scale=10., trans_scale=0.1, padding=0):
        super().__init__()

        self.trans_scale = trans_scale

        self.lower = nn.Parameter(torch.zeros(6))
        self.upper = nn.Parameter(torch.ones(6))
        # Kept for backward compatibility with code that inspects it; the
        # generator basis is now rebuilt per call (see `generate`).
        self.g0 = None
        self.padding = padding
        self.pad = nn.ConstantPad2d(padding, 0.)
        self.weights = None

    def crop(self, x):
        """Strip ``self.padding`` pixels from each side of the last two dims.

        A padding of 0 must be special-cased: the slice ``0:-0`` is ``0:0``
        and would return an empty tensor instead of the unchanged input.
        """
        if self.padding == 0:
            return x
        p = self.padding
        return x[:, :, p:-p, p:-p]

    def sample_weights(self, x):
        """Draw ``(batch, 6)`` coefficients in ``[lower, upper)`` into ``self.weights``."""
        bs = x.size(0)
        unit = torch.rand(bs, 6).to(x.device, x.dtype)
        self.weights = unit * (self.upper - self.lower) + self.lower

    def transform(self, x, weights):
        """Resample ``x`` with the affine maps generated from ``weights``.

        ``expm`` comes from ``augerino.utils``; it is presumably a batched
        matrix exponential over ``(batch, 3, 3)`` inputs — verify there.
        """
        generators = self.generate(weights)

        # Exponential map: Lie-algebra element -> 3x3 affine matrix.
        affine_matrices = expm(generators)
        # affine_grid only takes the top two rows of each 3x3 matrix.
        flowgrid = F.affine_grid(affine_matrices[:, :2, :], size=x.size(),
                                 align_corners=True)
        return F.grid_sample(x, flowgrid, align_corners=True)

    def forward(self, x):
        """Pad ``x``, sample fresh coefficients and return the transformed batch."""
        x = self.pad(x)
        self.sample_weights(x)
        return self.transform(x, self.weights)

    def inverse(self, x):
        """Apply the transform for ``-self.weights`` (from the last ``forward``) and crop."""
        x = self.transform(x, -self.weights)
        return self.crop(x)

    def _generator_basis(self, like):
        """Return the six generators as a ``(6, 3, 3)`` tensor matching ``like``.

        Built on every call so that device, dtype and ``trans_scale`` always
        reflect the current state instead of whatever the first call saw.
        """
        basis = like.new_zeros(6, 3, 3)
        # 0: translation along x, 1: translation along y
        basis[0, 0, 2] = 1. * self.trans_scale
        basis[1, 1, 2] = 1. * self.trans_scale
        # 2: rotation
        basis[2, 0, 1] = -1.
        basis[2, 1, 0] = 1.
        # 3: isotropic scaling
        basis[3, 0, 0] = 1.
        basis[3, 1, 1] = 1.
        # 4: stretch along one axis, squeeze along the other
        basis[4, 0, 0] = 1.
        basis[4, 1, 1] = -1.
        # 5: shear
        basis[5, 0, 1] = 1.
        basis[5, 1, 0] = 1.
        return basis

    def generate(self, weights):
        """Return the sum of the scaled generator matrices.

        Args:
            weights: tensor ``(batch, 6)`` of generator coefficients.

        Returns:
            Tensor ``(batch, 3, 3)`` where entry ``b`` is
            ``sum_k weights[b, k] * G_k``.
        """
        basis = self._generator_basis(weights.detach())
        return torch.einsum('bk,kij->bij', weights, basis)



class AugEqModel(nn.Module):
    """Wrap a model so its output is averaged over random augmentations.

    Each pass augments the input with ``aug``, runs ``model`` on it, maps the
    prediction back with ``aug.inverse`` and finally averages all passes.

    Args:
        model: callable applied to the augmented input.
        aug: callable augmentation that also provides ``inverse``.
        traincopies: number of augmented passes while ``self.training``.
        testcopies: number of augmented passes in evaluation mode.
    """

    def __init__(self, model, aug, traincopies=1, testcopies=8):
        super().__init__()
        self.aug = aug
        self.model = model
        self.traincopies = traincopies
        self.testcopies = testcopies

    def forward(self, x):
        """Return the mean of the inverse-augmented predictions for ``x``."""
        # ToDo: go to log softmax
        n_passes = self.testcopies if not self.training else self.traincopies
        # `aug` must run before `aug.inverse` within each pass, because the
        # augmentations in this file store the sampled transform on `forward`.
        predictions = [
            self.aug.inverse(self.model(self.aug(x)))
            for _ in range(n_passes)
        ]
        return sum(predictions) / n_passes

mfinzi avatar Feb 22 '22 21:02 mfinzi

I think it's in `models/aug_modules.py` now, but we should fix the broken import.

mfinzi avatar Feb 22 '22 21:02 mfinzi

Thanks for your clarification.

Fangwq avatar Feb 23 '22 01:02 Fangwq