RecurrentAttentionConvolutionalNeuralNetwork
Hi, where is the "apn.pt.pt" file?
import torch

def init_with_apn2(self):
    # Restore the pretrained Attention Proposal Network weights from the checkpoint file
    ckpt = torch.load('../checkpoints/CUBS/apn2.pt.pt')
    self.apn1.load_state_dict(ckpt['apn1_state_dict'])
I haven't made it available in the git repository. It is the pretrained model file for the Attention Proposal Network (APN).
You have to generate candidate labels for each image using the functions in utils.py. Change the file for coordinate labels in dataset.py to the one you generated. Then train the APN2 network in networks.py with the run_model_coords.sh script.
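If the training run does not already write its weights in the form `init_with_apn2` reads, a minimal sketch like the one below could produce a compatible file. It assumes only what the quoted snippet shows: the checkpoint is a `torch.save`d dict whose `'apn1_state_dict'` entry holds the APN weights. The helper names `save_apn_checkpoint` and `load_apn_checkpoint` are hypothetical, and the small `nn.Sequential` model is a stand-in for the real APN defined in networks.py.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_apn_checkpoint(apn: nn.Module, path: str) -> None:
    """Hypothetical helper: store APN weights under the key init_with_apn2 reads."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save({"apn1_state_dict": apn.state_dict()}, path)


def load_apn_checkpoint(apn: nn.Module, path: str) -> None:
    """Hypothetical mirror of init_with_apn2: load the stored state dict into `apn`."""
    ckpt = torch.load(path, map_location="cpu")
    apn.load_state_dict(ckpt["apn1_state_dict"])


def make_stand_in_apn() -> nn.Module:
    # Stand-in architecture only; the real APN lives in networks.py.
    return nn.Sequential(nn.Linear(8, 4), nn.Tanh(), nn.Linear(4, 3))


# Usage: round-trip weights through a temporary apn2.pt.pt file.
src, dst = make_stand_in_apn(), make_stand_in_apn()
with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "CUBS", "apn2.pt.pt")
    save_apn_checkpoint(src, ckpt_path)
    load_apn_checkpoint(dst, ckpt_path)
```

Saving `state_dict()` rather than the whole module keeps the file independent of the class definition, so it only needs a model with matching parameter names and shapes when loaded.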
OK, thanks.