
Model loading problem #16

Open · xiong3134 opened this issue Sep 24, 2019 · 2 comments

Comments

@xiong3134

After training, loading the model keeps throwing an error that the attention layer cannot be found, even though I do `import attention`.
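A minimal sketch of the failure mode, assuming the model was saved as model.h5 and the custom layer lives in a module named attention (both names hypothetical):

from keras.models import load_model
import attention  # importing the module alone does not help:
# load_model resolves layer class names through Keras's registry and
# the custom_objects argument, not through what happens to be imported,
# so this still raises: ValueError: Unknown layer: Attention
model = load_model('model.h5')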

@Chiang97912 commented Aug 24, 2020

The best fix I have found online so far is to override the custom layer's get_config method, for example:

class Attention(OurLayer):
    """Multi-head attention mechanism."""
    def __init__(self, heads, size_per_head, key_size=None,
                 mask_right=False, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.heads = heads
        self.size_per_head = size_per_head
        self.out_dim = heads * size_per_head
        self.key_size = key_size if key_size else size_per_head
        self.mask_right = mask_right

    def get_config(self):
        # Serialize every constructor argument so the layer can be
        # rebuilt from the saved config; omitting key_size and
        # mask_right would silently reset them on reload.
        config = {
            'heads': self.heads,
            'size_per_head': self.size_per_head,
            'key_size': self.key_size,
            'mask_right': self.mask_right,
        }
        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

See the link for details. (Keras rebuilds the layer via from_config, which by default just calls cls(**config), so every __init__ argument needs to appear in the config.)

Then specify the custom_objects argument when loading the model, for example:

model = load_model('model.h5', custom_objects={'Attention': Attention})
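For completeness, a minimal save/load round trip under the fix above; the toy model shape and the [query, key, value] call convention are assumptions based on the original layer:

from keras.layers import Input
from keras.models import Model, load_model

# Hypothetical toy model wrapping the Attention layer defined above;
# the layer is assumed to accept a [query, key, value] list.
x = Input(shape=(10, 64))
y = Attention(heads=8, size_per_head=16)([x, x, x])
model = Model(x, y)
model.save('model.h5')

# get_config now serializes every constructor argument, so load_model
# can rebuild the layer once the class itself is supplied:
restored = load_model('model.h5', custom_objects={'Attention': Attention})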

@bojone (Owner) commented Aug 24, 2020

This version will no longer be maintained. If you need the latest Keras version of attention, please refer to https://github.com/bojone/bert4keras/blob/master/bert4keras/layers.py.
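For later readers, a hedged sketch of that pointer, assuming bert4keras is installed and that the linked layers.py still exports a MultiHeadAttention class; check the file for the current class name and signature:

from bert4keras.layers import MultiHeadAttention

# Hypothetical drop-in usage; the argument names follow the linked
# file at the time of writing and may have changed since.
attn = MultiHeadAttention(heads=8, head_size=16)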
