SSSS
SSSS copied to clipboard
Possible updates for PyTorch 1.10
Thank you for your lecture; it is very useful. I know it has already been 5 years since these notes were originally released, and some PyTorch functions have been deprecated since then. The following are possible updates to schrodinger.py for PyTorch 1.10 and Python 3.11.
schrodinger.py
Replaced `torch.symeig(H, eigenvectors=True)` → `torch.linalg.eigh(H)`
import numpy as np
import torch
torch.set_default_dtype(torch.float64)
import torch.nn as nn
import matplotlib.pyplot as plt
class Schrodinger1D(nn.Module):
    """Inverse 1D Schrodinger problem on a mesh.

    The Hamiltonian is ``H = diag(V) + K``. ``K`` is the three-point
    finite-difference kinetic operator (``1/h^2`` on the diagonal and
    ``-1/(2 h^2)`` on the first off-diagonals, i.e. ``-1/2 d^2/dx^2``).
    The potential ``V`` is the only trainable parameter and starts as
    ``xmesh**2``. ``forward`` scores how far the ground-state density
    ``psi**2`` is from a target density.
    """

    def __init__(self, xmesh):
        """Build the Hamiltonian pieces on ``xmesh``.

        Args:
            xmesh: 1-D tensor of mesh points. The spacing is taken from its
                first two entries, so the mesh is assumed to be uniform.
        """
        super().__init__()
        # Non-persistent buffers follow the module through .to()/.cuda()/
        # .float() together with ``potential``, but are not written to
        # state_dict, so saved checkpoints still contain only ``potential``.
        self.register_buffer('xmesh', xmesh, persistent=False)
        self.potential = nn.Parameter(xmesh**2)

        nmesh = xmesh.shape[0]
        h2 = (xmesh[1] - xmesh[0]) ** 2
        # Allocate on the mesh's device as well as its dtype; otherwise a
        # CUDA mesh would be combined with CPU tensors here.
        opts = dict(dtype=xmesh.dtype, device=xmesh.device)
        K = (
            torch.diag(1/h2 * torch.ones(nmesh, **opts), diagonal=0)
            - torch.diag(0.5/h2 * torch.ones(nmesh-1, **opts), diagonal=1)
            - torch.diag(0.5/h2 * torch.ones(nmesh-1, **opts), diagonal=-1)
        )
        self.register_buffer('K', K, persistent=False)

    def _solve(self):
        """Return the unit-norm eigenvector of ``H`` with the smallest eigenvalue.

        ``torch.linalg.eigh`` returns eigenvalues in ascending order, so
        column 0 is the ground state. Its overall sign is arbitrary; the
        callers in this class only use its square.
        """
        H = torch.diag(self.potential) + self.K
        _, eigvecs = torch.linalg.eigh(H)  # replaces the deprecated torch.symeig
        return eigvecs[:, 0]

    def forward(self, target):
        """Return the L1 distance between the ground-state density and ``target``.

        Args:
            target: density on the same mesh as ``xmesh``.
        """
        psi = self._solve()
        return (psi**2 - target).abs().sum()

    def plot(self, target):
        """Redraw target density, current density and the scaled potential."""
        # Plotting needs no autograd graph.
        with torch.no_grad():
            psi = self._solve()
        # .cpu() so plotting also works when the model lives on a GPU.
        x = self.xmesh.detach().cpu().numpy()
        plt.cla()
        plt.plot(x, target.detach().cpu().numpy(), label='Target Density')
        plt.plot(x, psi.square().cpu().numpy(), label='Current Density')
        plt.plot(x, self.potential.detach().cpu().numpy()/10000, label='Potential (V/10000)')
        plt.legend()
        plt.draw()
if __name__ == '__main__':
    # Target density: amplitude 1 - |x| on |x| < 0.5 and zero elsewhere,
    # normalised to unit 2-norm and then squared, so the density sums to one.
    xmin, xmax, Nmesh = -1, 1, 500
    xmesh = torch.linspace(xmin, xmax, Nmesh)
    inside = xmesh.abs() < 0.5
    amplitude = torch.zeros(Nmesh)
    amplitude[inside] = 1. - xmesh[inside].abs()
    target = amplitude.div(amplitude.norm()).pow(2)

    model = Schrodinger1D(xmesh)
    lbfgs_options = dict(
        max_iter=10,
        tolerance_change=1E-7,
        tolerance_grad=1E-7,
        line_search_fn='strong_wolfe',
    )
    optimizer = torch.optim.LBFGS(model.parameters(), **lbfgs_options)

    def closure():
        """Recompute the density mismatch and its gradient for L-BFGS."""
        optimizer.zero_grad()
        mismatch = model(target)
        mismatch.backward()
        return mismatch

    # Interactive mode: redraw the fit after every optimizer step.
    plt.ion()
    for step in range(50):
        loss = optimizer.step(closure)
        print(step, loss.item())
        model.plot(target)
        plt.pause(0.01)
    plt.ioff()

    # Final, blocking view of the converged fit.
    model.plot(target)
    plt.show()
Thank you. Glad you find it useful. There might be multiple places that need updating.
For now, I will leave it as it is and keep your issue open. Alternatively, you are welcome to submit a PR.