Model loading problem

Open xiong3134 opened this issue 4 years ago • 2 comments

After training, loading the model keeps failing with an error that the attention layer cannot be found, even though I did import attention.

xiong3134 avatar Sep 24 '19 09:09 xiong3134

The best solution I have found online so far is to override the custom layer's get_config method, for example:

class Attention(OurLayer):
    """Multi-head attention mechanism
    """
    def __init__(self, heads, size_per_head, key_size=None,
                 mask_right=False, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.heads = heads
        self.size_per_head = size_per_head
        self.out_dim = heads * size_per_head
        self.key_size = key_size if key_size else size_per_head
        self.mask_right = mask_right

    def get_config(self):
        # Serialize every constructor argument; any field left out of this
        # dict silently falls back to its default when the model is reloaded.
        config = {'heads': self.heads,
                  'size_per_head': self.size_per_head,
                  'key_size': self.key_size,
                  'mask_right': self.mask_right}
        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
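Why get_config must cover every constructor argument can be seen without TensorFlow at all: deserialization reconstructs the layer by calling the class with the saved config as keyword arguments. A minimal sketch, with a stand-in `BaseLayer` in place of the real `keras.layers.Layer`:

```python
class BaseLayer:
    """Stand-in for keras.layers.Layer; the real base class also
    serializes fields such as `name` and `trainable`."""
    def __init__(self, name='layer'):
        self.name = name

    def get_config(self):
        return {'name': self.name}


class Attention(BaseLayer):
    def __init__(self, heads, size_per_head, key_size=None,
                 mask_right=False, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.heads = heads
        self.size_per_head = size_per_head
        self.key_size = key_size if key_size else size_per_head
        self.mask_right = mask_right

    def get_config(self):
        # Every constructor argument goes into the dict.
        config = {'heads': self.heads,
                  'size_per_head': self.size_per_head,
                  'key_size': self.key_size,
                  'mask_right': self.mask_right}
        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        # What deserialization does internally: rebuild from the saved dict.
        return cls(**config)


layer = Attention(heads=8, size_per_head=64, mask_right=True)
clone = Attention.from_config(layer.get_config())
print(clone.heads, clone.key_size, clone.mask_right)  # 8 64 True
```

If `mask_right` were missing from the config dict, `clone.mask_right` would silently come back as the default `False`.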

See the linked reference for details.

Then pass the custom_objects argument when loading the model, for example:

model = load_model('model.h5', custom_objects={'Attention': Attention})
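The reason custom_objects is needed: the saved file stores only the string name of each layer class, and the loader has to map that string back to a Python class. A pure-Python sketch of that lookup (the `deserialize_layer` helper is hypothetical, not the real Keras internals):

```python
def deserialize_layer(saved, custom_objects=None):
    # `saved` mimics the per-layer entry in a saved model: the class name
    # as a string, plus the dict produced by get_config.
    custom_objects = custom_objects or {}
    class_name = saved['class_name']
    if class_name not in custom_objects:
        # Keras raises a similar "Unknown layer" error when the name
        # cannot be resolved to a class.
        raise ValueError('Unknown layer: %s' % class_name)
    cls = custom_objects[class_name]
    return cls(**saved['config'])


class Attention:
    def __init__(self, heads, size_per_head):
        self.heads = heads
        self.size_per_head = size_per_head


saved = {'class_name': 'Attention',
         'config': {'heads': 8, 'size_per_head': 64}}

layer = deserialize_layer(saved, custom_objects={'Attention': Attention})
print(layer.heads)  # 8
```

Without the `{'Attention': Attention}` mapping the lookup fails, which matches the "can't find the attention layer" error in the original question even when the module was imported: importing makes the class visible to your code, but not to the loader's name table.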

Chiang97912 avatar Aug 24 '20 03:08 Chiang97912

This version is no longer maintained. If you need the latest Keras version of attention, please refer to https://github.com/bojone/bert4keras/blob/master/bert4keras/layers.py.

bojone avatar Aug 24 '20 03:08 bojone