
`ProtFlash` einops error with variable-length input during `nn.DataParallel` inference

Open · ayaanhossain opened this issue on Feb 22 '25 · 0 comments

**Describe the bug**
`ProtFlash` crashes with an einops error when wrapped in `torch.nn.DataParallel` for inference on variable-length inputs.

**To Reproduce**

```python
import torch
from ProtFlash.pretrain import load_prot_flash_base
from ProtFlash.utils import batchConverter

# Proteins we want to embed
data = [
    ("protein1", "FVLHFPCIHDHMQAVVQWITRMDMVYKFRADMYGKGKNNGFRDVHCVDHQQRFGWTQTMPDYPSGWEVASFKKSHIGRRASAPIGLGSYTKLKSMKIHMQKSTKWDWGMFKHQATVMMQEREQGRSENLGNYYTMNHCNTERRRHIITVIYMINPYRMRLKTNQKFLYNQCYFYKWVRWNTDAMTSMLNVTCNHSLYKQCWDHTYLLAYKDPQGSNEQNTDEGHVRMMVKECGPKILYYDCFTKVPMFHDLFGSWLMWLPILLQNLAVVDGYGTSVLMTEGDSYCEGVKFGNICTIFRRDASPSPINRIWVSYICLRSIGKSAGESKAFKYMVRVGWFKQGFCLYEKWSLDFFNEFEHPIGLNVWNNQKHHDTGFYLPFFRKDTIQHMSVEQWPDDECRYLNSIKNGAMTAEFSLMPFQCTPDKAQRFHEIFFKVSGWEWGMESGVALDVIQVMAEEWWLVFQIFFIHEHCYHCNVHTNRGHWSSVHGWNGFVAARDYIRLADRANHHWNNHIQENPASMERYIFGNLKRVAFINYQMAADPGFQTKVRRVDRYYNRVTVTIRISTWNKDPMLQTDKRTSSTYNFMQCRMWEWKNPNVDRYEFKRYSSKPEKDTSLACREFVVSGNEVARLRKDRPKHHLFFFWEDDYLGIAGFSAISTLEHEPGPPWHYIQPTKSHLLQNASRCYPALTFWIEHMQWEYMCWHPQEQGTDCMCPMPLSPIFCFEAGLDFADKPSPNWFSWMTCVMEARKGIYIAFDQSTPSIPCMMMRHMGFCGSWPPNKIPSIMKFGAKQQYRSKFHQVPLPANYVLHRPQLEAFMILFWHTIKNFEHSDKQQKAKTVQHEWMGFMFDPQKCHTEDCTCNYPIEDYDTIHRLWKQVKFYENYATTWIFLCTPLLQLTWWPRGIVCMEKDRTRQEMCHLQCTKVMMQTDLKTVFIMIDGILWMIPSKLEVVDHFWAEGMFMRTCCAWPMSNELQNLAFKHGPILSAHAWDTEGNTDEVGERAVVTGCIHYPWRRCANEYCWQCAIHGMCCQYFWHHAILPDDKADPYNDYLRENAPEWAPCQLYPHANKANEEHEAKECKAYKYVYMIRATGEFQATKRHCDPRFHWTIFGMMDALSDIDFNGFYDDWMIMFMPFTKYAGRVEYEIGIVRSFWKPQPNLPWILDHSPEMTCAIPGTTMHTVKLRQCQWRLLFPFHFHHARKWDQDMLKPTGTDGGYCIISLYRQRDKNECCSPFRQTWVDSIIPNIYDCTIKPKNADWKDHHFGRCGKMVWGMNDAQDLCRSNTKDDEEGENDPHQLWKGNPKMKG"),
    ("protein2", "SDAMFQANYAGTMDESVHKSDESDSLYWERMTCAMVISWVLQMQMMKPRWNCNGANWLKAQATTYFDGSEVDAERNHPPMLCFMCKCSCAPEHGERQFQQMNWKIMTQAKREYSTANMPTHMQRSHYLNAKNNCFFPGNRSDVDEVEQIKKMLFSNFKCWRGQQNAVYYYWRAKFETTLRCEVKMKETRQHPYPLFLHNDSHKMKRFREHNVENAASRHPHFKGTIGMRFWKPYKERKYIQMEGLCIITCTIPMGWERPVKKFWCVECFDIGMMCDEAGMGGKMGEILWICQCEDVYMNPNMYRAIYIEVPMICIKWGDIYLHQPGYNAIWIMNQVDLMDQCKGDTFFTVFQSYMVDCQNADNFQINKHHQLNNEWMRVWQSKRGQAWFPDYTECTYDPFRICGMSYPVWDASDDKWSLEHQDMGFQWDSPYWCRNCFWMHGTFDNYDFKEPFCFWWLITEPQQERNVDNTFRCRNEARFLAYMWDENNSWTCGWMVMAKTPCETYRCPTPNKFFRAAISPNMCGARRFGGVLSKQMFVQSWFANGKEIIHRNKCMDGICIFTVNPHKSHAHFIEDYCQRNWGDNPVHEWQSKKIREYDKGPQWSFEVLCTPSNNVLKNEDTELFENSKCFDVGHAKWIPDDRAKGMYATMRCVWPHNANSFFEECFPGSTRMGWEKKRMSWIVDWDTIYCKVFAGINFDDWYEEAYQRGPAQNGQPKDWFGQFHQNPKKPDDSQTFYFPDWGSGSKADQEFMFCPQRWCIEVMCQFRLRMTMAASHGNDKYAAFTIPRAQDVGDHDNGHSCIDLGATASWKGSYCYWCARANAQECADMASNGYPQDQMYLQHCVRHQLPWRFTLTDAYAWRDVDDNGGENFMLQEPRIFMNAMDMWMRGWTWRPMWTIAEAECYCPHDQYSYNQDTRFHNTFAGTFMTRGQDCSSYQKEGPHLDKECCLHCLDCKYFFFHRGPVWMQWSGCCVMGHFSNVGKCYLMIFWLQFSACTGFMEADQGALTHAELDYSHWGCWYGNDPQQMQMTDLIMCGYCPPKFCADETPPCIACKKHPRRNMLELYHRYLVLCCWKGKFVYARHWAEKNGACWSWMMSNKGLAERVCCASGGYKSQCSMNETWIATVLNEAWLFQAFKHDHHMMVSGVLDHRTHDECQLKQTSCNVSGIVMCVGELGHMMWLMQDQVGDHFNVNCGQGNQGPICHKVMVTDHLRINPWMKGYGCVLWPMLEMQYMWWLKCFVCIQFYPFRVWTARCAQLPTAMTKLLWAVTTILHFVLRCSPYCMLMAKGNEKPASRAAT"),
    ("protein3", "AKCTLGLDYGKCEARWYFSLAMIWRLYYYFVLRVDKCFITVIGMWYEDQGEIIMERQKDLQHFLHIHGRKKPHCTANFEVKHRMLMCYFQRHCGSWWENDQIQQYLHMDCLTMSDKNHAWNFFWCRFHFTHFFEWHIIYHHGGANEEGRNYHWLSMWGSASRTLVDKCSQSSGAWLAWYKSAFPCQSRSSNLTHCRFYITPKKFYITIAVDFVVWIRGRLAIKFHLADSHNMMILTYFQYSLYNTDCWMMDETGNGFDWWCHMIRFAHNDTALLGFAPHCIWFVFDFGCNHRLKKDQCRYSKGVFLSWWAVCNWPWHGGHQKRIGDMVVFLQNANCPCHNPKASWVRIVLCTGWHVMKGTAKHMFPIDEQFGPGIFHQMYSHGNFYGLCWWSHAQKYQMYSKAKMCQRLARTVNPWNKRNTMVCMQEDVRPLIDVHQQGEQALVQSAEFGEYDQNNEGNGQARNVYRYWGREKFLVKAGSLMKGNPVNYTAIIDSHDGFDCSTLNWTYAIRAGFIYGECECNILTNDHGTGCRICQEVTHMMPAGDQALRWGRAYVQTAAGVAMTREKMSIFSVLLVYIDNTWAGLCVCPWQSFTIAHHLFKIGQGSFVHDSVENNCKQYTCEKDMCGSDYFHRCATNHRTHMGNEYYFIDLCIMNNESQRKVISEFMMRGVGMWMQWYLIEMWLHCIEYMCATWSCCERSTCDVWRCAAQTPFRATVRVNKGWQEYANEPIKQRHYLDEQMHLTNAALITNRNHPPRDPFSSPWMNSCCFTMYVMQGAINDERVTHNYNGHRKTVRHLFHGASDHEDHWWILYEWCEHTTITSSSNCWNVYDIEAYWWVPLYMPERTQLEPQSWLTRFTANWPWAMSPVKQIVCFCQEGTHDEWMYEHDIMSHSPGWKDYWAVWTFPMPNPLQMYWNDSDHGSLGLKLHTFCMVHYKDNGWGMMFLCPRWGFMQYFFFKVTNNQCRFLQNGILRCKPAYHPMPRKTHMDPFQSVWCCWHGTHAREHNKVHEDEKAITSPVSPVCTGRWLAHIWWMWKITQLKQRLCDKNSESPAHGVWTMLGSAFCLGEHVMTWVHVWYHENWVMDIQMHHAGQFYANLTVMPQLDKFKNEFTHEQRTRYNGVFGTVHRWSPDTAYDIETIRAKWWEIIMMCSEGTQIWIMFDLMSIVNARNIKGNMKFILKGNELCKSTQRAHPSNTFDANPFSRHFRQDFLLDWIIAVEYLVDDSTMWIWVTQQMIRTMQKLVGEHDMPSYAFGFTCVYEGLMYIIMLLWSFHQIIDFRTSGCDILVVQGEMHMFNK"),
    ("protein4", "NDLHDCLWEIESIYFANPYVVPHKCKHEMHYLYPEQMGLVIGWRTWCFATWDRMAWVMQISKYADYGGSDLGEYEAQVCQMTCYPTHWPCWSGMIMYVEYKQSLKKVIFHEICARAMSRKEWTNAGDVEFYTEILPWVYEMAHDWADEDMHWFMPPSELSVNPVPWCHCAKEIKPWTHGSCMIDNGDPDQDKSESRVDRWSDTCNLMLLWKFYCLLLWIYARNFERPNANYKLVFKRASTRPRSVIKPETSPGHKPSQWYQHNIPTHLRLKMENRHCIEHPCREVRFMCCYSDFEIGGMDVRRDQTDRHSLSDIFFGICQMTQDAILFSNRCTEKKTFFHHDDEKHIRFWRDGDFHCKQGVHVEMSAWHLPPVPYKNKVTWLSHKRAQNLKLLIKMSMPTYCQEESATSYQVCVDGCIRFTNPWKHYMNGEDEERAPWSIATQQERDETYFWWSCGQYTIEYFHYMFKYKQFSSNSNHRTTEQNMHYVEEVLRTSACIHRDHEDDEMIKLRVMPRQCYDIPIFWAFYRARYCIKCEDDIKEGESEWMSPFNFCFFCWLIWEITPIFADSPKSDEIWNTFTAELAVVMGCNMRNCCSDWCRYKAVDMYRTDRPGICYFGLQLLNVITSYWSFSRGQFLSDHSKYRTWDYIHHPPKAYEANINVCFLSNINYLVSGFSEHENGPWTWKEWQGMNLKAKHTVRTIIWVAHRMFMVERRVGEMMISSFSEIHYKQCPHRWMCKCKAPGEVTTAHYNCCFYDVQWTGTDGECADCHAAYAACMTRNASPIVGPKLWQIKQHDKADISSFIDRFSTQGYQPTVAVSEGREKQPWCGYMFGHMFPKNDPWCGQNQMQNARKAIKAYEGKTRSETTKMRYKMLAPRWWLYMPNYRCAILQLLWHLMKYYLKNDIEDANPKHSMRQEMYDCWYDIPEGIPGIGRDEFWWLDRLTHYSERGHFQFPNRYYCMHPRVIWGEHMQTEGWKFYKYWWNSYFGPMAMDLNPIIQRSHIGMCKDMTYKAECRYMDYGCDFPVFQAAWNSTCDLNEKQKIAKSPVCMDNVHMQDVESPCRENYYVLMGWLSHRERCHLQHGKMPFNPSLTEQCVKPNQFDVENKDIADCTMWKENWMGWLWLVSRYEWEYEMSAATANWNLIPNLDPRQAQRDTMKRLMGWYCRHYDHERKNWLRCEWVPHETSFNNLWCFAQPYMDKAQQVHGAIKRRCFLHTRLGTKSYFPDCHVQLFCFDDQCMGKCEKQIMLFFIQNGHGPRCIKFHKGISTHTMENDPLRHTIRCCSTWGYNFSVFSSLPFWYYRHKYMMDSHAYD"),
]

# Tokenize
ids, batch_token, lengths = batchConverter(data)

# Load Model
model = load_prot_flash_base()

# Data Parallel Inference
model = torch.nn.DataParallel(model)  # commenting out this line
model.to(device='cuda')               # and this one removes the error

# Embedding
with torch.no_grad():
    token_embedding = model(batch_token, lengths)

# Generate per-sequence representations via averaging
sequence_representations = []
for i, (_, seq) in enumerate(data):
    sequence_representations.append(token_embedding[i, 0: len(seq) + 1].mean(0))
```

**Expected behavior**
A token embedding tensor of shape `[4, L, 768]`, where `L` is the length of the longest sequence in the batch.
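If the forward pass succeeded, a quick sanity check along these lines should pass (a hedged sketch; it assumes the embedding keeps one position per input token, so its width matches the padded width of `batch_token`):

```python
# Hedged sanity check for the expected output shape (assumption: the
# embedding width matches the padded token batch width).
assert token_embedding.shape[0] == len(data)            # 4 sequences
assert token_embedding.shape[1] == batch_token.shape[1]  # padded length L
assert token_embedding.shape[2] == 768                  # base-model hidden size
```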

**Error encountered**

```
einops.EinopsError:  Error while processing rearrange-reduction pattern "b (g j) -> b g () j".
 Input tensor shape: torch.Size([1, 1315]). Additional info: {'j': 21}.
 Shape mismatch, can't divide axis of length 1315 in chunks of 21
```
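For what it's worth, the failure mode can be reproduced in isolation. This minimal sketch assumes the model chunks the (padded) sequence axis into groups of `j = 21`, which no longer divides evenly once `DataParallel` hands a replica a slice with sequence length 1315:

```python
import torch
from einops import rearrange

x = torch.zeros(1, 1315)                   # replica-local shape from the traceback
rearrange(x, "b (g j) -> b g () j", j=21)  # raises the same EinopsError
```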

**Workaround**
The only fix I have found on my end is to not use `torch.nn.DataParallel` at all and run inference on a single device.
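A minimal sketch of that single-device path, reusing `data` from the repro above (assumption: without `DataParallel` the input tensor must be moved to the model's device explicitly):

```python
import torch
from ProtFlash.pretrain import load_prot_flash_base
from ProtFlash.utils import batchConverter

# Single-GPU inference, no DataParallel wrapper
model = load_prot_flash_base().to("cuda")
model.eval()

ids, batch_token, lengths = batchConverter(data)
with torch.no_grad():
    token_embedding = model(batch_token.to("cuda"), lengths)
```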

Thanks for any insights and solutions!
