Adding a path to sys.path doesn't work in train.py
I need to load a checkpoint that was trained in another folder. When I use 'torch.load', it throws an error because I changed the working directory. I searched online for a solution and found that I need to add two lines of code: 'import sys; sys.path.insert(0, '/my/model/path')'. When I try this code in a Jupyter notebook located in the root dir of the template, it works. However, when I add the same code to train.py, which is located at root_dir/src/train.py, torch.load doesn't work.
I think the problem is in the pyrootutils module. But I'm not familiar with it.
Hi @yangze68 ,
I don't know your use case, but are you sure you need to use torch.load()? Lightning is built with other ways of loading checkpoints in mind, for example:
model = LitModel.load_from_checkpoint("best_model.ckpt")
https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_basic.html
If you need to add a path to the PYTHONPATH, I don't think the pyrootutils package interferes in any way if you use these defaults:
https://github.com/ashleve/lightning-hydra-template/blob/8987b23b7f991a3de3f043058abd7ba4f63ea13f/src/train.py#L14
(the only thing it does is add the root dir to the PYTHONPATH and set/load env variables; there's no reason you shouldn't be able to add more paths with sys.path.insert)
Note that sys.path.insert(0, "/absolute/path/to/directory") should be given an absolute path, so make sure your path is not relative.
Also note that your working directory changes depending on whether you run your script from root_dir/ or src/, so a relative path can point to different places.
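To illustrate the point about absolute paths, here is a minimal sketch of a helper that resolves whatever path it is given to an absolute one before putting it at the front of sys.path. The name add_to_sys_path is hypothetical and not part of the template or pyrootutils.

```python
import sys
from pathlib import Path


def add_to_sys_path(directory: str) -> str:
    """Resolve `directory` to an absolute path and put it first on sys.path.

    Hypothetical helper for illustration only.
    """
    # resolve() turns a relative path into an absolute one, using the
    # current working directory at the time of the call.
    absolute = str(Path(directory).expanduser().resolve())
    # Avoid duplicate entries, then make this path the first one searched.
    if absolute in sys.path:
        sys.path.remove(absolute)
    sys.path.insert(0, absolute)
    return absolute
```

Because the path is resolved once, at call time, later changes to the working directory no longer affect which directory that sys.path entry points to.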
Thanks for your reply. I checked my code and found that I'm trying to load a model from another, similar template, which has the same folder structure as the current one. The interpreter keeps looking for the related code in the current project folder, even though I insert the extra path I want, which is why the error is raised.
Do you think this is related to adding the current root dir to the PYTHONPATH?
Since I'm doing a complicated task that can't use PyTorch Lightning, I have to write my own code. Therefore, I need to load the model using torch.load().
Thank you again for your excellent work. The template is really helpful.
I don't know if I understand correctly, but if you have two different folders/projects added to PYTHONPATH, both of which have the same folder structure (e.g. src/models/...), then the interpreter won't know which module you're really trying to import.
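The clash can be reproduced with a small, self-contained sketch. It creates two throwaway projects that both contain a package with the same name (demo_pkg, project_a and project_b are made-up names for this example, standing in for two templates that each have their own src/ package). Once the package has been imported from the first project, inserting the second project's path afterwards does not change what a later import returns, because Python reuses the module already cached in sys.modules.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def make_project(root: Path, label: str) -> None:
    """Create root/demo_pkg/__init__.py that records which project it lives in."""
    pkg = root / "demo_pkg"
    pkg.mkdir(parents=True)
    (pkg / "__init__.py").write_text(f"ORIGIN = {label!r}\n")


with tempfile.TemporaryDirectory() as tmp:
    project_a = Path(tmp) / "project_a"
    project_b = Path(tmp) / "project_b"
    make_project(project_a, "project_a")
    make_project(project_b, "project_b")

    # The first project is on sys.path when the package is first imported.
    sys.path.insert(0, str(project_a))
    importlib.invalidate_caches()
    first_origin = importlib.import_module("demo_pkg").ORIGIN

    # Putting the second project in front does not help: the import system
    # finds "demo_pkg" in sys.modules and never searches sys.path again.
    sys.path.insert(0, str(project_b))
    second_origin = importlib.import_module("demo_pkg").ORIGIN

    print(first_origin, second_origin)

    # Clean up so the demo leaves no trace in this interpreter.
    sys.path.remove(str(project_a))
    sys.path.remove(str(project_b))
    del sys.modules["demo_pkg"]
```

Both values come from project_a. In the situation described above, the same thing would happen if the current project's src package is already imported before the checkpoint from the other project is unpickled.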
You are correct. Thanks for your help!