ml-dab
DAB: Differentiable Approximation Bridges
A simplified example demonstrating a DAB network, as presented in the paper Improving Discrete Latent Representations With Differentiable Approximation Bridges.
Usage
The only dependency for this demo is PyTorch.
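If PyTorch is not already installed, it can be installed with pip:

pip install torch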
To run the 10-sort signum-dense problem described in Section 4.1 of the paper, simply run:
python main.py
This should produce output similar to the following, which corroborates the paper’s reported accuracy of 94.2%:
train[Epoch 2168][1999872.0 samples][7.79 sec]: Loss: 79.2356 DABLoss: 7.9058 Accuracy: 95.5683
…
test[Epoch 2168][399360.0 samples][0.91 sec]: Loss: 79.2329 DABLoss: 7.9012 Accuracy: 94.6424
Create a DAB for a custom non-differentiable function
- Create a suitable approximation neural network.
- Implement a custom hard (non-differentiable) function, similar to SignumWithMargin in models/dab.py.
- Stack a DAB module in your neural network pipeline.
- Add the DAB loss to your normal task loss (see the sketch after this list).
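A minimal sketch of these four steps is shown below. The names SignumWithMargin and DAB mirror those in models/dab.py, but the classes here are illustrative stand-ins, not the repository's implementation; in particular, the straight-through gradient routing and squared-error DAB penalty are one reasonable reading of the approach rather than the paper's exact formulation.

```python
import torch
import torch.nn as nn

class SignumWithMargin(nn.Module):
    """Hard, non-differentiable function: sign(x), zeroed inside a small margin.
    Illustrative stand-in for SignumWithMargin in models/dab.py."""
    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, x):
        hard = torch.sign(x)
        return torch.where(x.abs() < self.margin, torch.zeros_like(hard), hard)

class DAB(nn.Module):
    """Sketch of a DAB module: the forward pass emits the hard function's values,
    while gradients are routed through a small differentiable approximation
    network that the DAB loss trains to track the hard function."""
    def __init__(self, hard_fn, feature_dim):
        super().__init__()
        self.hard_fn = hard_fn
        self.approx = nn.Sequential(            # differentiable approximation network
            nn.Linear(feature_dim, feature_dim),
            nn.Tanh(),
            nn.Linear(feature_dim, feature_dim),
        )
        self.dab_loss = torch.tensor(0.0)       # refreshed on every forward pass

    def forward(self, x):
        with torch.no_grad():                   # the hard path carries no gradient
            hard = self.hard_fn(x)
        soft = self.approx(x)                   # differentiable surrogate
        # Penalty that pulls the approximator toward the hard function's output.
        self.dab_loss = (soft - hard).pow(2).mean()
        # Hard values in the forward pass, approximator gradients in the backward pass.
        return soft + (hard - soft).detach()

# Example: stack the DAB in a tiny pipeline and add its loss to the task loss.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(32, 10)                     # dummy batch of 10-dim inputs
    target = torch.randn(32, 10)
    encoder = nn.Linear(10, 10)
    decoder = nn.Linear(10, 10)
    dab = DAB(SignumWithMargin(margin=0.5), feature_dim=10)

    y = decoder(dab(encoder(x)))
    loss = nn.functional.mse_loss(y, target) + dab.dab_loss  # task loss + DAB loss
    loss.backward()                             # gradients reach the encoder via the approximator
    print(float(loss))
```

The key point is that the hard function's values are used in the forward pass, while the approximation network both supplies the backward gradient and is regressed toward the hard output by the DAB loss.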
Cite
@article{dabimprovingdiscreterepr2020,
  title={Improving Discrete Latent Representations With Differentiable Approximation Bridges},
  author={Ramapuram, Jason and Webb, Russ},
  journal={IEEE WCCI},
  year={2020}
}