pymanopt
Question About Data Batches
Hi there, my cost function can be written as a sum of functions, each of which depends on the data. My problem is that I have a lot of data and can't build the whole cost function at once without running out of memory. I'm therefore trying batch optimization: in each iteration I build part of the cost function from one batch and optimize it. In the next iteration I use the argmin from the previous iteration as the starting point for a new cost function, which is the cost function from the previous batches plus a new part built from the current batch.
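Written out, the scheme described above could look like this. The symbols $f_i$, $\mathcal{D}_i$, $B$, $F_k$ and $\mathcal{M}$ are notation introduced here for illustration, not taken from the post: $\mathcal{D}_i$ is the $i$-th of $B$ batches and $\mathcal{M}$ the manifold.

```latex
\begin{aligned}
f(x)   &= \sum_{i=1}^{B} f_i(x;\, \mathcal{D}_i),            && x \in \mathcal{M},\\
F_k(x) &= F_{k-1}(x) + f_k(x;\, \mathcal{D}_k),              && F_0 \equiv 0,\\
x_k    &\approx \operatorname*{arg\,min}_{x \in \mathcal{M}} F_k(x), && \text{solver initialized at } x_{k-1},
\end{aligned}
```

so that after the last batch $F_B = f$, and each round only has to add the term for one new batch.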
I'm using the PyTorch backend, and I'm not sure how to implement batch optimization with Pymanopt.
Any ideas?
Regards Yuval
Hello Yuval,
You could create a new problem structure for each cost function?
That is: you create a problem structure for your first cost function and run a solver on it with options set so that it only runs a few iterations or for a capped amount of time. Then you collect the best x produced by that solver, delete the current problem from memory (to free space), create a new problem structure for your second cost function, and run a solver on that new problem, initialized at the x you just found. And repeat.
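The warm-start loop described above can be sketched without pymanopt at all. The block below is a NumPy-only stand-in, not pymanopt's API: the per-batch cost (a quadratic form on the unit sphere), the hand-written Riemannian gradient descent `solve_on_sphere`, and the synthetic batches are all hypothetical. It also simplifies by giving each round a cost built from the current batch alone. What it shows is the structure: cap the solver's iterations, keep the last point, drop the batch, and start the next round from that point.

```python
import numpy as np


def make_batch_cost(batch):
    """Hypothetical per-batch cost on the unit sphere: f(x) = -x^T G x, G = batch^T batch / n.

    Minimizing over unit vectors x gives the dominant eigenvector of G,
    a stand-in for a real data-dependent cost.
    """
    gram = batch.T @ batch / batch.shape[0]

    def cost(x):
        return float(-x @ gram @ x)

    def egrad(x):
        # Euclidean gradient of the quadratic form -x^T G x.
        return -2.0 * gram @ x

    return cost, egrad


def solve_on_sphere(egrad, x0, max_iterations=20, step=0.05):
    """Minimal Riemannian gradient descent on the sphere with a capped iteration count."""
    x = x0 / np.linalg.norm(x0)
    for _ in range(max_iterations):
        g = egrad(x)
        rgrad = g - (x @ g) * x  # project onto the tangent space at x
        y = x - step * rgrad
        x = y / np.linalg.norm(y)  # retract back onto the sphere by normalizing
    return x


def warm_started_rounds(num_batches=5, batch_size=200, dim=5, seed=0):
    rng = np.random.default_rng(seed)
    scales = np.array([3.0] + [1.0] * (dim - 1))  # synthetic data with dominant direction e_0
    x = rng.normal(size=dim)  # arbitrary starting point for the first round
    costs = []
    for _ in range(num_batches):
        # Hypothetical loader: only the current batch is held in memory.
        batch = rng.normal(size=(batch_size, dim)) * scales
        cost, egrad = make_batch_cost(batch)
        x = solve_on_sphere(egrad, x)  # warm start at the previous round's point
        costs.append(cost(x))
        del batch, cost, egrad  # drop this round's data before loading the next batch
    return x, costs


x_final, round_costs = warm_started_rounds()
```

With pymanopt itself, the loop body would instead build a fresh problem and solver for each batch and pass the previous point as the initial point; the exact constructor and `run` arguments depend on the pymanopt version, so its documentation is the place to check.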
Would that do the trick for you?
Best, Nicolas
@NicolasBoumal thanks! It seems it did the trick. Unfortunately, working with a lot of data on the CPU alone is very slow; are there any plans to add GPU support?
Regards Yuval
I cannot speak for the PyManopt lead team, but I imagine GPU support is not part of the near-term plans. We have partial GPU support in the Matlab version of Manopt, in case that is of interest to you.
Stopping by to leave this cross-ref here; maybe it's of interest to @yuvalH9? It appears @Raph-AI is using a GPU/pymanopt/TensorFlow 2 combo :upside_down_face:
Another option might be to use a stateful data loader that your cost function closes over. Something along these lines should work without having to instantiate a new problem for each batch:
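```python
import pymanopt

data_loader = ...  # placeholder: stateful object whose get_batch() returns the next batch
manifold = ...  # placeholder: the pymanopt manifold for your problem
compute_batch_cost = ...  # placeholder: callable (batch, point) -> cost on that batch

# PyTorch backend decorator; its exact name and signature vary across pymanopt versions.
@pymanopt.function.PyTorch(manifold)
def cost(point):
    # Each evaluation pulls the next batch, so the cost closes over the loader's state.
    batch = data_loader.get_batch()
    return compute_batch_cost(batch, point)
```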
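To make "stateful" concrete, here is a NumPy-only sketch. `CyclingBatchLoader` is a hypothetical stand-in for the elided `data_loader` in the snippet above, `compute_batch_cost` is an illustrative mean-squared-distance cost, and the pymanopt decorator is left out so the example runs on its own. It shows that the same `cost` closure returns a value computed on a different batch each time it is called.

```python
import numpy as np


class CyclingBatchLoader:
    """Hypothetical stateful loader: each get_batch() call returns the next chunk.

    A real loader might read batches from disk instead of slicing an in-memory array.
    """

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size
        self.position = 0

    def get_batch(self):
        start = self.position
        stop = start + self.batch_size
        batch = self.data[start:stop]
        # Advance, wrapping around to the first batch after the last one.
        self.position = stop if stop < len(self.data) else 0
        return batch


def compute_batch_cost(batch, point):
    """Hypothetical per-batch cost: mean squared distance from the rows of `batch` to `point`."""
    return float(np.mean(np.sum((batch - point) ** 2, axis=1)))


data = np.arange(12, dtype=float).reshape(6, 2)  # six 2-D samples, i.e. three batches of two
data_loader = CyclingBatchLoader(data, batch_size=2)


def cost(point):
    # The closure reads whatever batch the loader hands out on this call.
    batch = data_loader.get_batch()
    return compute_batch_cost(batch, point)


point = np.zeros(2)
values = [cost(point) for _ in range(4)]  # -> [7.0, 63.0, 183.0, 7.0]
```

Because the cost now changes from call to call, a solver that evaluates it several times per iteration (for example, a backtracking line search comparing a trial point against the current one) would be comparing values from different batches, which is worth keeping in mind when choosing a solver for this setup.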
GPU support is currently not available because pymanopt's core is written on top of NumPy; the different backends are currently used only for their autodiff capabilities. We plan to address this in the future, either by rewriting the core on top of JAX or via a general backend abstraction. Work on that has not started yet, though, so I cannot give an ETA.