[Paper Reproduction and Improvement] An Improved MCWD Algorithm for Multi-Label Matrix Recovery from Weakly Labeled Data: Give Your Weakly Labeled Multi-Label Data a Head Start
The Improved MCWD Algorithm: Give Your Weakly Labeled Multi-Label Data a Head Start
- Preface
- The MCWD Algorithm
- Algorithm Overview
- Algorithm Improvements
- Implementation Code
- Experimental Results
- Summary
Preface
I recently finished working through Li Hang's *Statistical Learning Methods* and was itching to reproduce a few more algorithms. That coincided with an incomplete-annotation problem for Cloud Village video tags at NetEase Cloud Music, which is essentially weakly labeled data. I had previously done multi-label classification on this batch of data, and although every metric improved after the features went online, feeding weakly labeled data straight into a neural network for multi-label classification never felt right.
This was based on the following two considerations:
The MCWD Algorithm
While searching for an approach I also came across *Learning from Semi-Supervised Weak-Label Data* from Zhou Zhihua's group. Unfortunately its mathematics is fairly involved; even though I understood the idea, implementing it in Python would still have been quite difficult.
In the end I settled on *A Multi-Label Classification Algorithm for Weakly Labeled Data* (Wang Jingjing, Yang Youlong; please search for the original paper yourself). The paper uses weighted KNN between samples together with a second-order label-correlation strategy to recover the label matrix of weakly labeled data, and then applies a multi-label classification algorithm to produce the annotations.
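To make these two building blocks concrete before diving into the full implementation, here is a minimal toy sketch of a nearest-neighbour label estimate and a smoothed label co-occurrence matrix (my own illustration with made-up toy arrays, not the paper's exact formulation; MCWD additionally weights neighbours by a confidence matrix, as the code further down shows):

```python
import numpy as np

# Toy setup: 6 samples, 2 features, 3 labels; weak labels with 0 = "not observed".
X = np.array([[0.10, 0.20], [0.20, 0.10], [0.90, 0.80],
              [0.80, 0.90], [0.15, 0.25], [0.85, 0.75]])
Y = np.array([[1, 0, 0], [1, 1, 0], [0, 1, 1],
              [0, 0, 1], [1, 0, 0], [0, 1, 1]], dtype=float)

def knn_label_estimate(X, Y, i, k=2):
    """Soft label guess for sample i: the average of its k nearest
    neighbours' label vectors (uniform weights here; MCWD instead
    weights each neighbour by a confidence matrix W)."""
    dists = np.linalg.norm(X - X[i], axis=1)
    neighbours = np.argsort(dists)[1:k + 1]   # skip the sample itself
    return Y[neighbours].mean(axis=0)

def label_cooccurrence(Y, s=1):
    """Second-order label correlation: smoothed frequency of label j
    among samples that carry label i (one simple variant)."""
    pos = (Y == 1).astype(float)
    co = pos.T @ pos                          # positive co-occurrence counts
    return (co + s) / (pos.sum(axis=0, keepdims=True).T + 2 * s)

print(knn_label_estimate(X, Y, 0))   # e.g. [1. 0.5 0.], leaning towards labels 0 and 1
print(label_cooccurrence(Y))         # 3x3 matrix of smoothed conditional co-occurrence
```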
Algorithm Overview
The concrete steps of the algorithm given in the original paper are as follows:
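(The original step-by-step figure is not reproduced here; the outline below is my own paraphrase of the procedure as realised in the implementation further down, not the paper's wording.)

- Standardize the features, since distance computations follow.
- Inject a fraction of "reliable negative" labels (-1) into the zero entries of the weak label matrix.
- Initialize a confidence matrix W from the current label matrix.
- Iterate, enlarging the neighbourhood size each round: re-estimate every label as a confidence-weighted average over the k nearest neighbours (weighted KNN); update W by locking in confident agreements as ±1 and damping estimates that contradict an originally observed positive; threshold the estimates back to {-1, 0, +1} while always restoring the originally observed positives.
- Once few uncertain entries remain, keep only entries whose estimate times confidence exceeds a threshold and zero out the rest.
- Fill the remaining zeros using the second-order label-correlation matrix.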
Algorithm Improvements
Implementation Code
```python
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, hamming_loss


def data_get_2(path, p):
    """Load the Yeast CSV and randomly drop positive labels so that each
    sample keeps at most p of them (simulating weak labeling)."""
    data = pd.read_csv(path)
    feature = np.array(data.iloc[:, :103])
    tag = np.array(data.iloc[:, 103:])
    real_tag = tag.copy()
    for i in range(len(tag)):
        pos = np.where(tag[i, :] == 1)[0]
        if len(pos) > p:
            index_list = np.random.choice(pos, len(pos) - p, replace=False)
            tag[i, :][index_list] = 0
    return feature, tag, real_tag


def data_get_3(path, p):
    """Same as data_get_2, but for the genbase CSV layout
    (features in columns 1:1187, labels afterwards)."""
    data = pd.read_csv(path)
    feature = np.array(data.iloc[:, 1:1187])
    tag = np.array(data.iloc[:, 1187:])
    real_tag = tag.copy()
    for i in range(len(tag)):
        pos = np.where(tag[i, :] == 1)[0]
        if len(pos) > p:
            index_list = np.random.choice(pos, len(pos) - p, replace=False)
            tag[i, :][index_list] = 0
    return feature, tag, real_tag


class MCWD:
    def __init__(self, e=0.8, c=0.2, k_t=10, s=1, w_e=0.3, rate=0.5):
        self.e = e          # agreement threshold for locking in a label
        self.c = c          # damping factor for estimates contradicting an observed positive
        self.k_t = k_t      # neighbourhood growth step per iteration
        self.s = s          # smoothing term for the label correlation matrix
        self.w_e = w_e      # confidence threshold used when finalising labels
        self.rate = rate    # fraction of zero entries injected as negatives (-1)

    def standetlize(self):
        """Distance computations follow, so standardise each feature column."""
        for i in range(self.feature_dim):
            mean = np.mean(self.feature[:, i])
            std = np.sqrt(np.sum((self.feature[:, i] - mean) ** 2) / (self.sample_num - 1))
            self.feature[:, i] = (self.feature[:, i] - mean) / (std + 1e-8)

    def add_nagtive_tag(self):
        """Inject some negative (-1) labels into the zero entries of the training labels."""
        for j in range(self.tag_c):
            zeros_way = np.where(self.old_tag_list[:, j] == 0)[0]
            pj = int(len(zeros_way) * self.rate)
            if len(zeros_way) < pj:
                index_list = zeros_way
            else:
                index_list = np.random.choice(zeros_way, pj, replace=False)
            self.tag_list[:, j][index_list] = -1

    def cauclate_dis(self, x1, x2):
        return np.sum(abs(x1 - x2) ** 2) ** (1 / 2)

    def get_dis_matrix(self):
        """Fast, vectorised computation of the pairwise Euclidean distance matrix."""
        # Naive version kept for reference:
        # self.dis_matrix = np.zeros((self.sample_num, self.sample_num))
        # for i in range(self.sample_num):
        #     for j in range(self.sample_num):
        #         if i < j:
        #             dis = self.cauclate_dis(self.feature[i], self.feature[j])
        #             self.dis_matrix[i][j] = dis
        #             self.dis_matrix[j][i] = dis
        G = np.dot(self.feature, self.feature.T)
        H = np.tile(np.diag(G), (self.sample_num, 1))  # n rows, 1 for each row
        D = H + H.T - G * 2
        self.dis_matrix = np.sqrt(np.maximum(D, 0))    # clip tiny negatives from rounding

    def one_iter(self, t):
        n_near = t * self.k_t
        # Step 1: weighted-KNN estimate of every label vector.
        for index in range(self.sample_num):
            k_near_index = self.dis_matrix[index].argsort()[1:n_near + 1]
            l_w = np.abs(self.tag_W[k_near_index])
            t_l = self.tag_list[k_near_index]
            self.tag_list[index] = np.sum(l_w * t_l, 0) / (np.sum(l_w, 0) + 1e-8)  # in [-1, 1]
        # Step 2: update the confidence matrix W and re-threshold the labels.
        for index in range(self.sample_num):
            for j in range(self.tag_c):
                qij = self.tag_list[index, j]
                wij = self.tag_W[index, j]
                if np.sign(qij) == np.sign(wij) and np.abs(qij) > self.e and np.abs(wij) > self.e:
                    self.tag_W[index, j] = np.sign(qij)
                elif np.sign(qij) == -1 and np.sign(self.old_tag_list[index, j]) == 1:
                    self.tag_W[index, j] = self.c * (qij - np.min(self.tag_list[:, j])) / \
                        (np.max(self.tag_list[:, j]) - np.min(self.tag_list[:, j]))
                else:
                    self.tag_W[index, j] = qij
                if qij > 0:
                    self.tag_list[index, j] = 1
                elif qij < 0:
                    self.tag_list[index, j] = -1
                else:
                    self.tag_list[index, j] = 0
        # Originally observed positives are never overwritten.
        self.tag_list[np.where(self.old_tag_list == 1)] = 1

    def compute(self):
        """After the iterations, recover and complete the label matrix from W and the estimates."""
        for index in range(self.sample_num):
            for j in range(self.tag_c):
                if self.tag_list[index, j] * self.tag_W[index, j] > self.w_e:
                    self.tag_list[index, j] = np.sign(self.tag_list[index, j])
                else:
                    self.tag_list[index, j] = 0
        self.tag_list[np.where(self.old_tag_list == 1)] = 1

    def get_L_matricx(self):
        """Compute the (smoothed) label correlation matrix."""
        self.L_matricx = np.zeros((self.tag_c, self.tag_c))
        for i in range(self.tag_c):
            for j in range(self.tag_c):
                if i >= j:
                    a = np.sum((self.tag_list[:, i] == 1) * (self.tag_list[:, j] == 1)) + self.s
                    b = np.sum(self.tag_list[:, i] == 1) + 2 * self.s
                    c = np.sum(self.tag_list[:, j] == 1) + 2 * self.s
                    self.L_matricx[i][j] = a / b
                    self.L_matricx[j][i] = a / c

    def re_20_p(self):
        """Fill the remaining undecided (0) entries using the label correlation matrix."""
        self.get_L_matricx()
        for i, j in zip(np.where(self.tag_list == 0)[0], np.where(self.tag_list == 0)[1]):
            qij = self.tag_list[i].T.dot(self.L_matricx[:, j])
            qij = (qij - np.min(self.tag_list[:, j])) / \
                (np.max(self.tag_list[:, j]) - np.min(self.tag_list[:, j]))
            if qij > 0.5:
                self.tag_list[i][j] = 1
            else:
                self.tag_list[i][j] = -1

    def fit(self, feature, tag_list, max_iter=100, if_s=True):
        self.feature = np.array(feature, dtype=float)
        self.tag_list = np.array(tag_list, dtype=float)
        self.old_tag_list = np.array(tag_list)
        self.feature_dim = self.feature.shape[1]
        self.sample_num = self.feature.shape[0]
        self.tag_c = self.tag_list.shape[1]
        if if_s:
            print('standetlizing')
            self.standetlize()
            print('finish')
        print('add_nagtive_tag')
        self.add_nagtive_tag()
        print('finish')
        self.tag_W = self.tag_list.copy()
        print('get_dis_matrix')
        self.get_dis_matrix()
        print('finish')
        for it in range(max_iter):
            print(it)
            self.one_iter(it + 1)
            num = len(np.where((self.tag_list * self.tag_W) <= self.w_e)[0])
            print(num / (self.sample_num * self.tag_c))
            if num < self.sample_num * self.tag_c * 0.2:
                self.compute()
                break
        self.re_20_p()


def main():
    # ------ Yeast dataset ------------------------------------------------
    feature, tag, real_tag = data_get_2('./yeast.csv', 1)
    mcwd = MCWD(e=0.8, c=0.2, k_t=10, s=1, w_e=0.5, rate=0.25)
    mcwd.fit(feature, tag)
    pred_tag = mcwd.tag_list
    pred_tag[np.where(pred_tag == -1)] = 0
    print(np.sum(real_tag) - np.sum(tag))                 # number of dropped positives
    print(np.sum(pred_tag[np.where(real_tag == 1)]) - np.sum(tag[np.where(real_tag == 1)]))  # recovered positives
    print(np.sum(pred_tag[np.where(real_tag == 0)]))      # false positives introduced
    print("f1_macro")
    print(f1_score(real_tag, tag, average='macro'))
    print(f1_score(real_tag, pred_tag, average='macro'))
    print("f1_micro")
    print(f1_score(real_tag, tag, average='micro'))
    print(f1_score(real_tag, pred_tag, average='micro'))
    print("hamming_loss")
    print(hamming_loss(real_tag, tag))
    print(hamming_loss(real_tag, pred_tag))

    # ------ genbase dataset ----------------------------------------------
    feature, tag, real_tag = data_get_3('./file27546ea66c1.csv', 1)
    mcwd = MCWD(e=0.8, c=0.2, k_t=5, s=1, w_e=0.5, rate=0.25)
    mcwd.fit(feature, tag, if_s=False)
    pred_tag = mcwd.tag_list
    pred_tag[np.where(pred_tag == -1)] = 0
    print(np.sum(real_tag) - np.sum(tag))
    print(np.sum(pred_tag[np.where(real_tag == 1)]) - np.sum(tag[np.where(real_tag == 1)]))
    print(np.sum(pred_tag[np.where(real_tag == 0)]))
    print("f1_macro")
    print(f1_score(real_tag, tag, average='macro'))
    print(f1_score(real_tag, pred_tag, average='macro'))
    print("f1_micro")
    print(f1_score(real_tag, tag, average='micro'))
    print(f1_score(real_tag, pred_tag, average='micro'))
    print("hamming_loss")
    print(hamming_loss(real_tag, tag))
    print(hamming_loss(real_tag, pred_tag))


if __name__ == '__main__':
    main()
```

Experimental Results
Summary
The experimental results show that, compared with the weak label matrix before completion, the label matrix recovered by the algorithm achieves higher macro/micro F1 scores and a lower Hamming loss, to varying degrees, which supports the effectiveness of the algorithm.
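For reference, the before/after comparison that main() prints boils down to the following pattern (a minimal sketch on made-up toy matrices, using the same sklearn metrics as the script above):

```python
import numpy as np
from sklearn.metrics import f1_score, hamming_loss

# Toy ground truth, the weak (partially dropped) label matrix, and a recovered matrix.
real_tag = np.array([[1, 1, 0], [0, 1, 1], [1, 0, 1]])
weak_tag = np.array([[1, 0, 0], [0, 1, 1], [1, 0, 0]])   # some positives dropped
recovered = np.array([[1, 1, 0], [0, 1, 1], [1, 0, 0]])  # output of an MCWD-style completion

for name, pred in [("weak", weak_tag), ("recovered", recovered)]:
    print(name,
          "macro-F1:", f1_score(real_tag, pred, average="macro"),
          "micro-F1:", f1_score(real_tag, pred, average="micro"),
          "hamming:", hamming_loss(real_tag, pred))
```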
In particular: