MP-SPDZ
How to implement ResNet?
I tried to write ResNet18 like this:
program.options_from_args()

from Compiler import ml

try:
    ml.set_n_threads(int(program.args[2]))
except:
    pass

get_data = lambda train, transform=None: torchvision.datasets.CIFAR10(
    root='/tmp', train=train, download=True, transform=transform)

import torchvision, numpy

data = []
for train in True, False:
    ds = get_data(train)
    # normalize to [-1,1] before input
    samples = sfix.input_tensor_via(0, ds.data / 255 * 2 - 1, binary=True)
    labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True)
    data += [(labels, samples)]

(training_labels, training_samples), (test_labels, test_samples) = data
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms

class RestNetBasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(RestNetBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        output = self.conv1(x)
        output = F.relu(self.bn1(output))
        output = self.conv2(output)
        output = self.bn2(output)
        return F.relu(x + output)

class RestNetDownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(RestNetDownBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride[0], padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.extra = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=0),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        extra_x = self.extra(x)
        output = self.conv1(x)
        out = F.relu(self.bn1(output))
        out = self.conv2(out)
        out = self.bn2(out)
        return F.relu(extra_x + out)
class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1),
                                    RestNetBasicBlock(64, 64, 1))
        self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]),
                                    RestNetBasicBlock(128, 128, 1))
        self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]),
                                    RestNetBasicBlock(256, 256, 1))
        self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2, 1]),
                                    RestNetBasicBlock(512, 512, 1))
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.reshape(x.shape[0], -1)
        out = self.fc(out)
        return out
net = ResNet18()

# train for a bit
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

ds = get_data(train=True, transform=transform)
optimizer = torch.optim.Adam(net.parameters(), amsgrad=True)
criterion = nn.CrossEntropyLoss()

for i, data in enumerate(torch.utils.data.DataLoader(ds, batch_size=128)):
    inputs, labels = data
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    ds = get_data(False, transform)
    total = correct_classified = 0
    for data in torch.utils.data.DataLoader(ds, batch_size=128):
        inputs, labels = data
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct_classified += (predicted == labels).sum().item()
    test_acc = (100 * correct_classified / total)
    print('Cleartext test accuracy of the network: %.2f %%' % test_acc)
layers = ml.layers_from_torch(net, training_samples.shape, 128, input_via=0)

optimizer = ml.SGD(layers)
optimizer.fit(
    training_samples,
    training_labels,
    epochs=int(1),
    batch_size=128,
    validation_data=(test_samples, test_labels),
    program=program,
    reset=False
)
The error is CompilerError: unknown PyTorch module: ResNet18.
It seems I can't pass a self-defined module to the Compiler. Is there any example of ResNet18 inference in MP-SPDZ?
The PyTorch interface only supports sequential networks, but ResNet contains an addition and thus isn't sequential. We have implemented ResNet-50 inference, which you can run as follows from the MP-SPDZ root directory:
git clone https://github.com/mkskeller/EzPC
cd EzPC/Athos/Networks/ResNet
axel -a -n 5 -c --output ./PreTrainedModel http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NHWC.tar.gz
cd PreTrainedModel && tar -xvzf resnet_v2_fp32_savedmodel_NHWC.tar.gz && cd ..
python3 ResNet_main.py --runPrediction True --scalingFac 12 --saveImgAndWtData True
cd ../../../..
Scripts/fixed-rep-to-float.py EzPC/Athos/Networks/ResNet/ResNet_img_input.inp
Scripts/compile-emulate.py tf EzPC/Athos/Networks/ResNet/graphDef.bin 8
You can change compile-emulate.py in the last line to compile-run.py -E <protocol> in order to run the computation with an actual protocol instead of the emulator.
I have implemented simple training code for residual blocks in ml.py, which I hope provides some inspiration. If anyone has implemented complete ResNet training, I would be very glad to see it open-sourced.
class SimpleRes_Linear(DenseBase):
    def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
        if activation == 'id':
            self.activation_layer = None
        elif activation == 'relu':
            self.activation_layer = Relu([N, d, d_out])
        elif activation == 'square':
            self.activation_layer = Square([N, d, d_out])
        else:
            raise CompilerError('activation not supported: %s', activation)

        self.N = N
        self.d_in = d_in
        self.d_out = d_out
        self.d = d
        self.activation = activation
        # the posted snippet relies on this attribute below but never set it
        self.input_bias = True

        self.X = MultiArray([N, d, d_in], sfix)
        self.Y = MultiArray([N, d, d_out], sfix)
        self.W = Tensor([d_in, d_out], sfix)
        self.b = sfix.Array(d_out)

        back_N = min(N, self.back_batch_size)
        self.nabla_Y = MultiArray([back_N, d, d_out], sfix)
        self.nabla_X = MultiArray([back_N, d, d_in], sfix)
        self.nabla_W = sfix.Matrix(d_in, d_out)
        self.nabla_b = sfix.Array(d_out)

        self.debug = debug

        l = self.activation_layer
        if l:
            self.f_input = l.X
            l.Y = self.Y
            l.nabla_Y = self.nabla_Y
        else:
            self.f_input = self.Y

    def __repr__(self):
        return '%s(%s, %s, %s, activation=%s)' % \
            (type(self).__name__, self.N, self.d_in,
             self.d_out, repr(self.activation))

    def reset(self):
        d_in = self.d_in
        d_out = self.d_out
        r = math.sqrt(6.0 / (d_in + d_out))
        print('Initializing dense weights in [%f,%f]' % (-r, r))
        self.W.randomize(-r, r)
        self.b.assign_all(0)

    def input_from(self, player, raw=False):
        self.W.input_from(player, raw=raw)
        if self.input_bias:
            self.b.input_from(player, raw=raw)

    def compute_f_input(self, batch):
        N = len(batch)
        assert self.d == 1
        if self.input_bias:
            prod = MultiArray([N, self.d, self.d_out], sfix)
        else:
            prod = self.f_input
        max_size = program.Program.prog.budget // self.d_out

        @multithread(self.n_threads, N, max_size)
        def _(base, size):
            X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
            prod.assign_part_vector(
                X_sub.direct_mul(self.W, indices=(
                    batch.get_vector(base, size), regint.inc(self.d_in),
                    regint.inc(self.d_in), regint.inc(self.d_out))), base)

        if self.input_bias:
            if self.d_out == 1:
                @multithread(self.n_threads, N)
                def _(base, size):
                    v = prod.get_vector(base, size) \
                        + self.b.expand_to_vector(0, size) \
                        + self.X.expand_to_vector(0, size)
                    self.f_input.assign_vector(v, base)
            else:
                @for_range_multithread(self.n_threads, 100, N)
                def _(i):
                    # residual connection (requires d_in == d_out): add this
                    # sample's input; the posted snippet used self.X.get_vector(),
                    # which has the wrong length
                    v = prod[i].get_vector() + self.b.get_vector() + self.X[i].get_vector()
                    self.f_input[i].assign_vector(v)
        progress('f input')

    def _forward(self, batch=None):
        if batch is None:
            batch = regint.Array(self.N)
            batch.assign(regint.inc(self.N))

        self.compute_f_input(batch=batch)
        if self.activation_layer:
            self.activation_layer.forward(batch)

        if self.debug_output:
            print_ln('dense X %s', self.X.reveal_nested())
            print_ln('dense W %s', self.W.reveal_nested())
            print_ln('dense b %s', self.b.reveal_nested())
            print_ln('dense Y %s', self.Y.reveal_nested())

        if self.debug:
            limit = self.debug
            @for_range_opt(len(batch))
            def _(i):
                @for_range_opt(self.d_out)
                def _(j):
                    to_check = self.Y[i][0][j].reveal()
                    check = to_check > limit
                    @if_(check)
                    def _():
                        print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check)
                        print_ln('X %s', self.X[i].reveal_nested())
                        print_ln('W %s',
                                 [self.W[k][j].reveal() for k in range(self.d_in)])

    def backward(self, compute_nabla_X=True, batch=None):
        N = len(batch)
        d = self.d
        d_out = self.d_out
        X = self.X
        Y = self.Y
        W = self.W
        b = self.b
        nabla_X = self.nabla_X
        nabla_Y = self.nabla_Y
        nabla_W = self.nabla_W
        nabla_b = self.nabla_b

        if self.activation_layer:
            self.activation_layer.backward(batch)
            f_schur_Y = self.activation_layer.nabla_X
        else:
            f_schur_Y = nabla_Y

        if compute_nabla_X:
            @multithread(self.n_threads, N)
            def _(base, size):
                B = sfix.Matrix(N, d_out, address=f_schur_Y.address)
                nabla_X.assign_part_vector(
                    B.direct_mul_trans(W, indices=(regint.inc(size, base),
                                                   regint.inc(self.d_out),
                                                   regint.inc(self.d_out),
                                                   regint.inc(self.d_in))),
                    base)

            # skip connection: add a constant for the identity branch
            # (a full residual backward would add f_schur_Y here instead)
            nabla_X[:] += sfix.from_sint(1)
            print('res')  # compile-time marker

            if self.print_random_update:
                print_ln('backward %s', self)
                index = regint.get_random(64) % self.nabla_X.total_size()
                print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
                         index, self.nabla_X.to_array()[index].reveal())

            progress('nabla X')

        self.backward_params(f_schur_Y, batch=batch)
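For what it's worth, here is a rough, untested usage sketch (hypothetical sizes, and it assumes SimpleRes_Linear has been added to Compiler/ml.py so it is reachable as ml.SimpleRes_Linear): the layer is meant to slot into an ml.SGD layer list like the built-in Dense, with d_in == d_out so the skip connection is well-defined.

from Compiler import ml

n_examples = 1000   # hypothetical dataset size
n_features = 128    # hypothetical input width after flattening

layers = [
    ml.Dense(n_examples, n_features, 128, activation='relu'),
    ml.SimpleRes_Linear(n_examples, 128, 128, activation='relu'),  # residual block: d_in == d_out
    ml.Dense(n_examples, 128, 10),
    ml.MultiOutput(n_examples, 10),
]

optimizer = ml.SGD(layers)
# X and y would be sfix/sint tensors of shape [n_examples, n_features] and
# [n_examples, 10], prepared as in the question above:
# optimizer.fit(X, y, epochs=1, batch_size=128, program=program)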
You should find that version 0.3.9 now supports non-sequential PyTorch networks; for an example, see https://github.com/data61/MP-SPDZ/blob/master/Programs/Source/torch_resnet.py
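Assuming that version, a standard torchvision ResNet-18 should go through the same calls as in the question; the following is an untested sketch, and torch_resnet.py linked above remains the authoritative example:

import torchvision
from Compiler import ml

# standard ResNet-18 with a 10-class head for CIFAR-10
net = torchvision.models.resnet18(num_classes=10)

# same calls as in the question; from 0.3.9 the residual additions are handled
layers = ml.layers_from_torch(net, training_samples.shape, 128, input_via=0)
optimizer = ml.SGD(layers)
optimizer.fit(training_samples, training_labels, epochs=1, batch_size=128,
              validation_data=(test_samples, test_labels), program=program)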