ProtFlash
                                
                                
                                
                                    ProtFlash copied to clipboard
                            
                            
                            
                        `ProtFlash` einops error with variable length input during `nn.DataParallel` inference.
Describe the bug
ProtFlash crashes with an einops error during the forward pass when the model is wrapped with torch.nn.DataParallel and given a batch of variable-length sequences.
To Reproduce
# Reproduction script: ProtFlash inference raises an einops shape error only
# when the model is wrapped in torch.nn.DataParallel (works fine unwrapped).
import torch
from ProtFlash.pretrain import load_prot_flash_base
from ProtFlash.utils import batchConverter
# Proteins we want to embed (four variable-length sequences)
data = [
    ("protein1", "FVLHFPCIHDHMQAVVQWITRMDMVYKFRADMYGKGKNNGFRDVHCVDHQQRFGWTQTMPDYPSGWEVASFKKSHIGRRASAPIGLGSYTKLKSMKIHMQKSTKWDWGMFKHQATVMMQEREQGRSENLGNYYTMNHCNTERRRHIITVIYMINPYRMRLKTNQKFLYNQCYFYKWVRWNTDAMTSMLNVTCNHSLYKQCWDHTYLLAYKDPQGSNEQNTDEGHVRMMVKECGPKILYYDCFTKVPMFHDLFGSWLMWLPILLQNLAVVDGYGTSVLMTEGDSYCEGVKFGNICTIFRRDASPSPINRIWVSYICLRSIGKSAGESKAFKYMVRVGWFKQGFCLYEKWSLDFFNEFEHPIGLNVWNNQKHHDTGFYLPFFRKDTIQHMSVEQWPDDECRYLNSIKNGAMTAEFSLMPFQCTPDKAQRFHEIFFKVSGWEWGMESGVALDVIQVMAEEWWLVFQIFFIHEHCYHCNVHTNRGHWSSVHGWNGFVAARDYIRLADRANHHWNNHIQENPASMERYIFGNLKRVAFINYQMAADPGFQTKVRRVDRYYNRVTVTIRISTWNKDPMLQTDKRTSSTYNFMQCRMWEWKNPNVDRYEFKRYSSKPEKDTSLACREFVVSGNEVARLRKDRPKHHLFFFWEDDYLGIAGFSAISTLEHEPGPPWHYIQPTKSHLLQNASRCYPALTFWIEHMQWEYMCWHPQEQGTDCMCPMPLSPIFCFEAGLDFADKPSPNWFSWMTCVMEARKGIYIAFDQSTPSIPCMMMRHMGFCGSWPPNKIPSIMKFGAKQQYRSKFHQVPLPANYVLHRPQLEAFMILFWHTIKNFEHSDKQQKAKTVQHEWMGFMFDPQKCHTEDCTCNYPIEDYDTIHRLWKQVKFYENYATTWIFLCTPLLQLTWWPRGIVCMEKDRTRQEMCHLQCTKVMMQTDLKTVFIMIDGILWMIPSKLEVVDHFWAEGMFMRTCCAWPMSNELQNLAFKHGPILSAHAWDTEGNTDEVGERAVVTGCIHYPWRRCANEYCWQCAIHGMCCQYFWHHAILPDDKADPYNDYLRENAPEWAPCQLYPHANKANEEHEAKECKAYKYVYMIRATGEFQATKRHCDPRFHWTIFGMMDALSDIDFNGFYDDWMIMFMPFTKYAGRVEYEIGIVRSFWKPQPNLPWILDHSPEMTCAIPGTTMHTVKLRQCQWRLLFPFHFHHARKWDQDMLKPTGTDGGYCIISLYRQRDKNECCSPFRQTWVDSIIPNIYDCTIKPKNADWKDHHFGRCGKMVWGMNDAQDLCRSNTKDDEEGENDPHQLWKGNPKMKG"),
    ("protein2", "SDAMFQANYAGTMDESVHKSDESDSLYWERMTCAMVISWVLQMQMMKPRWNCNGANWLKAQATTYFDGSEVDAERNHPPMLCFMCKCSCAPEHGERQFQQMNWKIMTQAKREYSTANMPTHMQRSHYLNAKNNCFFPGNRSDVDEVEQIKKMLFSNFKCWRGQQNAVYYYWRAKFETTLRCEVKMKETRQHPYPLFLHNDSHKMKRFREHNVENAASRHPHFKGTIGMRFWKPYKERKYIQMEGLCIITCTIPMGWERPVKKFWCVECFDIGMMCDEAGMGGKMGEILWICQCEDVYMNPNMYRAIYIEVPMICIKWGDIYLHQPGYNAIWIMNQVDLMDQCKGDTFFTVFQSYMVDCQNADNFQINKHHQLNNEWMRVWQSKRGQAWFPDYTECTYDPFRICGMSYPVWDASDDKWSLEHQDMGFQWDSPYWCRNCFWMHGTFDNYDFKEPFCFWWLITEPQQERNVDNTFRCRNEARFLAYMWDENNSWTCGWMVMAKTPCETYRCPTPNKFFRAAISPNMCGARRFGGVLSKQMFVQSWFANGKEIIHRNKCMDGICIFTVNPHKSHAHFIEDYCQRNWGDNPVHEWQSKKIREYDKGPQWSFEVLCTPSNNVLKNEDTELFENSKCFDVGHAKWIPDDRAKGMYATMRCVWPHNANSFFEECFPGSTRMGWEKKRMSWIVDWDTIYCKVFAGINFDDWYEEAYQRGPAQNGQPKDWFGQFHQNPKKPDDSQTFYFPDWGSGSKADQEFMFCPQRWCIEVMCQFRLRMTMAASHGNDKYAAFTIPRAQDVGDHDNGHSCIDLGATASWKGSYCYWCARANAQECADMASNGYPQDQMYLQHCVRHQLPWRFTLTDAYAWRDVDDNGGENFMLQEPRIFMNAMDMWMRGWTWRPMWTIAEAECYCPHDQYSYNQDTRFHNTFAGTFMTRGQDCSSYQKEGPHLDKECCLHCLDCKYFFFHRGPVWMQWSGCCVMGHFSNVGKCYLMIFWLQFSACTGFMEADQGALTHAELDYSHWGCWYGNDPQQMQMTDLIMCGYCPPKFCADETPPCIACKKHPRRNMLELYHRYLVLCCWKGKFVYARHWAEKNGACWSWMMSNKGLAERVCCASGGYKSQCSMNETWIATVLNEAWLFQAFKHDHHMMVSGVLDHRTHDECQLKQTSCNVSGIVMCVGELGHMMWLMQDQVGDHFNVNCGQGNQGPICHKVMVTDHLRINPWMKGYGCVLWPMLEMQYMWWLKCFVCIQFYPFRVWTARCAQLPTAMTKLLWAVTTILHFVLRCSPYCMLMAKGNEKPASRAAT"),
    ("protein3", "AKCTLGLDYGKCEARWYFSLAMIWRLYYYFVLRVDKCFITVIGMWYEDQGEIIMERQKDLQHFLHIHGRKKPHCTANFEVKHRMLMCYFQRHCGSWWENDQIQQYLHMDCLTMSDKNHAWNFFWCRFHFTHFFEWHIIYHHGGANEEGRNYHWLSMWGSASRTLVDKCSQSSGAWLAWYKSAFPCQSRSSNLTHCRFYITPKKFYITIAVDFVVWIRGRLAIKFHLADSHNMMILTYFQYSLYNTDCWMMDETGNGFDWWCHMIRFAHNDTALLGFAPHCIWFVFDFGCNHRLKKDQCRYSKGVFLSWWAVCNWPWHGGHQKRIGDMVVFLQNANCPCHNPKASWVRIVLCTGWHVMKGTAKHMFPIDEQFGPGIFHQMYSHGNFYGLCWWSHAQKYQMYSKAKMCQRLARTVNPWNKRNTMVCMQEDVRPLIDVHQQGEQALVQSAEFGEYDQNNEGNGQARNVYRYWGREKFLVKAGSLMKGNPVNYTAIIDSHDGFDCSTLNWTYAIRAGFIYGECECNILTNDHGTGCRICQEVTHMMPAGDQALRWGRAYVQTAAGVAMTREKMSIFSVLLVYIDNTWAGLCVCPWQSFTIAHHLFKIGQGSFVHDSVENNCKQYTCEKDMCGSDYFHRCATNHRTHMGNEYYFIDLCIMNNESQRKVISEFMMRGVGMWMQWYLIEMWLHCIEYMCATWSCCERSTCDVWRCAAQTPFRATVRVNKGWQEYANEPIKQRHYLDEQMHLTNAALITNRNHPPRDPFSSPWMNSCCFTMYVMQGAINDERVTHNYNGHRKTVRHLFHGASDHEDHWWILYEWCEHTTITSSSNCWNVYDIEAYWWVPLYMPERTQLEPQSWLTRFTANWPWAMSPVKQIVCFCQEGTHDEWMYEHDIMSHSPGWKDYWAVWTFPMPNPLQMYWNDSDHGSLGLKLHTFCMVHYKDNGWGMMFLCPRWGFMQYFFFKVTNNQCRFLQNGILRCKPAYHPMPRKTHMDPFQSVWCCWHGTHAREHNKVHEDEKAITSPVSPVCTGRWLAHIWWMWKITQLKQRLCDKNSESPAHGVWTMLGSAFCLGEHVMTWVHVWYHENWVMDIQMHHAGQFYANLTVMPQLDKFKNEFTHEQRTRYNGVFGTVHRWSPDTAYDIETIRAKWWEIIMMCSEGTQIWIMFDLMSIVNARNIKGNMKFILKGNELCKSTQRAHPSNTFDANPFSRHFRQDFLLDWIIAVEYLVDDSTMWIWVTQQMIRTMQKLVGEHDMPSYAFGFTCVYEGLMYIIMLLWSFHQIIDFRTSGCDILVVQGEMHMFNK"),
    ("protein4", "NDLHDCLWEIESIYFANPYVVPHKCKHEMHYLYPEQMGLVIGWRTWCFATWDRMAWVMQISKYADYGGSDLGEYEAQVCQMTCYPTHWPCWSGMIMYVEYKQSLKKVIFHEICARAMSRKEWTNAGDVEFYTEILPWVYEMAHDWADEDMHWFMPPSELSVNPVPWCHCAKEIKPWTHGSCMIDNGDPDQDKSESRVDRWSDTCNLMLLWKFYCLLLWIYARNFERPNANYKLVFKRASTRPRSVIKPETSPGHKPSQWYQHNIPTHLRLKMENRHCIEHPCREVRFMCCYSDFEIGGMDVRRDQTDRHSLSDIFFGICQMTQDAILFSNRCTEKKTFFHHDDEKHIRFWRDGDFHCKQGVHVEMSAWHLPPVPYKNKVTWLSHKRAQNLKLLIKMSMPTYCQEESATSYQVCVDGCIRFTNPWKHYMNGEDEERAPWSIATQQERDETYFWWSCGQYTIEYFHYMFKYKQFSSNSNHRTTEQNMHYVEEVLRTSACIHRDHEDDEMIKLRVMPRQCYDIPIFWAFYRARYCIKCEDDIKEGESEWMSPFNFCFFCWLIWEITPIFADSPKSDEIWNTFTAELAVVMGCNMRNCCSDWCRYKAVDMYRTDRPGICYFGLQLLNVITSYWSFSRGQFLSDHSKYRTWDYIHHPPKAYEANINVCFLSNINYLVSGFSEHENGPWTWKEWQGMNLKAKHTVRTIIWVAHRMFMVERRVGEMMISSFSEIHYKQCPHRWMCKCKAPGEVTTAHYNCCFYDVQWTGTDGECADCHAAYAACMTRNASPIVGPKLWQIKQHDKADISSFIDRFSTQGYQPTVAVSEGREKQPWCGYMFGHMFPKNDPWCGQNQMQNARKAIKAYEGKTRSETTKMRYKMLAPRWWLYMPNYRCAILQLLWHLMKYYLKNDIEDANPKHSMRQEMYDCWYDIPEGIPGIGRDEFWWLDRLTHYSERGHFQFPNRYYCMHPRVIWGEHMQTEGWKFYKYWWNSYFGPMAMDLNPIIQRSHIGMCKDMTYKAECRYMDYGCDFPVFQAAWNSTCDLNEKQKIAKSPVCMDNVHMQDVESPCRENYYVLMGWLSHRERCHLQHGKMPFNPSLTEQCVKPNQFDVENKDIADCTMWKENWMGWLWLVSRYEWEYEMSAATANWNLIPNLDPRQAQRDTMKRLMGWYCRHYDHERKNWLRCEWVPHETSFNNLWCFAQPYMDKAQQVHGAIKRRCFLHTRLGTKSYFPDCHVQLFCFDDQCMGKCEKQIMLFFIQNGHGPRCIKFHKGISTHTMENDPLRHTIRCCSTWGYNFSVFSSLPFWYYRHKYMMDSHAYD"),
]
# Tokenize
# NOTE(review): assumes batchConverter pads batch_token to the longest sequence
# in the WHOLE batch and returns per-sequence lengths — confirm against
# ProtFlash.utils.batchConverter.
ids, batch_token, lengths = batchConverter(data)
# Load Model
model = load_prot_flash_base()
# Data Parallel Inference
model = torch.nn.DataParallel(model) # Commenting this line
model.to(device='cuda')              # and this one removes the error
# Embedding
# NOTE(review): batch_token/lengths stay on CPU; DataParallel scatters inputs
# across replicas along dim 0. Presumably each replica then re-derives its
# einops grouping from its own chunk's lengths while the tokens remain padded
# to the global max length, which would explain the reported
# "can't divide axis of length 1315 in chunks of 21" error — TODO confirm
# against ProtFlash's forward implementation.
with torch.no_grad():
    token_embedding = model(batch_token, lengths)
# Generate per-sequence representations via averaging
sequence_representations = []
for i, (_, seq) in enumerate(data):
    # Slice 0 : len(seq)+1 — presumably position 0 holds a BOS/CLS token so
    # len(seq)+1 positions belong to this sequence; verify against the
    # tokenizer's convention.
    sequence_representations.append(token_embedding[i, 0: len(seq) + 1].mean(0))
Expected behavior
Expect token embedding tensor of shape [4, L, 768] where L is the length of the longest sequence.
Error encountered
einops.EinopsError:  Error while processing rearrange-reduction pattern "b (g j) -> b g () j".
 Input tensor shape: torch.Size([1, 1315]). Additional info: {'j': 21}.
 Shape mismatch, can't divide axis of length 1315 in chunks of 21
Workaround
The only way to solve this problem on my end is to not use torch.nn.DataParallel.
Thanks for any insights and solutions!