
Would it be possible to use a pretrained model (such as Topview Mouse from DLC animal zoo)?

Open hummuscience opened this issue 1 year ago • 51 comments

I have been playing around with lightning pose the past few days and quite impressed with the training speed and performance!

Coming from DeepLabCut, I am testing LP on videos of mice captured from the top view. As you probably know, the DLC animal zoo has a pretrained model for this scenario.

Would it be possible to use that as a backbone for LP instead of the typical resnets?

I am still new to your codebase, so I might not have understood it in depth yet and might be missing something here...

hummuscience avatar May 24 '24 10:05 hummuscience

@hummuscience thanks for the kind words!

We currently do not offer any pretrained models, though we hope to start providing some within the year. Once DLC releases the dataset used to train their TopViewMouse network we can use that to train an LP version. I'll leave this issue open for now and update it once we start working on this.

We are also considering providing a pretrained model from the Facemap dataset (mouse face from different angles) and the CRIM13 dataset (top-down view of two mice, one black and one white).

If anybody has additional pretrained models they would like to see (and importantly, a pointer to a labeled dataset), please let us know!

themattinthehatt avatar May 24 '24 13:05 themattinthehatt

I just checked the AnimalZoo preprint and the mentioned references. It doesn't seem that any of the actual labeled datasets are available, but in some cases, the videos are (see here for example: https://zenodo.org/records/3608658).

I wonder if it would make sense to run the TopViewMouse model on these videos, extract some output frames (refine them in case of errors), and then convert the project to a DLC project. It might take some time, but could be worth the effort (for me, at least). How many frames would you aim for in this case?

The facemap model would be quite useful, since many people working with head-fixed mice could benefit from it.

The CRIM13 model is also interesting, but wouldn't have as many applications since it's specific to assays with two animals (one white, the other black). Unless this is a more common assay and I am just not aware of it.

On the subject of multianimal tracking. Are there currently any possibilities with LP?

hummuscience avatar May 27 '24 10:05 hummuscience

We've been in touch with the DLC folks and they plan to release the labeled datasets once the paper is out. However, it is not clear how long that will take, so your suggestion is probably the quickest way to get something workable. This isn't something we have the bandwidth to work on right now, but if you're up for trying it that would be great! I'd suggest labeling all the video frames with the TopViewMouse model, then doing the following:

  1. compute motion energy (absolute difference between keypoints on consecutive frames, averaged over all keypoints)
  2. remove frames with low motion energy (when the mouse is sitting still)
  3. for the remaining frames, run a clustering algorithm (kmeans is fine) on the poses with 1000 clusters. Then you can take 1 frame from each cluster to get 1000 labeled frames with a decent amount of pose variability (see a related implementation here, though note this runs kmeans on a PCA embedding of the frames rather than the predicted keypoints). A rough sketch of steps 1-3 is included after this list.
  4. refine errors in selected frames if necessary
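
Here's a rough sketch of steps 1-3, assuming the TopViewMouse predictions are loaded as a (num_frames, num_keypoints, 2) numpy array (the file name, motion-energy threshold, and cluster count are placeholders):

```python
import numpy as np
from sklearn.cluster import KMeans

# poses: (num_frames, num_keypoints, 2) array of TopViewMouse predictions
poses = np.load("topviewmouse_predictions.npy")  # placeholder file name

# 1. motion energy: absolute frame-to-frame displacement, averaged over keypoints
motion_energy = np.abs(np.diff(poses, axis=0)).mean(axis=(1, 2))
motion_energy = np.concatenate([[0.0], motion_energy])  # align with frame indices

# 2. drop low-motion frames (mouse sitting still); the threshold is arbitrary
kept_idxs = np.where(motion_energy > np.percentile(motion_energy, 25))[0]

# 3. cluster the remaining poses and keep the frame closest to each cluster center
flat = poses[kept_idxs].reshape(len(kept_idxs), -1)
kmeans = KMeans(n_clusters=1000, n_init=10, random_state=0).fit(flat)
selected = []
for c in range(kmeans.n_clusters):
    members = np.where(kmeans.labels_ == c)[0]
    if len(members) > 0:
        dists = np.linalg.norm(flat[members] - kmeans.cluster_centers_[c], axis=1)
        selected.append(kept_idxs[members[np.argmin(dists)]])
```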

I would also note that we find, even with the LP bells and whistles, that more labeled frames are always better. So if you're not too daunted by the refinement step, selecting 2k or 3k frames will almost certainly result in a more robust model than 1k frames, but 1k will still do a good job.

If you go down this route please let me know! Happy to keep discussing it with you.

Re: multi-animal tracking - this is something that we are working towards, but it will be a while before we have these features built into LP (unless the animals are visibly distinct, like in the CRIM13 dataset)

themattinthehatt avatar May 28 '24 13:05 themattinthehatt

So, I started to implement this for some of my own videos and currently refining the predictions on 300 images to test if it improves things.

Meanwhile, it seems like the datasets from the SuperAnimal paper are public now (or at least, I found them on zenodo: MausHaus and the whole TopViewMouse dataset).

I tried the pretrained model on some images from the MausHaus and the BlackMice datasets, but the results were actually not as good as I expected. My expectation was that the pretrained model would reproduce the labels from its own training set, but it didn't (or rather, it performed poorly).

Now there is also the entire TopViewMouse dataset. There, not all keypoints are labelled in all images (since they come from different datasets). I wondered if I can just go ahead and train LP with it like that.

The other issue is that the annotations are in a JSON file. I think I could manage to convert it to LP format though.
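
Something like this is what I have in mind (a rough sketch, assuming the annotations are COCO-style JSON and that LP wants the DLC-style CSV with scorer/bodyparts/coords header rows; the file names and keypoint list are placeholders):

```python
import json

import numpy as np
import pandas as pd

# placeholder list; must match the number and order of keypoints in the JSON
KEYPOINT_NAMES = ["nose", "left_ear", "right_ear"]

with open("annotations.json") as f:  # placeholder file name
    coco = json.load(f)

image_names = {im["id"]: im["file_name"] for im in coco["images"]}
rows = {}
for ann in coco["annotations"]:
    kpts = np.array(ann["keypoints"], dtype=float).reshape(-1, 3)  # x, y, visibility
    xy = kpts[:, :2]
    xy[kpts[:, 2] == 0] = np.nan  # keypoints not defined for this image become NaN
    rows[image_names[ann["image_id"]]] = xy.flatten()

# DLC/LP-style CSV: scorer / bodyparts / coords header rows, image paths as index
columns = pd.MultiIndex.from_product(
    [["topviewmouse"], KEYPOINT_NAMES, ["x", "y"]],
    names=["scorer", "bodyparts", "coords"],
)
pd.DataFrame.from_dict(rows, orient="index", columns=columns).to_csv("CollectedData_all.csv")
```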

hummuscience avatar Jul 10 '24 12:07 hummuscience

I looked at the TopViewMouse demo a year or so back and also found that it did not perform as well as expected on a top-view mouse video that looked very similar to one in the training dataset. Good to know that you could replicate this finding.

I still haven't had a chance to look into the TopViewMouse dataset - I do remember that not all keypoints are labeled in all images, but are the keypoints at least named the same when they are in the same location across datasets? If so then you can definitely train an LP model on these; you would just leave the ground truth label empty where it doesn't exist, and LP ignores that keypoint during training.

It shouldn't be too hard to convert the annotations from JSON to LP format. If you end up doing this please let me know and we can discuss best ways to train the model!

themattinthehatt avatar Jul 10 '24 13:07 themattinthehatt

The results are a bit better when one uses spatio-temporal adaptation, but not as good as I would expect.

Yes, the positions are the same. Then I will go ahead and try it out. I will report back :)

hummuscience avatar Jul 10 '24 13:07 hummuscience

Awesome, excited to see the results! I take it the labeled dataset doesn't come with the associated videos? Maybe I can ask the Mathises for that data; then we could extract the context frames and test out a context model as well.

themattinthehatt avatar Jul 10 '24 15:07 themattinthehatt

Training is running 👍 will post once it's done.

Yeah, the dataset doesn't contain any videos. I think some of the origin datasets could have videos (pranav 2018 maybe?). But yeah, it could be easier to ask the Mathises for the videos, even though it is possible that they won't have them...

Would it be possible to inform the context model with a non-context model somehow? Maybe one could use the PCA?

Btw, I wrote a script that automatically extracts the context frames for each image. Could that be useful to add as a utility in the scripts folder?
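
Roughly, the script does something like this (a sketch, assuming DLC/LP-style frame filenames like img000123.png that encode the frame index, and that the source video is available):

```python
import re
from pathlib import Path

import cv2

def extract_context_frames(video_path, labeled_frame_path, offsets=(-2, -1, 1, 2)):
    """Save the +/-2 context frames surrounding one labeled frame."""
    # frame index is parsed from the file name, e.g. img000123.png -> 123
    idx = int(re.search(r"(\d+)", Path(labeled_frame_path).stem).group(1))
    cap = cv2.VideoCapture(str(video_path))
    for off in offsets:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx + off)
        ok, frame = cap.read()
        if ok:
            out_path = Path(labeled_frame_path).parent / f"img{idx + off:06d}.png"
            cv2.imwrite(str(out_path), frame)
    cap.release()
```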

hummuscience avatar Jul 12 '24 14:07 hummuscience

First look at the training. It seems like I should be stopping the training earlier (150k?) or at least saving more checkpoints.

Test videos look very good :) I am thinking of training a DLC model with the same dataset (maybe some shuffles?) to compare output. There are 300 additional frames not contained in the original TopViewMouse5k that come from my own datasets.

Will check the evaluation

I am quite new to LP so I am not so sure about the choices in the config.yaml file. I added it below.

[screenshots: TensorBoard training and validation loss curves]

data:
  image_orig_dims:
    height: 480
    width: 640
  image_resize_dims:
    height: 512
    width: 512
  data_dir: /mnt/hpc_slurm/home/abdelhaym/freely-moving/lp-models/all_mighty_mouse/
  video_dir: /mnt/hpc_slurm/home/abdelhaym/freely-moving/lp-models/all_mighty_mouse/videos/
  csv_file: CollectedData_all.csv
  downsample_factor: 2
  num_keypoints: 27
  keypoint_names:
  - nose
  - left_ear
  - right_ear
  - left_ear_tip
  - right_ear_tip
  - left_eye
  - right_eye
  - neck
  - mid_back
  - mouse_center
  - mid_backend
  - mid_backend2
  - mid_backend3
  - tail_base
  - tail1
  - tail2
  - tail3
  - tail4
  - tail5
  - left_shoulder
  - left_midside
  - left_hip
  - right_shoulder
  - right_midside
  - right_hip
  - tail_end
  - head_midpoint
  mirrored_column_matches: null
  columns_for_singleview_pca:
  - 1
  - 2
  - 7
  - 9
  - 13
  - 19
  - 20
  - 21
  - 22
  - 23
  - 24
  - 26
training:
  imgaug: dlc-top-down
  train_batch_size: 8
  val_batch_size: 48
  test_batch_size: 48
  train_prob: 0.8
  val_prob: 0.1
  train_frames: 1
  num_gpus: 1
  num_workers: 4
  early_stop_patience: 3
  unfreezing_epoch: 20
  min_epochs: 300
  max_epochs: 750
  log_every_n_steps: 10
  check_val_every_n_epoch: 5
  gpu_id: 0
  rng_seed_data_pt: 0
  rng_seed_model_pt: 0
  lr_scheduler: multisteplr
  lr_scheduler_params:
    multisteplr:
      milestones:
      - 150
      - 200
      - 250
      gamma: 0.5
model:
  losses_to_use: []
  backbone: resnet50_animal_ap10k
  model_type: heatmap
  heatmap_loss_type: mse
  model_name: test
  checkpoint: null
  lightning_pose_version: 1.4.0
dali:
  general:
    seed: 123456
  base:
    train:
      sequence_length: 64
    predict:
      sequence_length: 128
  context:
    train:
      batch_size: 16
    predict:
      sequence_length: 96
losses:
  pca_multiview:
    log_weight: 5.0
    components_to_keep: 3
    epsilon: null
  pca_singleview:
    log_weight: 5.0
    components_to_keep: 0.99
    epsilon: null
  temporal:
    log_weight: 5.0
    epsilon: 20.0
    prob_threshold: 0.05
eval:
  hydra_paths:
  - ' '
  predict_vids_after_training: true
  save_vids_after_training: true
  fiftyone:
    dataset_name: freemovetest
    model_display_names:
    - test_freemoving
    launch_app_from_script: false
    remote: true
    address: 127.0.0.1
    port: 5151
  test_videos_directory: /mnt/hpc_slurm/home/abdelhaym/freely-moving/lp-models/all_mighty_mouse/test_videos/
  saved_vid_preds_dir: null
  confidence_thresh_for_vid: 0.9
  video_file_to_plot: null
  pred_csv_files_to_plot:
  - ' '
callbacks:
  anneal_weight:
    attr_name: total_unsupervised_importance
    init_val: 0.0
    increase_factor: 0.01
    final_val: 1.0
    freeze_until_epoch: 0

hummuscience avatar Jul 15 '24 08:07 hummuscience

I tried to run a semi-supervised model (PCA or temporal), but it fails due to the input images being of different sizes.

I could rescale the images and the keypoints; do you think that would make sense?
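
Something along these lines is what I have in mind (a sketch; the target size is just a placeholder):

```python
import cv2
import numpy as np

def rescale(image, keypoints_xy, target_hw=(480, 640)):
    """Resize an image and scale its (x, y) keypoints by the same factors."""
    h, w = image.shape[:2]
    target_h, target_w = target_hw
    resized = cv2.resize(image, (target_w, target_h))  # cv2.resize takes (width, height)
    scaled = keypoints_xy * np.array([target_w / w, target_h / h])  # x scales with width, y with height
    return resized, scaled
```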

hummuscience avatar Jul 15 '24 15:07 hummuscience

ah very cool, glad you were able to get this working!

I tried to run a semi-supervised model (PCA or temporal), but it fails due to the input images being of different sizes.

yes, PCA will require the frames to be the same size; but actually this is an interesting use case, because even if you resized the frames to the same size, the size of the animal would vary a lot from dataset to dataset, so maybe PCA wouldn't work so well anyway. I'll have to give this some more thought.

config options

  • image_resize_dims: 512x512 is quite large; you might also try 256x256 or 384x384 and see if they work as well or better. Model training/inference will certainly be faster with smaller resize dims. Training can actually be more accurate in some cases too, because the number of trainable parameters increases as you increase the resize dims; I've seen (through other users) that even with 1000x1000 pixel images of a freely moving mouse, resizing to something smaller than 512x512 is more accurate.
  • train_prob/val_prob: I would up the train_prob so you use more of the data for training, maybe something like 0.95/0.05? That doesn't leave any leftover for test data, but it depends on how you want to test the model (it could be on held-out data you labeled that isn't in the CollectedData.csv file). Also, there are a ton of frames in this dataset (>5k, right?), so you could even go further to 0.98/0.02 or something (this would still give you >100 validation frames).
  • train_batch_size: if you decrease the resize_dims you could double this to 16 without memory issues I'm guessing; then the model will train faster

everything else looks good! I see you set imgaug to dlc-top-down already, which is great :ok_hand:

Would it be possible to inform the context model with a non-context model somehow? Maybe one could use the PCA?

the lack of context frames and the frames being different sizes mean it is difficult to play with the context/semi-supervised features of the model. there's no real easy workaround to either of these with this dataset that I can think of off the top of my head. So I guess I'd say use the fully supervised model first and see how that works for you?

First look at the training. It seems like I should be stopping the training earlier (150k?) or at least saving more checkpoints.

I'm kinda surprised the validation loss starts going up so early - but my intuition is that this is related to the big resize dim (512x512), I'm curious if this goes away with smaller resize dims.

themattinthehatt avatar Jul 15 '24 17:07 themattinthehatt

Re-training now with your suggestions.

I have an RTX A6000, so I increased the batch size to 32 without any memory error (after 40 epochs). I went with a 256x256 size and 0.97 train / 0.02 val, which ends up with 100+ validation images and 50+ test images.

Interestingly, the model is not much faster, an epoch takes 1 minute

I was checking the loss development with Tensorboard, and it seems like the loss is reducing quicker than the first model I trained.

It could be due to the larger training data size and larger batch size that the loss is reducing faster and it's generalizing better.

[screenshots: TensorBoard loss curves for the new model]

hummuscience avatar Jul 16 '24 11:07 hummuscience

Going through the evaluation results of the first model I trained, it seems like the model doesn't generalize across datasets where fewer keypoints are labelled.

For example, in this dataset, only 4 keypoints were labelled. When we predict all 26 keypoints, the model does quite a bad job of it, with high confidence.

Am I doing something wrong?

In the SuperAnimal paper, they mention using gradient masking of the heatmaps to deal with this issue. I have no clear idea what that means, though.

[screenshots: example frames showing mispredicted keypoints on the 4-keypoint dataset]

hummuscience avatar Jul 16 '24 12:07 hummuscience

Interestingly, the model is not much faster, an epoch takes 1 minute

I guess this makes sense, you have fewer batches but each batch takes longer to process. Is your GPU utilization at or near 100% while training? If not you could also increase training.num_workers to something like 6 or 8 (depending on the number of cores you have) and that would speed training up a bit.

I was checking the loss development with Tensorboard, and it seems like the loss is reducing quicker than the first model I trained.

You are correct that this is due to the increased batch size - you can see in the first model that there is a big dip around 10k iterations. This is actually due to the unfreezing of the backbone weights. When you increase the batch size you take fewer steps per epoch, and so the backbone is unfrozen earlier (in terms of number of gradient steps). Actually in your case because there are so many frames you could probably reduce the parameter training.unfreezing_epoch to something like 5 instead of 20, but that doesn't invalidate any of your current results.

For example, in this dataset, only 4 keypoints were labelled. When we predict all 26 keypoints, the model does quite a bad job at it, with a high confidence.

Huh this is a bit unexpected, I would have guessed the model could generalize better than this. I'm curious if the 256x256 model generalizes better. This is a great test though! Do you see this kind of issue with other datasets that have few labeled keypoints? I forget how exactly the SuperAnimal paper does the gradient masking but I'll look into that and get back to you.

themattinthehatt avatar Jul 16 '24 13:07 themattinthehatt

I found the code that DLC uses for gradient masking in their dlc_pytorch branch

        for b in range(batch_size):
            for heatmap_idx, group_keypoints in enumerate(coords[b]):
                for keypoint in group_keypoints:
                    # FIXME: Gradient masking weights should be parameters
                    if keypoint[-1] == 0:
                        # full gradient masking
                        weights[b, heatmap_idx] = 0.0
                    elif keypoint[-1] == -1:
                        # full gradient masking
                        weights[b, heatmap_idx] = 0.0

                    elif keypoint[-1] > 0:
                        # keypoint visible
                        self.update(
                            heatmap=heatmap[b, :, :, heatmap_idx],
                            grid=grid,
                            keypoint=keypoint[..., :2],
                            locref_map=self.get_locref(locref_map, b, heatmap_idx),
                            locref_mask=self.get_locref(locref_mask, b, heatmap_idx),
                        )

If I understand this correctly, the gradient masking is applied to the weights before the backward pass? But I am not sure, since in the SuperAnimal paper they mention that it is applied before the loss calculation.

In LP the masking is done during [the loss calculation](https://github.com/danbider/lightning-pose/blob/4967266feb59a8c08ff3c31d08f520d480cd10d1/lightning_pose/losses/losses.py#L149), as in, all NaNs are removed before loss calculation. Is that the same approach though?

Checking the SuperAnimal paper, there are some example images with vs. without masking:

[images: SuperAnimal paper figure comparing predictions with vs. without gradient masking]

The methods part mentions the following:

Training naively on these projected annotations would harm the training stability, as the loss function penalizes undefined keypoints, as if they were not visible (i.e., occluded).

For stable training of our panoptic pose estimation model, we mask components of the loss function across keypoints. The keypoint mask $n_k$ is set to 1 if the keypoint $k$ is present in the annotation of the image and set to 0 if the keypoint is absent. We denote the predicted probability for keypoint $k$ at pixel $(i, j)$ as $p_k(i, j) \in [0, 1]$ and the respective label as $t_k(i, j) \in \{0, 1\}$, and formulate the masked $L_z$ error loss function as

$$ L_z = \sum_{k=1}^{m} n_k \sum_{i,j} |p_k(i, j) - t_k(i, j)|_z, $$

with $z=2$ for mean square error and $z=1$ for L1 loss (e.g., used for locref maps in DLCRNet) and the masked cross-entropy loss function as

$$ L_{CE} = - \sum_{k=1}^{m} \sum_{i,j} n_k t_k(i, j) \log p_k(i, j). $$

Note that we make distinct the difference between not annotated and not defined in the original dataset and we only mask undefined keypoints. This is important as, in the case of sideview animals, "not annotated" could also mean occluded/invisible. Adding masking to not annotated keypoints will encourage the model to assign high likelihood to occluded keypoints.
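
If I read the equations right, the masked MSE heatmap loss would translate to something like this in PyTorch (a sketch; the tensor shapes are my assumption):

```python
import torch

def masked_heatmap_mse(pred, target, mask):
    """pred/target: (batch, num_keypoints, height, width) heatmaps;
    mask: (batch, num_keypoints), 1 where the keypoint is defined, 0 where undefined."""
    per_keypoint = ((pred - target) ** 2).sum(dim=(2, 3))  # sum over pixels (i, j)
    return (mask * per_keypoint).sum()  # undefined keypoints contribute no gradient
```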

hummuscience avatar Jul 16 '24 14:07 hummuscience

Does LP distinguish between "unlabelled" keypoints (because of occlusions) and keypoints that were not labelled at all?

hummuscience avatar Jul 16 '24 15:07 hummuscience

No, currently LP does not distinguish between the two - if a ground truth label is missing it is dropped from the loss function with the remove_nans function you pointed out (though see "Why does the network produce high confidence values..." here: https://lightning-pose.readthedocs.io/en/latest/source/faqs.html).

So yes, on a first pass it appears that LP by default does the same gradient masking that the SuperAnimal paper implements. Looking at the DLC function you linked, weights does not refer to the weights of the model, but rather the mask applied to the loss ($n_k$ in their notation above).

themattinthehatt avatar Jul 16 '24 15:07 themattinthehatt

Oh something else I just realized: the TopViewMouse dataset contains some datasets with multiple animals - how do you deal with this right now? LP cannot currently handle multi-animal pose estimation.

themattinthehatt avatar Jul 16 '24 15:07 themattinthehatt

I removed the two datasets that have multiple animals (TriMouse and Golden Lab).

I also just realized that one of the datasets (Kiehn Lab Openfield) actually doesn't have labels, so I will rerun the training without it.

hummuscience avatar Jul 16 '24 15:07 hummuscience

Great. Let me know how the generalization looks with the 256x256 model when done, I'm still scratching my head a bit about the bad performance in the frames you showed above.

themattinthehatt avatar Jul 16 '24 16:07 themattinthehatt

This is the current state of the training (it still has the Kiehn Lab Openfield data though).

The train_heatmap_mse_loss is plateauing as well as the RMSE loss. Not sure what to think about the val loss.

Should I stop the training in this case? If I stop the training, will it save the checkpoint? (I set the config to save multiple checkpoints but realized that it is not implemented on the dynamic_crop branch)

[screenshots: TensorBoard training and validation loss curves]

hummuscience avatar Jul 16 '24 18:07 hummuscience

the noisiness in the validation plot is weird, especially compared to the black line. is black 512x512 and magenta 256x256?

one thing you can do is hit the three dots in the upper right hand corner of these plots and change the y-axis to a log scale, that's typically more helpful the further you get in training - my guess is that the train_heatmap_mse_loss isn't plateauing yet, it's just decreasing on smaller scales.

the model should be saving out weights along the way, you'll find them in the tb_logs directory (you'll have to go down a couple more subdirectories).

themattinthehatt avatar Jul 17 '24 00:07 themattinthehatt

btw if you're training on the dynamic_crop branch I would recommend switching over to main and pulling the latest updates, I recently made an update to the validation dataloader that might be relevant here (a5e38316053bcfa127007772be114e516d492b91). Previously the validation data was passing through the image augmentation pipeline, so the "validation" data was actually different on every single pass. In the most up-to-date main branch that has been fixed and now the validation data is not augmented. Obviously you were doing the same thing with the 512x512 network and didn't see the weird spikes in the validation loss but wouldn't hurt to remove that factor.

themattinthehatt avatar Jul 17 '24 00:07 themattinthehatt

Alright, re-training now on the unsupervised multiview branch with 8 workers and the unfreezing epoch set to 5.

I will also have a look at how well the magenta model performed.

hummuscience avatar Jul 17 '24 12:07 hummuscience

So far so good. Orange is the latest model.

[screenshots: TensorBoard loss curves; orange is the latest model]

hummuscience avatar Jul 17 '24 15:07 hummuscience

(I wonder if we should move this to discussions?)

So, the model is trained and it seems to perform better than the previous ones. It might be due to the removal of the weird datasets though (2-3 mice and the one without labels).

I would have expected the early stopping to kick in at some point after 5.5 hours, but maybe I am misunderstanding how it works. How do I know which checkpoint was used for the inference in the end? I understand that Lightning uses the checkpoint with the "best score", but I can't find out how this is determined...

[screenshots: TensorBoard loss curves for the final model]

The predictions on some datasets look quite good: [example frames]

While on others, less so: [example frames]

Especially this effect of all the predictions being "squished" into the central line of the animal or to the front of the animal.

One obvious difference between the two types of datasets is the size of the animal. I thought this should only affect semi-supervised PCA methods, but I have an empty losses_to_use list in the config. For some weird reason though, my test_video predictions have a *_pca_singleview_error.csv file. Is the PCA still being computed and used? (I am on the semisupervised multiview branch).

This issue reminds me of the situation DLC had with the SuperAnimal models, which is why they adopted spatial-pyramid search during testing and video adaptation during inference on previously unseen data.

One mistake I made was that I didn't adjust the train/val/test proportions after removing one of the datasets. I ended up with small val/test sets (80 images each) which did not contain all datasets. I wonder if it would make sense to perform the train/val/test split manually to make sure all datasets are represented in each split.

hummuscience avatar Jul 18 '24 10:07 hummuscience

awesome, thanks so much for looking into this. yeah maybe we can switch over to the discord? https://discord.gg/tDUPdRj4BM

to answer your most recent questions here though:

  • if you set training.early_stopping to true in the config file the model doesn't actually perform early stopping (apologies for this, old nomenclature that we should update). rather, the trainer monitors the validation loss and saves out a model checkpoint every time the validation loss is lower than any previously recorded validation loss. the model is trained for the full number of epochs still. so when you load the weights you'll get the ones corresponding to that lowest validation loss epoch.
  • those first three images you shared look pretty good. do the turquoise markers correspond to the set of keypoints that were labeled for that particular dataset, and the blue markers the unlabeled keypoints? (if so that's very helpful). I'm not sure why the generalization to missing keypoints would work in some instances but not others, especially that last one where...well...it's just a dark mouse on a bright background.
  • regarding the PCA error file, if you set losses_to_use to be an empty list the model will not use the pca loss during training, but it will still automatically compute that loss on videos after running inference if columns_for_singleview_pca isn't null
  • i still haven't read deeply into their video adaptation method, but I would think that's not necessary on frames from datasets that were used in the training
  • if you want to manually create a test dataset you can select a set of frames, remove those frames from the CollectedData.csv file, and place them in a file called CollectedData_new.csv. Then you can use all frames in CollectedData.csv for train/val. Our training function will automatically look for a file named CollectedData_new.csv (or, more generally, the name of the csv file used for training with "_new" appended to the end) and run inference on that. This is what we did for all of our OOD experiments in the LP paper, to have strict control over which frames were train/val vs OOD test. A rough sketch of this split is below.
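
For example (a rough sketch, assuming image paths of the form labeled-data/<dataset>/img0000.png and the csv file name from your config):

```python
import pandas as pd

df = pd.read_csv("CollectedData_all.csv", header=[0, 1, 2], index_col=0)

# group by source dataset, assuming paths like labeled-data/<dataset>/img0000.png
groups = df.index.to_series().str.split("/").str[1]
ood_idx = df.groupby(groups).sample(n=40, random_state=0).index

df.drop(ood_idx).to_csv("CollectedData_all.csv")     # frames used for train/val
df.loc[ood_idx].to_csv("CollectedData_all_new.csv")  # picked up automatically for OOD inference
```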

themattinthehatt avatar Jul 18 '24 18:07 themattinthehatt

I am currently thinking that the issue of the labels on the sides of the animal getting bad predictions has to do with the relative proportion of these labels vs. others (such as ears or tail base) in the whole dataset.

I will try to do some stratification of the train/test data or some oversampling. I haven't found hints of this being implemented in LP so far. Do you think it might make sense to add?

hummuscience avatar Aug 15 '24 12:08 hummuscience

That's a good observation; it is indeed possible that oversampling would force the model to focus more on these less-frequent keypoints. One way to do this would be to implement a custom sampler for the pytorch data loader, which would allow upweighting of the labeled examples with more keypoints.

I don't have the bandwidth to work on this for the time being, but happy to discuss the details more if you're interested in trying it out.
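
For reference, the simplest version of this might just use torch's built-in WeightedRandomSampler rather than a fully custom sampler class (a rough sketch; how you build the per-frame labels tensor and the dataset object is up to you):

```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

def make_weighted_loader(dataset, labels, batch_size=16):
    """labels: (num_frames, num_keypoints, 2) tensor with NaNs for missing keypoints."""
    num_labeled = (~torch.isnan(labels[..., 0])).sum(dim=1).float()
    weights = num_labeled / num_labeled.sum()  # frames with more labeled keypoints are drawn more often
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```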

themattinthehatt avatar Aug 20 '24 20:08 themattinthehatt

Nice! Good to know there is a way to implement this directly, I will give it a look :)

In the meantime, I ran a few experiments. I took the whole dataset I am working with, created an OOD test set (40 images from each modality/dataset), and used the rest for training with a 90% train / 10% eval split.

I tried 4 different approaches.

  1. Trained untouched ("Original")
  2. "Oversampled": the images with the under-represented labels (created a "weight" per image based on how much under-represented labels it has) were oversampled while keeping the total number of images for training fixed. The final training set had a total of around 4400 images (1900 if duplicates are removed).
  3. "Oversampled-Inflated": since I am "losing" precious training data if I oversample while keeping the number of images the same, I inflated the total number of images by 1.5, kept the original dataset, and then added more images, but there I oversampled the under-represented images
  4. "Oversampled-2.5Inflated": The same as 3, but with 2.5 inflation, to makes sure even more under-represented images are present (in retrospect, I should have thought of a better way the bias the data more towards the under-represented labels).

Doing this led to the following distributions of labels in the training data:

[plot: relative bodypart proportions in each training set]

Here is the validation loss during training: (blue original, dark blue oversampled, pink inflated, purple 2.5 inflated)

[screenshot: validation loss curves for the four models]

I also saved a snapshot every 50 epochs from each model to run the eval on the OOD dataset and see the development:

Here is the median log pixel error for some keypoints (mainly the problematic ones, plus mouse_center/tail_base, which are present in every dataset). The vertical lines show the "best" model as chosen by LP: [plot: pixel error per keypoint across checkpoints]

I wanted to see if the models perform similarly on different origin datasets. Here is the log pixel error for "left_hip"

[plot: left_hip pixel error by source dataset and checkpoint]

And tail_base:

[plot: tail_base pixel error by source dataset and checkpoint]

It seems like the oversampling and inflation improve the model performance for these rare labels.

The varying performance between datasets is still a bit puzzling... Will have a closer look at the weights for each image, it might be that OFT/3CSI are over-represented for some reason.

hummuscience avatar Aug 21 '24 13:08 hummuscience