keypoint-detection
help with larger maxxvit models
Hey, I'm trying to add more MaxxViT models, but I often end up with errors.
For example, these work:
class MaxVitPicoUnet(MaxVitUnet):
    MODEL_NAME = "maxvit_rmlp_pico_rw_256"  # 7.5M params.
    FEATURE_CONFIG = [
        {"down": 2, "channels": 32},
        {"down": 4, "channels": 32},
        {"down": 8, "channels": 64},
        {"down": 16, "channels": 128},
        {"down": 32, "channels": 256},
    ]
class MaxVitSmallUnet(MaxVitUnet):
    MODEL_NAME = "maxxvit_rmlp_small_rw_256"  # 7.5M params.
    FEATURE_CONFIG = [
        {"down": 2, "channels": 96},
        {"down": 4, "channels": 96},
        {"down": 8, "channels": 192},
        {"down": 16, "channels": 384},
        {"down": 32, "channels": 768},
    ]
How would you do it for maxvit_small_tf_512?
@ExtReMLapin you need to match the feature config to the actual model.
For example, for your model the feature maps are (taken from the docs):
for o in output:
    # print shape of each feature map in output
    # e.g.:
    #  torch.Size([1, 64, 256, 256])
    #  torch.Size([1, 96, 128, 128])
    #  torch.Size([1, 192, 64, 64])
    #  torch.Size([1, 384, 32, 32])
    #  torch.Size([1, 768, 16, 16])
    print(o.shape)
So you would need to change the number of channels accordingly in the FEATURE_CONFIG (in this case, change the number of channels for the first entry to 64).
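As a rough sketch, the shapes quoted above can be turned into a FEATURE_CONFIG-style list mechanically. The helper name `feature_config_from_shapes` is hypothetical (it is not part of this repo), and the 512x512 input size is an assumption based on the model name:

```python
def feature_config_from_shapes(shapes, input_size):
    """Build a FEATURE_CONFIG-style list from (N, C, H, W) feature-map shapes.

    Hypothetical helper for illustration; assumes square inputs and feature maps.
    """
    config = []
    for _, channels, height, _ in shapes:
        # "down" is the downsampling factor of this feature map relative to the input.
        config.append({"down": input_size // height, "channels": channels})
    return config


# Shapes quoted above for maxvit_small_tf_512 (assumed 512x512 input).
shapes = [
    (1, 64, 256, 256),
    (1, 96, 128, 128),
    (1, 192, 64, 64),
    (1, 384, 32, 32),
    (1, 768, 16, 16),
]
FEATURE_CONFIG = feature_config_from_shapes(shapes, input_size=512)
```

The resulting list has the same layout as the configs in the question, so it could be pasted into a new MaxVitUnet subclass with MODEL_NAME set to "maxvit_small_tf_512".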
At some point I wanted to make this generic, but I haven't dedicated the time yet. Feel free to make a PR ;)
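One possible way to make this generic, as a sketch only: timm models created with `features_only=True` expose a `feature_info` object whose `channels()` and `reduction()` methods report the channel count and downsampling factor of each feature map. The function names below are hypothetical and not part of this repo, and this has not been wired into MaxVitUnet:

```python
def feature_config_from_info(channels, reductions):
    """Pair per-stage channel counts with their downsampling factors."""
    return [{"down": r, "channels": c} for c, r in zip(channels, reductions)]


def feature_config_for(model_name):
    """Query timm for the feature layout of a backbone (hypothetical helper).

    timm is imported lazily so the pure helper above works without it.
    """
    import timm

    # pretrained=False avoids downloading weights; only the metadata is needed.
    model = timm.create_model(model_name, pretrained=False, features_only=True)
    info = model.feature_info
    return feature_config_from_info(info.channels(), info.reduction())
```

With timm installed, calling `feature_config_for("maxvit_small_tf_512")` should then give a list that can be compared against a hand-written FEATURE_CONFIG.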