torch-mlir icon indicating copy to clipboard operation
torch-mlir copied to clipboard

Error when trying to emit MLIR for Training MNIST kernel

Open chadlonso opened this issue 2 years ago • 2 comments

Hi, I m trying to generate mlir for the training MNIST kernel using the below script ` import argparse import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms from torch.optim.lr_scheduler import StepLR import torch_mlir from functorch import make_fx from torch.nn.utils import stateless from torch._functorch.compile_utils import strip_overloads from torch._decomp import get_decompositions

class Net(nn.Module):
    """Convolutional MNIST classifier returning log-probabilities for 10 classes.

    Layout: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU -> 2x2 max-pool
    -> dropout(0.25) -> flatten -> linear(9216->128) -> ReLU -> dropout(0.5)
    -> linear(128->10) -> log-softmax over dim 1.  The 9216 = 64 * 12 * 12
    input width of ``fc1`` corresponds to 1x28x28 input images.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes named_parameters() order (conv1, conv2,
        # fc1, fc2), which flat argument lists built from it rely on.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        # Parameter-free layers; ``relu`` is reused at three points in forward().
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.flatten = nn.Flatten(start_dim=1)

    def forward(self, x):
        """Apply the layer pipeline to ``x`` and return per-class log-probabilities."""
        pipeline = (
            self.conv1, self.relu,
            self.conv2, self.relu,
            self.pool, self.dropout1,
            self.flatten,
            self.fc1, self.relu, self.dropout2,
            self.fc2, self.logsoftmax,
        )
        for stage in pipeline:
            x = stage(x)
        return x

# Single module instance used by forward(): its parameters/buffers are swapped
# in via functional_call, and its parameters() are what the optimizer steps.
mod = Net()

def forward(params, buffers, args):
    """Run one Adadelta training step on the global ``mod``.

    Args:
        params: Mapping of parameter name -> tensor substituted into ``mod``
            (the script passes ``dict(mod.named_parameters())``).
        buffers: Mapping of buffer name -> tensor substituted into ``mod``.
        args: Input batch forwarded to ``mod``.

    Returns:
        The ``(params, buffers)`` mappings that were passed in.

    Note:
        The labels are read from the module-level ``target`` global, which the
        data-loading loop defines before this function is traced.
        NOTE(review): make_fx appears to match the example arguments against
        this exact signature, so adding parameters (even with defaults) may
        break tracing -- confirm before changing it.
    """
    params_and_buffers = {**params, **buffers}
    # NOTE: a new optimizer is built on every call, so optimizer state does not
    # carry over between calls. It also steps mod.parameters(), not ``params``;
    # the two coincide only when the caller passes mod's own Parameter objects.
    optimizer = optim.Adadelta(mod.parameters(), lr=1.0)
    optimizer.zero_grad()
    # torch.func.functional_call supersedes the deprecated
    # torch.nn.utils.stateless.functional_call and takes the same arguments.
    res = torch.func.functional_call(mod, params_and_buffers, args, {})
    loss = F.nll_loss(res, target)
    loss.backward()
    optimizer.step()
    return params, buffers

def get_sorted_params(named_params):
    """Return the values of ``named_params`` as a list ordered by their keys."""
    ordered_keys = sorted(named_params)
    return [named_params[key] for key in ordered_keys]

# Convert images to tensors, then normalize with mean 0.1307 and std 0.3081.
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])
# The training split is downloaded into ../data if needed (download=True).
dataset1 = datasets.MNIST('../data', train=True, download=True,
                    transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                    transform=transform)
# Loader kwargs are commented out, so DataLoader defaults are used.
train_loader = torch.utils.data.DataLoader(dataset1)#,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2)#, **test_kwargs)

# Take only the first training batch. `data` and `target` become module-level
# globals; forward() reads `target` from here rather than as an argument.
data = None
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to("cpu"), target.to("cpu")
    break

# Flat example inputs for torch_mlir.compile: parameter values in
# named_parameters() order followed by the input batch (no target, since
# forward() does not take it as an argument).
arg = []
for name, param in mod.named_parameters():
    arg.append(param.data)
arg.append(data)
# Trace one training step of forward() into an FX graph.
fx_graph = make_fx(forward)(dict(mod.named_parameters()),
                        dict(mod.named_buffers()), data)
# NOTE(review): presumably swaps make_fx's pytree-aware codegen for the plain
# CodeGen so the graph accepts the flat positional inputs in `arg` -- confirm.
fx_graph.graph.set_codegen(torch.fx.graph.CodeGen())
fx_graph.recompile()
# NOTE(review): `sinput` is never used; strip_overloads seems to be called for
# its effect on fx_graph -- confirm what it returns.
sinput = strip_overloads(fx_graph)

# Script the traced GraphModule before handing it to torch_mlir.compile.
ts_graph = torch.jit.script(fx_graph)

# Lower to linalg-on-tensors. Per this report, this fails with "tensor with
# unknown rank" on a torch.tensor_static_info_cast of the first argument.
# NOTE(review): ts_graph is already scripted, yet use_tracing=True is passed --
# confirm which path torch_mlir.compile takes for a ScriptModule.
linalg_on_tensors_mlir = torch_mlir.compile(
    ts_graph,
    arg,
    output_type=torch_mlir.OutputType.LINALG_ON_TENSORS, use_tracing=True)
print(linalg_on_tensors_mlir)

`

But I'm getting the following error:

python exception: Failure while executing pass pipeline:
error: unknown: unsupported by backend contract: tensor with unknown rank
note: unknown: see current operation: %25 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[32,1,3,3],f32>) -> !torch.vtensor
note: unknown: this is likely due to a missing transfer function in abstract_interp_lib_gen.py

In the print-IR-after-all output, I noticed that a "torch.tensor_static_info_cast" op is created for each function argument during the AdjustCallingConventions pass.

chadlonso avatar Aug 10 '23 05:08 chadlonso

Answered on discord: https://discord.com/channels/636084430946959380/742573221882364009/1139614233978474646

ramiro050 avatar Aug 11 '23 23:08 ramiro050

Hi, after following the advice on Discord, I still can't get it running. It fails with the following error:

python: /home/hao/torch-mlir/externals/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::torch::Torch::ValueTensorType, From = mlir::Type]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.

The issue seems to be `optimizer.step()`, or more specifically assigning the result of autograd, because commenting it out makes it work.

This can be reproduced using the following code. The version of torch-mlir I'm using is commit 56a663690ccd378182ea7dbf95b7b2a54463e3e9 #3440.

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import torch_mlir
from torch_mlir import fx
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from functorch import make_fx
from torch.nn.utils import stateless
from torch._functorch.compile_utils import strip_overloads
from torch._decomp import get_decompositions
class Net(nn.Module):
    """Convolutional MNIST classifier returning log-probabilities for 10 classes.

    Layout: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU -> 2x2 max-pool
    -> dropout(0.25) -> flatten -> linear(9216->128) -> ReLU -> dropout(0.5)
    -> linear(128->10) -> log-softmax over dim 1.  The 9216 = 64 * 12 * 12
    input width of ``fc1`` corresponds to 1x28x28 input images.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes named_parameters() order (conv1, conv2,
        # fc1, fc2), which flat argument lists built from it rely on.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        # Parameter-free layers; ``relu`` is reused at three points in forward().
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.flatten = nn.Flatten(start_dim=1)

    def forward(self, x):
        """Apply the layer pipeline to ``x`` and return per-class log-probabilities."""
        pipeline = (
            self.conv1, self.relu,
            self.conv2, self.relu,
            self.pool, self.dropout1,
            self.flatten,
            self.fc1, self.relu, self.dropout2,
            self.fc2, self.logsoftmax,
        )
        for stage in pipeline:
            x = stage(x)
        return x

# Module instance and optimizer shared by forward(). The optimizer is built
# once here, over mod.parameters(), so every forward() call reuses the same
# optimizer object.
mod = Net()
optimizer = optim.Adadelta(mod.parameters(), lr=1.0)

def forward(params, buffers, data, target):
    """Perform one training step of the global ``mod`` (traced with ``make_fx``).

    Zeroes gradients, runs ``mod`` with ``params`` and ``buffers`` substituted
    via ``torch.func.functional_call``, back-propagates the NLL loss against
    ``target`` and applies one step of the module-level Adadelta ``optimizer``,
    which was constructed over ``mod.parameters()``.

    Returns:
        ``(params, buffers, loss)`` -- the mappings passed in plus the loss.
        NOTE(review): params/buffers are presumably returned so the traced
        graph exposes the updated parameters as outputs -- confirm.
    """
    optimizer.zero_grad()
    params_and_buffers = {**params, **buffers}
    res = torch.func.functional_call(mod, params_and_buffers, data,
                               {})
    loss = F.nll_loss(res, target)
    loss.backward()
    optimizer.step() # this line causes the error (ValueTensorType cast assertion when lowering)
    return (params, buffers, loss)

# Convert images to tensors, then normalize with mean 0.1307 and std 0.3081.
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])
# The training split is downloaded into ../data if needed (download=True).
dataset1 = datasets.MNIST('../data', train=True, download=True,
                    transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                    transform=transform)
# Loader kwargs are commented out, so DataLoader defaults are used.
train_loader = torch.utils.data.DataLoader(dataset1)#,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2)#, **test_kwargs)

# Take only the first training batch as the example input/label pair.
data = None
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to("cpu"), target.to("cpu")
    break

# Flat example arguments: parameter values in named_parameters() order,
# followed by the input batch and its labels.
arg = []
for name, param in mod.named_parameters():
    arg.append(param.data)
arg.append(data)
arg.append(target)
# Trace one training step of forward() (loss, backward, optimizer.step) into an FX graph.
fx_graph = make_fx(forward)(dict(mod.named_parameters()),
                        dict(mod.named_buffers()), data, target)
# NOTE(review): presumably swaps make_fx's pytree-aware codegen for the plain
# CodeGen so the graph accepts the flat positional inputs in `arg` -- confirm.
fx_graph.graph.set_codegen(torch.fx.graph.CodeGen())
fx_graph.recompile()
# NOTE(review): `sinput` is never used; strip_overloads seems to be called for
# its effect on fx_graph -- confirm what it returns.
sinput = strip_overloads(fx_graph)

# Import the traced graph through torch_mlir.fx and lower it to linalg-on-tensors.
# Per this report, this aborts with the ValueTensorType cast assertion while
# optimizer.step() (or torch.autograd.grad) is part of the traced function.
train_fn = fx.export_and_import(
    fx_graph,
    *arg,
    output_type="linalg-on-tensors",
    func_name="forward",
)

print(train_fn)

I also tried replacing the training function with:

def forward(params, buffers, data, target):
    """Compute the NLL loss of ``mod`` on one batch and its parameter gradients.

    Args:
        params: Mapping of parameter name -> tensor substituted into ``mod``.
        buffers: Mapping of buffer name -> tensor substituted into ``mod``.
        data: Input batch passed to ``mod``.
        target: Class-index labels for ``F.nll_loss``.

    Returns:
        ``(loss, grads)`` where ``grads`` is a tuple with one gradient per
        entry of ``params``, in the mapping's iteration order.
    """
    params_and_buffers = {**params, **buffers}
    res = torch.func.functional_call(mod, params_and_buffers, data, {})
    loss = F.nll_loss(res, target)
    # Differentiate w.r.t. the tensors actually substituted into the module.
    # Using mod.parameters() here only works when the caller passes mod's own
    # Parameter objects (as dict(mod.named_parameters()) does); for any other
    # params the loss would not depend on the differentiated tensors.
    param_tensors = list(params.values())
    # Reported in this issue to trigger the ValueTensorType cast assertion when
    # lowering to linalg-on-tensors.
    grads = torch.autograd.grad(loss, param_tensors, create_graph=False)
    return (loss, grads)

and it still failed with the same error:

python: /home/hao/torch-mlir/externals/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::torch::Torch::ValueTensorType, From = mlir::Type]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.


Also, emitting the torch dialect works, producing the following output:

module {
  func.func @forward(%arg0: !torch.vtensor<[32,1,3,3],f32>, %arg1: !torch.vtensor<[32],f32>, %arg2: !torch.vtensor<[64,32,3,3],f32>, %arg3: !torch.vtensor<[64],f32>, %arg4: !torch.vtensor<[128,9216],f32>, %arg5: !torch.vtensor<[128],f32>, %arg6: !torch.vtensor<[10,128],f32>, %arg7: !torch.vtensor<[10],f32>, %arg8: !torch.vtensor<[1,1,28,28],f32>, %arg9: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[32,1,3,3],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[64,32,3,3],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[128,9216],f32>, !torch.vtensor<[128],f32>, !torch.vtensor<[10,128],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[32,1,3,3],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[64,32,3,3],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[128,9216],f32>, !torch.vtensor<[128],f32>, !torch.vtensor<[10,128],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[],f32>) {
    %int32 = torch.constant.int 32
    %int3 = torch.constant.int 3
    %0 = torch.vtensor.literal(dense<-1.000000e+00> : tensor<f64>) : !torch.vtensor<[],f64>
    %1 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
    %2 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
    %3 = torch.vtensor.literal(dense<0.000000e+00> : tensor<1x128xf64>) : !torch.vtensor<[1,128],f64>
    %4 = torch.vtensor.literal(dense<5.000000e-01> : tensor<f64>) : !torch.vtensor<[],f64>
    %5 = torch.vtensor.literal(dense<0.000000e+00> : tensor<1x64x12x12xf64>) : !torch.vtensor<[1,64,12,12],f64>
    %float1.000000e00 = torch.constant.float 1.000000e+00
    %float0.000000e00 = torch.constant.float 0.000000e+00
    %int7 = torch.constant.int 7
    %6 = torch.vtensor.literal(dense<7.500000e-01> : tensor<f64>) : !torch.vtensor<[],f64>
    %float9.999990e-07 = torch.constant.float 9.9999999999999995E-7
    %float9.999990e-02 = torch.constant.float 0.099999999999999978
    %float9.000000e-01 = torch.constant.float 9.000000e-01
    %int12 = torch.constant.int 12
    %int64 = torch.constant.int 64
    %int128 = torch.constant.int 128
    %int10 = torch.constant.int 10
    %float-1.000000e00 = torch.constant.float -1.000000e+00
    %int6 = torch.constant.int 6
    %int-100 = torch.constant.int -100
    %true = torch.constant.bool true
    %float5.000000e-01 = torch.constant.float 5.000000e-01
    %int9216 = torch.constant.int 9216
    %float7.500000e-01 = torch.constant.float 7.500000e-01
    %none = torch.constant.none
    %int2 = torch.constant.int 2
    %false = torch.constant.bool false
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %7 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %8 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %9 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %10 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %11 = torch.aten.convolution %arg8, %arg0, %arg1, %7, %8, %9, %false, %10, %int1 : !torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[32,1,3,3],f32>, !torch.vtensor<[32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,32,26,26],f32>
    %12 = torch.aten.relu %11 : !torch.vtensor<[1,32,26,26],f32> -> !torch.vtensor<[1,32,26,26],f32>
    %13 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %14 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %16 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %17 = torch.aten.convolution %12, %arg2, %arg3, %13, %14, %15, %false, %16, %int1 : !torch.vtensor<[1,32,26,26],f32>, !torch.vtensor<[64,32,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,24,24],f32>
    %18 = torch.aten.relu %17 : !torch.vtensor<[1,64,24,24],f32> -> !torch.vtensor<[1,64,24,24],f32>
    %19 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %20 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %21 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %22 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %result0, %result1 = torch.aten.max_pool2d_with_indices %18, %19, %20, %21, %22, %false : !torch.vtensor<[1,64,24,24],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,12,12],f32>, !torch.vtensor<[1,64,12,12],si64>
    %23 = torch.prim.ListConstruct %int1, %int64, %int12, %int12 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %24 = torch.aten.empty.memory_format %23, %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,64,12,12],f32>
    %25 = torch.aten.to.dtype %24, %int7, %false, %false, %none : !torch.vtensor<[1,64,12,12],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,64,12,12],f64>
    %26 = torch.aten.uniform %5, %float0.000000e00, %float1.000000e00, %none : !torch.vtensor<[1,64,12,12],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[1,64,12,12],f64>
    %27 = torch.aten.lt.Tensor %26, %6 : !torch.vtensor<[1,64,12,12],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[1,64,12,12],i1>
    %28 = torch.aten.to.dtype %27, %int6, %false, %false, %none : !torch.vtensor<[1,64,12,12],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,64,12,12],f32>
    %29 = torch.aten.div.Scalar %28, %float7.500000e-01 : !torch.vtensor<[1,64,12,12],f32>, !torch.float -> !torch.vtensor<[1,64,12,12],f32>
    %30 = torch.aten.mul.Tensor %result0, %29 : !torch.vtensor<[1,64,12,12],f32>, !torch.vtensor<[1,64,12,12],f32> -> !torch.vtensor<[1,64,12,12],f32>
    %31 = torch.prim.ListConstruct %int1, %int9216 : (!torch.int, !torch.int) -> !torch.list<int>
    %32 = torch.aten.view %30, %31 : !torch.vtensor<[1,64,12,12],f32>, !torch.list<int> -> !torch.vtensor<[1,9216],f32>
    %33 = torch.aten.transpose.int %arg4, %int0, %int1 : !torch.vtensor<[128,9216],f32>, !torch.int, !torch.int -> !torch.vtensor<[9216,128],f32>
    %34 = torch.aten.mm %32, %33 : !torch.vtensor<[1,9216],f32>, !torch.vtensor<[9216,128],f32> -> !torch.vtensor<[1,128],f32>
    %35 = torch.aten.mul.Scalar %34, %int1 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
    %36 = torch.aten.mul.Scalar %arg5, %int1 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
    %37 = torch.aten.add.Tensor %35, %36, %int1 : !torch.vtensor<[1,128],f32>, !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
    %38 = torch.aten.relu %37 : !torch.vtensor<[1,128],f32> -> !torch.vtensor<[1,128],f32>
    %39 = torch.prim.ListConstruct %int1, %int128 : (!torch.int, !torch.int) -> !torch.list<int>
    %40 = torch.aten.empty.memory_format %39, %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,128],f32>
    %41 = torch.aten.to.dtype %40, %int7, %false, %false, %none : !torch.vtensor<[1,128],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,128],f64>
    %42 = torch.aten.uniform %3, %float0.000000e00, %float1.000000e00, %none : !torch.vtensor<[1,128],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[1,128],f64>
    %43 = torch.aten.lt.Tensor %42, %4 : !torch.vtensor<[1,128],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[1,128],i1>
    %44 = torch.aten.to.dtype %43, %int6, %false, %false, %none : !torch.vtensor<[1,128],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,128],f32>
    %45 = torch.aten.div.Scalar %44, %float5.000000e-01 : !torch.vtensor<[1,128],f32>, !torch.float -> !torch.vtensor<[1,128],f32>
    %46 = torch.aten.mul.Tensor %38, %45 : !torch.vtensor<[1,128],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[1,128],f32>
    %47 = torch.aten.transpose.int %arg6, %int0, %int1 : !torch.vtensor<[10,128],f32>, !torch.int, !torch.int -> !torch.vtensor<[128,10],f32>
    %48 = torch.aten.mm %46, %47 : !torch.vtensor<[1,128],f32>, !torch.vtensor<[128,10],f32> -> !torch.vtensor<[1,10],f32>
    %49 = torch.aten.mul.Scalar %48, %int1 : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.vtensor<[1,10],f32>
    %50 = torch.aten.mul.Scalar %arg7, %int1 : !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
    %51 = torch.aten.add.Tensor %49, %50, %int1 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[1,10],f32>
    %values, %indices = torch.aten.max.dim %51, %int1, %true : !torch.vtensor<[1,10],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],si64>
    %52 = torch.aten.sub.Tensor %51, %values, %int1 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,10],f32>
    %53 = torch.aten.exp %52 : !torch.vtensor<[1,10],f32> -> !torch.vtensor<[1,10],f32>
    %54 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
    %55 = torch.aten.sum.dim_IntList %53, %54, %true, %none : !torch.vtensor<[1,10],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32>
    %56 = torch.aten.log %55 : !torch.vtensor<[1,1],f32> -> !torch.vtensor<[1,1],f32>
    %57 = torch.aten.sub.Tensor %52, %56, %int1 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,10],f32>
    %58 = torch.aten.ne.Scalar %arg9, %int-100 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1],i1>
    %59 = torch.aten.where.self %58, %arg9, %2 : !torch.vtensor<[1],i1>, !torch.vtensor<[1],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64>
    %60 = torch.aten.unsqueeze %59, %int1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
    %61 = torch.aten.gather %57, %int1, %60, %false : !torch.vtensor<[1,10],f32>, !torch.int, !torch.vtensor<[1,1],si64>, !torch.bool -> !torch.vtensor<[1,1],f32>
    %62 = torch.aten.squeeze.dim %61, %int1 : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32>
    %63 = torch.aten.neg %62 : !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32>
    %64 = torch.aten.ne.Scalar %arg9, %int-100 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1],i1>
    %65 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %66 = torch.aten.where.self %64, %63, %65 : !torch.vtensor<[1],i1>, !torch.vtensor<[1],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1],f32>
    %67 = torch.aten.ne.Scalar %arg9, %int-100 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1],i1>
    %68 = torch.aten.sum %67, %none : !torch.vtensor<[1],i1>, !torch.none -> !torch.vtensor<[],si64>
    %69 = torch.aten.to.dtype %68, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %70 = torch.aten.sum %66, %none : !torch.vtensor<[1],f32>, !torch.none -> !torch.vtensor<[],f32>
    %71 = torch.aten.div.Tensor %70, %69 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
    %72 = torch.aten.to.dtype %1, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %73 = torch.aten.div.Tensor %72, %69 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
    %74 = torch.aten.unsqueeze %arg9, %int1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
    %75 = torch.aten.ne.Scalar %74, %int-100 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1,1],i1>
    %76 = torch.aten.where.self %75, %74, %2 : !torch.vtensor<[1,1],i1>, !torch.vtensor<[1,1],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[1,1],si64>
    %77 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %78 = torch.prim.ListConstruct %int1, %int10 : (!torch.int, !torch.int) -> !torch.list<int>
    %79 = torch.aten.broadcast_to %77, %78 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[1,10],f32>
    %80 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %81 = torch.aten.to.dtype %0, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %82 = torch.aten.broadcast_to %81, %80 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[1,1],f32>
    %83 = torch.aten.scatter.src %79, %int1, %76, %82 : !torch.vtensor<[1,10],f32>, !torch.int, !torch.vtensor<[1,1],si64>, !torch.vtensor<[1,1],f32> -> !torch.vtensor<[1,10],f32>
    %84 = torch.aten.ne.Scalar %74, %int-100 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1,1],i1>
    %85 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %86 = torch.aten.where.self %84, %73, %85 : !torch.vtensor<[1,1],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,1],f32>
    %87 = torch.aten.mul.Tensor %83, %86 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1,1],f32> -> !torch.vtensor<[1,10],f32>
    %88 = torch.aten.exp %57 : !torch.vtensor<[1,10],f32> -> !torch.vtensor<[1,10],f32>
    %89 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
    %90 = torch.aten.sum.dim_IntList %87, %89, %true, %none : !torch.vtensor<[1,10],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32>
    %91 = torch.aten.mul.Tensor %88, %90 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1,1],f32> -> !torch.vtensor<[1,10],f32>
    %92 = torch.aten.sub.Tensor %87, %91, %int1 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1,10],f32>, !torch.int -> !torch.vtensor<[1,10],f32>
    %93 = torch.aten.transpose.int %47, %int0, %int1 : !torch.vtensor<[128,10],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,128],f32>
    %94 = torch.aten.mm %92, %93 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[10,128],f32> -> !torch.vtensor<[1,128],f32>
    %95 = torch.aten.transpose.int %92, %int0, %int1 : !torch.vtensor<[1,10],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,1],f32>
    %96 = torch.aten.mm %95, %46 : !torch.vtensor<[10,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[10,128],f32>
    %97 = torch.aten.transpose.int %96, %int0, %int1 : !torch.vtensor<[10,128],f32>, !torch.int, !torch.int -> !torch.vtensor<[128,10],f32>
    %98 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
    %99 = torch.aten.sum.dim_IntList %92, %98, %true, %none : !torch.vtensor<[1,10],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,10],f32>
    %100 = torch.prim.ListConstruct %int10 : (!torch.int) -> !torch.list<int>
    %101 = torch.aten.view %99, %100 : !torch.vtensor<[1,10],f32>, !torch.list<int> -> !torch.vtensor<[10],f32>
    %102 = torch.aten.transpose.int %97, %int0, %int1 : !torch.vtensor<[128,10],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,128],f32>
    %103 = torch.aten.mul.Tensor %94, %45 : !torch.vtensor<[1,128],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[1,128],f32>
    %104 = torch.aten.threshold_backward %103, %38, %int0 : !torch.vtensor<[1,128],f32>, !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
    %105 = torch.aten.transpose.int %33, %int0, %int1 : !torch.vtensor<[9216,128],f32>, !torch.int, !torch.int -> !torch.vtensor<[128,9216],f32>
    %106 = torch.aten.mm %104, %105 : !torch.vtensor<[1,128],f32>, !torch.vtensor<[128,9216],f32> -> !torch.vtensor<[1,9216],f32>
    %107 = torch.aten.transpose.int %104, %int0, %int1 : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int -> !torch.vtensor<[128,1],f32>
    %108 = torch.aten.mm %107, %32 : !torch.vtensor<[128,1],f32>, !torch.vtensor<[1,9216],f32> -> !torch.vtensor<[128,9216],f32>
    %109 = torch.aten.transpose.int %108, %int0, %int1 : !torch.vtensor<[128,9216],f32>, !torch.int, !torch.int -> !torch.vtensor<[9216,128],f32>
    %110 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
    %111 = torch.aten.sum.dim_IntList %104, %110, %true, %none : !torch.vtensor<[1,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128],f32>
    %112 = torch.prim.ListConstruct %int128 : (!torch.int) -> !torch.list<int>
    %113 = torch.aten.view %111, %112 : !torch.vtensor<[1,128],f32>, !torch.list<int> -> !torch.vtensor<[128],f32>
    %114 = torch.aten.transpose.int %109, %int0, %int1 : !torch.vtensor<[9216,128],f32>, !torch.int, !torch.int -> !torch.vtensor<[128,9216],f32>
    %115 = torch.prim.ListConstruct %int1, %int64, %int12, %int12 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %116 = torch.aten.view %106, %115 : !torch.vtensor<[1,9216],f32>, !torch.list<int> -> !torch.vtensor<[1,64,12,12],f32>
    %117 = torch.aten.mul.Tensor %116, %29 : !torch.vtensor<[1,64,12,12],f32>, !torch.vtensor<[1,64,12,12],f32> -> !torch.vtensor<[1,64,12,12],f32>
    %118 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %119 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %120 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %121 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %122 = torch.aten.max_pool2d_with_indices_backward %117, %18, %118, %119, %120, %121, %false, %result1 : !torch.vtensor<[1,64,12,12],f32>, !torch.vtensor<[1,64,24,24],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.vtensor<[1,64,12,12],si64> -> !torch.vtensor<[1,64,24,24],f32>
    %123 = torch.aten.threshold_backward %122, %18, %int0 : !torch.vtensor<[1,64,24,24],f32>, !torch.vtensor<[1,64,24,24],f32>, !torch.int -> !torch.vtensor<[1,64,24,24],f32>
    %124 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %125 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %126 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %127 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %128 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %129 = torch.aten.convolution %123, %arg2, %none, %124, %125, %126, %true, %128, %int1 : !torch.vtensor<[1,64,24,24],f32>, !torch.vtensor<[64,32,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,32,26,26],f32>
    %130 = torch.aten.transpose.int %12, %int0, %int1 : !torch.vtensor<[1,32,26,26],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,1,26,26],f32>
    %131 = torch.aten.transpose.int %123, %int0, %int1 : !torch.vtensor<[1,64,24,24],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,1,24,24],f32>
    %132 = torch.aten.convolution %130, %131, %none, %124, %125, %126, %false, %127, %int1 : !torch.vtensor<[32,1,26,26],f32>, !torch.vtensor<[64,1,24,24],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[32,64,3,3],f32>
    %133 = torch.aten.transpose.int %132, %int0, %int1 : !torch.vtensor<[32,64,3,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,32,3,3],f32>
    %134 = torch.prim.ListConstruct %int0, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %135 = torch.aten.sum.dim_IntList %123, %134, %false, %none : !torch.vtensor<[1,64,24,24],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[64],f32>
    %136 = torch.aten.threshold_backward %129, %12, %int0 : !torch.vtensor<[1,32,26,26],f32>, !torch.vtensor<[1,32,26,26],f32>, !torch.int -> !torch.vtensor<[1,32,26,26],f32>
    %137 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %138 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %139 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %140 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %141 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    %142 = torch.aten.convolution %136, %arg0, %none, %137, %138, %139, %true, %141, %int1 : !torch.vtensor<[1,32,26,26],f32>, !torch.vtensor<[32,1,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.none
    %143 = torch.aten.transpose.int %arg8, %int0, %int1 : !torch.vtensor<[1,1,28,28],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,28,28],f32>
    %144 = torch.aten.transpose.int %136, %int0, %int1 : !torch.vtensor<[1,32,26,26],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,1,26,26],f32>
    %145 = torch.aten.convolution %143, %144, %none, %137, %138, %139, %false, %140, %int1 : !torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[32,1,26,26],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,32,3,3],f32>
    %146 = torch.aten.transpose.int %145, %int0, %int1 : !torch.vtensor<[1,32,3,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,1,3,3],f32>
    %147 = torch.prim.ListConstruct %int0, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %148 = torch.aten.sum.dim_IntList %136, %147, %false, %none : !torch.vtensor<[1,32,26,26],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[32],f32>
    %149 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %150 = torch.prim.ListConstruct %int32, %int1, %int3, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %151 = torch.aten.broadcast_to %149, %150 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[32,1,3,3],f32>
    %152 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %153 = torch.prim.ListConstruct %int32, %int1, %int3, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %154 = torch.aten.broadcast_to %152, %153 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[32,1,3,3],f32>
    %155 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %156 = torch.prim.ListConstruct %int32 : (!torch.int) -> !torch.list<int>
    %157 = torch.aten.broadcast_to %155, %156 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[32],f32>
    %158 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %159 = torch.prim.ListConstruct %int32 : (!torch.int) -> !torch.list<int>
    %160 = torch.aten.broadcast_to %158, %159 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[32],f32>
    %161 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %162 = torch.prim.ListConstruct %int64, %int32, %int3, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %163 = torch.aten.broadcast_to %161, %162 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[64,32,3,3],f32>
    %164 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %165 = torch.prim.ListConstruct %int64, %int32, %int3, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %166 = torch.aten.broadcast_to %164, %165 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[64,32,3,3],f32>
    %167 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %168 = torch.prim.ListConstruct %int64 : (!torch.int) -> !torch.list<int>
    %169 = torch.aten.broadcast_to %167, %168 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[64],f32>
    %170 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %171 = torch.prim.ListConstruct %int64 : (!torch.int) -> !torch.list<int>
    %172 = torch.aten.broadcast_to %170, %171 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[64],f32>
    %173 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %174 = torch.prim.ListConstruct %int128, %int9216 : (!torch.int, !torch.int) -> !torch.list<int>
    %175 = torch.aten.broadcast_to %173, %174 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[128,9216],f32>
    %176 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %177 = torch.prim.ListConstruct %int128, %int9216 : (!torch.int, !torch.int) -> !torch.list<int>
    %178 = torch.aten.broadcast_to %176, %177 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[128,9216],f32>
    %179 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %180 = torch.prim.ListConstruct %int128 : (!torch.int) -> !torch.list<int>
    %181 = torch.aten.broadcast_to %179, %180 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[128],f32>
    %182 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %183 = torch.prim.ListConstruct %int128 : (!torch.int) -> !torch.list<int>
    %184 = torch.aten.broadcast_to %182, %183 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[128],f32>
    %185 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %186 = torch.prim.ListConstruct %int10, %int128 : (!torch.int, !torch.int) -> !torch.list<int>
    %187 = torch.aten.broadcast_to %185, %186 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[10,128],f32>
    %188 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %189 = torch.prim.ListConstruct %int10, %int128 : (!torch.int, !torch.int) -> !torch.list<int>
    %190 = torch.aten.broadcast_to %188, %189 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[10,128],f32>
    %191 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %192 = torch.prim.ListConstruct %int10 : (!torch.int) -> !torch.list<int>
    %193 = torch.aten.broadcast_to %191, %192 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[10],f32>
    %194 = torch.aten.to.dtype %2, %int6, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %195 = torch.prim.ListConstruct %int10 : (!torch.int) -> !torch.list<int>
    %196 = torch.aten.broadcast_to %194, %195 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[10],f32>
    %197 = torch.aten.mul.Scalar %151, %float9.000000e-01 : !torch.vtensor<[32,1,3,3],f32>, !torch.float -> !torch.vtensor<[32,1,3,3],f32>
    %198 = torch.aten.mul.Tensor %146, %146 : !torch.vtensor<[32,1,3,3],f32>, !torch.vtensor<[32,1,3,3],f32> -> !torch.vtensor<[32,1,3,3],f32>
    %199 = torch.aten.add.Tensor %197, %198, %float9.999990e-02 : !torch.vtensor<[32,1,3,3],f32>, !torch.vtensor<[32,1,3,3],f32>, !torch.float -> !torch.vtensor<[32,1,3,3],f32>
    %200 = torch.aten.add.Scalar %199, %float9.999990e-07, %int1 : !torch.vtensor<[32,1,3,3],f32>, !torch.float, !torch.int -> !torch.vtensor<[32,1,3,3],f32>
    %201 = torch.aten.sqrt %200 : !torch.vtensor<[32,1,3,3],f32> -> !torch.vtensor<[32,1,3,3],f32>
    %202 = torch.aten.add.Scalar %154, %float9.999990e-07, %int1 : !torch.vtensor<[32,1,3,3],f32>, !torch.float, !torch.int -> !torch.vtensor<[32,1,3,3],f32>
    %203 = torch.aten.sqrt %202 : !torch.vtensor<[32,1,3,3],f32> -> !torch.vtensor<[32,1,3,3],f32>
    %204 = torch.aten.div.Tensor %203, %201 : !torch.vtensor<[32,1,3,3],f32>, !torch.vtensor<[32,1,3,3],f32> -> !torch.vtensor<[32,1,3,3],f32>
    %205 = torch.aten.mul.Tensor %204, %146 : !torch.vtensor<[32,1,3,3],f32>, !torch.vtensor<[32,1,3,3],f32> -> !torch.vtensor<[32,1,3,3],f32>
    %206 = torch.aten.add.Tensor %arg0, %205, %float-1.000000e00 : !torch.vtensor<[32,1,3,3],f32>, !torch.vtensor<[32,1,3,3],f32>, !torch.float -> !torch.vtensor<[32,1,3,3],f32>
    %207 = torch.aten.mul.Scalar %157, %float9.000000e-01 : !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32>
    %208 = torch.aten.mul.Tensor %148, %148 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32>
    %209 = torch.aten.add.Tensor %207, %208, %float9.999990e-02 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32>
    %210 = torch.aten.add.Scalar %209, %float9.999990e-07, %int1 : !torch.vtensor<[32],f32>, !torch.float, !torch.int -> !torch.vtensor<[32],f32>
    %211 = torch.aten.sqrt %210 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32>
    %212 = torch.aten.add.Scalar %160, %float9.999990e-07, %int1 : !torch.vtensor<[32],f32>, !torch.float, !torch.int -> !torch.vtensor<[32],f32>
    %213 = torch.aten.sqrt %212 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32>
    %214 = torch.aten.div.Tensor %213, %211 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32>
    %215 = torch.aten.mul.Tensor %214, %148 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32>
    %216 = torch.aten.add.Tensor %arg1, %215, %float-1.000000e00 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32>
    %217 = torch.aten.mul.Scalar %163, %float9.000000e-01 : !torch.vtensor<[64,32,3,3],f32>, !torch.float -> !torch.vtensor<[64,32,3,3],f32>
    %218 = torch.aten.mul.Tensor %133, %133 : !torch.vtensor<[64,32,3,3],f32>, !torch.vtensor<[64,32,3,3],f32> -> !torch.vtensor<[64,32,3,3],f32>
    %219 = torch.aten.add.Tensor %217, %218, %float9.999990e-02 : !torch.vtensor<[64,32,3,3],f32>, !torch.vtensor<[64,32,3,3],f32>, !torch.float -> !torch.vtensor<[64,32,3,3],f32>
    %220 = torch.aten.add.Scalar %219, %float9.999990e-07, %int1 : !torch.vtensor<[64,32,3,3],f32>, !torch.float, !torch.int -> !torch.vtensor<[64,32,3,3],f32>
    %221 = torch.aten.sqrt %220 : !torch.vtensor<[64,32,3,3],f32> -> !torch.vtensor<[64,32,3,3],f32>
    %222 = torch.aten.add.Scalar %166, %float9.999990e-07, %int1 : !torch.vtensor<[64,32,3,3],f32>, !torch.float, !torch.int -> !torch.vtensor<[64,32,3,3],f32>
    %223 = torch.aten.sqrt %222 : !torch.vtensor<[64,32,3,3],f32> -> !torch.vtensor<[64,32,3,3],f32>
    %224 = torch.aten.div.Tensor %223, %221 : !torch.vtensor<[64,32,3,3],f32>, !torch.vtensor<[64,32,3,3],f32> -> !torch.vtensor<[64,32,3,3],f32>
    %225 = torch.aten.mul.Tensor %224, %133 : !torch.vtensor<[64,32,3,3],f32>, !torch.vtensor<[64,32,3,3],f32> -> !torch.vtensor<[64,32,3,3],f32>
    %226 = torch.aten.add.Tensor %arg2, %225, %float-1.000000e00 : !torch.vtensor<[64,32,3,3],f32>, !torch.vtensor<[64,32,3,3],f32>, !torch.float -> !torch.vtensor<[64,32,3,3],f32>
    %227 = torch.aten.mul.Scalar %169, %float9.000000e-01 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
    %228 = torch.aten.mul.Tensor %135, %135 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
    %229 = torch.aten.add.Tensor %227, %228, %float9.999990e-02 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
    %230 = torch.aten.add.Scalar %229, %float9.999990e-07, %int1 : !torch.vtensor<[64],f32>, !torch.float, !torch.int -> !torch.vtensor<[64],f32>
    %231 = torch.aten.sqrt %230 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
    %232 = torch.aten.add.Scalar %172, %float9.999990e-07, %int1 : !torch.vtensor<[64],f32>, !torch.float, !torch.int -> !torch.vtensor<[64],f32>
    %233 = torch.aten.sqrt %232 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
    %234 = torch.aten.div.Tensor %233, %231 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
    %235 = torch.aten.mul.Tensor %234, %135 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
    %236 = torch.aten.add.Tensor %arg3, %235, %float-1.000000e00 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
    %237 = torch.aten.mul.Scalar %175, %float9.000000e-01 : !torch.vtensor<[128,9216],f32>, !torch.float -> !torch.vtensor<[128,9216],f32>
    %238 = torch.aten.mul.Tensor %114, %114 : !torch.vtensor<[128,9216],f32>, !torch.vtensor<[128,9216],f32> -> !torch.vtensor<[128,9216],f32>
    %239 = torch.aten.add.Tensor %237, %238, %float9.999990e-02 : !torch.vtensor<[128,9216],f32>, !torch.vtensor<[128,9216],f32>, !torch.float -> !torch.vtensor<[128,9216],f32>
    %240 = torch.aten.add.Scalar %239, %float9.999990e-07, %int1 : !torch.vtensor<[128,9216],f32>, !torch.float, !torch.int -> !torch.vtensor<[128,9216],f32>
    %241 = torch.aten.sqrt %240 : !torch.vtensor<[128,9216],f32> -> !torch.vtensor<[128,9216],f32>
    %242 = torch.aten.add.Scalar %178, %float9.999990e-07, %int1 : !torch.vtensor<[128,9216],f32>, !torch.float, !torch.int -> !torch.vtensor<[128,9216],f32>
    %243 = torch.aten.sqrt %242 : !torch.vtensor<[128,9216],f32> -> !torch.vtensor<[128,9216],f32>
    %244 = torch.aten.div.Tensor %243, %241 : !torch.vtensor<[128,9216],f32>, !torch.vtensor<[128,9216],f32> -> !torch.vtensor<[128,9216],f32>
    %245 = torch.aten.mul.Tensor %244, %114 : !torch.vtensor<[128,9216],f32>, !torch.vtensor<[128,9216],f32> -> !torch.vtensor<[128,9216],f32>
    %246 = torch.aten.add.Tensor %arg4, %245, %float-1.000000e00 : !torch.vtensor<[128,9216],f32>, !torch.vtensor<[128,9216],f32>, !torch.float -> !torch.vtensor<[128,9216],f32>
    %247 = torch.aten.mul.Scalar %181, %float9.000000e-01 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
    %248 = torch.aten.mul.Tensor %113, %113 : !torch.vtensor<[128],f32>, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
    %249 = torch.aten.add.Tensor %247, %248, %float9.999990e-02 : !torch.vtensor<[128],f32>, !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
    %250 = torch.aten.add.Scalar %249, %float9.999990e-07, %int1 : !torch.vtensor<[128],f32>, !torch.float, !torch.int -> !torch.vtensor<[128],f32>
    %251 = torch.aten.sqrt %250 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
    %252 = torch.aten.add.Scalar %184, %float9.999990e-07, %int1 : !torch.vtensor<[128],f32>, !torch.float, !torch.int -> !torch.vtensor<[128],f32>
    %253 = torch.aten.sqrt %252 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
    %254 = torch.aten.div.Tensor %253, %251 : !torch.vtensor<[128],f32>, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
    %255 = torch.aten.mul.Tensor %254, %113 : !torch.vtensor<[128],f32>, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
    %256 = torch.aten.add.Tensor %arg5, %255, %float-1.000000e00 : !torch.vtensor<[128],f32>, !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
    %257 = torch.aten.mul.Scalar %187, %float9.000000e-01 : !torch.vtensor<[10,128],f32>, !torch.float -> !torch.vtensor<[10,128],f32>
    %258 = torch.aten.mul.Tensor %102, %102 : !torch.vtensor<[10,128],f32>, !torch.vtensor<[10,128],f32> -> !torch.vtensor<[10,128],f32>
    %259 = torch.aten.add.Tensor %257, %258, %float9.999990e-02 : !torch.vtensor<[10,128],f32>, !torch.vtensor<[10,128],f32>, !torch.float -> !torch.vtensor<[10,128],f32>
    %260 = torch.aten.add.Scalar %259, %float9.999990e-07, %int1 : !torch.vtensor<[10,128],f32>, !torch.float, !torch.int -> !torch.vtensor<[10,128],f32>
    %261 = torch.aten.sqrt %260 : !torch.vtensor<[10,128],f32> -> !torch.vtensor<[10,128],f32>
    %262 = torch.aten.add.Scalar %190, %float9.999990e-07, %int1 : !torch.vtensor<[10,128],f32>, !torch.float, !torch.int -> !torch.vtensor<[10,128],f32>
    %263 = torch.aten.sqrt %262 : !torch.vtensor<[10,128],f32> -> !torch.vtensor<[10,128],f32>
    %264 = torch.aten.div.Tensor %263, %261 : !torch.vtensor<[10,128],f32>, !torch.vtensor<[10,128],f32> -> !torch.vtensor<[10,128],f32>
    %265 = torch.aten.mul.Tensor %264, %102 : !torch.vtensor<[10,128],f32>, !torch.vtensor<[10,128],f32> -> !torch.vtensor<[10,128],f32>
    %266 = torch.aten.add.Tensor %arg6, %265, %float-1.000000e00 : !torch.vtensor<[10,128],f32>, !torch.vtensor<[10,128],f32>, !torch.float -> !torch.vtensor<[10,128],f32>
    %267 = torch.aten.mul.Scalar %193, %float9.000000e-01 : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
    %268 = torch.aten.mul.Tensor %101, %101 : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
    %269 = torch.aten.add.Tensor %267, %268, %float9.999990e-02 : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
    %270 = torch.aten.add.Scalar %269, %float9.999990e-07, %int1 : !torch.vtensor<[10],f32>, !torch.float, !torch.int -> !torch.vtensor<[10],f32>
    %271 = torch.aten.sqrt %270 : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
    %272 = torch.aten.add.Scalar %196, %float9.999990e-07, %int1 : !torch.vtensor<[10],f32>, !torch.float, !torch.int -> !torch.vtensor<[10],f32>
    %273 = torch.aten.sqrt %272 : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
    %274 = torch.aten.div.Tensor %273, %271 : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
    %275 = torch.aten.mul.Tensor %274, %101 : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
    %276 = torch.aten.add.Tensor %arg7, %275, %float-1.000000e00 : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
    return %206, %216, %226, %236, %246, %256, %266, %276, %206, %216, %226, %236, %246, %256, %266, %276, %71 : !torch.vtensor<[32,1,3,3],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[64,32,3,3],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[128,9216],f32>, !torch.vtensor<[128],f32>, !torch.vtensor<[10,128],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[32,1,3,3],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[64,32,3,3],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[128,9216],f32>, !torch.vtensor<[128],f32>, !torch.vtensor<[10,128],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[],f32>
  }
}


moomoohorse321 avatar Mar 14 '25 01:03 moomoohorse321