mdmm
Simple mathematical example from the Platt and Barr (1987) paper
Many thanks for this package.
Would it be possible for you to add the following example to your library? I believe it is a bit clearer, and it fits nicely with the background of the package, as it implements one of the examples from the paper by Platt and Barr.
Kind Regards
The example minimizes the distance to the origin subject to the point lying on the line $y = 1 - x$. Essentially:
$$\min_{x, y} \; x^2 + y^2 \qquad \text{s.t.} \;\; x + y - 1 = 0$$
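For reference, the exact solution can be derived with a standard Lagrange multiplier argument:

$$\mathcal{L}(x, y, \lambda) = x^2 + y^2 + \lambda\,(x + y - 1)$$

$$\frac{\partial \mathcal{L}}{\partial x} = 2x + \lambda = 0, \qquad \frac{\partial \mathcal{L}}{\partial y} = 2y + \lambda = 0 \quad\Longrightarrow\quad x = y = -\frac{\lambda}{2}$$

Substituting into the constraint $x + y - 1 = 0$ gives $\lambda = -1$ and hence $(x, y) = (0.5, 0.5)$, which is the value the numerical result below is compared against.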
"""
Pet problem from Constrained Differential Optimization (Platt and Barr 1987).
The paper discusses general problems of the form:
    min f(x)
    s.t.: g(x) = 0
Among other examples, it discusses finding the closest point to the origin such that it lies on the line y = 1 - x.
This point can be found by solving:
    min x^2 + y^2
    s.t.: x + y - 1 = 0
which is what this script does.
Bram van der Heijden
2023
"""
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import mdmm
import torch

def f(x):
    """Cost function representing the squared distance from the origin."""
    loss = x[0]**2 + x[1]**2
    return loss

def g(x):
    """Constraint function, representing the line y = 1 - x."""
    loss = x[0] + x[1] - 1
    return loss

# Create the optimization variables; these need requires_grad=True so that an automatic differentiation graph is built.
x = torch.tensor([4, 3], dtype=torch.float32, requires_grad=True)
# Set up the constrained problem: the equality constraint g(x) = 0, the MDMM module that manages its Lagrange multiplier, and an optimizer that updates both x and the multiplier.
constraint = mdmm.EqConstraint(partial(g, x), 0)
mdmm_module = mdmm.MDMM([constraint])
optimizer = mdmm_module.make_optimizer(x, lr=1e-1)
# Optimize the problem.
max_itr = 500
itr = np.arange(max_itr)
loss_itr = np.zeros(max_itr)
cons_itr = np.zeros(max_itr)
for i in range(max_itr):
    # Record and print iteration updates.
    loss_itr[i] = float(f(x))
    cons_itr[i] = float(g(x))
    print(f"Itr: {i:7,d}, loss {loss_itr[i]:4g}, constraint {cons_itr[i]:4g}")

    # Perform a step of the optimization procedure: the MDMM module augments
    # the loss with the constraint terms before backpropagation.
    loss = f(x)
    mdmm_return = mdmm_module(loss)
    optimizer.zero_grad()
    mdmm_return.value.backward()
    optimizer.step()
# Exact solution is at (0.5, 0.5)
print("")
print("Result is: ", x.detach().numpy())
print("Error to exact is: ", float(torch.norm(x-torch.tensor([0.5, 0.5]))))
# Plot the convergence graph.
fig, ax = plt.subplots(1, 2)
ax[0].set_ylabel("Cost")
ax[0].set_xlabel("Iteration number")
ax[0].plot(itr, loss_itr)
ax[1].set_ylabel("Constraint")
ax[1].set_xlabel("Iteration number")
ax[1].plot(itr, cons_itr)
plt.show()
The optimization finds a good approximation of the exact solution at $(x, y) = (0.5, 0.5)$, and both the cost and the constraint residual converge appropriately.
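The same pattern should extend to inequality constraints. Below is a minimal sketch of the same toy problem with the constraint relaxed to $x + y \ge 1$; it assumes mdmm provides a MinConstraint class with the same call convention as EqConstraint (the class name and signature are an assumption here, so check the package documentation before use). The minimizer is unchanged, because the unconstrained optimum $(0, 0)$ violates the inequality.

from functools import partial

import mdmm
import torch


def f(x):
    """Squared distance from the origin."""
    return x[0]**2 + x[1]**2


def g(x):
    """Left-hand side of the inequality x + y >= 1."""
    return x[0] + x[1]


x = torch.tensor([4.0, 3.0], requires_grad=True)

# Assumed API: mdmm.MinConstraint(fn, min_value) enforces fn() >= min_value.
constraint = mdmm.MinConstraint(partial(g, x), 1.0)
mdmm_module = mdmm.MDMM([constraint])
optimizer = mdmm_module.make_optimizer([x], lr=1e-1)

for _ in range(500):
    mdmm_return = mdmm_module(f(x))
    optimizer.zero_grad()
    mdmm_return.value.backward()
    optimizer.step()

print("Result:", x.detach().numpy())  # should approach [0.5, 0.5]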