attention
Model loading problem

After training, loading the model keeps failing with an error that the attention layer cannot be found, even though I have imported attention.
The best solution I have found online is to override the custom layer's get_config method, for example:
class Attention(OurLayer):
    """Multi-head attention.
    """
    def __init__(self, heads, size_per_head, key_size=None,
                 mask_right=False, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.heads = heads
        self.size_per_head = size_per_head
        self.out_dim = heads * size_per_head
        self.key_size = key_size if key_size else size_per_head
        self.mask_right = mask_right

    def get_config(self):
        # Every __init__ argument must be saved here, otherwise
        # load_model cannot reconstruct the layer from the config.
        config = {
            'heads': self.heads,
            'size_per_head': self.size_per_head,
            'key_size': self.key_size,
            'mask_right': self.mask_right,
        }
        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
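To see why get_config must cover every constructor argument, here is a minimal, framework-free sketch of the round trip that load_model performs: serialize the layer to a config dict, then rebuild it by calling the constructor with that dict. The Base class below is a hypothetical stand-in for the real Keras Layer, kept only so the example runs on its own.

```python
class Base:
    """Hypothetical stand-in for keras.layers.Layer (illustration only)."""
    def __init__(self, **kwargs):
        self.name = kwargs.get('name', 'attention')

    def get_config(self):
        return {'name': self.name}

    @classmethod
    def from_config(cls, config):
        # Keras rebuilds the layer by calling __init__ with the saved
        # config, so every __init__ argument must appear in get_config.
        return cls(**config)


class Attention(Base):
    def __init__(self, heads, size_per_head, key_size=None,
                 mask_right=False, **kwargs):
        super().__init__(**kwargs)
        self.heads = heads
        self.size_per_head = size_per_head
        self.key_size = key_size if key_size else size_per_head
        self.mask_right = mask_right

    def get_config(self):
        config = {
            'heads': self.heads,
            'size_per_head': self.size_per_head,
            'key_size': self.key_size,
            'mask_right': self.mask_right,
        }
        return dict(super().get_config(), **config)


# Round trip: config out, layer back in — mirrors what load_model does.
layer = Attention(heads=8, size_per_head=16, mask_right=True)
restored = Attention.from_config(layer.get_config())
print(restored.heads, restored.key_size, restored.mask_right)  # → 8 16 True
```

If get_config dropped key_size or mask_right, from_config would silently fall back to the defaults and the restored layer would differ from the trained one.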
See the link for details.
Then specify the custom_objects argument when loading, for example:
model = load_model('model.h5', custom_objects={'Attention': Attention})
This version will no longer be maintained. If you need the latest Keras implementation of attention, please refer to https://github.com/bojone/bert4keras/blob/master/bert4keras/layers.py.