
Add DaViT

fffffgggg54 opened this issue 2 years ago · 12 comments

Adapt the DaViT model from https://arxiv.org/abs/2204.03645 and https://github.com/dingmyu/davit.

Notably, the model performs on par with many newer models such as MaxViT while offering higher throughput, a design that should allow for easy pyramid feature extraction, and resolution invariance. In my experiments, the model transfers well between training and test images of different sizes (train at 56x56, infer at 640x640, and vice versa).

fffffgggg54 avatar Dec 07 '22 01:12 fffffgggg54

The documentation is not available anymore as the PR was closed or merged.

The implementation is done and passes all tests. I'm not sure whether the styling and documentation have any issues. I got feature extraction working by implementing the pyramid features as the basic forward: the regular classification forward in the parent class calls it, and the child class overrides forward to return the pyramid features directly. Looking at some of the other transformer implementations, I think a generalized version of this structure should work well, though it's not as flexible as FeatureDictNet.
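As a rough illustration of the structure described above (a toy stand-in with made-up layer choices, not the PR code):

```python
import torch.nn as nn


class TinyPyramidNet(nn.Module):
    """Classification model: forward_features returns the per-stage pyramid,
    forward classifies from the last feature map."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.stages = nn.ModuleList([
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
        ])
        self.head = nn.Linear(32, num_classes)

    def forward_features(self, x):
        feats = []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)
        return feats

    def forward(self, x):
        x = self.forward_features(x)[-1]       # last pyramid level
        return self.head(x.mean(dim=(2, 3)))   # global average pool, then classify


class TinyPyramidFeatures(TinyPyramidNet):
    """Feature-extraction variant: override forward to return the pyramid instead of logits."""
    def forward(self, x):
        return self.forward_features(x)
```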

fffffgggg54 avatar Dec 09 '22 11:12 fffffgggg54

@fffffgggg54 the FeatureDict methods can be used (I've been meaning to do this as an example for swin / maxvit / etc) but the swin ones need some re-org.

You want to reorganize into stages instead of having everything in module lists sequenced by looping:

```
DaViTStage
  self.downsample = PatchEmbed
  self.channel = Channel...()
  self.spatial = Spatial()
```

In general it's easier to reason about the model, with say FX manipulation or layer-scale module manipulation, if modules in the model (as ordered by model.named_modules() and parameters) are generally ordered to match their call sequence...
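A minimal sketch of that ordering principle, with placeholder blocks standing in for the actual DaViT modules (names and layers here are illustrative assumptions):

```python
import torch.nn as nn


class Stage(nn.Module):
    """Attributes are declared in the same order they are called in forward,
    so named_modules() and parameters() mirror the data flow."""
    def __init__(self, in_chs, out_chs):
        super().__init__()
        self.downsample = nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2)  # stand-in for a patch-merging downsample
        self.channel_block = nn.Identity()  # stand-in for the channel attention block
        self.spatial_block = nn.Identity()  # stand-in for the spatial (window) attention block

    def forward(self, x):
        x = self.downsample(x)
        x = self.channel_block(x)
        x = self.spatial_block(x)
        return x
```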

I can sort this out maybe next week; I'm going to tackle this for FocalNet (they organized it like swin, w/ downsample at end of stage) and swin.

rwightman avatar Dec 09 '22 19:12 rwightman

Also, having the two cpe modules instantiated as cpe1/cpe2, relative to the other modules in the order they are used in fwd, instead of in a ModuleList would be preferred... I've done a lot of checkpoint remapping on the fly for this sort of thing:

https://github.com/rwightman/pytorch-image-models/blob/0fe90449e5e07e91a78ea847b76eda6010b55283/timm/models/convnext.py#L331-L352

https://github.com/rwightman/pytorch-image-models/blob/0fe90449e5e07e91a78ea847b76eda6010b55283/timm/models/efficientformer.py#L488-L513
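A condensed sketch of what such an on-the-fly remap can look like (the key patterns below are illustrative for the cpe case, not copied from those files):

```python
import re


def checkpoint_filter_fn(state_dict, model):
    """Remap original checkpoint keys to the reorganized module layout (illustrative patterns)."""
    out = {}
    for k, v in state_dict.items():
        # e.g. 'stages.0.cpes.0.proj.weight' -> 'stages.0.cpe1.proj.weight'
        #      'stages.0.cpes.1.proj.weight' -> 'stages.0.cpe2.proj.weight'
        k = re.sub(r'cpes\.(\d)\.', lambda m: f'cpe{int(m.group(1)) + 1}.', k)
        out[k] = v
    return out
```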

rwightman avatar Dec 09 '22 19:12 rwightman

Made a couple changes, PR 5 in my fork.

The cpe modules were moved out of ModuleLists and into the stages as separate cpe1 and cpe2 modules. Other stages (PatchEmbed) were separated out as much as possible without sacrificing flexibility. The main DaViT stages have also been put into a separate class. I feel that these changes should definitely stay.

Since the DaViTStage class bears a resemblance to some of the stuff in ByoaNet/ByobNet and the structure of DaViT is also similar to those models, it might make sense to try and use that here.

I also got rid of all of the ModuleLists and iteration and replaced them with SequentialWithSize, which works like Sequential but passes through a size argument. I was hesitant to go this route in yesterday's version because the class doesn't work with TorchScript; I'm not sure whether that's due to my implementation or something on torch's side. Looping through ModuleLists seems to be the best naive way around it. It might be possible to rewrite everything without the size argument.
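For context, a minimal sketch of what a Sequential variant that threads a size argument might look like (an assumption about the general shape, not the code from the fork):

```python
import torch.nn as nn


class SequentialWithSize(nn.Sequential):
    """Like nn.Sequential, but threads an (H, W) size through each child so
    modules that need the spatial size (and may change it) can update it."""
    def forward(self, x, size):
        for module in self:
            x, size = module(x, size)
        return x, size
```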

FeatureDictNet didn't work, since it calls module(x) instead of module(x, size). I'm not sure if it's worth overriding that method when overriding forward in DaViTFeatures also works. Implicitly determining the size would fix this.

I'll work on removing the size parameter, which should solve both the TorchScript and the feature extraction issues. Let me know if you have any input on how to move forward with this.

fffffgggg54 avatar Dec 10 '22 14:12 fffffgggg54

I've removed the size argument and everything works with nn.Sequential now. TorchScript works, as does FeatureDictNet. PatchEmbed also no longer requires register_notrace_module. Some parts are a bit messy due to conversions between (N, HW, C), (N, C, H, W), and (N, H, W, C).
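One way this works without an explicit size argument is to recover H and W from the tensor shape inside each block; a generic sketch of that pattern (not the actual DaViT blocks):

```python
import torch.nn as nn


class TokenBlock(nn.Module):
    """Accepts (N, C, H, W), recovers H and W from the tensor shape, runs a token-space
    op on (N, HW, C), then restores (N, C, H, W) so plain nn.Sequential chaining works."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        N, C, H, W = x.shape                        # size is implicit in the tensor
        x = x.flatten(2).transpose(1, 2)            # (N, C, H, W) -> (N, HW, C)
        x = self.norm(x)                            # any token-space operation
        x = x.transpose(1, 2).reshape(N, C, H, W)   # back to (N, C, H, W)
        return x
```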

fffffgggg54 avatar Dec 11 '22 05:12 fffffgggg54

@fffffgggg54 changes looking good, a few minor formatting things I'll take care of and will do some testing; I'll likely combine it with polishing up FocalNet and do it in the same block of time, maybe later this week.

Don't be surprised if this PR doesn't get merged until after I merge to another branch and then merge that :)

FYI I did some refactoring of module imports, so that's why tests broke. I can take care of the fix when I do the final polish, but it's a fairly simple change to the imports.

rwightman avatar Dec 12 '22 07:12 rwightman

@rwightman Would it be possible for you to tell me specifically what needs to be done, particularly for the formatting? I'm trying to learn about contributing to large projects and I want the experience. I've rebased my fork to use the new module structure.

fffffgggg54 avatar Dec 13 '22 08:12 fffffgggg54

@fffffgggg54 a number of the changes are just small style nits that I often do for consistency... I can point out a few examples.

For the downsample, I think I'd still keep the very first patch projection as its own top-level .patch_embed module (and have no downsample in the first stage), and then have one in each of the other stages. This is for consistency with convnet stems and other hierarchical transformers.
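A rough sketch of that layout (everything besides the .patch_embed name is a placeholder):

```python
import torch.nn as nn


class HierarchicalNet(nn.Module):
    """Top-level patch_embed acts as the stem; the first stage has no downsample,
    while later stages each own one (Identity stands in for the real blocks)."""
    def __init__(self, dims=(96, 192, 384, 768)):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, dims[0], kernel_size=4, stride=4)  # stride-4 stem
        stages = []
        for i, dim in enumerate(dims):
            downsample = nn.Identity() if i == 0 else nn.Conv2d(dims[i - 1], dim, 2, stride=2)
            stages.append(nn.Sequential(downsample, nn.Identity()))
        self.stages = nn.Sequential(*stages)

    def forward(self, x):
        x = self.patch_embed(x)
        return self.stages(x)
```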

I think that's it actually at this point, you've done a great job fitting the style and using the model builder features, etc. The last fix for the module naming change I made recently should be quick; you're already using the new multi-weight default cfgs w/ added pretrained tags, which is awesome.

The import block looks something like this now, a bit more granular; non-model support modules have a _ in front, so when looking at the models dir they're all at the top, separate from the models. layers moved out to another top-level package.

https://github.com/rwightman/pytorch-image-models/blob/84178fca60fca9d35c0ea5ec08af2e88bf058559/timm/models/maxxvit.py#L49-L59
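For reference, an import block in that style looks roughly like the following (the exact symbols vary per model; see the maxxvit.py link above):

```python
# inside timm/models/<model_name>.py
from timm.layers import DropPath, LayerNorm2d, Mlp, trunc_normal_

from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
```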

rwightman avatar Dec 14 '22 23:12 rwightman

I separated the first downsample out. I don't like the way it looks compared to all the stages having a patch embed and attention blocks, but it makes sense.

The imports are also updated to the latest structure and tests pass in my fork. I've contacted a few of the authors for additional weights, but haven't heard back. It might make sense for you to ask once this gets merged.

fffffgggg54 avatar Dec 17 '22 06:12 fffffgggg54

@fffffgggg54 thanks again for the PR. I've been fiddling with the models and I'm falling short of the official classification results when trying to validate; I've run through a few different crop size and interpolation combos, and the best I got was 84.2 for B while the official impl claims 84.6. I didn't see anything obviously wrong, but it feels like maybe a pad or something isn't quite right for the classification setup (weird that those models aren't in the official impl).

rwightman avatar Jan 09 '23 17:01 rwightman

@rwightman I did a few tests and it seems the setup of the head was likely causing the trouble. The current impl has norm->pool->fc, while pool->norm->fc gets me ~0.5% higher top-1 val accuracy for the tiny model. I'll update the PR and send my setup once I fiddle with it a bit more and get results matching the paper.
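For concreteness, the two head orderings being compared look roughly like this (a generic sketch over token inputs, not the exact timm code):

```python
import torch.nn as nn


class NormPoolFcHead(nn.Module):
    """Original ordering: normalize tokens, then global-average-pool, then classify."""
    def __init__(self, dim, num_classes):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):          # x: (N, HW, C) tokens
        x = self.norm(x)
        x = x.mean(dim=1)          # pool after norm
        return self.fc(x)


class PoolNormFcHead(nn.Module):
    """Ordering that matched the official results better: pool, then norm, then classify."""
    def __init__(self, dim, num_classes):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):          # x: (N, HW, C) tokens
        x = x.mean(dim=1)          # pool before norm
        x = self.norm(x)
        return self.fc(x)
```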

fffffgggg54 avatar Jan 13 '23 00:01 fffffgggg54

I changed the head to pool->norm->fc with an implementation from convnext. This got me 82.80% on T, 84.28% on S, and 84.58% on B with bicubic interpolation, 0.85 crop, and 224 image size. I updated the model config to reflect these settings. I'm not sure why the B variant is still off by 0.02%.

fffffgggg54 avatar Jan 16 '23 01:01 fffffgggg54

@fffffgggg54 thanks, merged. I did a bit of refactoring; I didn't really feel their 'patch embed' was a patch embed anymore, given it was being used like a stride-4 conv for the stem and a 2x2 box-filter-like conv for downsampling, so I split it.

Weights are in HF hub https://huggingface.co/timm/davit_base.msft_in1k

rwightman avatar Jan 27 '23 21:01 rwightman