RecurrentAttentionConvolutionalNeuralNetwork
Hi, where is the "apn.pt.pt" file?
import torch

def init_with_apn2(self):
    # Load the pretrained checkpoint produced by the APN2 training run
    ckpt = torch.load('../checkpoints/CUBS/apn2.pt.pt')
    self.apn1.load_state_dict(ckpt['apn1_state_dict'])
I haven't made it available in the repository. It is the pretrained model file for the Attention Proposal Network (APN).
You have to generate candidate coordinate labels for each image using the functions in utils.py, then change the path of the coordinate-label file in dataset.py to the file you generated. After that, train the APN2 network defined in networks.py with the run_model_coords.sh script.
OK, thanks.