lipEstimation copied to clipboard
Computation of the "exact" method not working
Hi there, I just checked this repository because I might have to use it for my PhD, but it turns out I have highlighted some unwanted behavior of the script. To check the implementation, I built a model whose Lipschitz constant is known (exactly 4) and attained at a specific vector (called main_direct in the script). It is attained when all ReLU activations are positive. We see that processing the vector does increase the norm by a factor of 4. The computation of the spectral norm of the Jacobian of the network does give the right value: 4. Therefore I was puzzled when the function lipschitz_second_order_ub in exact mode didn't give the correct answer. Please tell me what I am doing wrong.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import sys
from functools import reduce
from lipschitz_approximations import lipschitz_second_order_ub
from lipschitz_utils import *
def compute_spectral_norm(weight, n_power_iterations, do_power_iteration=True, eps=1e-12):
    """Estimate the largest singular value of a 2-D tensor by power iteration.

    Args:
        weight: 2-D tensor of shape ``(h, w)``.
        n_power_iterations: number of power-iteration steps to run.
        do_power_iteration: when false, the iteration is skipped entirely.
        eps: numerical-stability epsilon forwarded to ``F.normalize``.

    Returns:
        A 0-dim tensor ``u^T W v``. This approximates the spectral norm only
        after enough iterations; with ``do_power_iteration=False`` or
        ``n_power_iterations == 0`` it is evaluated on the random initial
        vectors and is NOT a spectral-norm estimate.
    """
    weight_mat = weight
    h, w = weight_mat.size()
    # Randomly initialise `u` and `v` as unit vectors with the dtype/device of `weight`.
    u = F.normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=eps)
    v = F.normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=eps)
    if do_power_iteration:
        with torch.no_grad():
            for _ in range(n_power_iterations):
                # Spectral norm of weight equals `u^T W v`, where `u` and `v`
                # are the first left and right singular vectors.
                # This power iteration produces approximations of `u` and `v`.
                # Fresh tensors are produced each step (no `out=`), so no
                # defensive clone is needed afterwards.
                v = F.normalize(torch.mv(weight_mat.t(), u), dim=0, eps=eps)
                u = F.normalize(torch.mv(weight_mat, v), dim=0, eps=eps)
    sigma = torch.dot(u, torch.mv(weight_mat, v))
    return sigma
def rvs(dim=3):
    """Draw a random ``dim x dim`` orthogonal matrix with determinant +1.

    The matrix is built as a product of Householder reflections of
    decreasing size, followed by a row-wise sign correction. Randomness
    comes from the global ``np.random`` state, so ``np.random.seed``
    makes the result reproducible.

    Args:
        dim: side length of the square matrix.

    Returns:
        A ``(dim, dim)`` float ndarray ``H`` with ``H @ H.T == I`` (up to
        rounding) and ``det(H) == 1``.
    """
    random_state = np.random
    H = np.eye(dim)
    D = np.ones((dim,))
    for n in range(1, dim):
        x = random_state.normal(size=(dim - n + 1,))
        D[n - 1] = np.sign(x[0])
        x[0] -= D[n - 1] * np.sqrt((x * x).sum())
        # Householder transformation acting on the trailing (dim-n+1) block.
        Hx = np.eye(dim - n + 1) - 2.0 * np.outer(x, x) / (x * x).sum()
        mat = np.eye(dim)
        mat[n - 1:, n - 1:] = Hx
        H = np.dot(H, mat)
    # Fix the last sign such that the determinant is 1.
    D[-1] = (-1) ** (1 - (dim % 2)) * D.prod()
    # Equivalent to np.dot(np.diag(D), H): scale row i of H by D[i].
    H = (D * H.T).T
    return H
def reconfigure_ortho(ortho):
    """Flip the sign of every row whose first entry is negative, in place.

    After the call the first column of ``ortho`` has no negative entries.
    Negating whole rows of an orthogonal matrix keeps it orthogonal.

    Args:
        ortho: 2-D ndarray; it is modified in place.

    Returns:
        The same array object that was passed in.
    """
    # Negate entire rows at once; the column count comes from the array
    # itself, so non-square inputs are handled as well.
    negative_rows = ortho[:, 0] < 0
    ortho[negative_rows] = -ortho[negative_rows]
    return ortho
def generate_weight_matrix(size):
    """Build three chained ``size x size`` weight matrices and a reference direction.

    Each matrix is ``new_orth @ diag(sig) @ previous_orth.T``, where the
    orthogonal factors come from ``rvs`` followed by ``reconfigure_ortho``
    and the left factor of one matrix is reused as the right factor of the
    next.

    NOTE(review): ``sig`` is all ones here, so each product is a product of
    the orthogonal factors only -- confirm this is the intended scaling.

    Returns:
        ``(weights, main_direct)``: the list of three matrices and the first
        column of the initial orthogonal factor.
    """
    right_factor = reconfigure_ortho(rvs(size))
    main_direct = right_factor.T[0]
    scaling = np.diag(np.ones(size))
    weights = []
    for _ in range(3):
        left_factor = reconfigure_ortho(rvs(size))
        weights.append(left_factor @ scaling @ right_factor.T)
        # The next matrix is aligned on this one's left factor.
        right_factor = left_factor
    return weights, main_direct
def testing_virmaux_on_crafted_model(size):
    """Compare Lipschitz estimates on a crafted two-layer ReLU network.

    Builds ``Linear -> ReLU -> Linear`` (no biases) from the first two
    matrices of ``generate_weight_matrix(size)`` and prints, in order:

    * the norm amplification of the network along ``main_direct``
      (a lower bound on the Lipschitz constant);
    * the result of ``lipschitz_second_order_ub(model, algo='exact')``;
    * the power-iteration estimate for the product of the two weight
      matrices with an identity activation pattern.

    Args:
        size: input, hidden and output width of the network.
    """
    layer1 = nn.Linear(size, size, bias=False)
    layer2 = nn.Linear(size, size, bias=False)
    model = nn.Sequential(layer1, nn.ReLU(), layer2)
    weights, main_direct = generate_weight_matrix(size)
    layer1.weight = nn.Parameter(torch.FloatTensor(weights[0]))
    layer2.weight = nn.Parameter(torch.FloatTensor(weights[1]))
    # NOTE(review): the pasted script had residue of calls with empty ""
    # path arguments here (apparently saving/loading the model and the
    # direction); they could not be recovered and are omitted.

    # Identity activation pattern: every ReLU treated as active.
    sigma = np.diag(np.ones(size))
    # W0^T @ sigma @ W1^T is the transpose of W1 @ sigma @ W0, so both
    # have the same spectral norm.
    weights_prod = np.dot(np.dot(weights[0].T, sigma), weights[1].T)

    print("input main direction norm")
    direction = torch.FloatTensor(main_direct)
    norm_input = torch.norm(direction).item()
    norm_output = torch.norm(model(direction)).item()
    print("Lipschitz constant is at least equal to:", norm_output / norm_input)

    # The original read the size from an undefined `X_train[[0]]`; a single
    # sample of width `size` is used instead.
    # NOTE(review): assumes compute_module_input_sizes accepts a
    # (batch, features) size -- confirm against lipschitz_utils.
    input_size = torch.Size([1, size])
    compute_module_input_sizes(model, input_size)
    print("computation lipschitz constant by Virmaux")
    print(lipschitz_second_order_ub(model, algo='exact'))

    print("computation Lipschitz constant following Virmaux principle")
    value = compute_spectral_norm(
        torch.tensor(weights_prod),
        n_power_iterations=100,
        do_power_iteration=True,
        eps=1e-12,
    ).item()
    # The estimate was computed but never reported in the pasted script.
    print(value)
if __name__ == "__main__":
    # NOTE(review): the body of this guard was lost in the pasted script.
    # Size 10 is inferred from the 1024-step (2**10) progress bar in the
    # reported output -- confirm with the author.
    testing_virmaux_on_crafted_model(10)
Here is the output:
input main direction norm
Lipschitz constant is at least equal to: 4.0
computation lipschitz constant by Virmaux
ratio s 0.5
factor abs prod: 0.999999991079676
100%|#####################################################| 1024/1024 [00:00<00:00, 17019.34it/s]
factor 0.8029070624208176
computation Lipschitz constant following Virmaux principle
tensor(4., dtype=torch.float64)