Comp-Agg (A Compare-Aggregate Model for Matching Text Sequences)
CompareAggregate研究意義:
1、采用“比較-聚合”框架,并對此進行改進
2、采用多種數據集驗證模型的泛化性
?
本文主要結構如下所示:
一、Abstract
? ? ? 摘要部分主要介紹本文利用詞嵌入作為輸入,CNN網絡作為聚合函數,提出比較聚合框架;關注于不同的比較函數來對文本向量進行匹配;并且使用不同的幾份數據評估模型;
基于element-wise的比較函數可能會比復雜神經網絡效果更好。
二、Introudction
? ? ? 首先提及了很多自然語言處理任務都需要對兩個或多個句子進行匹配,然后作出決定。
? ? ? ?
三、Method
? ? ? ? ? ? ?主要介紹模型的結構以及六個不同的比較函數
四、Experiment
? ? ? ? ? ? 實驗部分主要介紹不同比較函數以及組合函數在四個不同任務數據集合的效果,證明組合比較函數模型的有效性
五、Related Work
? ? ? ? ? ?相關工作部分簡單的描述了孿生網絡、注意力機制以及比較-聚合網絡相關的應用
六、Conclusions
? ? ? ? ? ? 最后一部分總結了本文系統分析“比較-聚合”模型在四個不同任務數據集上的有效性,此外還提出了詞級別的比較函數element-wise 比較函數表現好于其它函數,并且根據實驗結果很多不同任務可以共享“比較-聚合”結構,在未來的任務中,可以把它使用在多任務學習中。
? ? ? ? ?關鍵點: 采用“比較-聚合”框架;利用多種數據集證明模型的有效性;提出多種比較函數并探究了交互的最佳方式
? ? ? ? ?創新點: 利用門控單元提取語義特征,利用注意力機制完成句子權重匹配,利用向量的差和積進行特征提取
七、Code
# -*- coding: utf-8 -*-# @Time : 2021/2/14 下午2:07 # @Author : TaoWang # @Description : "比較-聚合" 模型結構import torch import torch.nn as nn import numpy as np from torch.utils.data import DataLoader, Dataset from torch.autograd import Variable# 預處理層 class Preprocess(nn.Module):def __init__(self, in_features, out_features):""":param in_features: :param out_features: """super().__init__()self.Wi = nn.Parameter(torch.randn(in_features, out_features))self.bi = nn.Parameter(torch.randn(out_features))self.wu = nn.Parameter(torch.randn(in_features, out_features))self.bu = nn.Parameter(torch.randn(out_features))def forward(self, x):""":param x: :return: """gate = torch.matmul(x, self.Wi)gate = torch.sigmoid(gate + self.bi.expand_as(date))out = torch.matmul(x, self.Wu)out = torch.tanh(out + self.bu.expand_as(out))return gate * out# 注意力層 class Attention(nn.Module):def __init__(self):super().__init__()self.wg = nn.Parameter(torch.randn(hidden_size, hidden_size))self.bg = nn.Parameter(torch.randn(hidden_size))def forward(self, q, a):""":param q: :param a: :return: """G = torch.matmul(q, self.wg)G = G + self.bg.expand_as(G)G = torch.matmul(G, a.permute(0, 2, 1))G = torch.softmax(G, dim=1)H = torch.matmul(G.permute(0, 2, 1), q)return H# 模型比較層 class Compare(nn.Module):def __init__(self):super().__init__()self.W = nn.Parameter(torch.randn(2*hidden_size, hidden_size))self.b = nn.Parameter(torch.randn(hidden_size))def forward(self, h, a):""":param h: :param a: :return: """sub = (h - a) * (h - a)mult = h * aT = torch.matmul(torch.cat([sub, mult], dim=2), self.W)T = torch.relu(T + self.b.expand_as(T))return T# 模型比較聚合層匯總class CompAgg(torch.nn.Module):def __init__(self):super(CompAgg, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_size)self.embedding.weight.data.copy_(torch.from_numpy(embed))self.preprocess = Preprocess(embedding_size, hidden_size)self.attention = Attention()self.compare = Compare()self.aggregate = nn.Conv1d(in_channels=max_len, out_channels=window, kernel_size=3, stride=1, padding=1)self.predict = nn.Linear(window * hidden_size, classes)def forward(self, q, a):""":param q: 設 q長度 20:param a: 設 a長度 40:return: """# emb_q: batch * 20 * 200, emb_a: batch * 40 * 200emb_q, emb_a = self.embedding(q), self.embedding(a)# q_bar: batch * 20 * 100, a_bar: batch * 40 * 100q_bar, a_bar = self.preprocess(emb_q), self.preprocess(emb_a)# H: batch * 40 * 100H = self.attention(q_bar, a_bar)# T: batch * 40 * 100T = self.compare(H, a_bar)# r: batch * 3 * 100r = self.aggregate(T)# r: batch * 300r = r.view(-1, window * hidden_size)# out: batch * 3out = self.predict(r)return out?
總結
以上是生活随笔為你收集整理的Comp-Agg (A Compare-Aggregate Model for Matching Text Sequences)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: SiameseNet(Learning
- 下一篇: ESIM (Enhanced LSTM