PDARTS 网络结构搜索程序分析
PDARTS, i.e. Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation, is an improvement on DARTS. DARTS consumes too much GPU memory to search with a deep network; P-DARTS therefore splits the search into three stages, progressively increasing network depth while shrinking the set of candidate operations. Constructing the network three times lengthens the search schedule; the process is illustrated in the paper's pipeline figure.
In addition, the algorithm applies fine-grained control over the pruning details. chenxin061/pdarts is adapted from quark0/darts, and the main-function logic is somewhat involved.
train_search.py
```python
start_time = time.time()
main()
end_time = time.time()
duration = end_time - start_time
logging.info('Total searching time: %ds', duration)
```

main
(Flowchart: main — build data transforms, CIFAR100 dataset and DataLoader, CrossEntropyLoss, Network, optimizers with CosineAnnealingLR, per-epoch train/infer, utils.save; read arch_parameters and apply softmax; in the last stage use get_min_k_no_zero, otherwise get_min_k; then logging_switches, parse_network, check_sk_number, delete_min_sk_prob, keep_1_on, keep_2_branches.)

```python
if not torch.cuda.is_available():
    logging.info('No GPU device available')
    sys.exit(1)
np.random.seed(args.seed)
torch.cuda.set_device(args.gpu)
cudnn.benchmark = True
torch.manual_seed(args.seed)
cudnn.enabled = True
torch.cuda.manual_seed(args.seed)
logging.info('GPU device = %d' % args.gpu)
logging.info("args = %s", args)
```

The per-stage processing is not factored out into functions, so the control flow is hard to follow.
_data_transforms_cifar100 includes random cropping, horizontal flipping, normalization, and Cutout.
torchvision's CIFAR100 is a subclass of CIFAR10.
torch.utils.data.sampler.SubsetRandomSampler samples elements randomly from a given list of indices, without replacement.
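A minimal sketch (toy dataset and split point assumed) of how the search script splits a single dataset into a training queue and a validation queue with SubsetRandomSampler — each sampler draws only from its own index subset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Toy dataset of 10 samples; train_search.py uses CIFAR instead.
data = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))
indices = list(range(10))
split = 5  # first half trains the weights, second half trains the architecture

train_queue = DataLoader(data, batch_size=2, sampler=SubsetRandomSampler(indices[:split]))
valid_queue = DataLoader(data, batch_size=2, sampler=SubsetRandomSampler(indices[split:]))

# Each queue only ever yields targets from its own half of the indices.
train_targets = sorted(int(t) for _, ys in train_queue for t in ys)
valid_targets = sorted(int(t) for _, ys in valid_queue for t in ys)
```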
PRIMITIVES defines the 8 operations available to the network. After three rounds of dropping num_to_drop operations, each edge is left with one operation or none.
switches_normal and switches_reduce are switch matrices over the candidate operation names; a cell has 14 connections (edges).
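A sketch (variable names assumed to match the repo) of the initial switch matrices: 14 edges per cell, each with all len(PRIMITIVES) = 8 candidates enabled:

```python
import copy

PRIMITIVES = ['none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect',
              'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5']

# One row of booleans per edge; pruning later flips entries to False.
switches = [[True for _ in PRIMITIVES] for _ in range(14)]
switches_normal = copy.deepcopy(switches)
switches_reduce = copy.deepcopy(switches)
```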
Each stage builds and trains a fresh network in turn; sp stands for search phase.
P-DARTS grows the search network depth 5 -> 11 -> 17 cells, whereas DARTS searches with an 8-cell network.
Network builds the super-network.
count_parameters_in_MB reports the model size.
train takes two optimizers: Adam for searching the architecture and SGD for training the model weights.
During the last 5 epochs, infer is called to evaluate the model on the validation set.
utils.save stores the result of each stage's training; the problem is that the filename is reused, so later stages overwrite earlier ones.
switches_normal_2 and switches_reduce_2 are snapshots of the operation lists taken before stage-2 processing.
arch_parameters returns $(\alpha_{normal}, \alpha_{reduce})$.
normal_prob is computed as:
$$\frac{\exp(\alpha_o^{(i,j)})}{\sum_{o'\in\mathcal{O}} \exp(\alpha_{o'}^{(i,j)})}$$
idxs records the type indices of the operations that are still active.
get_min_k returns the indices of the num_to_drop[sp] smallest entries.
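A hedged re-implementation sketch of get_min_k (the repo's version is loop-based; this one sorts, but returns the same positions):

```python
def get_min_k(probs, k):
    # Sort edge-local operation indices by probability, keep the k smallest.
    order = sorted(range(len(probs)), key=lambda i: probs[i])
    return order[:k]

# The two lowest-probability entries sit at positions 3 and 1.
dropped = get_min_k([0.5, 0.1, 0.3, 0.05], 2)
```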
get_min_k_no_zero first checks whether idxs contains 0 (the none operation).
In the final stage, all Zero (none) operations are dropped; in earlier stages, the specified number of lowest-weight operations is dropped.
```python
# drop operations with low architecture weights
arch_param = model.arch_parameters()
normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()
for i in range(14):
    idxs = []
    for j in range(len(PRIMITIVES)):
        if switches_normal[i][j]:
            idxs.append(j)
    if sp == len(num_to_keep) - 1:
        # for the last stage, drop all Zero operations
        drop = get_min_k_no_zero(normal_prob[i, :], idxs, num_to_drop[sp])
    else:
        drop = get_min_k(normal_prob[i, :], num_to_drop[sp])
    for idx in drop:
        switches_normal[i][idxs[idx]] = False
```

The reduction cells are handled in the same way.
```python
reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()
for i in range(14):
    idxs = []
    for j in range(len(PRIMITIVES)):
        if switches_reduce[i][j]:
            idxs.append(j)
    if sp == len(num_to_keep) - 1:
        drop = get_min_k_no_zero(reduce_prob[i, :], idxs, num_to_drop[sp])
    else:
        drop = get_min_k(reduce_prob[i, :], num_to_drop[sp])
    for idx in drop:
        switches_reduce[i][idxs[idx]] = False
logging.info('switches_normal = %s', switches_normal)
logging_switches(switches_normal)
logging.info('switches_reduce = %s', switches_reduce)
logging_switches(switches_reduce)
```

At the end of each stage, the architecture parameters are read out.
normal_final and reduce_final record, for each edge, the maximum probability among its selected non-none operations.
The first step of a cell has only two edges, so start = 2 skips them; the remaining steps cover edge index ranges 2-4, 5-8, and 9-13.
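A quick check of that edge-index bookkeeping: with steps = 4, node i has 2 + i incoming edges, so the per-step ranges are [0,1], [2,4], [5,8], [9,13], 14 edges in total:

```python
# Compute the inclusive edge-index range covered by each of the 4 steps.
ranges = []
start = 0
for i in range(4):
    n = 2 + i          # node i has 2 + i incoming edges
    ranges.append((start, start + n - 1))
    start += n
```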
tbsn and tbsr hold the candidate edge positions of the current step for the normal and reduction cells, sorted by operation probability; keep_normal and keep_reduce record the indices of the connections to keep.
Filtering yields the final switches_normal and switches_reduce, with two retained operations per step.
parse_network decodes the switch matrices into the network genotype.
check_sk_number counts the skip_connect operations in the network's normal cells; skip_connect is index 3 in PRIMITIVES.
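A hedged sketch of check_sk_number — count edges whose skip_connect switch (index 3 in PRIMITIVES) is still on; the rows here are truncated to 4 entries purely for illustration:

```python
def check_sk_number(switches):
    # An edge contributes a skip connection iff its index-3 switch is enabled.
    return sum(1 for row in switches if row[3])

# 3 of the 14 edges still have skip_connect enabled (illustrative data).
switches = [[False, False, False, True]] * 3 + [[False, False, False, False]] * 11
num_sk = check_sk_number(switches)
```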
delete_min_sk_prob zeroes out the lowest-weight skip connection.
keep_1_on drops two operations per edge and keeps one.
keep_2_branches prunes the connections so that each step keeps only two incoming edges.
The number of skip_connect operations in the normal cells is reduced step by step, logging each resulting genotype.
```python
# translate switches into genotype
genotype = parse_network(switches_normal, switches_reduce)
logging.info(genotype)
## restrict skipconnect (normal cell only)
logging.info('Restricting skipconnect...')
# generating genotypes with different numbers of skip-connect operations
for sks in range(0, 9):
    max_sk = 8 - sks
    num_sk = check_sk_number(switches_normal)
    if not num_sk > max_sk:
        continue
    while num_sk > max_sk:
        normal_prob = delete_min_sk_prob(switches_normal, switches_normal_2, normal_prob)
        switches_normal = keep_1_on(switches_normal_2, normal_prob)
        switches_normal = keep_2_branches(switches_normal, normal_prob)
        num_sk = check_sk_number(switches_normal)
    logging.info('Number of skip-connect: %d', max_sk)
    genotype = parse_network(switches_normal, switches_reduce)
    logging.info(genotype)
```

train
Three running metrics are initialized.
```python
objs = utils.AvgrageMeter()
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()
```

When searching the architecture, a batch is first drawn from valid_queue for training.
```python
for step, (input, target) in enumerate(train_queue):
    model.train()
    n = input.size(0)
    input = input.cuda()
    target = target.cuda(non_blocking=True)
    if train_arch:
        # In the original implementation of DARTS, it is
        # input_search, target_search = next(iter(valid_queue)), which slows down
        # the training when using PyTorch 0.4 and above.
        try:
            input_search, target_search = next(valid_queue_iter)
        except:
            valid_queue_iter = iter(valid_queue)
            input_search, target_search = next(valid_queue_iter)
        input_search = input_search.cuda()
        target_search = target_search.cuda(non_blocking=True)
        optimizer_a.zero_grad()
        logits = model(input_search)
        loss_a = criterion(logits, target_search)
        loss_a.backward()
        nn.utils.clip_grad_norm_(model.arch_parameters(), args.grad_clip)
        optimizer_a.step()
```

The network weights are then trained on the training set.
```python
optimizer.zero_grad()
logits = model(input)
loss = criterion(logits, target)
loss.backward()
nn.utils.clip_grad_norm_(network_params, args.grad_clip)
optimizer.step()
```

utils.accuracy computes the accuracy on the training set.
```python
prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
objs.update(loss.data.item(), n)
top1.update(prec1.data.item(), n)
top5.update(prec5.data.item(), n)
if step % args.report_freq == 0:
    logging.info('TRAIN Step: %03d Objs: %e R1: %f R5: %f', step, objs.avg, top1.avg, top5.avg)
return top1.avg, objs.avg
```

infer
(Flowchart: infer — iterate valid_queue in eval mode, forward through Network, CrossEntropyLoss, utils.accuracy, return objs/top1/top5.)

```python
objs = utils.AvgrageMeter()
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()
model.eval()
for step, (input, target) in enumerate(valid_queue):
    input = input.cuda()
    target = target.cuda(non_blocking=True)
    with torch.no_grad():
        logits = model(input)
        loss = criterion(logits, target)
    prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
    n = input.size(0)
    objs.update(loss.data.item(), n)
    top1.update(prec1.data.item(), n)
    top5.update(prec5.data.item(), n)
    if step % args.report_freq == 0:
        logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
return top1.avg, objs.avg
```

_data_transforms_cifar10
Compared with the original transforms, Cutout is added.
```python
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
if args.cutout:
    train_transform.transforms.append(Cutout(args.cutout_length))
valid_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
return train_transform, valid_transform
```

Cutout
```python
def __init__(self, length):
    self.length = length

def __call__(self, img):
    h, w = img.size(1), img.size(2)
    mask = np.ones((h, w), np.float32)
    y = np.random.randint(h)
    x = np.random.randint(w)
    y1 = np.clip(y - self.length // 2, 0, h)
    y2 = np.clip(y + self.length // 2, 0, h)
    x1 = np.clip(x - self.length // 2, 0, w)
    x2 = np.clip(x + self.length // 2, 0, w)
    mask[y1: y2, x1: x2] = 0.
    mask = torch.from_numpy(mask)
    mask = mask.expand_as(img)
    img *= mask
    return img
```

count_parameters_in_MB
```python
return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
```

parse_network
The nested function _parse_switches parses both cell types, recording each operation's type and input node, and returns a tuple of type Genotype.
```python
def _parse_switches(switches):
    n = 2
    start = 0
    gene = []
    step = 4
    for i in range(step):
        end = start + n
        for j in range(start, end):
            for k in range(len(switches[j])):
                if switches[j][k]:
                    gene.append((PRIMITIVES[k], j - start))
        start = end
        n = n + 1
    return gene

gene_normal = _parse_switches(switches_normal)
gene_reduce = _parse_switches(switches_reduce)
concat = range(2, 6)
genotype = Genotype(normal=gene_normal, normal_concat=concat,
                    reduce=gene_reduce, reduce_concat=concat)
return genotype
```

Network
C is the number of channels, layers the number of cells, steps the number of internal nodes, multiplier the output-channel multiplier, and stem_multiplier the stem-channel multiplier.
switch_ons records the number of enabled candidate operations at each edge; self.switch_on simply takes the count from the first edge.
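A sketch of that bookkeeping (switch values invented): count the enabled candidates per edge, then read off the first count, implicitly assuming every edge keeps the same number:

```python
# 14 edges, each with the same (illustrative) 5-of-8 candidates enabled.
switches = [[True, False, True, True, False, True, True, False] for _ in range(14)]
switch_ons = [sum(row) for row in switches]  # enabled count per edge
switch_on = switch_ons[0]                    # the model keeps only this scalar
```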
The network does not downsample at the stem; reduction cells are inserted at 1/3 and 2/3 of the depth.
```python
C_curr = stem_multiplier * C
self.stem = nn.Sequential(
    nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
    nn.BatchNorm2d(C_curr)
)
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
self.cells = nn.ModuleList()
reduction_prev = False
for i in range(layers):
    if i in [layers // 3, 2 * layers // 3]:
        C_curr *= 2
        reduction = True
        cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr,
                    reduction, reduction_prev, switches_reduce, self.p)
    else:
        reduction = False
        cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr,
                    reduction, reduction_prev, switches_normal, self.p)
    reduction_prev = reduction
    self.cells += [cell]
    C_prev_prev, C_prev = C_prev, multiplier * C_curr
```

_initialize_alphas initializes the architecture parameters as Variable rather than torch.nn.Parameter.
```python
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self._initialize_alphas()
```

forward
Cells of the same type share the architecture parameters.
```python
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
    if cell.reduction:
        if self.alphas_reduce.size(1) == 1:
            weights = F.softmax(self.alphas_reduce, dim=0)
        else:
            weights = F.softmax(self.alphas_reduce, dim=-1)
    else:
        if self.alphas_normal.size(1) == 1:
            weights = F.softmax(self.alphas_normal, dim=0)
        else:
            weights = F.softmax(self.alphas_normal, dim=-1)
    s0, s1 = s1, cell(s0, s1, weights)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0), -1))
return logits
```

update_p
update_p complicates data parallelism.
```python
for cell in self.cells:
    cell.p = self.p
    cell.update_p()
```

_loss
This function is unused.
```python
logits = self(input)
return self._criterion(logits, target)
```

_initialize_alphas
k is the number of MixedOps in a cell; self.switch_on is the number of candidate operations in each MixedOp.
```python
k = sum(1 for i in range(self._steps) for n in range(2 + i))
num_ops = self.switch_on
self.alphas_normal = Variable(1e-3 * torch.randn(k, num_ops).cuda(), requires_grad=True)
self.alphas_reduce = Variable(1e-3 * torch.randn(k, num_ops).cuda(), requires_grad=True)
self._arch_parameters = [
    self.alphas_normal,
    self.alphas_reduce,
]
```

arch_parameters
```python
return self._arch_parameters
```

Cell
(Cell diagram: preprocess0 and preprocess1 feed MixedOp0-13; each of the four intermediate nodes sums its incoming MixedOp outputs, and the node outputs are concatenated.) FactorizedReduce applies two convolutions to spatially interleaved positions.
As in NASNet, AmoebaNet, and PNAS, convolutions use the ReLU-Conv-BN order (ReLUConvBN).
Weights are not manually initialized.
With steps=4, a Cell contains 2+3+4+5=14 MixedOps, i.e. len(self.cell_ops) = 14; each node has 2 extra edges for the two cell inputs.
```python
def __init__(self, steps, multiplier, C_prev_prev, C_prev, C,
             reduction, reduction_prev, switches, p):
    super(Cell, self).__init__()
    self.reduction = reduction
    self.p = p
    if reduction_prev:
        self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
    else:
        self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier
    self.cell_ops = nn.ModuleList()
    switch_count = 0
    for i in range(self._steps):
        for j in range(2 + i):
            stride = 2 if reduction and j < 2 else 1
            op = MixedOp(C, stride, switch=switches[switch_count], p=self.p)
            self.cell_ops.append(op)
            switch_count = switch_count + 1
```

update_p
```python
for op in self.cell_ops:
    op.p = self.p
    op.update_p()
```

forward
Each intermediate node is computed from all of its predecessors:
$$x^{(j)} = \sum_{i<j} o^{(i,j)}(x^{(i)})$$
A special $\mathit{zero}$ operation is also included to indicate a lack of connection between two nodes. The task of learning the cell therefore reduces to learning the operations on its edges.
At each step, the outputs of all incoming operations are summed; the accumulating offset implies self.cell_ops holds 2+3+4+5=14 entries.
```python
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
offset = 0
for i in range(self._steps):
    s = sum(self.cell_ops[offset + j](h, weights[offset + j]) for j, h in enumerate(states))
    offset += len(states)
    states.append(s)
return torch.cat(states[-self._multiplier:], dim=1)
```

MixedOp
OPS is the operation dictionary. affine=False disables the learnable parameters of nn.BatchNorm2d, equivalent to a plain BN layer in Caffe.
DARTS appendix A.1.1 notes that, since the architecture varies throughout the search, batch normalization always uses batch-specific statistics rather than global moving averages, and the learnable affine parameters in all batch normalizations are disabled during search to avoid rescaling the outputs of the candidate operations. The code, however, does not set track_running_stats=False.
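An illustration of that detail: with affine=False the BatchNorm carries no learnable weight or bias, yet track_running_stats stays True by default, so running statistics are still accumulated during training:

```python
import torch.nn as nn

# affine=False removes the learnable scale/shift, but not the running stats.
bn = nn.BatchNorm2d(16, affine=False)
```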
switch is the operation mask, with len(switch) = len(PRIMITIVES). The enabled operations among the 8 in PRIMITIVES are stored into self.m_ops.
```python
def __init__(self, C, stride, switch, p):
    super(MixedOp, self).__init__()
    self.m_ops = nn.ModuleList()
    self.p = p
    for i in range(len(switch)):
        if switch[i]:
            primitive = PRIMITIVES[i]
            op = OPS[primitive](C, stride, False)
            if 'pool' in primitive:
                op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
            if isinstance(op, Identity) and p > 0:
                op = nn.Sequential(op, nn.Dropout(self.p))
            self.m_ops.append(op)
```

update_p
If the first module of a Sequential is Identity, the Dropout that follows it has its probability updated.
```python
for op in self.m_ops:
    if isinstance(op, nn.Sequential):
        if isinstance(op[0], Identity):
            op[1].p = self.p
```

forward
Let $\mathcal{O}$ be a set of candidate operations (e.g. convolution, max pooling, $\mathit{zero}$), where each operation represents a function $o(\cdot)$ applied to $x^{(i)}$.
To make the search space continuous, DARTS relaxes the categorical choice of a particular operation to a softmax over all possible operations:
$$\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp(\alpha_o^{(i,j)})}{\sum_{o' \in \mathcal{O}} \exp(\alpha_{o'}^{(i,j)})}\, o(x)$$
where the operation mixing weights for a node pair $(i,j)$ are parameterized by a vector $\alpha^{(i,j)}$ of dimension $|\mathcal{O}|$.
The task of architecture search then reduces to learning a set of continuous variables $\alpha = \big\{ \alpha^{(i,j)} \big\}$. At the end of the search, a discrete architecture is obtained by replacing each mixed operation $\bar{o}^{(i,j)}$ with the most likely operation, i.e.
$o^{(i,j)} = \mathrm{argmax}_{o \in \mathcal{O}} \, \alpha^{(i,j)}_o$.
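A numeric sketch of the relaxation on a single edge (toy alpha values): softmax turns alpha into mixing weights that sum to 1, and discretization simply keeps the argmax operation:

```python
import torch
import torch.nn.functional as F

alpha = torch.tensor([1.0, 2.0, 0.5])   # one alpha per candidate operation
weights = F.softmax(alpha, dim=-1)      # continuous mixing weights
chosen = int(alpha.argmax())            # discretized choice at the end of search
```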
Calling functions other than forward on the model means torch.nn.DataParallel cannot be used normally.
delete_min_sk_prob
The nested function _get_sk_idx returns -1 if the given row has no skip connection; otherwise it returns the index of the skip connection within the original switches_bk row.
```python
def _get_sk_idx(switches_in, switches_bk, k):
    if not switches_in[k][3]:
        idx = -1
    else:
        idx = 0
        for i in range(3):
            if switches_bk[k][i]:
                idx = idx + 1
    return idx
```

To avoid modifying the input, sk_prob records the skip-connect weight at each edge; the minimum of these is then set to 0.
```python
probs_out = copy.deepcopy(probs_in)
sk_prob = [1.0 for i in range(len(switches_bk))]
for i in range(len(switches_in)):
    idx = _get_sk_idx(switches_in, switches_bk, i)
    if not idx == -1:
        sk_prob[i] = probs_out[i][idx]
d_idx = np.argmin(sk_prob)
idx = _get_sk_idx(switches_in, switches_bk, d_idx)
probs_out[d_idx][idx] = 0.0
return probs_out
```

keep_1_on
For each edge, idxs records the indices of the enabled operations; get_min_k_no_zero finds the 2 non-none operations with the smallest probability at that edge, which are dropped.
```python
switches = copy.deepcopy(switches_in)
for i in range(len(switches)):
    idxs = []
    for j in range(len(PRIMITIVES)):
        if switches[i][j]:
            idxs.append(j)
    drop = get_min_k_no_zero(probs[i, :], idxs, 2)
    for idx in drop:
        switches[i][idxs[idx]] = False
return switches
```

keep_2_branches
final_prob holds the maximum operation probability at each edge.
```python
switches = copy.deepcopy(switches_in)
final_prob = [0.0 for i in range(len(switches))]
for i in range(len(switches)):
    final_prob[i] = max(probs[i])
```

The first step has only two edges, so they are kept directly.
For each of the remaining 3 steps, the maximum probabilities are taken in turn, sorted, and the two highest-probability edge positions are kept.
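A hedged sketch of that selection rule (probabilities invented): within each step's edge range, keep the two edges with the largest best-op probability; the first step's two edges (indices 0 and 1) are always kept:

```python
# One max-probability per edge, 14 edges total (illustrative values).
final_prob = [0.9, 0.8, 0.1, 0.7, 0.6, 0.2, 0.3, 0.5, 0.4,
              0.15, 0.25, 0.35, 0.45, 0.05]

keep = [0, 1]  # the first step's two edges are kept unconditionally
start = 2
for i in range(3):
    end = start + i + 3  # edge ranges: [2,5), [5,9), [9,14)
    best = sorted(range(start, end), key=lambda j: final_prob[j], reverse=True)
    keep.extend(sorted(best[:2]))
    start = end
```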
Finally, the positions not selected are masked off in switches.
```python
for i in range(len(switches)):
    if not i in keep:
        for j in range(len(PRIMITIVES)):
            switches[i][j] = False
return switches
```

logging_switches
```python
for i in range(len(switches)):
    ops = []
    for j in range(len(switches[i])):
        if switches[i][j]:
            ops.append(PRIMITIVES[j])
    logging.info(ops)
```