Add "load pytorch tensor" section into the burn book

Open • med1844 opened this issue 1 year ago • 9 comments

Issue based on discussion #2315, @antimora

To the best of my knowledge, here's how to load a tensor:

  1. In Python: make sure you wrap the tensor in a dict before saving it, e.g.

    import torch
    torch.save({"some_key": tensor}, "path/to/tensor.pt")  # dict key must match the Rust field name
    
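  2. In Rust:

    use burn::backend::NdArray;
    use burn::module::{Module, Param};
    use burn::record::{FullPrecisionSettings, Recorder};
    use burn::tensor::{backend::Backend, Tensor};
    use burn_import::pytorch::PyTorchFileRecorder;

    // The field name must match the dict key used in `torch.save`.
    // `#[derive(Module)]` also generates the `FloatTensorRecord` type used below.
    #[derive(Module, Debug)]
    struct FloatTensor<B: Backend, const D: usize> {
        some_key: Param<Tensor<B, D>>,
    }

    fn main() {
        type B = NdArray;
        let device = Default::default();
        // Load the .pt file directly into the generated record (rank 3 here).
        let tensor: FloatTensorRecord<B, 3> =
            PyTorchFileRecorder::<FullPrecisionSettings>::new()
                .load("path/to/tensor.pt".into(), &device)
                .unwrap();
        let tensor = tensor.some_key.val();
    }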
    

med1844 • Sep 30 '24 00:09

Hi, could I take up this issue?

csking101 • Oct 26 '24 02:10

> Hi, could I take up this issue?

Yes! Please go ahead. We would appreciate your contribution. Let me know if you need more info.

antimora • Oct 28 '24 18:10

It seems that this issue is still open and there haven't been any updates; I would love to take up the issue if possible.

Just to clarify, is this an issue about improving documentation?

ARelaxedScholar • Dec 10 '24 04:12

> It seems that this issue is still open and there haven't been any updates; I would love to take up the issue if possible.
>
> Just to clarify, is this an issue about improving documentation?

Hi @ARelaxedScholar, yes, it's about updating the documentation. Any help is greatly appreciated!

antimora • Dec 10 '24 19:12

> > It seems that this issue is still open and there haven't been any updates; I would love to take up the issue if possible. Just to clarify, is this an issue about improving documentation?
>
> Hi @ARelaxedScholar, yes, it's about updating the documentation. Any help is greatly appreciated!

Amazing, I'd love to help!

Sorry if these are silly, but I have two questions before I start.

  1. Where do I write or update the documentation? (This will be my first legit open-source contribution.) I have skimmed the contributor-book and will take the time to read it properly later, but I didn't see any indication of where documentation changes should go.
  2. Is there a documentation or writing style that you're looking for? I tend to have a strong voice in my writing, which I imagine might not be ideal for more technical documents. Are there any references I can model my writing on? Is the burn.dev book itself that reference?

Thanks for welcoming me. :)

ARelaxedScholar • Dec 10 '24 22:12

> > It seems that this issue is still open and there haven't been any updates; I would love to take up the issue if possible. Just to clarify, is this an issue about improving documentation?
>
> Hi @ARelaxedScholar, yes, it's about updating the documentation. Any help is greatly appreciated!

Never mind, apparently I am blind. Lol. I'm currently in exam season, but I'll start doing it when I am done.

ARelaxedScholar • Dec 16 '24 01:12

Just encountered this issue on my project. Agree the docs need a section on this. I'm writing up my notes here in case they're helpful. I have a few months of experience with Rust and more experience with Python and torch, so I'm guessing I am the target audience for the doc.

The solution that worked for me was in #2315. The solution in this issue almost worked, but I think it has a typo; the last line should be: `let tensor = tensor.some_key.val();`

The most mysterious thing to me was where types like `FloatTensorRecord` came from. Searching through the Rust docs confused me more than it helped. My current understanding is that it is generated by the `#[derive(Module)]` on the `FloatTensor` struct.

The fact that Burn needs the tensor to be saved inside a dictionary is not obvious (though it's easy to work around). It would definitely be good to document this.

One other dead end I went down was not wrapping the `Tensor` in a `Param`. Everything seemed to work, but the tensor was all zeros. I did find explanations of what `Param` means in other parts of the docs.
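A cheap guard against that silent all-zeros failure is to summarize the loaded values and compare them with the same statistics printed on the PyTorch side (for example `tensor.min()`, `tensor.max()`, `tensor.mean()`). The sketch below is plain Rust over a flattened `&[f32]` so it runs without Burn; `Summary` and `summarize` are hypothetical helpers, and how you flatten a Burn tensor into a `Vec<f32>` is assumed and depends on your Burn version.

```rust
/// Hypothetical helper: basic statistics of a flattened tensor.
#[derive(Debug, PartialEq)]
pub struct Summary {
    pub min: f32,
    pub max: f32,
    pub mean: f32,
    /// True when every element is exactly 0.0, which may mean the
    /// values were never actually loaded from the .pt file.
    pub all_zero: bool,
}

/// Summarizes a flattened tensor; returns `None` for an empty slice.
pub fn summarize(values: &[f32]) -> Option<Summary> {
    if values.is_empty() {
        return None;
    }
    let mut min = f32::INFINITY;
    let mut max = f32::NEG_INFINITY;
    let mut sum = 0.0f64; // accumulate in f64 to limit rounding error
    for &v in values {
        min = min.min(v);
        max = max.max(v);
        sum += v as f64;
    }
    Some(Summary {
        min,
        max,
        mean: (sum / values.len() as f64) as f32,
        all_zero: values.iter().all(|&v| v == 0.0),
    })
}
```

An `all_zero` of `true`, or a mean that disagrees with what PyTorch reports, is a hint that the weights did not end up where you expected.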

For document organization: I was porting a model from PyTorch to Rust. I had the basic architecture defined and managed to load the weights from the PyTorch file. The next step was to validate the Rust model on test inputs: I wanted reassurance that the model was doing the right thing before I ported over all the input preprocessing from PyTorch.

nwhitehead • Jan 13 '25 23:01

> The solution in this issue almost worked, but I think it has a typo; the last line should be: `let tensor = tensor.some_key.val();`

Thanks for pointing this out! I have edited the typo.

I also ran into this problem when porting a model from PyTorch to Rust. Maybe we should expand the scope here and further improve the "Import Models - PyTorch Model" section to show how to test outputs, which would involve loading tensors and comparing the results using `all_close`.
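The comparison step could look something like the sketch below. It applies the element-wise tolerance rule documented for `torch.allclose`, `|a - e| <= atol + rtol * |e|`, and also reports the worst mismatch to help locate where a ported model diverges. It is plain Rust over flattened `f32` slices so it runs without Burn; `slices_all_close` and `max_abs_diff` are hypothetical helpers rather than Burn's own `all_close`, and suitable `rtol`/`atol` values are an assumption to tune per model.

```rust
/// Element-wise closeness in the style of `torch.allclose`:
/// every pair must satisfy |a - e| <= atol + rtol * |e|.
/// Slices of different lengths are never considered close; NaNs never compare close.
pub fn slices_all_close(actual: &[f32], expected: &[f32], rtol: f32, atol: f32) -> bool {
    actual.len() == expected.len()
        && actual
            .iter()
            .zip(expected)
            .all(|(a, e)| (a - e).abs() <= atol + rtol * e.abs())
}

/// Largest element-wise absolute difference and the index where it occurs.
/// Returns `None` if the slices differ in length or are empty.
pub fn max_abs_diff(actual: &[f32], expected: &[f32]) -> Option<(usize, f32)> {
    if actual.len() != expected.len() || actual.is_empty() {
        return None;
    }
    let mut worst = (0, 0.0f32);
    for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
        let d = (a - e).abs();
        if d > worst.1 {
            worst = (i, d);
        }
    }
    Some(worst)
}
```

In a port, `expected` would typically be the PyTorch output saved with `torch.save` and loaded as described at the top of this issue, and `actual` the Rust model's output on the same input; `max_abs_diff` then points at the first element worth inspecting when the check fails.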

med1844 • Jan 14 '25 03:01

@nwhitehead, @med1844 thanks for your input. Yes, I agree with you. I came to the same conclusion about the need to verify whether a ported model matches the original model. We can definitely improve the experience, since this is likely to be a common pattern.

antimora • Jan 14 '25 16:01