Inconsistency of predictions depending on batch size
Dear Authors,
Thank you for the great tool.
I want to implement an option to predict scores with a batch size larger than 1. During my first tests I noticed that the predictions differ depending on the batch size. Could you check what might be the reason for this behaviour of the model? Below I provide an example variant (chr12-110435045-G-A) whose score differs between a single-variant prediction and a batch of size 4: 0.5400000214576721 in the original version versus 0.5299999713897705 on the batch. I also provide my code to reproduce the issue. To keep the question compact, the example shows a prediction mismatch for just one of the models.
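For context, here is a minimal, Pangolin-independent sketch of the effect I suspect is behind this (my assumption, not anything taken from Pangolin's code): on a GPU, the same sample can give slightly different outputs when run alone versus inside a batch, because different kernels and reduction orders may be selected per batch size. The Conv1d shape below is arbitrary.

import torch
torch.manual_seed(0)
conv = torch.nn.Conv1d(4, 32, kernel_size=11).eval()
if torch.cuda.is_available():
    conv = conv.cuda()
x = torch.randn(4, 4, 1000, device=next(conv.parameters()).device)
with torch.no_grad():
    batched = conv(x)[2]      # sample 2 computed inside a batch of 4
    single = conv(x[2:3])[0]  # the same sample computed on its own
print(torch.equal(batched, single))           # often False on a GPU
print((batched - single).abs().max().item())  # typically around 1e-6

The full reproduction script: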
import torch
import numpy as np
import pyfastx
from pkg_resources import resource_filename
from pangolin.model import *
###############################################################################################
test_variants = [
    'chr12-110435044-T-C',
    'chr12-110435044-T-G',
    'chr12-110435045-G-A',
    'chr12-110435045-G-C',
]
atol = 0.000001 # tolerance value to be used in np.allclose()
d = 50
reference_fasta_path = 'GRCh38.primary_assembly.genome.fa'
###############################################################################################
# the same as in the original version
IN_MAP = np.asarray([[0, 0, 0, 0],
                     [1, 0, 0, 0],
                     [0, 1, 0, 0],
                     [0, 0, 1, 0],
                     [0, 0, 0, 1]])
def one_hot_encode(seq, strand):
    seq = seq.upper().replace('A', '1').replace('C', '2')
    seq = seq.replace('G', '3').replace('T', '4').replace('N', '0')
    if strand == '+':
        seq = np.asarray(list(map(int, list(seq))))
    elif strand == '-':
        seq = np.asarray(list(map(int, list(seq[::-1]))))
        seq = (5 - seq) % 5  # Reverse complement
    return IN_MAP[seq.astype('int8')]
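# sanity checks: the encoder returns shape (len(seq), 4); the .T used further
# below converts to the (4, length) channel-first layout the model expects,
# and strand '-' encodes the reverse complement
assert one_hot_encode('ACGTN', '+').shape == (5, 4)
assert np.array_equal(one_hot_encode('AACG', '-'), one_hot_encode('CGTT', '+'))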
models = []
for i in [0, 2, 4, 6]:
    for j in range(1, 4):
        model = Pangolin(L, W, AR)
        # pass the package name explicitly so the weight files resolve when
        # this script runs outside the pangolin package
        if torch.cuda.is_available():
            model.cuda()
            weights = torch.load(resource_filename("pangolin", "models/final.%s.%s.3.v2" % (j, i)))
        else:
            weights = torch.load(resource_filename("pangolin", "models/final.%s.%s.3.v2" % (j, i)),
                                 map_location=torch.device('cpu'))
        model.load_state_dict(weights)
        model.eval()
        models.append(model)
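# 3 replicate weight files for each of the 4 selected model indices
assert len(models) == 12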
###############################################################################################
# process variants
def prepare_variant_for_batch(lnum, chr, pos, ref, alt, fasta, d):
    # take 5000+d bases of context on each side of the variant
    seq = fasta[chr][pos-5001-d:pos+len(ref)+4999+d].seq
    ref_seq = seq
    alt_seq = seq[:5000+d] + alt + seq[5000+d+len(ref):]
    return ref_seq, alt_seq
fasta = pyfastx.Fasta(reference_fasta_path)
batch_chroms = []
batch_positions = []
batch_refs = []
batch_alts = []
for test_variant in test_variants:
    chr, pos, ref, alt = test_variant.split('-')
    pos = int(pos)
    ref_seq, alt_seq = prepare_variant_for_batch(0, chr, pos, ref, alt, fasta, d)
    batch_chroms.append(chr)
    batch_positions.append(pos)
    batch_refs.append(ref_seq)
    batch_alts.append(alt_seq)
model = models[0]
strand = '-'
# predict on batch
encoded_refs = [] # store encoded reference sequences in a list
encoded_alts = [] # store encoded alternative sequences in a list
for i in range(len(batch_refs)):
    ref_seq = torch.from_numpy(one_hot_encode(batch_refs[i], strand).T).float()
    alt_seq = torch.from_numpy(one_hot_encode(batch_alts[i], strand).T).float()
    encoded_refs.append(ref_seq)
    encoded_alts.append(alt_seq)
batch_ref = torch.stack(encoded_refs) # create a tensor with multiple ref sequences
batch_alt = torch.stack(encoded_alts) # create a tensor with multiple alt sequences
if torch.cuda.is_available():
    batch_ref = batch_ref.to(torch.device("cuda"))
    batch_alt = batch_alt.to(torch.device("cuda"))
with torch.no_grad():
    # j is left over from the loading loop above (j == 3 here); the same channel
    # index is used for the single prediction below, so the comparison is consistent;
    # [:, ...] keeps the batch dimension instead of the original [0][...] indexing
    pred_ref = model(batch_ref)[:, [1, 4, 7, 10][j], :].cpu().numpy()
    pred_alt = model(batch_alt)[:, [1, 4, 7, 10][j], :].cpu().numpy()
# predict single
i = 2  # index of chr12-110435045-G-A in test_variants
ref_seq = one_hot_encode(batch_refs[i], strand).T
ref_seq = torch.from_numpy(np.expand_dims(ref_seq, axis=0)).float()
alt_seq = one_hot_encode(batch_alts[i], strand).T
alt_seq = torch.from_numpy(np.expand_dims(alt_seq, axis=0)).float()
if torch.cuda.is_available():
    ref_seq = ref_seq.to(torch.device("cuda"))
    alt_seq = alt_seq.to(torch.device("cuda"))
with torch.no_grad():
    pred_ref_single = model(ref_seq)[0][[1, 4, 7, 10][j], :].cpu().numpy()
    pred_alt_single = model(alt_seq)[0][[1, 4, 7, 10][j], :].cpu().numpy()
# compare
print(np.allclose(pred_ref_single, pred_ref[i], atol=atol)) # Switches from True to False between atol=0.00001 and atol=0.000001
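To make the mismatch easier to interpret, the magnitude can also be printed instead of just the boolean; given the atol observation above, it should land between 1e-6 and 1e-5. The rounding illustration at the end is only my guess at where the 0.54 vs. 0.53 difference in the final score could come from, not a statement about Pangolin's internals.

print(np.abs(pred_ref_single - pred_ref[i]).max())  # expected between 1e-6 and 1e-5
# if the score is rounded to two decimals somewhere downstream, a drift of this
# size near the 0.535 boundary is enough to flip 0.54 to 0.53:
print(round(0.5350001, 2), round(0.5349999, 2))  # 0.54 0.53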