代价敏感多标签主动学习的代码开发跟踪
1 簡介
代價敏感多標簽主動學習目前是閔老師小組正在進行的一個開發項目,目的是將代價敏感和主動學習思想應用到多標簽學習中。整個Java代碼涵蓋了很多技術:并行計算、batch處理。本文就是在學習這個代碼后的一些總結。學習方法采用的至頂向下。
2 Cmale類
該類為測試的主類。
2.1 數據
- dataset:保存整個多標簽數據集
- numInstances:樣本的個數
- numConditions:條件屬性的個數
- numLabels:標簽的個數
- outputFile:輸出的文件
- multiLabelAnn:用于分類的多標簽神經網絡
- representativenessArray:保存所有實例的代表性
- representativenessRankArray:所有實例代表性的排名
2.2 方法
-
(1)Cmale:構造方法
step 1. 讀數據文件構建dataset
step 2. 計算實例的代表性
step 3. 準備輸出文件 -
(2)initializeMultiLabelAnn:初始化多標簽神經網絡
利用dataset,全連接層節點,并行層節點構建多標簽神經網絡 -
(3)boundedTrain:給定輪數的上界、下界進行訓練
-
(4)boundedEmphasizedTrain:給定訓練輪數的上界, 進行針對性 (增量) 訓練
-
(5)computeInstanceRepresentativeness:基于密度峰值來計算實例的代表性
-
(6)twoStageLearn:兩階段學習: 冷啟動 (僅考慮對象代表性與標簽稀少性) 與 正常訓練 (考慮標簽不確定性) 注意: 這里是算法的核心, 需要改策略
-
(7)randomSelectionLearn: 隨機選擇標簽的學習, 作為對比算法. 如果我們的策略不比隨機策略好, 就沒有意義
3 MultiLabelData類:數據管理
本類讀入 arff 文件, 存儲成一個數據矩陣和標簽矩陣
3.1 數據
- dataMatrix: 數據矩陣
- labelMatrix: 標簽矩陣
- predictedLabelMatrix: 預測的標簽矩陣
- labelQueriedMatrix: 記錄哪些標簽被查詢
- 查詢代價、誤分類代價等
3.2 方法
- reset: 重置以支持多次訓練
- randomQuery: 隨機查詢給定數量的標簽, 支持隨機查詢方案
- getScareLabels: 找出哪些標簽是稀少的
- queryLabels: 查詢某個對象的一組標簽. 需要在內部保持數據的一致性, 出錯影響大
- computeAccuracy: 根據預測的標簽矩陣計算準確率. 需要預先給出預測值
- computeTrainingAccuracy: 計算在訓練集中的準確率, 以支持訓練結束的終止條件
- computeTotalCost: 計算總代價, 包括查詢代價與誤分類代價
- distance: 計算兩個實例之間的距離 (Manhattan 或 Euclidean)
4. 分類器構建 MultiLabelAnn.java
本類的神經網絡支持全連接層和并行連接層. 輸入端口數為條件屬性數, 輸出端口數為標簽數的 2 倍.
4.1 變量
dataset: 數據集
4.2 方法
- train: 訓練一輪, 僅使用被查詢過的對象
- emphasizedTrain: 訓練一輪, 被強調的數據多次訓練, 支持增量學習. 這是因為主動學習過程是增量學習
- test: 使用所有數據測試
- computeLabelUncertaintyMatrix: 計算標簽不確定性矩陣 注意: 這是核心方法, 以后可能要修改策略
- getMostUncertainLabelIndices: 獲得不確定性最高的幾個標簽, 包括對象下標 (一個) 與標簽下標 (多個) 注意: 以后可能修改, 以支持多個對象的批量選取, 縮短程序運行時間
- getUncertainLabelBatch: 大家好, 我就是上一條說的 “以后”
- forward: 神經網絡標準的前向操作
- backPropagation: 神經網絡標準的回饋操作
總結
以上是生活随笔為你收集整理的代价敏感多标签主动学习的代码开发跟踪的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: (pytorch-深度学习)包含并行连结
- 下一篇: 我目前的主要研究方向