SSSS icon indicating copy to clipboard operation
SSSS copied to clipboard

Possible updates for PyTorch 1.10

Open · jhualberta opened this issue 11 months ago · 1 comment

Thank you for your lecture; it is very useful. I know it has already been 5 years since these notes were originally released, and some functions in PyTorch have since been deprecated. The following are possible updates for PyTorch 1.10 and Python 3.11 in schrodinger.py.

schrodinger.py

Change: replaced the deprecated `torch.symeig(H, eigenvectors=True)` with `torch.linalg.eigh(H)`.

import numpy as np 
import torch
torch.set_default_dtype(torch.float64)
import torch.nn as nn
import matplotlib.pyplot as plt

class Schrodinger1D(nn.Module):
    """Learn a 1D potential whose ground-state density matches a target.

    The Hamiltonian is discretised on the grid ``xmesh`` as
    ``H = diag(V) + K``. Here ``V`` is the trainable ``potential`` and ``K``
    is the three-point finite-difference matrix for ``-1/2 d^2/dx^2``.
    Truncating the stencil at the grid edges means the wavefunction is
    implicitly zero just outside ``xmesh`` (hard-wall boundaries).

    Args:
        xmesh: 1D tensor of grid points. The spacing is taken from the first
            two points, so the grid is assumed to be uniform (for example
            from ``torch.linspace``). TODO: confirm for other callers.

    Raises:
        ValueError: if ``xmesh`` has fewer than two points (no spacing).
    """

    def __init__(self, xmesh):
        super().__init__()

        nmesh = xmesh.shape[0]
        if nmesh < 2:
            raise ValueError(f"xmesh needs at least 2 points, got {nmesh}")

        # The grid is fixed (not trained) but must follow the module through
        # .to()/.double()/.cuda(). A non-persistent buffer does that without
        # adding an entry to state_dict(), so saved checkpoints are unchanged.
        self.register_buffer('xmesh', xmesh, persistent=False)

        # Trainable potential, initialised to the harmonic well V(x) = x^2.
        self.potential = nn.Parameter(xmesh**2)

        # Kinetic operator -1/2 d^2/dx^2 with the 3-point stencil:
        #   (K psi)_i = (psi_i - 0.5 * psi_{i-1} - 0.5 * psi_{i+1}) / h^2
        # It is built on the mesh's dtype *and device* so that
        # diag(potential) + K never mixes CPU and GPU tensors.
        h2 = (xmesh[1] - xmesh[0]) ** 2
        main_diag = torch.ones(nmesh, dtype=xmesh.dtype, device=xmesh.device) / h2
        off_diag = -0.5 * torch.ones(nmesh - 1, dtype=xmesh.dtype, device=xmesh.device) / h2
        K = torch.diag(main_diag) + torch.diag(off_diag, 1) + torch.diag(off_diag, -1)
        self.register_buffer('K', K, persistent=False)

    def _solve(self):
        """Return the ground-state eigenvector of ``H = diag(potential) + K``.

        ``torch.linalg.eigh`` returns eigenvalues in ascending order, so
        column 0 belongs to the smallest eigenvalue. Its eigenvectors have
        unit norm, so ``psi**2`` sums to one. The overall sign is arbitrary.
        That is harmless here because callers only use ``psi**2``.
        """
        H = torch.diag(self.potential) + self.K
        _, eigvecs = torch.linalg.eigh(H)  # replaces the removed torch.symeig
        return eigvecs[:, 0]

    def forward(self, target):
        """Return ``sum(|psi**2 - target|)``, the L1 density mismatch.

        Presumably ``target`` is a density on the same grid (length
        ``len(xmesh)``), normalised to sum to one. TODO: confirm at call sites.
        """
        psi = self._solve()
        return (psi**2 - target).abs().sum()

    def plot(self, target):
        """Redraw target density, current density and potential/10000.

        Draws on the current matplotlib axes. Tensors are detached and moved
        to the CPU before ``.numpy()``. This is needed when the module lives
        on a GPU or when ``target`` requires grad.
        """
        # Plotting needs no autograd graph; skip building one.
        with torch.no_grad():
            psi = self._solve()

        x = self.xmesh.detach().cpu().numpy()

        plt.cla()
        plt.plot(x, target.detach().cpu().numpy(), label='Target Density')
        plt.plot(x, psi.square().cpu().numpy(), label='Current Density')
        plt.plot(x, self.potential.detach().cpu().numpy()/10000, label='Potential (V/10000)')
        plt.legend()
        plt.draw()

if __name__ == '__main__':
    # Uniform grid on [xmin, xmax].
    xmin, xmax, Nmesh = -1, 1, 500
    xmesh = torch.linspace(xmin, xmax, Nmesh)

    # Target density: a triangular bump 1 - |x| on |x| < 0.5, zero elsewhere.
    # After dividing by its L2 norm and squaring, it sums to one, the same
    # normalisation as the model's psi**2.
    inside = xmesh.abs() < 0.5
    bump = torch.where(inside, 1. - xmesh.abs(), torch.zeros_like(xmesh))
    target = bump.div(bump.norm()).pow(2)

    model = Schrodinger1D(xmesh)
    optimizer = torch.optim.LBFGS(
        model.parameters(),
        max_iter=10,
        tolerance_change=1E-7,
        tolerance_grad=1E-7,
        line_search_fn='strong_wolfe',
    )

    def closure():
        # LBFGS calls this repeatedly within one step(): clear stale grads,
        # recompute the density mismatch, and backpropagate it.
        optimizer.zero_grad()
        mismatch = model(target)
        mismatch.backward()
        return mismatch

    plt.ion()
    for epoch in range(50):
        loss = optimizer.step(closure)
        print(epoch, loss.item())
        model.plot(target)
        plt.pause(0.01)

    # Switch back to blocking mode so the final figure stays on screen.
    plt.ioff()
    model.plot(target)
    plt.show()

jhualberta avatar Jan 30 '25 00:01 jhualberta

Thank you. Glad you find it useful. There may be multiple places that need updating.

For now, I will leave the code as it is and keep your issue open. In the meantime, you are welcome to submit a PR.

wangleiphy avatar Jan 30 '25 01:01 wangleiphy