DNABERT
Readme section 5.2
The README outlines installation and a tutorial for use.
Section 5 says:
Visualization of DNABERT consists of 2 steps: calculate attention scores and plot.
However, section 5.2 seems to be incomplete. It consists only of the heading:
####5.2 Plotting tool
Do you have any further instructions on how to use the plotting tool?
Thanks for making DNABERT available!
P.S. The following generates one plot for me. Not sure I know what it is. :-)
pip install matplotlib
pip install seaborn
cd data_process_template
export KMER=6
export MODEL_PATH=../ft/$KMER
python ../visualize.py --model_path $MODEL_PATH --kmer $KMER

You can use BertViz as well for visualization. DNABERT is available on the HuggingFace Hub, find it here: https://huggingface.co/zhihan1996. Load it with HuggingFace Transformers and then use BertViz to visualize the attention weights. I assume you can load it like this:
import IPython
from IPython.display import display
from transformers import AutoTokenizer, AutoModel
from bertviz import head_view, model_view

# Load the 6-mer DNABERT checkpoint from the HuggingFace Hub, with attention outputs enabled
tokenizer = AutoTokenizer.from_pretrained('zhihan1996/DNA_bert_6', do_lower_case=False)
model = AutoModel.from_pretrained('zhihan1996/DNA_bert_6', output_attentions=True)

# BertViz needs require.js / d3 / jquery available in the notebook front end
def call_html():
    display(IPython.core.display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))

sequence = '...'  # your sequence
call_html()
inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']
attention = model(input_ids)[-1]       # with output_attentions=True, the last element is the attentions
input_id_list = input_ids[0].tolist()  # batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)
model_view(attention, tokens)
I haven't run this exact code, but I have used the same approach with ESM and ProtTrans models, so I assume it works with DNABERT as well. There might be slight differences you need to handle, but this is the general way of doing it.
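One assumption worth checking on my part: DNABERT's tokenizer vocabulary is built from k-mers, and the DNABERT README describes feeding sequences as space-separated overlapping 6-mers, so a raw string of bases may not tokenize the way you expect. A small helper like the sketch below (the seq_to_kmers name is mine, purely illustrative) could be used to prepare the input before calling tokenizer.encode_plus:

def seq_to_kmers(seq, k=6):
    # Turn a raw DNA string into space-separated overlapping k-mers,
    # e.g. 'ACGTACGT' -> 'ACGTAC CGTACG GTACGT' for k=6.
    return ' '.join(seq[i:i + k] for i in range(len(seq) - k + 1))

# Example: prepare a raw sequence before tokenizing
raw = 'CACAGCACAGCCCAGCCAAGCC'
sequence = seq_to_kmers(raw, k=6)
inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)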
Thanks @Moeinh77!
I would greatly appreciate your wise feedback on the following.
If I modify your code to this:
%matplotlib widget
from bertviz import head_view, model_view
from transformers import AutoTokenizer, AutoModel
import IPython
# I downloaded the model to a local directory 'DNA_bert_6'
tokenizer = AutoTokenizer.from_pretrained('DNA_bert_6', do_lower_case=False)
model = AutoModel.from_pretrained('DNA_bert_6', output_attention=True)
def call_html():
    display(IPython.core.display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))
sequence = 'CACAGCACAGCCCAGCCAAGCCAGGCCAGCCCAGCCCAGCCAAGCCACGCCACTCCACTACACTAGACTAGGCTAGGCTAGGCCAGGCCCGGCCCTGCCCTGCCCTGTCCTGTCCTGTCCTGTCCTGTCCTGTCCTGCCCTGCACTGCAGTGCAGCGCAGCCCAGCCCAGCCCCGCCCCCCCCCCTCCCCTGCCCTGTCCTGTACTGTAGTGTAGGGTAGGGTAGGGGAGGGGTGGGGTCGGGTCTGGTCTGGTCTGGTCTGGACTGGAATGGAACGGAACAGAACAGAACAGCACAGCCCAGCCAAGCCAGGCCAGGCCAGGACAGGAGAGGAGTGGAGTGGAGTGGAGTGGTGTGGTTTGGTTTGGTTTAGTTTAATTTAAGTTAAGATAAGAGAAGAGGAGAGGCGAGGCAAGGCAGGGCAGGGCAGGGCAGGGGAGGGGAGGGGAGGGGAGTGGAGTCGAGTCGAGTCGCGTCGCCTCGCCTCGCCTTGCCTTGCCTTGCCTTGCCTTGCCCTGCCCTGCCCTGCCCTGTCCTGTGCTGTGCTGTGCCGTGCCATGCCACGCCACACCACAC'
#print(sequence, len(sequence))
call_html()
inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']
device = 'cpu'
attention = model(input_ids.to(device))[-1]
print(attention)
input_id_list = input_ids[0].tolist() # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)
model_view(attention, tokens)
I get the error
============================================================
<class 'transformers.tokenization_bert.BertTokenizerFast'>
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_131706/782624532.py in <module>
7 # Also note I commented out "output_attention=True"
8 tokenizer = AutoTokenizer.from_pretrained('DNA_bert_6', do_lower_case=False)
----> 9 model = AutoModel.from_pretrained('DNA_bert_6', output_attention=True)
10
11 def call_html():
~/Repos/DNABERT/src/transformers/modeling_auto.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
382 for config_class, model_class in MODEL_MAPPING.items():
383 if isinstance(config, config_class):
--> 384 return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
385 raise ValueError(
386 "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
~/Repos/DNABERT/src/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
645
646 # Instantiate model.
--> 647 model = cls(config, *model_args, **model_kwargs)
648
649 if state_dict is None and not from_tf:
TypeError: __init__() got an unexpected keyword argument 'output_attention'
But if I comment out 'output_attention' like this:
model = AutoModel.from_pretrained('DNA_bert_6')#, output_attention=True)
then I get the error:
============================================================
<class 'transformers.tokenization_bert.BertTokenizerFast'>
tensor([[ 0.6328, -0.1791,  0.0956, -0.5731, -0.1763,  0.0412,  0.0398, -0.1986,
... (output of print(attention), truncated: a single 1 x 768 tensor) ...
          0.3177,  0.0729, -0.6886, -0.0284,  0.3048,  0.0920,  0.2971, -0.0899]],
       grad_fn=<TanhBackward0>)
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
/tmp/ipykernel_131706/694873838.py in <module>
35 input_id_list = input_ids[0].tolist() # Batch index 0
36 tokens = tokenizer.convert_ids_to_tokens(input_id_list)
---> 37 model_view(attention, tokens)
~/anaconda3/envs/dnabert37/lib/python3.7/site-packages/bertviz/model_view.py in model_view(attention, tokens, sentence_b_start, prettify_tokens, display_mode, encoder_attention, decoder_attention, cross_attention, encoder_tokens, decoder_tokens, include_layers, include_heads, html_action)
60 raise ValueError("If you specify 'attention' you may not specify any encoder-decoder arguments. This"
61 " argument is only for self-attention models.")
---> 62 n_heads = num_heads(attention)
63 if include_layers is None:
64 include_layers = list(range(num_layers(attention)))
~/anaconda3/envs/dnabert37/lib/python3.7/site-packages/bertviz/util.py in num_heads(attention)
24
25 def num_heads(attention):
---> 26 return attention[0][0].size(0)
27
28
IndexError: Dimension specified as 0 but tensor has no dimensions
(Note that it is not raising the ValueError on line 60 of model_view.py; the error is on line 62.)
And at that point I am stuck. Any ideas? Thanks a lot!
Hi @mepster, I don't know specifically why this error is raised, but I suggest trying BertModel and BertTokenizer instead of AutoModel and AutoTokenizer. When you comment out output_attention there are no attentions to visualize, so avoid commenting it out and try BertModel; hopefully that will fix it.
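For what it's worth, here is a minimal sketch of that suggestion (untested against your local DNA_bert_6 checkpoint). One likely culprit: the Transformers keyword is output_attentions, with an s; the singular output_attention is what triggers the TypeError above.

from transformers import BertTokenizer, BertModel
from bertviz import model_view

# Assumes the model was downloaded to a local directory 'DNA_bert_6'
tokenizer = BertTokenizer.from_pretrained('DNA_bert_6', do_lower_case=False)
# Note the plural keyword: output_attentions (output_attention raises the TypeError above)
model = BertModel.from_pretrained('DNA_bert_6', output_attentions=True)

sequence = '...'  # your sequence, as space-separated 6-mers
inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']

# Reuse the call_html() helper above if your notebook needs require.js for rendering
outputs = model(input_ids)
attention = outputs[-1]  # with output_attentions=True, the last element is the per-layer attention tuple

tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
model_view(attention, tokens)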