
Adding a much simpler example

AJJLagerweij opened this issue 1 year ago

Enhancement

To aid understanding of this great library, one could add a very simple example problem; a proposal is shown below.

Motivation

Currently, the optimization problems in the examples are rather complex. To make the library more intuitive to use, one could add a toy example that is simple enough to solve by hand, yet still illustrates the approach clearly.

Alternatives

I propose solving the following: find the point on the line $y = 1 - x$ that is closest to the origin. Minimizing the distance is equivalent to minimizing the squared distance, so the problem is:

$$\min_{x, y} \; x^2 + y^2 \qquad \text{s.t.} \quad x + y - 1 = 0$$

This is an example problem discussed in the paper by Platt and Barr (1987) on constrained differential optimization, which also led to the development of another Lagrange multiplier optimization library, mdmm. The cooper implementation is:

"""
Toy problem from Constrained Differential Optimization (Platt and Barr 1987).

The paper discusses general problems of the following form:
    min f(x)
        s.t.: g(x) = 0

Among others, it discusses the example of finding the point on the line
y = 1 - x that is closest to the origin. This point is found by solving:
    min x^2 + y^2
        s.t.: x + y - 1 = 0
which is what this example does.

Bram van der Heijden
2023
"""
import matplotlib.pyplot as plt
import numpy as np
import cooper
import torch


class ConstraintOptimizer(cooper.ConstrainedMinimizationProblem):
    """Minimize `loss(x)` subject to the equality constraint `constraint(x) == 0`."""

    def __init__(self, loss, constraint):
        self.loss = loss
        self.constraint = constraint
        super().__init__(is_constrained=True)

    def closure(self, x):
        # Compute the cost function.
        loss = self.loss(x)

        # Compute the violation of the constraint function.
        eq_defect = self.constraint(x)

        return cooper.CMPState(loss=loss, eq_defect=eq_defect)


def f(x):
    """Cost function: the squared distance from the origin."""
    loss = x[0]**2 + x[1]**2
    return loss


def g(x):
    """Constraint function for the line y = 1 - x; it is zero on the line."""
    defect = x[0] + x[1] - 1
    return defect


# Define optimization problem.
cmp = ConstraintOptimizer(f, g)
formulation = cooper.LagrangianFormulation(cmp)

# Define the primal variables and optimizer.
x = torch.tensor([4, 3], dtype=torch.float32, requires_grad=True)
primal_optimizer = cooper.optim.ExtraAdam([x], lr=0.2)

# Define the dual optimizer. It is only partially instantiated here; cooper
# completes it once the Lagrange multipliers have been created.
dual_optimizer = cooper.optim.partial_optimizer(cooper.optim.ExtraAdam, lr=0.05)

# Wrap the primal and dual optimizers together.
constraint_optimizer = cooper.ConstrainedOptimizer(formulation, primal_optimizer, dual_optimizer)

# Perform the training.
max_itr = 250
itr = np.arange(max_itr)
loss_itr = np.zeros(max_itr)
cons_itr = np.zeros(max_itr)
for i in range(max_itr):
    # Record and print the cost and the constraint violation.
    loss_itr[i] = float(f(x))
    cons_itr[i] = float(g(x))
    print(f"Itr: {i:7,d}, loss {loss_itr[i]:4g}, cons {cons_itr[i]:4g}")

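    # Perform one constrained optimization step. ExtraAdam is an extrapolation
    # optimizer, so step() receives the closure and its arguments in order to
    # re-evaluate the problem at the extrapolated point.
    constraint_optimizer.zero_grad()
    lagrangian = formulation.composite_objective(cmp.closure, x)
    formulation.custom_backward(lagrangian)
    constraint_optimizer.step(cmp.closure, x)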

# Exact solution is at (0.5, 0.5)
print("")
print("Result: ", x.detach().numpy())
print("Distance to the exact solution: ", float(torch.norm(x - torch.tensor([0.5, 0.5]))))

# Plot the convergence graph.
fig, ax = plt.subplots(1, 2)
ax[0].set_ylabel("Cost")
ax[0].set_xlabel("Iteration number")
ax[0].plot(itr, loss_itr)
ax[1].set_ylabel("Constraint")
ax[1].set_xlabel("Iteration number")
ax[1].plot(itr, cons_itr)
plt.show()
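Because the problem is small enough to solve by hand, the exact solution can be checked directly. Setting the gradient of the Lagrangian $\mathcal{L}(x, y, \lambda) = x^2 + y^2 + \lambda (x + y - 1)$ to zero gives $2x + \lambda = 0$, $2y + \lambda = 0$ and $x + y = 1$, hence $x = y = 1/2$ and $\lambda = -1$. The sketch below reaches the same point without cooper, using plain simultaneous gradient descent on $(x, y)$ and gradient ascent on $\lambda$, in the spirit of the basic differential method of multipliers from Platt and Barr. It is an illustration, not cooper's algorithm: the gradients are written out by hand, and the function name `solve_by_gda`, the step size `lr = 0.1` and the iteration count are hypothetical choices.

```python
# Hypothetical plain-Python sketch (not cooper's implementation) of
# simultaneous gradient descent-ascent on the Lagrangian
#   L(x, y, lam) = x**2 + y**2 + lam * (x + y - 1)
# The primal variables (x, y) descend; the multiplier lam ascends.


def solve_by_gda(x=4.0, y=3.0, lam=0.0, lr=0.1, steps=500):
    """Return (x, y, lam) after `steps` simultaneous descent-ascent updates."""
    for _ in range(steps):
        defect = x + y - 1.0      # constraint violation g(x, y) = dL/dlam
        grad_x = 2.0 * x + lam    # dL/dx
        grad_y = 2.0 * y + lam    # dL/dy
        # Simultaneous update: every gradient uses the values from this step.
        x, y, lam = x - lr * grad_x, y - lr * grad_y, lam + lr * defect
    return x, y, lam


if __name__ == "__main__":
    x, y, lam = solve_by_gda()
    print(f"x = {x:.6f}, y = {y:.6f}, lambda = {lam:.6f}")
```

Started from the same point as the cooper script, `(4, 3)`, this should return values very close to `(0.5, 0.5, -1.0)`. Plain simultaneous updates tend to spiral around the saddle point of the Lagrangian before settling; the cooper script instead uses `ExtraAdam`, an extragradient-style optimizer intended to damp such oscillations.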

References

  1. J. C. Platt and A. H. Barr, ‘Constrained differential optimization’, in Proceedings of the 1987 International Conference on Neural Information Processing Systems (NIPS’87). Cambridge, MA, USA: MIT Press, 1987, pp. 612-621.
  2. K. Crowson (crowsonkb), ‘mdmm’, available on GitHub

AJJLagerweij, Aug 22 '23 07:08