
[PyTorch Model Implementation 9] HAN_Attention

周奇文
2023-12-01

HAN_Attention model implementation

NLP model code GitHub repository: https://github.com/lyj157175/Models
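For reference, the code below follows the attention formulation of the HAN paper (Yang et al., 2016, "Hierarchical Attention Networks for Document Classification"), which the comments refer to. At the word level, with h_it the bidirectional-GRU hidden state of word t in sentence i and u_w a learned word context vector:

    u_it = tanh(W_w * h_it + b_w)
    a_it = softmax_t(u_it^T * u_w)        (softmax over the words t of sentence i)
    s_i  = sum_t a_it * h_it

The sentence level repeats the same pattern over the sentence hidden states h_i with a sentence context vector u_s, producing the document vector v, which is fed to the final linear classifier.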

import torch
import torch.nn as nn
from torch.nn import functional as F


class HAN_Attention(nn.Module):
    '''Hierarchical Attention Network for document classification: word-level and sentence-level attention'''
    def __init__(self, vocab_size, embedding_dim, gru_size, class_num, weights=None, is_pretrain=False):
        super(HAN_Attention, self).__init__()
        if is_pretrain:
            self.word_embed = nn.Embedding.from_pretrained(weights, freeze=False)
        else:
            self.word_embed = nn.Embedding(vocab_size, embedding_dim)
        # Word-level attention
        self.word_gru = nn.GRU(input_size=embedding_dim, hidden_size=gru_size, num_layers=1, bidirectional=True, batch_first=True)
        self.word_query = nn.Parameter(torch.randn(2*gru_size, 1))   # word context vector u_w in the paper, randomly initialized
        self.word_fc = nn.Linear(2*gru_size, 2*gru_size)
        # Sentence-level attention
        self.sentence_gru = nn.GRU(input_size=2*gru_size, hidden_size=gru_size, num_layers=1, bidirectional=True, batch_first=True)
        self.sentence_query = nn.Parameter(torch.randn(2*gru_size, 1))   # sentence context vector u_s in the paper, randomly initialized
        self.sentence_fc = nn.Linear(2*gru_size, 2*gru_size)
        # Document classification
        self.class_fc = nn.Linear(2*gru_size, class_num)

    def forward(self, x, use_gpu=False):  # x: b, sentence_num, sentence_len
        sentence_num = x.size(1)
        sentence_len = x.size(2)
        x = x.view(-1, sentence_len)  # b*sentence_num, sentence_len
        embed_x = self.word_embed(x)  # b*sentence_num , sentence_len, embedding_dim
        word_output, word_hidden = self.word_gru(embed_x)  # word_output: b*sentence_num, sentence_len, 2*gru_size
        # Compute u_it = tanh(W_w * h_it + b_w)
        word_attention = torch.tanh(self.word_fc(word_output))  # b*sentence_num, sentence_len, 2*gru_size
        # Compute word attention weights a_it = softmax(u_it^T * u_w)
        weights = torch.matmul(word_attention, self.word_query)  # b*sentence_num, sentence_len, 1
        weights = F.softmax(weights, dim=1)   # b*sentence_num, sentence_len, 1

        x = x.unsqueeze(2)  # b*sentence_num, sentence_len, 1
        if use_gpu:
            # Zero out the attention weights at padding positions (token id 0)
            weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float).cuda())  # b*sentence_num, sentence_len, 1
        else:
            weights = torch.where(x != 0, weights, torch.full_like(x, 0, dtype=torch.float))
        # Renormalize the masked weights; add a small epsilon (1e-4) so sentences that are entirely padding do not divide by zero
        weights = weights / (torch.sum(weights, dim=1).unsqueeze(1) + 1e-4)  # b*sentence_num, sentence_len, 1
        
        # Sentence vector s_i = sum_t(a_it * h_it): b*sentence_num, 2*gru_size -> b, sentence_num, 2*gru_size
        sentence_vector = torch.sum(weights * word_output, dim=1).view(-1, sentence_num, word_output.size(2))

        sentence_output, sentence_hidden = self.sentence_gru(sentence_vector)  # sentence_output: b, sentence_num, 2*gru_size
        # Compute u_i = tanh(W_s * h_i + b_s)
        sentence_attention = torch.tanh(self.sentence_fc(sentence_output))  # b, sentence_num, 2*gru_size
        # Compute sentence attention weights a_i = softmax(u_i^T * u_s)
        sentence_weights = torch.matmul(sentence_attention, self.sentence_query)   # b, sentence_num, 1
        sentence_weights = F.softmax(sentence_weights, dim=1)   # b, sentence_num, 1

        # Rebuild a sentence-level padding mask: a sentence counts as padding only if all of its token ids are 0
        x = x.view(-1, sentence_num, x.size(1))   # b, sentence_num, sentence_len
        x = torch.sum(x, dim=2).unsqueeze(2)  # b, sentence_num, 1
        if use_gpu:
            # Zero out the attention weights of all-padding sentences
            sentence_weights = torch.where(x != 0, sentence_weights, torch.full_like(x, 0, dtype=torch.float).cuda())
        else:
            sentence_weights = torch.where(x != 0, sentence_weights, torch.full_like(x, 0, dtype=torch.float))  # b, sentence_num, 1
        sentence_weights = sentence_weights / (torch.sum(sentence_weights, dim=1).unsqueeze(1) + 1e-4)  # b, sentence_num, 1

        # Document vector v = sum_i(a_i * h_i)
        document_vector = torch.sum(sentence_weights * sentence_output, dim=1)   # b, 2*gru_size
        document_class = self.class_fc(document_vector)   # b, class_num
        return document_class


if __name__ == '__main__':
    model = HAN_Attention(3000, 200, 50, 4)
    x = torch.zeros(64, 50, 100).long()   # b, sentence_num, sentence_len
    x[0][0][0:10] = 1
    document_class = model(x)
    print(document_class.shape)  # 64, 4
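
Beyond the all-zeros smoke test above, the standalone snippet below is a minimal usage sketch (assuming the class and imports above are available): it shows the is_pretrain path and one training step. The "pretrained" matrix is a random placeholder and the optimizer/hyperparameters are illustrative only, not part of the original repo.

# Minimal usage sketch: pretrained embeddings + one training step.
# The "pretrained" matrix is a random placeholder; hyperparameters are illustrative.
pretrained = torch.randn(3000, 200)   # replace with real word vectors, shape (vocab_size, embedding_dim)
model = HAN_Attention(3000, 200, 50, 4, weights=pretrained, is_pretrain=True)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randint(0, 3000, (8, 50, 100))   # b, sentence_num, sentence_len
y = torch.randint(0, 4, (8,))              # document-level class labels

logits = model(x)             # b, class_num
loss = criterion(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())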





