torchTT
A problem with using torchTT.interpolate.dmrg_cross()
I tried to use interpolate.dmrg_cross() on a NumPy array.
Generate a random array:
import numpy as np
import torch as tn
import torchtt as tntt
test = np.random.rand(10, 10, 10)
Define the function in this way:
def func(args):
    return tn.tensor([tn.from_numpy(test)[*args[0]]], dtype=tn.complex128)
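The reshape in the traceback asks for many entries at once, which suggests that dmrg_cross calls the function with a whole batch of multi-indices and expects one value per index back. The sketch below uses only NumPy and rests on that assumption: it models the batch as an integer array of shape (M, 3), one row per entry of `test`. `func_single` mirrors the function above (it reads only `args[0]`), while `func_batched` is a hypothetical vectorized variant. Both names and the example batch are illustrative, not part of torchTT.

```python
import numpy as np

rng = np.random.default_rng(0)
test = rng.random((10, 10, 10))

def func_single(args):
    # Mirrors the original callback: only the first row of the batch is used,
    # so a batch of M indices yields a single value.
    return np.array([test[tuple(args[0])]], dtype=np.complex128)

def func_batched(args):
    # Advanced indexing with one index array per dimension returns
    # one value per row of `args`, i.e. an array of length M.
    return test[tuple(args.T)].astype(np.complex128)

# Hypothetical batch: every (i, j, k) with k in {0, 1}, i.e. 10 * 10 * 2 = 200 rows.
i, j, k = np.meshgrid(np.arange(10), np.arange(10), np.arange(2), indexing="ij")
batch = np.stack([i.ravel(), j.ravel(), k.ravel()], axis=1)

print(batch.shape)                # (200, 3)
print(func_single(batch).shape)   # (1,)
print(func_batched(batch).shape)  # (200,)
```

If dmrg_cross really passes a 2-D int64 torch tensor, the torch counterpart of `func_batched` would be `tn.from_numpy(test)[tuple(args.T)].to(tn.complex128)`. This is an untested assumption about the argument type.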
When I tried to use dmrg_cross to interpolate this random test tensor,
N = list(test.shape)
x = tntt.interpolate.dmrg_cross(func, N, eps=10**(-8))
I got the error message:
Traceback (most recent call last):
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevconsole.py", line 364, in runcode
coro = func()
^^^^^^
File "<input>", line 1, in <module>
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/Users/boyuanshi/Desktop/second_project/equilibrium_v2/Screened_Interactions_Plot.py", line 61, in <module>
x = tntt.interpolate.dmrg_cross(func, N, eps=10**(-8))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/boyuanshi/.conda/envs/Desktop/lib/python3.11/site-packages/torchTT-2.0-py3.11-macosx-10.9-x86_64.egg/torchtt/interpolate.py", line 514, in dmrg_cross
supercore = tn.reshape(function(eval_index),[rank[k],N[k],N[k+1],rank[k+2]])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[1, 10, 10, 2]' is invalid for input of size 1
I am confused about why this happens. What is causing this error?
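The sizes in the error message can be checked with a little arithmetic: the target shape [1, 10, 10, 2] holds 1 · 10 · 10 · 2 = 200 entries, while the function returned a tensor with a single element. A NumPy-only sketch of the same reshape (NumPy raises ValueError where torch raises RuntimeError):

```python
import numpy as np

# Target shape copied from the RuntimeError in the traceback.
target_shape = [1, 10, 10, 2]
needed = int(np.prod(target_shape))
print(needed)  # 200

# What the original function returned: a single value.
one_value = np.zeros(1)
try:
    np.reshape(one_value, target_shape)
    reshape_failed = False
except ValueError as err:
    reshape_failed = True
    print("reshape of 1 element fails:", err)

# A result with one value per requested entry reshapes cleanly.
full_batch = np.zeros(needed)
print(np.reshape(full_batch, target_shape).shape)  # (1, 10, 10, 2)
```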