
issues in `AntiBERTyRunner.py`

Open · Elmiar0642 opened this issue 2 years ago · 0 comments

Hey there,

I attempted to re-run the new v3.0.x of IGFold with OpenMM on my system last night. After updating and upgrading the packages, I tried to run the notebook and found the following error thrown from the script AntiBERTyRunner.py:

File "/xxx/yyy/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
    embeddings[i] = embeddings[i][:, a == 1]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
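For reference, this error reproduces whenever a CPU tensor is indexed with a boolean mask that still lives on the GPU. A minimal standalone repro (my own example, not from the library; it needs a CUDA device to run):

```python
import torch

x = torch.randn(9, 120, 512)                # tensor already moved to the CPU
mask = torch.ones(120, device="cuda") == 1  # boolean mask still on the GPU

x[:, mask]  # RuntimeError: indices should be either on cpu or on the
            # same device as the indexed tensor (cpu)
```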

To resolve this, I checked which device each of the variables embeddings and attention_mask lives on.

Both were created on the GPU, and only embeddings was detached and moved to the CPU. So I made the following change:

  • Moved both to the CPU and converted both into Python lists:

```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)
embeddings = embeddings.detach().cpu().tolist()

for i, a in enumerate(attention_mask.detach().cpu().tolist()):
    embeddings[i] = embeddings[i][:, a == 1]
```

It threw the following error:

File "/home/randd/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
    embeddings[i] = embeddings[i][:, a == 1]
TypeError: list indices must be integers or slices, not tuple
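In hindsight this is expected: .tolist() converts the tensors into plain nested Python lists, and embeddings[i][:, a == 1] then hands the tuple (slice(None, None), a == 1) to list.__getitem__, which only accepts integers or slices. (Note also that a == 1 on a Python list compares the whole list to 1 and just returns False, rather than building an element-wise mask.) A minimal repro of the TypeError:

```python
x = [[1, 2, 3], [4, 5, 6]]  # nested Python list, like the output of .tolist()
x[:, [True, True, False]]   # TypeError: list indices must be integers or slices, not tuple
```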

To understand the core problem, I wanted to inspect embeddings and attention_mask. So:

```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)

print(f"embeddings: {embeddings} length:{embeddings.shape}")
print(f"attention_mask: {attention_mask.detach().cpu()} length:{attention_mask.shape}")

embeddings = embeddings.detach().cpu().tolist()

for i, a in enumerate(attention_mask.detach().cpu().tolist()):
    embeddings[i] = embeddings[i][:, a == 1]
```

Details

```text
embeddings: tensor([[[[ 2.8550e-01, -9.1305e-01, -3.2084e-01,  ..., -2.6094e-01,
           -1.1316e-01,  6.2532e-02],
          [ 1.2959e-01, -2.3578e-01, -9.5074e-01,  ..., -3.4716e-01,
            3.9048e-01, -7.9039e-01],
          [-1.1861e-01, -8.5111e-01,  1.7778e-01,  ..., -6.4417e-01,
           -1.6268e-01, -7.4019e-01],
          ...,
          [ 1.4825e+00,  1.0562e+00, -5.5296e-01,  ...,  4.6048e-02,
           -5.8749e-01,  3.5935e-01],
          [ 1.1087e+00,  8.3452e-01, -4.6560e-01,  ..., -6.5979e-01,
            7.0711e-02,  1.3638e+00],
          [ 7.1583e-01,  8.4463e-01,  7.4550e-01,  ...,  5.5646e-01,
           -6.0864e-01,  1.2408e+00]],

         [[ 8.2428e-01, -6.0705e-01, -9.0634e-01,  ..., -4.5286e-02,
           -6.8834e-02,  4.4105e-01],
          [ 7.2001e-01,  6.3411e-01, -1.0107e+00,  ..., -4.3047e-01,
           -5.7251e-01, -6.7011e-01],
          [ 4.6859e-01, -8.5742e-01, -1.5053e-02,  ..., -2.8734e-01,
           -1.0233e+00, -3.6219e-01],
          ...,
          [ 1.0764e+00,  1.1695e+00, -6.8277e-01,  ...,  2.8122e-02,
           -9.8832e-01,  1.4659e-01],
          [ 8.8104e-01,  1.1147e+00, -7.1646e-01,  ..., -1.0783e-01,
           -7.9473e-01,  1.0538e+00],
          [ 6.3558e-01,  9.0190e-01,  4.0055e-01,  ...,  3.1800e-01,
           -1.0868e+00,  9.7025e-01]],

         [[ 9.6156e-01, -9.6647e-01, -1.4004e+00,  ..., -6.3557e-01,
            4.1958e-01, -1.8568e-01],
          [ 3.0844e-01,  1.0339e+00, -1.5486e+00,  ...,  2.1584e-01,
           -3.8619e-01, -8.9405e-01],
          [ 4.5382e-01, -3.8623e-01,  1.7961e-01,  ..., -1.4155e-01,
           -1.1880e+00, -5.4827e-01],
          ...,
          [ 9.9114e-01,  5.7983e-01, -2.9399e-01,  ..., -4.6010e-01,
           -6.7488e-01, -6.2466e-01],
          [ 7.5153e-01,  4.8691e-01, -5.4032e-01,  ...,  2.6127e-01,
           -1.0607e+00,  7.8277e-01],
          [ 8.5168e-01,  4.9293e-01, -2.6708e-01,  ...,  3.8526e-01,
           -1.1824e+00,  8.5203e-01]],

         ...,

         [[ 1.2814e+00, -4.3900e-01, -3.2785e-01,  ..., -1.2414e+00,
           -6.3775e-01, -1.3176e+00],
          [ 3.0157e-01,  1.6172e+00, -1.3343e+00,  ..., -1.2285e+00,
           -5.5167e-01, -1.8283e+00],
          [ 3.5919e-01, -2.6482e-01, -1.0645e+00,  ..., -4.3375e-02,
           -3.2065e-01, -9.8966e-01],
          ...,
          [ 1.8181e+00, -1.6646e-01, -1.2666e+00,  ...,  1.0637e+00,
            1.4646e+00, -1.6298e+00],
          [ 1.0763e+00, -5.1882e-01, -6.8510e-01,  ...,  1.3576e+00,
            1.2688e+00, -1.4657e+00],
          [ 1.7986e+00, -7.4009e-02, -1.2577e+00,  ...,  1.0660e+00,
            1.4812e+00, -1.4051e+00]],

         [[ 1.2025e+00, -5.5392e-01, -1.0193e+00,  ..., -8.1229e-01,
           -2.3811e-01, -4.7275e-01],
          [ 6.5538e-01,  1.1917e+00, -5.2697e-01,  ..., -8.7801e-01,
           -7.4126e-01, -1.9144e+00],
          [ 2.5875e-01, -7.9232e-01, -8.5029e-01,  ...,  6.4324e-02,
           -8.0997e-02, -1.9687e+00],
          ...,
          [ 1.4830e+00, -1.9244e-01, -6.8066e-01,  ...,  2.1269e-01,
            1.0873e+00, -1.3896e+00],
          [ 5.3997e-01, -1.4820e-01, -2.0483e-01,  ...,  7.3495e-01,
            8.6871e-01, -1.3526e+00],
          [ 1.6477e+00, -5.3092e-02, -7.1276e-01,  ...,  3.2879e-01,
            1.1778e+00, -9.6469e-01]],

         [[ 1.5494e+00, -9.5254e-01, -8.3588e-01,  ..., -4.2762e-01,
            6.2013e-01,  1.0120e-02],
          [ 4.4904e-02,  7.8505e-01, -1.0384e+00,  ..., -7.8334e-02,
           -1.7476e-01, -1.6311e+00],
          [ 1.7894e-01, -9.9010e-01, -1.1633e+00,  ...,  6.0122e-01,
           -1.0615e-01, -1.5358e+00],
          ...,
          [ 1.2771e+00, -1.8352e-01, -1.4466e+00,  ..., -6.2605e-01,
            1.2011e+00, -2.0856e+00],
          [ 5.6284e-01, -9.5801e-02, -1.1209e+00,  ..., -5.1828e-01,
            4.9442e-01, -1.5956e+00],
          [ 1.1071e+00,  3.0336e-01, -1.8048e+00,  ..., -3.8724e-01,
            1.1147e+00, -1.5361e+00]]],


        [[[ 2.8550e-01, -9.1305e-01, -3.2084e-01,  ..., -2.6094e-01,
           -1.1316e-01,  6.2532e-02],
          [ 5.0035e-01,  5.4549e-01,  3.4283e-01,  ..., -3.0739e-01,
           -4.9315e-01, -1.1373e+00],
          [-4.0275e-01,  2.1443e-02,  2.0123e-01,  ..., -2.4489e-01,
            8.3188e-01, -6.5645e-01],
          ...,
          [ 4.0514e-01, -3.2213e-01,  3.7994e-01,  ...,  1.2408e-01,
            6.3095e-01,  9.2037e-03],
          [ 1.9132e-01, -4.4131e-01,  4.2406e-01,  ..., -2.6266e-01,
            9.8391e-01,  5.5734e-01],
          [ 4.0278e-01, -4.9534e-02,  3.3810e-01,  ...,  1.4354e-01,
            8.4249e-01,  4.0723e-01]],

         [[ 6.2418e-02, -6.1317e-01, -1.5439e+00,  ..., -3.1803e-01,
           -2.0041e-01,  4.4618e-01],
          [-6.7039e-02,  1.2193e+00, -5.0822e-01,  ...,  3.5469e-01,
            2.6262e-02, -7.7125e-01],
          [-9.5805e-01,  1.4456e-01, -1.8127e-01,  ...,  3.6328e-01,
            1.4936e+00, -4.5747e-02],
          ...,
          [ 6.8287e-02,  8.2539e-01,  5.4192e-02,  ..., -1.1069e-01,
            6.6216e-01,  7.4946e-01],
          [-1.9581e-01,  6.8329e-01, -2.6928e-01,  ..., -7.0956e-01,
            7.8344e-01,  1.4804e+00],
          [-4.1462e-02,  8.8683e-01, -5.2905e-01,  ..., -2.5274e-01,
            7.1604e-01,  1.2256e+00]],

         [[ 3.5130e-01, -1.5874e+00, -1.7016e+00,  ...,  6.8850e-01,
           -5.8646e-01,  1.7784e-01],
          [ 1.1386e-01,  1.3657e+00, -8.2388e-01,  ...,  4.7490e-01,
            1.2626e+00, -3.1313e-01],
          [-1.1854e+00, -1.1600e-03, -7.3433e-01,  ...,  7.6139e-01,
            1.6375e+00,  1.8955e-01],
          ...,
          [-6.9969e-01,  1.1508e+00,  7.0558e-02,  ...,  4.2873e-01,
            5.6067e-01,  5.2250e-01],
          [-5.0788e-01,  6.6331e-01, -6.1032e-01,  ..., -2.3532e-01,
            8.2221e-01,  7.9204e-01],
          [-2.6820e-01,  8.5643e-01, -4.7090e-01,  ..., -2.8118e-01,
            6.5296e-01,  6.8785e-01]],

         ...,

         [[-9.0217e-02, -2.6741e-01, -1.0890e+00,  ...,  1.8798e+00,
           -3.2522e-03, -1.5653e-01],
          [-6.9740e-01,  1.4951e+00, -6.4886e-01,  ..., -1.3687e-01,
            1.4956e+00,  3.7487e-01],
          [-1.6580e-01,  1.1264e-01, -7.6442e-01,  ...,  4.3402e-01,
            1.9541e+00,  1.2029e+00],
          ...,
          [ 1.9953e-01,  2.6025e+00, -4.9651e-01,  ...,  5.0344e-01,
           -1.2114e-02,  3.9688e-01],
          [-1.0917e+00,  1.2115e+00,  6.2053e-01,  ...,  8.5435e-01,
           -4.5358e-02,  3.5120e-01],
          [ 6.1694e-01,  2.1130e+00, -1.1016e+00,  ...,  2.8187e-01,
            9.5419e-02, -3.5959e-01]],

         [[ 5.0400e-01, -5.3220e-01, -1.0173e+00,  ...,  2.1676e+00,
           -3.6843e-01, -1.8500e-01],
          [-2.1364e-01,  9.2027e-01, -2.5382e-01,  ...,  1.1757e-01,
            9.4363e-01,  6.0816e-01],
          [-1.0163e-01, -3.2413e-02, -7.2567e-01,  ...,  1.1070e+00,
            1.3306e+00,  1.0462e+00],
          ...,
          [ 3.0022e-01,  2.6991e+00, -4.7573e-01,  ..., -1.0428e-01,
           -7.8721e-02,  1.1695e+00],
          [-1.0961e+00,  6.7808e-01, -3.0792e-01,  ...,  7.1660e-01,
           -2.0900e-01,  4.5738e-01],
          [ 8.4948e-01,  1.9340e+00, -1.1624e+00,  ..., -2.2008e-01,
            4.5761e-01,  9.6474e-01]],

         [[-2.9150e-01,  4.8298e-01, -3.7572e-01,  ...,  2.4827e+00,
           -1.9686e-01,  2.9108e-01],
          [-4.5003e-01,  4.0321e-01, -1.0218e+00,  ..., -1.9378e-01,
            5.3391e-01,  3.8499e-01],
          [-7.7064e-02, -5.0206e-01, -1.3377e+00,  ...,  9.1953e-01,
            5.2488e-01,  1.2372e-01],
          ...,
          [ 7.2962e-01,  1.8133e+00,  2.9414e-01,  ...,  7.3038e-01,
           -2.0271e-01,  2.1481e+00],
          [-7.7066e-01, -1.0586e-01, -9.3787e-02,  ...,  1.0239e+00,
           -2.1658e-01,  9.3203e-01],
          [ 1.0556e+00,  9.7592e-01, -1.2148e+00,  ..., -4.7689e-02,
           -1.4709e-02,  2.9145e-01]]]], device='cuda:0') length:torch.Size([2, 9, 120, 512])
attention_mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) length:torch.Size([2, 120])
```

I made this change:

```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)

print(f"embeddings: {embeddings} length:{embeddings.shape}")
print(f"attention_mask: {attention_mask.detach().cpu()} length:{attention_mask.shape}")

embeddings = embeddings.detach().cpu()

for i, a in enumerate(attention_mask.detach().cpu()):
    embeddings[i] = embeddings[i][:, a == 1]
```

So, in this final attempt I kept both as tensors on the CPU, but as expected it threw a tensor dimension mismatch error:

File "/home/randd/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
    embeddings[i] = embeddings[i][:, a == 1]
RuntimeError: The expanded size of the tensor (120) must match the existing size (109) at non-singleton dimension 1.  Target sizes: [9, 120, 512].  Tensor sizes: [9, 109, 512]

This happens because embeddings has shape [2, 9, 120, 512], whereas attention_mask has shape [2, 120].
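The in-place assignment fails because embeddings is still one stacked tensor: masking the padded second sequence yields a [9, 109, 512] result, and tensor item assignment cannot shrink a [9, 120, 512] slot. A sketch of what I think should work instead (my assumption, not the library's confirmed fix) is to keep both tensors on the same device and unbind the batch into a Python list, so each sequence can keep its own length after masking:

```python
# Sketch under my assumptions, not the library's confirmed code:
# keep everything on the CPU and split the batch dimension into a
# Python list so each entry can shrink independently after masking.
embeddings = embeddings.detach().cpu()          # [2, 9, 120, 512]
attention_mask = attention_mask.detach().cpu()  # [2, 120]

embeddings = list(embeddings)  # two tensors of shape [9, 120, 512]
for i, a in enumerate(attention_mask):
    embeddings[i] = embeddings[i][:, a == 1]    # [9, L_i, 512], L_i = unpadded length
```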

What is the end goal of this snippet? Why does it throw this error? Please help me resolve this.

Elmiar0642 · Aug 04 '23 06:08