burn
Add "load pytorch tensor" section into the burn book
Issue based on discussion #2315, @antimora
To the best of my knowledge, here's how to load a tensor:
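- In Python: wrap the tensor in a dict before saving it. The dict key must match the field name of the Rust struct, e.g.

```python
import torch
torch.save({"some_key": tensor}, "path/to/tensor.pt")  # "some_key" must match the Rust field name
```

- In Rust:

```rust
// Assumes burn with the `ndarray` feature and burn-import with the `pytorch` feature.
use burn::backend::NdArray;
use burn::module::{Module, Param};
use burn::record::{FullPrecisionSettings, Recorder};
use burn::tensor::{backend::Backend, Tensor};
use burn_import::pytorch::PyTorchFileRecorder;

// The field name must match the dict key used in `torch.save`.
#[derive(Module, Debug)]
struct FloatTensor<B: Backend, const D: usize> {
    some_key: Param<Tensor<B, D>>,
}

fn main() {
    type B = NdArray;
    let device = Default::default();

    // `#[derive(Module)]` generates `FloatTensorRecord`, which the recorder fills in.
    let tensor: FloatTensorRecord<B, 3> = PyTorchFileRecorder::<FullPrecisionSettings>::new()
        .load("path/to/tensor.pt".into(), &device)
        .unwrap();

    // Unwrap the `Param` to get the loaded tensor.
    let tensor = tensor.some_key.val();
    println!("{tensor}");
}
```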
Hi, could I take up this issue?
Yes! Please go ahead. We would appreciate your contribution. Let me know if you need more info.
It seems that this issue is still open and there haven't been any updates. I would love to take it up if possible.
Just to clarify, is this issue about improving the documentation?
Hi @ARelaxedScholar, yes, it's about updating the documentation. Any help is greatly appreciated!
Amazing, I'd love to help!
Sorry if these are silly, but I have two questions before I start:
- Where do I write or update the documentation? (This will be my first legit open source contribution.) I have skimmed through the contributor book and will take the time to read it properly later, but I didn't see any indication of where documentation changes should be made.
- Is there a documentation or writing style you're looking for? I tend to have a strong voice in my writing, which I imagine might not be ideal for technical documents. Are there any references I can model my writing on? Is the burn.dev book itself that reference?
Thanks for welcoming me. :)
Never mind, apparently I am blind. Lol. I'm currently in exam season, but I'll start doing it when I am done.
I just encountered this issue in my project and agree the docs need a section on this. I'm writing up my notes here in case they're helpful. I have a few months of experience with Rust and more experience with Python and torch, so I'm guessing I am the target audience for the doc.
The solution that worked for me was in #2315. The solution in this issue almost worked, but I think it has a typo; the last line should be:
let tensor = tensor.some_key.val();
The most mysterious thing to me was where types like FloatTensorRecord came from. Searching through the Rust docs was more confusing than helpful here. My current understanding is that it comes from the #[derive(Module)] on the FloatTensor struct.
The fact that the Rust side needs the tensor to be saved inside a dictionary is not obvious (but easy to work around). This would definitely be good to document.
Another "dead end" I went down was not wrapping the Tensor in a Param. Things seemed to work, but the tensor was all zeros. I did find explanations of what Param means in other parts of the docs.
For document organization, here is my use case: I was porting a model from PyTorch to Rust. I had the basic architecture defined and had managed to load the weights from the PyTorch file. The next step was to validate the Rust model on test inputs. I wanted reassurance that the model was doing the right thing before porting over all the input preprocessing code from PyTorch.
Thanks for pointing this out! I have edited the typo.
I also ran into this problem when porting a model from PyTorch to Rust. Maybe we should expand the goal here and further improve the "Import Models - PyTorch Model" part to include how to test outputs, which would involve loading tensors and comparing results using all_close.
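The comparison step suggested above can be sketched without any burn-specific code. In practice burn's `Tensor::all_close` does this check on tensors directly; the plain-Rust sketch below only spells out the tolerance rule, assuming it matches the `torch.allclose` formula |actual - expected| <= atol + rtol * |expected|. The helper names `all_close` and `max_abs_diff` and the sample values are hypothetical, and the functions work on flattened `f32` slices (for example, outputs exported from both models).

```rust
/// Hypothetical helper: true if every element of `actual` is within tolerance
/// of the matching element of `expected`, using
/// |actual - expected| <= atol + rtol * |expected| (the torch.allclose rule).
/// Slices of different lengths are never close; any NaN makes the check fail.
fn all_close(actual: &[f32], expected: &[f32], rtol: f32, atol: f32) -> bool {
    actual.len() == expected.len()
        && actual
            .iter()
            .zip(expected)
            .all(|(a, e)| (a - e).abs() <= atol + rtol * e.abs())
}

/// Hypothetical helper: largest absolute element-wise difference, useful for
/// seeing how far off a ported model is when `all_close` fails.
/// Returns None on a length mismatch. Note that `f32::max` skips NaN values,
/// so rely on `all_close` to catch NaNs.
fn max_abs_diff(actual: &[f32], expected: &[f32]) -> Option<f32> {
    if actual.len() != expected.len() {
        return None;
    }
    Some(
        actual
            .iter()
            .zip(expected)
            .map(|(a, e)| (a - e).abs())
            .fold(0.0, f32::max),
    )
}

fn main() {
    // Hypothetical outputs: one from the ported burn model, one from PyTorch.
    let rust_output = [0.1_f32, 0.5, 0.9];
    let torch_output = [0.1_f32, 0.500_001, 0.9];

    // Same defaults as torch.allclose: rtol = 1e-5, atol = 1e-8.
    let ok = all_close(&rust_output, &torch_output, 1e-5, 1e-8);
    println!(
        "all_close: {ok}, max_abs_diff: {:?}",
        max_abs_diff(&rust_output, &torch_output)
    );
}
```

Reporting the maximum difference alongside the boolean makes it easier to tell a genuine porting bug (large differences) from ordinary floating-point noise (differences just above the tolerance).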
@nwhitehead, @med1844 thanks for your input. Yes, I agree. I came to the same conclusion about the need to verify whether a ported model matches the original model. We can definitely improve the experience, since this is likely to be a common pattern.