bert4keras
bert4keras copied to clipboard
请问如何取出Attention矩阵进行分析呢?
提问时请尽可能提供如下信息:
基本信息
- 你使用的操作系统: linux64
- 你使用的Python版本: 3.6
- 你使用的Tensorflow版本:1.14
- 你使用的Keras版本: 2.3.1
- 你使用的bert4keras版本: 0.8.3
- 你使用纯keras还是tf.keras: keras
- 你加载的预训练模型: BERT
想取出最后一层多头注意力的Attention计算结果进行分析,请问应该如何操作? 即注意力机制中的Q和K的点积矩阵如何获取?谢谢!
m = keras.Model(bert_model.input, bert_model.get_layer('LAYER_NAME').output)
m.predict([batch_token_ids, batch_segment_ids])
这个,了解一下keras.backend.function
,用它比较好实现~
请问楼主实现了吗