prototypical-networks-tf
Implementation of Prototypical Networks for Few-shot Learning in TensorFlow 2.0
Prototypical Networks for Few-shot Learning in TensorFlow 2.0
Implementation of the Prototypical Networks for Few-shot Learning paper (https://arxiv.org/abs/1703.05175) in TensorFlow 2.0. The model has been tested on the Omniglot and miniImagenet datasets using the same data splits as in the paper.
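The core idea of the paper can be shown without TensorFlow. The sketch below is a framework-free illustration, not the repository's `prototf` code: each class prototype is the mean of that class's embedded support examples, and a query is classified with a softmax over negative squared Euclidean distances to the prototypes. The function names and the toy 2-D "embeddings" are hypothetical; in the actual model the embeddings come from a trained convolutional encoder.

```python
import math
from typing import Dict, List, Sequence

Vector = Sequence[float]


def compute_prototypes(support: Dict[str, List[Vector]]) -> Dict[str, List[float]]:
    """Prototype of class k = mean of the embedded support examples of class k."""
    prototypes = {}
    for label, vectors in support.items():
        dim = len(vectors[0])
        prototypes[label] = [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]
    return prototypes


def squared_euclidean(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def class_probabilities(query: Vector, prototypes: Dict[str, List[float]]) -> Dict[str, float]:
    """Softmax over negative squared distances from the query to each prototype."""
    logits = {label: -squared_euclidean(query, p) for label, p in prototypes.items()}
    max_logit = max(logits.values())  # subtract the max for numerical stability
    exps = {label: math.exp(z - max_logit) for label, z in logits.items()}
    total = sum(exps.values())
    return {label: e / total for label, e in exps.items()}


def episode_loss(queries: List[Vector], labels: List[str],
                 prototypes: Dict[str, List[float]]) -> float:
    """Mean negative log-probability of the true class over the query set."""
    nll = [-math.log(class_probabilities(q, prototypes)[y]) for q, y in zip(queries, labels)]
    return sum(nll) / len(nll)


# Toy 2-way-2-shot episode with hand-written 2-D "embeddings" (hypothetical).
protos = compute_prototypes({"a": [[0.0, 0.0], [0.0, 2.0]], "b": [[4.0, 0.0], [4.0, 2.0]]})
probs = class_probabilities([1.0, 1.0], protos)
print(protos)                      # {'a': [0.0, 1.0], 'b': [4.0, 1.0]}
print(max(probs, key=probs.get))   # a
```

Training minimizes `episode_loss` averaged over many episodes; everything specific to the repository (encoder architecture, optimizer, batching) is left out here.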
 
Dependencies and Installation
- The code has been tested on Ubuntu 18.04 with Python 3.6.8 and TensorFlow 2.0.0-alpha0
- The two main dependencies are TensorFlow and the Pillow package (Pillow is listed in the package dependencies)
- To install the `prototf` library, run `python setup.py install`
- Run `bash data/download_omniglot.sh` from the repo's root directory to download the Omniglot dataset
- miniImagenet was downloaded from the brilliant repo by renmengye (https://github.com/renmengye/few-shot-ssl-public) and placed into the `data/mini-imagenet` folder
Repository Structure
The repository is organized as follows:
- `data` contains scripts for downloading the datasets and is used as the default directory for the datasets themselves.
- `prototf` is the library containing the model itself (`prototf/models`) and the logic for loading and processing datasets (`prototf/data`).
- `scripts` contains the launch scripts: `train/run_train.py` starts training and `eval/run_eval.py` starts evaluation.
- `tests` contains a basic training procedure with small parameter values to check general correctness.
- `results` contains a `.md` file with the current configuration and details of the conducted experiments.
Training
- Training and evaluation configurations are specified through config files; each config describes a single train+eval environment.
- Run `python scripts/train/run_train.py --config scripts/config_omniglot.conf` to run training on Omniglot with the default parameters.
- Run `python scripts/train/run_train.py --config scripts/config_miniimagenet.conf` to run training on miniImagenet with the default parameters.
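The format of the `.conf` files is not described in this README. Assuming they are INI-style files that Python's standard `configparser` can read, a launch script might handle the `--config` flag roughly as sketched below. The helpers `parse_args` and `load_config`, the `TRAIN`/`EVAL` section names, and every key and value in `EXAMPLE` are hypothetical placeholders, not the actual contents of `scripts/config_omniglot.conf`.

```python
import argparse
import configparser
from typing import Dict, List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse a --config flag like the one in the commands above (hypothetical helper)."""
    parser = argparse.ArgumentParser(description="Run training")
    parser.add_argument("--config", type=str, required=True, help="Path to a .conf file")
    return parser.parse_args(argv)


def load_config(text: str, section: str) -> Dict[str, str]:
    """Read an INI-style config (assumed format) and return one section as a dict."""
    config = configparser.ConfigParser()
    config.read_string(text)
    return dict(config[section])


# Hypothetical config: one train section and one eval section, i.e. one
# train+eval environment per file. Keys and values are placeholders.
EXAMPLE = """
[TRAIN]
data.dataset = omniglot
data.way = 60
data.support = 5
data.query = 5

[EVAL]
data.way = 5
data.support = 5
data.query = 5
"""

args = parse_args(["--config", "scripts/config_omniglot.conf"])
train_cfg = load_config(EXAMPLE, "TRAIN")
print(args.config)
print(int(train_cfg["data.way"]))
```

In a real script the file at `args.config` would be read with `ConfigParser.read(path)` instead of `read_string`; the string form is used here only so the sketch runs without any files on disk.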
Evaluating
- Run `python scripts/eval/run_eval.py --config scripts/config_omniglot.conf` to run evaluation on Omniglot
- Run `python scripts/eval/run_eval.py --config scripts/config_miniimagenet.conf` to run evaluation on miniImagenet
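Few-shot accuracy is commonly reported as the mean accuracy over many randomly sampled test episodes, often together with a 95% confidence interval. This README does not say how `run_eval.py` aggregates its numbers, so the following is only an assumed illustration of that convention; `episode_accuracy` and `summarize_accuracies` are hypothetical helpers, and the predictions are toy inputs, not results from this repository.

```python
import math
import statistics
from typing import Sequence, Tuple


def episode_accuracy(predictions: Sequence, labels: Sequence) -> float:
    """Fraction of query examples in one episode whose predicted class is correct."""
    correct = sum(1 for p, y in zip(predictions, labels) if p == y)
    return correct / len(labels)


def summarize_accuracies(episode_accs: Sequence[float]) -> Tuple[float, float]:
    """Return the mean accuracy and the half-width of a normal-approximation 95% CI."""
    mean = statistics.mean(episode_accs)
    if len(episode_accs) < 2:
        return mean, 0.0  # a single episode gives no estimate of the spread
    half_width = 1.96 * statistics.stdev(episode_accs) / math.sqrt(len(episode_accs))
    return mean, half_width


# Toy per-episode predictions for two 5-way episodes (illustration only).
accs = [
    episode_accuracy(["a", "b", "c", "d", "e"], ["a", "b", "c", "d", "e"]),  # 1.0
    episode_accuracy(["a", "b", "c", "d", "a"], ["a", "b", "c", "d", "e"]),  # 0.8
]
mean, ci = summarize_accuracies(accs)
print(f"{100 * mean:.1f}% +/- {100 * ci:.1f}%")
```

With only a handful of episodes the interval is wide; averaging over many more test episodes shrinks it roughly as one over the square root of the episode count.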
Tests
- Run `python -m unittest tests/test_omniglot.py` from the repo's root to test Omniglot
- Run `python -m unittest tests/test_mini_imagenet.py` from the repo's root to test miniImagenet
Results
Omniglot:

| Environment | 5-way-5-shot | 5-way-1-shot | 20-way-5-shot | 20-way-1-shot |
|---|---|---|---|---|
| Accuracy | 99.4% | 97.4% | 98.4% | 92.2% |

miniImagenet:

| Environment | 5-way-5-shot | 5-way-1-shot |
|---|---|---|
| Accuracy | 66.0% | 43.5% |

Additional settings can be found in the `results` folder in the root of the repository.
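Each column in the results tables names an episode configuration: an N-way-K-shot episode draws N classes, gives the model K labelled support examples per class, and scores it on held-out query examples from the same N classes. The sketch below samples one such episode from an in-memory `{class: [examples]}` dictionary. It assumes that data layout and is not the loader in `prototf/data`; `sample_episode`, `n_query`, and the toy dataset are hypothetical.

```python
import random
from typing import Dict, Hashable, List, Optional, Tuple


def sample_episode(
    dataset: Dict[Hashable, List],
    n_way: int,
    k_shot: int,
    n_query: int,
    rng: Optional[random.Random] = None,
) -> Tuple[Dict[Hashable, list], Dict[Hashable, list]]:
    """Sample an N-way-K-shot episode with n_query query examples per class."""
    rng = rng or random.Random()
    classes = rng.sample(list(dataset), n_way)  # N distinct classes
    support, query = {}, {}
    for cls in classes:
        examples = dataset[cls]
        if len(examples) < k_shot + n_query:
            raise ValueError(f"class {cls!r} has too few examples for this episode")
        # One draw without replacement keeps support and query disjoint.
        picked = rng.sample(examples, k_shot + n_query)
        support[cls] = picked[:k_shot]
        query[cls] = picked[k_shot:]
    return support, query


# Toy dataset: 30 "characters" with 20 placeholder example ids each.
toy = {f"char_{i}": [f"char_{i}_img_{j}" for j in range(20)] for i in range(30)}
support, query = sample_episode(toy, n_way=5, k_shot=1, n_query=5, rng=random.Random(0))
print(len(support), [len(v) for v in support.values()])  # 5 [1, 1, 1, 1, 1]
```

Passing a seeded `random.Random` makes episodes reproducible, which is convenient when comparing runs on the same test episodes.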