deepsvg

Training classifier on bottleneck

Open tsaxena opened this issue 4 years ago • 6 comments

@alexandre01 I am trying to train a classifier on the icons using the bottleneck embeddings I get from inference with the pretrained model. In some cases, though, it doesn't work. I used the code from your latent_ops notebook to encode each icon, and for some icons I get this error: `The size of tensor a (10) must match the size of tensor b (8) at non-singleton dimension 0`. Am I missing something?

tsaxena avatar Dec 28 '20 21:12 tsaxena

I am using the following functions to encode an icon, but they do not seem to work for some of the icon indices.

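```python
import torch

# model, cfg, device, dataset and batchify are assumed to be defined
# as in the latent_ops notebook.

def encode(data):
    model_args = batchify((data[key] for key in cfg.model_args), device)
    with torch.no_grad():
        z = model(*model_args, encode_mode=True)
        return z


def encode_icon(idx):
    data = dataset.get(id=idx, random_aug=False)
    return encode(data)
```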
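For building a classifier dataset, one option is to encode every index and record the ones that fail instead of stopping at the first error. The following is only a minimal sketch: `encode_all` and its `encode_fn` argument are hypothetical names (pass something like `encode_icon` above), and it assumes the failure surfaces as a `RuntimeError`, which is what PyTorch raises for a tensor size mismatch.

```python
def encode_all(indices, encode_fn):
    """Encode each index with encode_fn; collect failures instead of aborting.

    encode_fn is any callable mapping an icon index to an embedding
    (hypothetical parameter; e.g. the encode_icon function above).
    """
    embeddings = {}
    skipped = []
    for idx in indices:
        try:
            embeddings[idx] = encode_fn(idx)
        except RuntimeError as err:
            # PyTorch reports shape mismatches as RuntimeError.
            skipped.append((idx, str(err)))
    return embeddings, skipped
```

The `skipped` list then shows which icons were dropped and why, so the classifier is trained only on icons that produced an embedding.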

tsaxena avatar Dec 28 '20 21:12 tsaxena

Hello @tsaxena! Yes, unfortunately this happens because the pretrained model was trained with a maximum of 8 paths per SVG. Since the model uses an index embedding over paths, it cannot run inference on SVGs with more paths than that.

You'd need to either train a model with more paths or filter your classification dataset down to SVGs with at most eight paths.
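As a rough illustration of the filtering option, the sketch below counts `<path>` elements in raw SVG text and keeps only icons within the limit. It is not taken from the DeepSVG codebase: `count_paths` and `filter_icons` are hypothetical helpers, and the number of path groups the model sees after DeepSVG's own preprocessing may differ from the raw `<path>` count, so treat this as a first pass only.

```python
import xml.etree.ElementTree as ET
from typing import Dict

MAX_NUM_GROUPS = 8  # path limit of the pretrained model, per the reply above


def count_paths(svg_text: str) -> int:
    """Count <path> elements in an SVG document, ignoring XML namespaces."""
    root = ET.fromstring(svg_text)
    return sum(1 for el in root.iter() if el.tag.rsplit("}", 1)[-1] == "path")


def filter_icons(icons: Dict[str, str], max_paths: int = MAX_NUM_GROUPS) -> Dict[str, str]:
    """Keep only the icons (name -> SVG text) with at most max_paths paths."""
    return {name: svg for name, svg in icons.items() if count_paths(svg) <= max_paths}
```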

alexandre01 avatar Dec 28 '20 21:12 alexandre01

Thanks for the prompt reply, @alexandre01. So does that mean you did not use all 100k icons for training the pretrained model?

tsaxena avatar Dec 28 '20 23:12 tsaxena

Yes, the dataset is about 100k icons, but because of time constraints the pretrained model was only trained on a filtered subset. And I don't have access to the GPU server I used anymore.

alexandre01 avatar Dec 28 '20 23:12 alexandre01

I eventually want to fine-tune the network on SVGs that have more than 8 paths. Would you suggest training from scratch? From what I understand, the maximum number of SVG paths is a configuration parameter that can be changed.

tsaxena avatar Dec 28 '20 23:12 tsaxena

I am not @alexandre01. But what you say is correct, I think.

The max number of paths is a config parameter and can be changed in deepsvg/model/config.py:

```python
self.max_num_groups = 8          # Number of paths (N_P)
```

There are also individual configs in configs/deepsvg/ that override this variable. You may need to increase the number of paths there as well.
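To illustrate how such an override interacts with the default, here is a self-contained sketch. The classes are hypothetical stand-ins, not the actual DeepSVG config classes, and `check_fits` is an invented helper that fails early with a readable message rather than a tensor size mismatch.

```python
class DefaultModelConfig:
    """Hypothetical stand-in for the defaults in deepsvg/model/config.py."""

    def __init__(self):
        self.max_num_groups = 8  # Number of paths (N_P)


class MyExperimentConfig(DefaultModelConfig):
    """Hypothetical stand-in for a per-experiment config in configs/deepsvg/."""

    def __init__(self, max_num_groups=16):
        super().__init__()
        # Assignments made after super().__init__() replace the defaults.
        self.max_num_groups = max_num_groups


def check_fits(num_paths, cfg):
    """Raise a clear error if an SVG has more paths than the config allows."""
    if num_paths > cfg.max_num_groups:
        raise ValueError(
            f"SVG has {num_paths} paths but the config allows {cfg.max_num_groups}"
        )


cfg = MyExperimentConfig(max_num_groups=16)
check_fits(10, cfg)  # passes: 10 <= 16
```

Raising the value only helps for a model trained with that setting; the existing pretrained checkpoint stays limited to 8 paths.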

You would need to retrain to cope with a larger number of paths. If you have a beefy GPU, retraining does not take very long. On my RTX 3090, I can retrain from scratch within hours.

What is your use case?

pwichmann avatar Dec 29 '20 14:12 pwichmann

@tsaxena, hi, I'm also running into tensor mismatches with my own SVG images. I tried to adjust max_num_groups and max_total_len in this project's configuration and retrain the model with those values to resolve the mismatch. Unfortunately, I didn't make any progress. Have you solved this problem? Can you give me some detailed guidance? I would be infinitely grateful and look forward to hearing from you!

HeCunr avatar Mar 14 '24 14:03 HeCunr

@pwichmann, hello, I followed the method you described, but it failed. Can you give me some detailed guidance? I would be infinitely grateful and look forward to hearing from you!

HeCunr avatar Mar 14 '24 14:03 HeCunr