BPG-MF学习笔记
論文及代碼出處
論文原文:Beyond Alternating Updates for Matrix Factorization with Inertial Bregman Proximal Gradient Algorithms
補(bǔ)充材料下載鏈接:https://proceedings.neurips.cc/paper/2019/file/bc7f621451b4f5df308a8e098112185d-Supplemental.zip
代碼出處:https://github.com/mmahesh/cocain-bpg-matrix-factorization
BPG-MF算法
算法流程
無(wú)正則
根據(jù)算法流程對(duì) P k P^k Pk和 Q k Q^k Qk進(jìn)一步推導(dǎo),可以得到無(wú)正則項(xiàng)的BPG-MF算法為如下形式:
L2正則
代碼結(jié)構(gòu)
??作者提供的程序包實(shí)現(xiàn)了BPG-MF、CoCaIn BPG-MF、
BPG-MF-WB、PALM和iPALM五種算法,可以通過(guò)修改main.py文件中的algo參數(shù)進(jìn)行選擇。同時(shí)還可以通過(guò)修改dataset_option對(duì)使用的數(shù)據(jù)集進(jìn)行選擇。
??算法功能實(shí)現(xiàn)的函數(shù)在my_functions.py中,主要函數(shù)及其功能如下:
| main_func | 計(jì)算數(shù)據(jù)一致項(xiàng) |
| grad | 計(jì)算光滑項(xiàng)g的梯度 |
| make_update | 實(shí)現(xiàn)算法的更新策略 |
| breg | 計(jì)算Bregman距離 |
make_update函數(shù)
??breg_num為1時(shí),該函數(shù)實(shí)現(xiàn)了PALM與iPALM;breg_num為2時(shí),該函數(shù)實(shí)現(xiàn)了BPG相關(guān)算法。接下來(lái)討論BPG算法的代碼實(shí)現(xiàn)。
??BPG算法的abs_fun_num可以選擇正則化形式,等于3時(shí)實(shí)現(xiàn)了無(wú)正則和L2正則(因?yàn)長(zhǎng)2正則與無(wú)正則僅差一次項(xiàng)系數(shù),詳見(jiàn)supplementary),等于2時(shí)實(shí)現(xiàn)了L1正則。
無(wú)正則和L2正則
if breg_num ==2:# Calculates CoCaIn BPG-MF, BPG-MF, BPG-MF updates# 計(jì)算g對(duì)U和Z的偏導(dǎo)grad_u, grad_z = grad(A, U1, Z1, lam, fun_num=0)# 計(jì)算h對(duì)U和Z的偏導(dǎo)grad_h_1_a = U1*(np.linalg.norm(U1)**2 + np.linalg.norm(Z1)**2)grad_h_1_b = Z1*(np.linalg.norm(U1)**2 + np.linalg.norm(Z1)**2)grad_h_2_a = U1grad_h_2_b = Z1# 是否為對(duì)稱(chēng)矩陣sym_setting = 0if abs_fun_num == 3:# Code for No-Regularization and L2 Regularizationif exp_option==1:# No-Regularization is equivalent to L2 Regularization with lam=0# 計(jì)算P^kp_l = (1/uL_est)*grad_u - (c_1*grad_h_1_a + c_2*grad_h_2_a) # uL_est = 1, means lambda = 1# 計(jì)算 Q^k # lambda_0 is corresponding to lamq_l = (1/uL_est)*grad_z - (c_1*grad_h_1_b + c_2*grad_h_2_b)if sym_setting == 0: #default option# 解三次方程,temp_y為根coeff = [c_1*(np.linalg.norm(p_l)**2 + np.linalg.norm(q_l)**2), 0,(c_2 + (lam/uL_est)), -1]temp_y = np.roots(coeff)[-1].real# U^(k+1) = -r * P^k, Z^(k+1) = -r * Q^k, return (-1)*temp_y*p_l, (-1)*temp_y*q_lelse:p_new = p_l + q_l.Tcoeff = [4*c_1*(np.linalg.norm(p_new)**2), 0,2*(c_2 + (lam/uL_est)), -1]temp_y = np.roots(coeff)[-1].realreturn (-1)*temp_y*p_new, (-1)*temp_y*(p_new.T)L1正則
if abs_fun_num == 2:if exp_option==1:# L1 Regularization simpletp_l = (1/uL_est)*grad_u - (c_1*grad_h_1_a + c_2*grad_h_2_a) # 計(jì)算 P^kp_l = -np.maximum(0, np.abs(-tp_l)-lam*(1/uL_est))*np.sign(-tp_l) # 計(jì)算 -S_t0(-P^k)tq_l = (1/uL_est)*grad_z - (c_1*grad_h_1_b + c_2*grad_h_2_b) # 計(jì)算 Q^Kq_l = -np.maximum(0, np.abs(-tq_l)-lam*(1/uL_est))*np.sign(-tq_l) # 計(jì)算 -S_t0(-Q^k)# 解三次方程,temp_y為根coeff = [c_1*(np.linalg.norm(p_l)**2 + np.linalg.norm(q_l)**2), 0,(c_2), -1]temp_y = np.roots(coeff)[-1].real# 為了統(tǒng)一U^k和Z^k的計(jì)算形式,p_l和q_l相較于論文多了負(fù)號(hào)return (-1)*temp_y*p_l, (-1)*temp_y*q_l總結(jié)
以上是生活随笔為你收集整理的BPG-MF学习笔记的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: pyhton输油管问题
- 下一篇: 数据查询语句