
Readme section 5.2

Open mepster opened this issue 2 years ago • 4 comments

The README outlines installation and a tutorial for use.

Section 5 says:

Visualiazation of DNABERT consists of 2 steps. Calcualate attention scores and Plot.

However section 5.2 seems to be incomplete. It consists only of:

####5.2 Plotting tool

Do you have any further instructions how to use the plotting tool?

Thanks for making DNABERT available!

mepster avatar Dec 16 '22 01:12 mepster

P.S. The following generates one plot for me. Not sure I know what it is. :-)

pip install matplotlib
pip install seaborn

cd data_process_template

export KMER=6
export MODEL_PATH=../ft/$KMER

python ../visualize.py --model_path $MODEL_PATH --kmer $KMER

[plot image]

mepster avatar Dec 16 '22 02:12 mepster

You can use BertViz as well for visualization. DNABERT is available on the HuggingFace platform; find it here: https://huggingface.co/zhihan1996. Load it with HuggingFace Transformers, then use BertViz to visualize the attention weights. I assume you can load it like this:

import IPython

from transformers import AutoTokenizer, AutoModel
from bertviz import head_view, model_view

tokenizer = AutoTokenizer.from_pretrained('zhihan1996/DNA_bert_6', do_lower_case=False)
model = AutoModel.from_pretrained('zhihan1996/DNA_bert_6', output_attentions=True)

def call_html():
    display(IPython.core.display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))

sequence = '...'  # your sequence
call_html()
inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']
device = 'cpu'  # or 'cuda' if the model has been moved to a GPU
attention = model(input_ids.to(device))[-1]

input_id_list = input_ids[0].tolist() # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list) 
model_view(attention, tokens)

I haven't run this code, but I have written similar code that works with the ESM and ProtTrans models, so I assume it works with DNABERT as well. There might be slight differences that you will need to handle, but this is the general way of doing it.
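One difference I would expect: the DNABERT tokenizers are built on k-mer vocabularies, so the input may need to be formatted as space-separated k-mers rather than one raw string. A rough sketch of a helper, assuming the 6-mer model (my own sketch; the exact formatting DNABERT expects may differ):

def seq2kmer(seq, k=6):
    # Split a raw DNA string into overlapping k-mers joined by spaces,
    # e.g. 'CACAGCA' -> 'CACAGC ACAGCA'.
    return ' '.join(seq[i:i + k] for i in range(len(seq) - k + 1))

sequence = seq2kmer('CACAGCACAGCCCAGCCAAGCC')  # substitute your own sequence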

Moeinh77 avatar Dec 16 '22 16:12 Moeinh77

Thanks @Moeinh77 !

I would greatly appreciate your wise feedback on the following.

If I modify your code to this:

%matplotlib widget
from bertviz import head_view, model_view
from transformers import AutoTokenizer, AutoModel
import IPython

# I downloaded the model to a local directory 'DNA_bert_6'
tokenizer = AutoTokenizer.from_pretrained('DNA_bert_6', do_lower_case=False)
model = AutoModel.from_pretrained('DNA_bert_6', output_attention=True) 

def call_html():
    display(IPython.core.display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))

sequence = 'CACAGCACAGCCCAGCCAAGCCAGGCCAGCCCAGCCCAGCCAAGCCACGCCACTCCACTACACTAGACTAGGCTAGGCTAGGCCAGGCCCGGCCCTGCCCTGCCCTGTCCTGTCCTGTCCTGTCCTGTCCTGTCCTGCCCTGCACTGCAGTGCAGCGCAGCCCAGCCCAGCCCCGCCCCCCCCCCTCCCCTGCCCTGTCCTGTACTGTAGTGTAGGGTAGGGTAGGGGAGGGGTGGGGTCGGGTCTGGTCTGGTCTGGTCTGGACTGGAATGGAACGGAACAGAACAGAACAGCACAGCCCAGCCAAGCCAGGCCAGGCCAGGACAGGAGAGGAGTGGAGTGGAGTGGAGTGGTGTGGTTTGGTTTGGTTTAGTTTAATTTAAGTTAAGATAAGAGAAGAGGAGAGGCGAGGCAAGGCAGGGCAGGGCAGGGCAGGGGAGGGGAGGGGAGGGGAGTGGAGTCGAGTCGAGTCGCGTCGCCTCGCCTCGCCTTGCCTTGCCTTGCCTTGCCTTGCCCTGCCCTGCCCTGCCCTGTCCTGTGCTGTGCTGTGCCGTGCCATGCCACGCCACACCACAC'
#print(sequence, len(sequence))

call_html()
inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']
device = 'cpu'
attention = model(input_ids.to(device))[-1]
print(attention)

input_id_list = input_ids[0].tolist() # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)
model_view(attention, tokens)

I get this error:

============================================================
<class 'transformers.tokenization_bert.BertTokenizerFast'>

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_131706/782624532.py in <module>
      7 # Also note I commented out "output_attention=True"
      8 tokenizer = AutoTokenizer.from_pretrained('DNA_bert_6', do_lower_case=False)
----> 9 model = AutoModel.from_pretrained('DNA_bert_6', output_attention=True)
     10 
     11 def call_html():

~/Repos/DNABERT/src/transformers/modeling_auto.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    382         for config_class, model_class in MODEL_MAPPING.items():
    383             if isinstance(config, config_class):
--> 384                 return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
    385         raise ValueError(
    386             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"

~/Repos/DNABERT/src/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    645 
    646         # Instantiate model.
--> 647         model = cls(config, *model_args, **model_kwargs)
    648 
    649         if state_dict is None and not from_tf:

TypeError: __init__() got an unexpected keyword argument 'output_attention'

But if I comment out 'output_attention' like this:

model = AutoModel.from_pretrained('DNA_bert_6')#, output_attention=True) 

then I get the error:

============================================================
<class 'transformers.tokenization_bert.BertTokenizerFast'>

tensor([[ 0.6328, -0.1791,  0.0956, -0.5731, -0.1763,  0.0412,  0.0398, -0.1986,
         -0.1466,  0.0585, -0.3806,  0.3187,  0.2130, -0.2569,  0.0940, -0.2212,
         ...
          0.3177,  0.0729, -0.6886, -0.0284,  0.3048,  0.0920,  0.2971, -0.0899]],
       grad_fn=<TanhBackward0>)

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipykernel_131706/694873838.py in <module>
     35 input_id_list = input_ids[0].tolist() # Batch index 0
     36 tokens = tokenizer.convert_ids_to_tokens(input_id_list)
---> 37 model_view(attention, tokens)

~/anaconda3/envs/dnabert37/lib/python3.7/site-packages/bertviz/model_view.py in model_view(attention, tokens, sentence_b_start, prettify_tokens, display_mode, encoder_attention, decoder_attention, cross_attention, encoder_tokens, decoder_tokens, include_layers, include_heads, html_action)
     60             raise ValueError("If you specify 'attention' you may not specify any encoder-decoder arguments. This"
     61                              " argument is only for self-attention models.")
---> 62         n_heads = num_heads(attention)
     63         if include_layers is None:
     64             include_layers = list(range(num_layers(attention)))

~/anaconda3/envs/dnabert37/lib/python3.7/site-packages/bertviz/util.py in num_heads(attention)
     24 
     25 def num_heads(attention):
---> 26     return attention[0][0].size(0)
     27 
     28 

IndexError: Dimension specified as 0 but tensor has no dimensions


(Note that it is not raising the ValueError on line 60 of model_view.py; the error is on line 62.)

And at that point I am stuck. Any ideas? Thanks a lot!

mepster avatar Feb 15 '23 02:02 mepster

Hi @mepster, I don't know specifically why this error is raised, but I suggest trying BertModel and BertTokenizer instead of AutoModel and AutoTokenizer. When you comment out output_attention there is no attention to visualize, so avoid commenting it out and try BertModel; hopefully that will fix it.
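Something like this rough sketch is what I have in mind (untested with DNABERT; also note that the transformers keyword is output_attentions with an s, which is what the TypeError above is complaining about):

from transformers import BertTokenizer, BertModel
from bertviz import model_view

# Load the locally downloaded checkpoint; note the plural keyword output_attentions.
tokenizer = BertTokenizer.from_pretrained('DNA_bert_6', do_lower_case=False)
model = BertModel.from_pretrained('DNA_bert_6', output_attentions=True)

sequence = '...'  # your sequence
inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']

# With output_attentions=True, the per-layer attention tensors are the last
# element of the model output, which is the format model_view expects.
attention = model(input_ids)[-1]

tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
model_view(attention, tokens)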

Moeinh77 avatar Feb 16 '23 17:02 Moeinh77