Unnormalized MPS after DMRG when getting samples in multithreaded mode
Problem
Hi @GTorlai,
I am trying to use your code with ITensors.jl to compute the ground state of a 2D Heisenberg model and obtain random Pauli measurements from it. The code works fine when I start Julia normally without any arguments, include my function in the REPL, and execute it. However, I noticed that Julia was not using the full potential of my CPU (only about 8 cores are used), and I looked into the getsamples function, where threads are being used. That's why I thought starting Julia with more threads, i.e. `julia -t 15`, could speed up the sampling process. With that change I was still observing only one CPU being utilized during the DMRG sweeps (but this is maybe another issue?). The sampling part of the code (example below) was using all 15 cores successfully. However, for some Hamiltonians I obtained the following error:
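As a side note on the confusing core counts (this is just my own guess at an explanation, not something from the docs): the DMRG sweeps are dominated by dense linear algebra, which is parallelized by BLAS threads, while getsamples uses Julia's `@threads` loops, driven by the Julia threads set with `-t`. The two pools are configured independently, which might explain why the busy-core count does not match the `-t` value:

```julia
using LinearAlgebra
using Base.Threads

# Julia-level threads (set with `julia -t N`) drive `@threads` loops like the
# one inside getsamples; BLAS threads drive the dense matrix operations that
# dominate DMRG. The two pools are independent of each other.
println("Julia threads: ", nthreads())
println("BLAS threads:  ", BLAS.get_num_threads())

# They can be tuned separately, e.g. to avoid oversubscription when running
# with many Julia threads:
BLAS.set_num_threads(1)
println("BLAS threads now: ", BLAS.get_num_threads())
```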
ERROR: LoadError: TaskFailedException
Stacktrace:
[1] wait
@ ./task.jl:322 [inlined]
[2] threading_run(func::Function)
@ Base.Threads ./threadingconstructs.jl:34
[3] macro expansion
@ ./threadingconstructs.jl:93 [inlined]
[4] getsamples(M0::MPS, bases::Matrix{String}, nshots::Int64; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ PastaQ ~/.julia/environments/v1.6/dev/PastaQ/src/circuits/getsamples.jl:189
[5] getsamples(M0::MPS, bases::Matrix{String}, nshots::Int64)
@ PastaQ ~/.julia/environments/v1.6/dev/PastaQ/src/circuits/getsamples.jl:182
[6] top-level scope
@ ~/provable-ml-quantum/julia_scripts/test_failed_simu.jl:25
[7] include(fname::String)
@ Base.MainInclude ./client.jl:444
[8] top-level scope
@ REPL[12]:1
nested task error: sample: MPS is not normalized, norm=1.0000609555601314
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] sample(m::MPS)
@ ITensors ~/.julia/environments/v1.6/dev/ITensors/src/mps/mps.jl:524
[3] getsamples!(M::MPS; readout_errors::NamedTuple{(:p1given0, :p0given1), Tuple{Nothing, Nothing}})
@ PastaQ ~/.julia/environments/v1.6/dev/PastaQ/src/circuits/getsamples.jl:260
[4] getsamples!
@ ~/.julia/environments/v1.6/dev/PastaQ/src/circuits/getsamples.jl:257 [inlined]
[5] macro expansion
@ ~/.julia/environments/v1.6/dev/PastaQ/src/circuits/getsamples.jl:161 [inlined]
[6] (::PastaQ.var"#118#threadsfor_fun#179"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MPS, Vector{Vector{Vector{Int64}}}, UnitRange{Int64}})(onethread::Bool)
@ PastaQ ./threadingconstructs.jl:81
[7] #invokelatest#2
@ ./essentials.jl:708 [inlined]
[8] invokelatest
@ ./essentials.jl:706 [inlined]
[9] macro expansion
@ ./threadingconstructs.jl:86 [inlined]
[10] getsamples(M0::MPS, nshots::Int64; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ PastaQ ~/.julia/environments/v1.6/dev/PastaQ/src/circuits/getsamples.jl:159
[11] getsamples(M0::MPS, nshots::Int64)
@ PastaQ ~/.julia/environments/v1.6/dev/PastaQ/src/circuits/getsamples.jl:155
[12] macro expansion
@ ~/.julia/environments/v1.6/dev/PastaQ/src/circuits/getsamples.jl:193 [inlined]
[13] (::PastaQ.var"#133#threadsfor_fun#185"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Matrix{String}, Int64, MPS, Vector{Vector{Vector{Pair{String, Int64}}}}, Int64, UnitRange{Int64}})(onethread::Bool)
@ PastaQ ./threadingconstructs.jl:81
[14] (::PastaQ.var"#133#threadsfor_fun#185"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Matrix{String}, Int64, MPS, Vector{Vector{Vector{Pair{String, Int64}}}}, Int64, UnitRange{Int64}})()
@ PastaQ ./threadingconstructs.jl:48
in expression starting at ~/provable-ml-quantum/julia_scripts/test_failed_simu.jl:25
I then computed the normalization constants before and after the sampling process and printed them. The norm was always 1.000000, both for the normal MPS and for the dense one, so this error does not make sense to me. Could you please help me find the bug?
Minimal working example
This is a minimal example to reproduce the bug, and here are the files needed for it: https://cloud.ml.jku.at/s/FS6fj9oTpzGSfHx. The MPS state was obtained during a simulation whose code I will post below. This small script produces the above error when Julia is started in multithreaded mode (with `-t 15` I see 22 sub-processes) and works just fine when Julia is started with `-t 1` (though weirdly I still see 8 processes in htop).
# this is the file test_failed_simu.jl
using ITensors
using ITensors.HDF5
using LinearAlgebra
using PastaQ
using Random
using Printf
using JLD
using Debugger
f = h5open("./data/failed_simulation_5x5_id9_gs.jld", "r")
psi = read(f,"psi", MPS)
close(f)
# display(psi)
norm1 = norm(psi)
norm1d = norm(dense(psi))
@printf("norm before normalize %f, dense %f\n", norm1, norm1d)
Random.seed!(0)
other_vars = load("./data/failed_simulation_5x5_id9.jld")
N = 5 * 5
nshots = 1000
samples = getsamples(dense(psi), randombases(N, nshots), 1)
Original code
The following is some adapted code that was given to me by Hsing-Yuan Huang, which was probably also written by you? It generates the ground states of different random Hamiltonians and samples from them; the samples are then stored. The problem occurs even though I played with different values of minsweeps, the cutoff Λ, and the energy tolerance ϵ.
using ITensors
using ITensors.HDF5
using LinearAlgebra
using PastaQ
using Random
# using Plots
using Printf
# using Test
# using StatsBase
using JLD
using ProgressBars
using Debugger
using CodeTracking
# TODO move to GPU
# https://github.com/ITensor/ITensors.jl/blob/74db39bab0a1bcd6a02dd744a38c74714ad6ff83/ITensorGPU/examples/dmrg.jl
# https://github.com/ITensor/ITensors.jl/tree/main/ITensorGPU
function main()
    n = 1
    Random.seed!(n * 1234)
    basepath = "data/"
    Lx = 5
    Ly = 5
    N = Lx * Ly
    nshots = 1000
    Jmin = 0.0
    Jmax = 2.0
    npoints = 100
    @printf("Lx=%i, Ly=%i ", Lx, Ly)
    Λ = 1e-8
    noise = 1e-9
    χ = [10, 10, 20, 20, 50, 100, 200, 200, 500, 1000, 1500]
    χ₀ = 10
    ϵ = 1e-4
    nsweeps = 500
    minsweeps = 10
    groundstate = nothing
    for k in ProgressBar(1:npoints)
        @printf("iteration %i", k)
        sites = siteinds("Qubit", N; conserve_qns=true)
        lattice = square_lattice(Lx, Ly; yperiodic=false, xperiodic=false)
        J = Jmax * rand(length(lattice))
        # println("couplings:")
        # display(J)
        # Define the Heisenberg spin Hamiltonian on this lattice
        # for (b, bond) in enumerate(lattice)
        #     print(bond)
        # end
        ampo = AutoMPO()
        for (b, bond) in enumerate(lattice)
            ampo .+= J[b], "X", bond.s1, "X", bond.s2
            ampo .+= J[b], "Y", bond.s1, "Y", bond.s2
            ampo .+= J[b], "Z", bond.s1, "Z", bond.s2
        end
        H = MPO(ampo, sites)
        st = [isodd(n) ? "1" : "0" for n = 1:N]
        ψ₀ = randomMPS(sites, st)
        sweeps = Sweeps(nsweeps)
        maxdim!(sweeps, χ...)
        cutoff!(sweeps, Λ)
        noise!(sweeps, noise)
        @printf("Running dmrg\n")
        observer = DMRGObserver(["Z"], sites, energy_tol=ϵ, minsweeps=minsweeps)
        E, ψ = dmrg(H, ψ₀, sweeps; observer=observer, outputlevel=1)
        ψ = LinearAlgebra.normalize(ψ)
        norm1 = norm(ψ)
        norm1d = norm(dense(ψ))
        @printf("norm after normalize %f, dense %f\n", norm1, norm1d)
        @printf("Measurements\n")
        SvN = entanglemententropy(ψ)
        norm2 = norm(ψ)
        norm2d = norm(dense(ψ))
        @printf("norm after entanglemententropy %f, dense %f\n", norm2, norm2d)
        ## X = measure(ψ, "X")
        ## Z = measure(ψ, "Z")
        # XX = measure(ψ, ("X", "X"))
        # YY = measure(ψ, ("Y", "Y"))
        # ZZ = measure(ψ, ("Z", "Z"))
        # this is the same as above, but in the new version
        # X = expect(ψ, "X")
        # Y = expect(ψ, "Y")
        # Z = expect(ψ, "Z")
        XX = correlation_matrix(ψ, "X", "X")
        norm3 = norm(ψ)
        norm3d = norm(dense(ψ))
        @printf("norm after correlation_matrix XX %f, dense %f\n", norm3, norm3d)
        YY = correlation_matrix(ψ, "Y", "Y")
        norm4 = norm(ψ)
        norm4d = norm(dense(ψ))
        @printf("norm after correlation_matrix YY %f, dense %f\n", norm4, norm4d)
        ZZ = correlation_matrix(ψ, "Z", "Z")
        norm5 = norm(ψ)
        norm5d = norm(dense(ψ))
        @printf("norm after correlation_matrix ZZ %f, dense %f\n", norm5, norm5d)
        # ψ = LinearAlgebra.normalize(ψ)
        try
            @printf("Sampling\n")
            samples = getsamples(dense(ψ), randombases(N, nshots), 1) # nshots different meas. bases, each basis with one shot
            @bp
            path = basepath * "simulation_$(Lx)x$(Ly)_id$(npoints * (n - 1) + k).jld"
            JLD.jldopen(path, "w") do fout
                JLD.write(fout, "J", J)
                JLD.write(fout, "ZZ", ZZ)
                JLD.write(fout, "YY", YY)
                JLD.write(fout, "XX", XX)
                JLD.write(fout, "E", E)
                JLD.write(fout, "SvN", SvN)
                JLD.write(fout, "samples", samples)
            end
            fout = h5open(basepath * "simulation_$(Lx)x$(Ly)_id$(npoints * (n - 1) + k)_gs.jld", "w")
            write(fout, "psi", ψ)
            close(fout)
        catch e
            showerror(stdout, e)
            @printf("Sampling failed for sample %i\n", npoints * (n - 1) + k)
            # saving error states
            path = basepath * "failed_simulation_$(Lx)x$(Ly)_id$(npoints * (n - 1) + k).jld"
            JLD.jldopen(path, "w") do fout
                JLD.write(fout, "J", J)
                JLD.write(fout, "ZZ", ZZ)
                JLD.write(fout, "YY", YY)
                JLD.write(fout, "XX", XX)
                JLD.write(fout, "E", E)
                JLD.write(fout, "SvN", SvN)
            end
            fout = h5open(basepath * "failed_simulation_$(Lx)x$(Ly)_id$(npoints * (n - 1) + k)_gs.jld", "w")
            write(fout, "psi", ψ)
            close(fout)
        end
        println("")
    end
end

print(@code_string main())

if "" != PROGRAM_FILE && realpath(@__FILE__) == realpath(PROGRAM_FILE)
    main()
end
Thank you in advance for your time and for providing this useful package!
P.S.: Package info:
(@v1.6) pkg> status
Status `~/.julia/environments/v1.6/Project.toml`
[052768ef] CUDA v3.8.5
[da1fd8a2] CodeTracking v1.0.9
[31a5f54b] Debugger v0.7.6
[d89171c1] ITensorGPU v0.0.5
[9136182c] ITensors v0.3.15 `dev/ITensors`
[4138dd39] JLD v0.12.5
[2b0e0bc5] LanguageServer v4.2.0
[30b07047] PastaQ v0.0.23 `dev/PastaQ`
[91a5bcdd] Plots v1.29.1
[49802e3a] ProgressBars v1.4.1
[2913bbd2] StatsBase v0.33.16
(@v1.6) pkg>
The packages ITensors and PastaQ are unmodified; they are just installed in development mode to make code search with Visual Studio Code more reliable.
It looks like in the threaded loop multiple threads are modifying the MPS in-place: https://github.com/GTorlai/PastaQ.jl/blob/54cc060b60a13b132d93578f7bf5516eab23927a/src/circuits/getsamples.jl#L161
I wonder if that is causing issues?
Hm, I reckon that this is not the issue, since nshots is 1 in my setting. I am trying to measure one shot in many different bases, so wouldn't the parallelization over nshots in line 161 only create a single thread?
Also, I tried changing `sample_ = getsamples!(M; kwargs...)` into `sample_ = getsamples!(copy(M); kwargs...)`, however I am still experiencing the bug :|
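To convince myself that this failure mode is at least consistent with a data race, I wrote this toy analogue (my own simplified sketch, not PastaQ code; the function names are made up). Many tasks renormalizing the same array in place can observe half-updated data, whereas giving each task its own deep copy cannot. Incidentally, Julia's `copy` is shallow while `deepcopy` recursively duplicates the underlying storage, so maybe `deepcopy(M)` is worth a try as well:

```julia
using Base.Threads

# Toy analogue of the suspected race (NOT PastaQ code): every task mutates
# the *same* vector in place, the way concurrent getsamples! calls would all
# act on one shared MPS.
function norms_shared(nshots)
    state = fill(1.0, 4)
    results = zeros(nshots)
    @threads for i in 1:nshots
        state .*= 2.0                     # in-place update on shared data
        state ./= sqrt(sum(abs2, state))  # in-place renormalization
        results[i] = sum(abs2, state)     # may observe a half-updated state
    end
    return results
end

# Race-free variant: each task works on its own deep copy of the state.
function norms_copied(nshots)
    state = fill(1.0, 4)
    results = zeros(nshots)
    @threads for i in 1:nshots
        s = deepcopy(state)
        s .*= 2.0
        s ./= sqrt(sum(abs2, s))
        results[i] = sum(abs2, s)
    end
    return results
end
```

As far as I can tell, with `-t 1` both functions deterministically return all ones, while with many threads the shared-state version can intermittently return squared norms different from 1.0, which would match the intermittent nature of the error.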
Hi @VietTralala,
I encountered the same issue. Have you solved the problem?