日韩性视频-久久久蜜桃-www中文字幕-在线中文字幕av-亚洲欧美一区二区三区四区-撸久久-香蕉视频一区-久久无码精品丰满人妻-国产高潮av-激情福利社-日韩av网址大全-国产精品久久999-日本五十路在线-性欧美在线-久久99精品波多结衣一区-男女午夜免费视频-黑人极品ⅴideos精品欧美棵-人人妻人人澡人人爽精品欧美一区-日韩一区在线看-欧美a级在线免费观看

歡迎訪問 生活随笔!

生活随笔

當前位置: 首頁 > 编程资源 > 编程问答 >内容正文

编程问答

PDARTS 网络结构搜索程序分析

發(fā)布時間:2023/12/20 编程问答 37 豆豆
生活随笔 收集整理的這篇文章主要介紹了 PDARTS 网络结构搜索程序分析 小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.

PDARTS 即 Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation,是對 DARTS 的改進。DARTS 內(nèi)存占用過高,訓練不了較大的模型;PDARTS 將訓練劃分為3個階段,逐步搜索,在增加網(wǎng)絡深度的同時縮減操作種類。構造3次網(wǎng)絡拉長了訓練周期,過程如下圖所示:


此外,算法還對篩選細節(jié)進行了控制。chenxin061/pdarts 修改自 quark0/darts,主函數(shù)邏輯稍顯復雜。

train_search.py

start_time = time.time()main() end_time = time.time()duration = end_time - start_timelogging.info('Total searching time: %ds', duration)

main()

Created with Rapha?l 2.2.0mainargsutils._data_transforms_cifar100torchvision.datasets.CIFAR100torch.utils.data.DataLoadertorch.nn.CrossEntropyLossNetworkoptim.lr_scheduler.CosineAnnealingLRoptim.Optimizer.stepoptim.lr_scheduler.CosineAnnealingLR.get_lrNetwork.update_ptraininferutils.saveNetwork.arch_parameterstorch.nn.functional.softmaxlast stage?get_min_k_no_zerologging_switchesparse_networkcheck_sk_numberdelete_min_sk_probkeep_1_onkeep_2_branchesEndget_min_kyesno if not torch.cuda.is_available():logging.info('No GPU device available')sys.exit(1)np.random.seed(args.seed)torch.cuda.set_device(args.gpu)cudnn.benchmark = Truetorch.manual_seed(args.seed)cudnn.enabled=Truetorch.cuda.manual_seed(args.seed)logging.info('GPU device = %d' % args.gpu)logging.info("args = %s", args)

沒有將階段內(nèi)的處理封裝為函數(shù),流程不太直觀。

_data_transforms_cifar100 包括隨機截取、翻轉(zhuǎn)、標準化和隨機裁剪。
CIFAR100 是 CIFAR10 的子類。
torch.utils.data.sampler.SubsetRandomSampler 從給定的索引列表中隨機抽取元素樣本,不替換。

# prepare datasetif args.cifar100:train_transform, valid_transform = utils._data_transforms_cifar100(args)else:train_transform, valid_transform = utils._data_transforms_cifar10(args)if args.cifar100:train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)else:train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)num_train = len(train_data)indices = list(range(num_train))split = int(np.floor(args.train_portion * num_train))train_queue = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),pin_memory=True, num_workers=args.workers)valid_queue = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),pin_memory=True, num_workers=args.workers)

PRIMITIVES 定義了網(wǎng)絡可用的原語,共8種。經(jīng)3輪丟棄num_to_drop后,操作位置上剩1種或無操作。
switches_normal和switches_reduce為操作名稱列表。單元內(nèi)的連接數(shù)量為14。

# build Networkcriterion = nn.CrossEntropyLoss()criterion = criterion.cuda()switches = []for i in range(14):switches.append([True for j in range(len(PRIMITIVES))])switches_normal = copy.deepcopy(switches)switches_reduce = copy.deepcopy(switches)# To be moved to argsnum_to_keep = [5, 3, 1]num_to_drop = [3, 2, 2]if len(args.add_width) == 3:add_width = args.add_widthelse:add_width = [0, 0, 0]if len(args.add_layers) == 3:add_layers = args.add_layerselse:add_layers = [0, 6, 12]if len(args.dropout_rate) ==3:drop_rate = args.dropout_rateelse:drop_rate = [0.0, 0.0, 0.0]eps_no_archs = [10, 10, 10]

依次構建每個階段的網(wǎng)絡進行訓練。sp即 search phase。
P-DARTS 網(wǎng)絡深度為5->11->17,DARTS 為7。
Network 構建網(wǎng)絡。
count_parameters_in_MB 統(tǒng)計模型大小。
train 傳入兩種優(yōu)化器,搜索結構用 Adam,訓練模型用 SGD。
最后5個 epoch 調(diào)用 infer 在驗證集上測試模型。

for sp in range(len(num_to_keep)):model = Network(args.init_channels + int(add_width[sp]), CIFAR_CLASSES, args.layers + int(add_layers[sp]), criterion, switches_normal=switches_normal, switches_reduce=switches_reduce, p=float(drop_rate[sp]))model = model.cuda()logging.info("param size = %fMB", utils.count_parameters_in_MB(model))network_params = []for k, v in model.named_parameters():if not (k.endswith('alphas_normal') or k.endswith('alphas_reduce')):network_params.append(v) optimizer = torch.optim.SGD(network_params,args.learning_rate,momentum=args.momentum,weight_decay=args.weight_decay)optimizer_a = torch.optim.Adam(model.arch_parameters(),lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs), eta_min=args.learning_rate_min)sm_dim = -1epochs = args.epochseps_no_arch = eps_no_archs[sp]scale_factor = 0.2for epoch in range(epochs):scheduler.step()lr = scheduler.get_lr()[0]logging.info('Epoch: %d lr: %e', epoch, lr)epoch_start = time.time()# trainingif epoch < eps_no_arch:model.p = float(drop_rate[sp]) * (epochs - epoch - 1) / epochsmodel.update_p()train_acc, train_obj = train(train_queue, valid_queue, model, network_params, criterion, optimizer, optimizer_a, lr, train_arch=False)else:model.p = float(drop_rate[sp]) * np.exp(-(epoch - eps_no_arch) * scale_factor) model.update_p() train_acc, train_obj = train(train_queue, valid_queue, model, network_params, criterion, optimizer, optimizer_a, lr, train_arch=True)logging.info('Train_acc %f', train_acc)epoch_duration = time.time() - epoch_startlogging.info('Epoch time: %ds', epoch_duration)# validationif epochs - epoch < 5:valid_acc, valid_obj = infer(valid_queue, model, criterion)logging.info('Valid_acc %f', valid_acc)

utils.save 保存階段訓練的結果。問題是名字一樣會覆蓋。
switches_normal_2和switches_reduce_2為第2階段處理前的操作列表。

utils.save(model, os.path.join(args.save, 'weights.pt'))print('------Dropping %d paths------' % num_to_drop[sp])# Save switches info for s-c refinement. if sp == len(num_to_keep) - 1:switches_normal_2 = copy.deepcopy(switches_normal)switches_reduce_2 = copy.deepcopy(switches_reduce)

arch_parameters 返回 (αnormal,αreduce)(\alpha_{normal}, \alpha_{reduce})(αnormal?,αreduce?)
計算normal_prob:
exp(αo(i,j))∑o′∈Oexp(αo′(i,j))\begin{aligned} \frac{\mathrm{exp}(\alpha_o^{(i,j)})}{\sum_{o&#x27;\in\mathcal{O}}\mathrm{exp}(\alpha_{o&#x27;}^{(i,j)})} \end{aligned} oO?exp(αo(i,j)?)exp(αo(i,j)?)??
idxs記錄處于活躍狀態(tài)的操作符的類型索引。
get_min_k 返回最小的num_to_drop[sp]個索引。
get_min_k_no_zero 先檢查idxs是否有0。

在最后一個階段丟棄所有空操作,否則丟棄指定數(shù)量的小權重操作。

# drop operations with low architecture weightsarch_param = model.arch_parameters()normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy() for i in range(14):idxs = []for j in range(len(PRIMITIVES)):if switches_normal[i][j]:idxs.append(j)if sp == len(num_to_keep) - 1:# for the last stage, drop all Zero operationsdrop = get_min_k_no_zero(normal_prob[i, :], idxs, num_to_drop[sp])else:drop = get_min_k(normal_prob[i, :], num_to_drop[sp])for idx in drop:switches_normal[i][idxs[idx]] = False

縮減單元的處理與之相同。

reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()for i in range(14):idxs = []for j in range(len(PRIMITIVES)):if switches_reduce[i][j]:idxs.append(j)if sp == len(num_to_keep) - 1:drop = get_min_k_no_zero(reduce_prob[i, :], idxs, num_to_drop[sp])else:drop = get_min_k(reduce_prob[i, :], num_to_drop[sp])for idx in drop:switches_reduce[i][idxs[idx]] = Falselogging.info('switches_normal = %s', switches_normal)logging_switches(switches_normal)logging.info('switches_reduce = %s', switches_reduce)logging_switches(switches_reduce)

在階段的末尾,讀取結構參數(shù)。
normal_final和reduce_final記錄每個單元中非空操作選中的最大概率。

if sp == len(num_to_keep) - 1:arch_param = model.arch_parameters()normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()reduce_prob = F.softmax(arch_param[1], dim=sm_dim).data.cpu().numpy()normal_final = [0 for idx in range(14)]reduce_final = [0 for idx in range(14)]# remove all Zero operationsfor i in range(14):if switches_normal_2[i][0] == True:normal_prob[i][0] = 0normal_final[i] = max(normal_prob[i])if switches_reduce_2[i][0] == True:reduce_prob[i][0] = 0reduce_final[i] = max(reduce_prob[i])

單元中的第1層為兩個操作,start = 2跳過。2-4,5-8,9-13。
tbsn和tbsr為標準和縮減單元當前層供選擇的位置。根據(jù)操作概率的大小排序。keep_normal和keep_reduce記錄需要保持的連接的索引。
過濾得到最終的switches_normal和switches_reduce,每層兩個操作。

# Generate Architecture, similar to DARTSkeep_normal = [0, 1]keep_reduce = [0, 1]n = 3start = 2for i in range(3):end = start + ntbsn = normal_final[start:end]tbsr = reduce_final[start:end]edge_n = sorted(range(n), key=lambda x: tbsn[x])keep_normal.append(edge_n[-1] + start)keep_normal.append(edge_n[-2] + start)edge_r = sorted(range(n), key=lambda x: tbsr[x])keep_reduce.append(edge_r[-1] + start)keep_reduce.append(edge_r[-2] + start)start = endn = n + 1# set switches according the ranking of arch parametersfor i in range(14):if not i in keep_normal:for j in range(len(PRIMITIVES)):switches_normal[i][j] = Falseif not i in keep_reduce:for j in range(len(PRIMITIVES)):switches_reduce[i][j] = False

parse_network 根據(jù)編碼列表解析得到網(wǎng)絡基因型。
check_sk_number 檢查網(wǎng)絡標準單元中skip_connect的數(shù)量,對應 PRIMITIVES 的索引3。
delete_min_sk_prob 刪除最小權重的跳躍連接。
keep_1_on 丟2留一。
keep_2_branches 修剪連接,每層僅保留兩個。

逐漸減少網(wǎng)絡標準單元中skip_connect的數(shù)量并記錄。

# translate switches into genotypegenotype = parse_network(switches_normal, switches_reduce)logging.info(genotype)## restrict skipconnect (normal cell only)logging.info('Restricting skipconnect...')# generating genotypes with different numbers of skip-connect operationsfor sks in range(0, 9):max_sk = 8 - sks num_sk = check_sk_number(switches_normal) if not num_sk > max_sk:continuewhile num_sk > max_sk:normal_prob = delete_min_sk_prob(switches_normal, switches_normal_2, normal_prob)switches_normal = keep_1_on(switches_normal_2, normal_prob)switches_normal = keep_2_branches(switches_normal, normal_prob)num_sk = check_sk_number(switches_normal)logging.info('Number of skip-connect: %d', max_sk)genotype = parse_network(switches_normal, switches_reduce)logging.info(genotype)

train

初始化3個指標。

objs = utils.AvgrageMeter()top1 = utils.AvgrageMeter()top5 = utils.AvgrageMeter()

如果訓練結構,從valid_queue中取數(shù)據(jù),先行訓練。

for step, (input, target) in enumerate(train_queue):model.train()n = input.size(0)input = input.cuda()target = target.cuda(non_blocking=True)if train_arch:# In the original implementation of DARTS, it is input_search, target_search = next(iter(valid_queue), which slows down# the training when using PyTorch 0.4 and above. try:input_search, target_search = next(valid_queue_iter)except:valid_queue_iter = iter(valid_queue)input_search, target_search = next(valid_queue_iter)input_search = input_search.cuda()target_search = target_search.cuda(non_blocking=True)optimizer_a.zero_grad()logits = model(input_search)loss_a = criterion(logits, target_search)loss_a.backward()nn.utils.clip_grad_norm_(model.arch_parameters(), args.grad_clip)optimizer_a.step()

在訓練集上訓練權重。

optimizer.zero_grad()logits = model(input)loss = criterion(logits, target)loss.backward()nn.utils.clip_grad_norm_(network_params, args.grad_clip)optimizer.step()

調(diào)用 utils.accuracy 計算訓練集上的準確率。

prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))objs.update(loss.data.item(), n)top1.update(prec1.data.item(), n)top5.update(prec5.data.item(), n)if step % args.report_freq == 0:logging.info('TRAIN Step: %03d Objs: %e R1: %f R5: %f', step, objs.avg, top1.avg, top5.avg)return top1.avg, objs.avg

infer

Created with Rapha?l 2.2.0infervalid_queuenn.Module.evalNetworknn.CrossEntropyLossutils.accuracyobjs, top1, top5End objs = utils.AvgrageMeter()top1 = utils.AvgrageMeter()top5 = utils.AvgrageMeter()model.eval()for step, (input, target) in enumerate(valid_queue):input = input.cuda()target = target.cuda(non_blocking=True)with torch.no_grad():logits = model(input)loss = criterion(logits, target)prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))n = input.size(0)objs.update(loss.data.item(), n)top1.update(prec1.data.item(), n)top5.update(prec5.data.item(), n)if step % args.report_freq == 0:logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)return top1.avg, objs.avg

_data_transforms_cifar10

相比原有變換多了 Cutout。

CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(CIFAR_MEAN, CIFAR_STD),])if args.cutout:train_transform.transforms.append(Cutout(args.cutout_length))valid_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(CIFAR_MEAN, CIFAR_STD),])return train_transform, valid_transform

Cutout

def __init__(self, length):self.length = lengthdef __call__(self, img):h, w = img.size(1), img.size(2)mask = np.ones((h, w), np.float32)y = np.random.randint(h)x = np.random.randint(w)y1 = np.clip(y - self.length // 2, 0, h)y2 = np.clip(y + self.length // 2, 0, h)x1 = np.clip(x - self.length // 2, 0, w)x2 = np.clip(x + self.length // 2, 0, w)mask[y1: y2, x1: x2] = 0.mask = torch.from_numpy(mask)mask = mask.expand_as(img)img *= maskreturn img

count_parameters_in_MB

return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6

parse_network

嵌套定義函數(shù)_parse_switches。解析兩種類型的單元,記錄操作類型和所在層次,得到 Genotype 類型的元組。

def _parse_switches(switches):n = 2start = 0gene = []step = 4for i in range(step):end = start + nfor j in range(start, end):for k in range(len(switches[j])):if switches[j][k]:gene.append((PRIMITIVES[k], j - start))start = endn = n + 1return genegene_normal = _parse_switches(switches_normal)gene_reduce = _parse_switches(switches_reduce)concat = range(2, 6)genotype = Genotype(normal=gene_normal, normal_concat=concat, reduce=gene_reduce, reduce_concat=concat)return genotype

Network

C為通道數(shù)量,layers為層數(shù),steps為內(nèi)部所劃分的層次,multiplier為輸出通道的乘數(shù),stem_multiplier為柄通道乘數(shù)。
switch_ons記錄每個操作位置可選操作的數(shù)量。self.switch_on直接取第一個位置的操作數(shù)。

def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3, switches_normal=[], switches_reduce=[], p=0.0):super(Network, self).__init__()self._C = Cself._num_classes = num_classesself._layers = layersself._criterion = criterionself._steps = stepsself._multiplier = multiplierself.p = pself.switches_normal = switches_normalswitch_ons = []for i in range(len(switches_normal)):ons = 0for j in range(len(switches_normal[i])):if switches_normal[i][j]:ons = ons + 1switch_ons.append(ons)ons = 0self.switch_on = switch_ons[0]

網(wǎng)絡起始未下采樣,在1/3和2/3處插入縮減單元。

C_curr = stem_multiplier*Cself.stem = nn.Sequential(nn.Conv2d(3, C_curr, 3, padding=1, bias=False),nn.BatchNorm2d(C_curr))C_prev_prev, C_prev, C_curr = C_curr, C_curr, Cself.cells = nn.ModuleList()reduction_prev = Falsefor i in range(layers):if i in [layers//3, 2*layers//3]:C_curr *= 2reduction = Truecell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, switches_reduce, self.p)else:reduction = Falsecell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, switches_normal, self.p) # cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, switches)reduction_prev = reductionself.cells += [cell]C_prev_prev, C_prev = C_prev, multiplier*C_curr

_initialize_alphas 初始化結構參數(shù),類型為Variable,而不是 torch.nn.Parameter。

self.global_pooling = nn.AdaptiveAvgPool2d(1)self.classifier = nn.Linear(C_prev, num_classes)self._initialize_alphas()

forward

同類型的不同單元公用結構參數(shù)。

s0 = s1 = self.stem(input)for i, cell in enumerate(self.cells):if cell.reduction:if self.alphas_reduce.size(1) == 1:weights = F.softmax(self.alphas_reduce, dim=0)else:weights = F.softmax(self.alphas_reduce, dim=-1)else:if self.alphas_normal.size(1) == 1:weights = F.softmax(self.alphas_normal, dim=0)else:weights = F.softmax(self.alphas_normal, dim=-1)s0, s1 = s1, cell(s0, s1, weights)out = self.global_pooling(s1)logits = self.classifier(out.view(out.size(0),-1))return logits

update_p

update_p 給數(shù)據(jù)并行帶來了麻煩。

for cell in self.cells:cell.p = self.pcell.update_p()

_loss

函數(shù)沒有用到。

logits = self(input)return self._criterion(logits, target)

_initialize_alphas

k為單元中 MixedOp 的數(shù)量,self.switch_on為 MixedOp 中候選操作的種類。

k = sum(1 for i in range(self._steps) for n in range(2+i))num_ops = self.switch_onself.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)self._arch_parameters = [self.alphas_normal,self.alphas_reduce,]

arch_parameters

return self._arch_parameters

Cell

preprocess0MixedOp0preprocess1MixedOp1add0MixedOp2MixedOp3MixedOp4add1MixedOp5MixedOp6MixedOp7MixedOp8add2MixedOp9MixedOp10MixedOp11MixedOp12MixedOp13add3concatenate

FactorizedReduce 采用位置交錯的兩組卷積。
與 NASNet、AmoebaNet 和 PNAS 一樣卷積采用 ReLUConvBN。

沒有手動初始化權重。

steps=4,使得 Cell 中包含 2+3+4+5=14 個 MixedOp,即len(self.cell_ops)=14。每層多2個用于處理輸入。

def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, switches, p):super(Cell, self).__init__()self.reduction = reductionself.p = pif reduction_prev:self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)else:self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)self._steps = stepsself._multiplier = multiplierself.cell_ops = nn.ModuleList()switch_count = 0for i in range(self._steps):for j in range(2+i):stride = 2 if reduction and j < 2 else 1op = MixedOp(C, stride, switch=switches[switch_count], p=self.p)self.cell_ops.append(op)switch_count = switch_count + 1

update_p

for op in self.cell_ops:op.p = self.pop.update_p()

forward

每個中間節(jié)點都基于其所有先前節(jié)點計算:

x(j)=∑i&lt;jo(i,j)(x(i))\begin{aligned} x^{(j)} = \sum_{i&lt;j} o^{(i, j)}(x^{(i)}) \end{aligned} x(j)=i<j?o(i,j)(x(i))?

還包括一個特殊的 zero\mathit{zero}zero 操作,表示兩個節(jié)點之間缺少連接。 因此,學習單元的任務減少了學習其邊緣的操作。

對于每一步,累加所有操作的輸出。offset不斷累加意味著self.cell_ops的數(shù)量為2+3+4+5=14。

s0 = self.preprocess0(s0)s1 = self.preprocess1(s1)states = [s0, s1]offset = 0for i in range(self._steps):s = sum(self.cell_ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states))offset += len(states)states.append(s)return torch.cat(states[-self._multiplier:], dim=1)

MixedOp

OPS 為操作字典。affine=False設置 nn.BatchNorm2d 屏蔽可學習參數(shù),等效于 Caffe 中的 BN 層。

DARTS 的 A.1.1 中指出由于架構在整個搜索過程中會有所不同,因此其始終使用批量特定的統(tǒng)計信息進行批量標準化而不是全局移動平均值。在搜索過程中禁用所有批量標準化中可學習的仿射參數(shù),以避免重新調(diào)整候選操作的輸出。然而,代碼中并未設置track_running_stats=False。

switch為操作的掩碼,len(switch)=len(PRIMITIVES)。PRIMITIVES 共有8種操作,存儲到self.m_ops。

def __init__(self, C, stride, switch, p):super(MixedOp, self).__init__()self.m_ops = nn.ModuleList()self.p = pfor i in range(len(switch)):if switch[i]:primitive = PRIMITIVES[i]op = OPS[primitive](C, stride, False)if 'pool' in primitive:op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))if isinstance(op, Identity) and p > 0:op = nn.Sequential(op, nn.Dropout(self.p))self.m_ops.append(op)

update_p

如果第一個操作是Identity,則在后面添加操作。

for op in self.m_ops:if isinstance(op, nn.Sequential):if isinstance(op[0], Identity):op[1].p = self.p

forward

O\mathcal{O}O 為一組候選操作(例如卷積、最大合并、zero\mathit{zero}zero),其中每個操作代表應用于 x(i)x^{(i)}x(i) 的函數(shù) o(?)o(\cdot)o(?)

為了使搜索空間連續(xù),DARTS 將特定操作的分類選擇放寬為所有可能操作的 softmax:
oˉ(i,j)(x)=∑o∈Oexp?(αo(i,j))∑o′∈Oexp?(αo′(i,j))o(x)\begin{aligned} \bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp(\alpha_o^{(i,j)})}{\sum_{o&#x27; \in \mathcal{O}} \exp(\alpha_{o&#x27;}^{(i,j)})} o(x) \end{aligned} oˉ(i,j)(x)=oO?oO?exp(αo(i,j)?)exp(αo(i,j)?)?o(x)?
其中一對節(jié)點 (i,j)(i,j)(i,j) 的操作混合權重由維數(shù) ∣O∣|\mathcal{O}|O 的向量 α(i,j)\alpha^{(i,j)}α(i,j) 參數(shù)化。

然后,架構搜索的任務化簡為學習一組連續(xù)變量 α={α(i,j)}\alpha = \big\{ \alpha^{(i,j)} \big\}α={α(i,j)}。在搜索結束時,可以通過用最可能的操作替換每個混合操作 oˉ(i,j)\bar{o}^{(i,j)}oˉ(i,j) 來獲得離散體系結構,即
o(i,j)=argmaxo∈O&ThinSpace;αo(i,j)o^{(i,j)} = \mathrm{argmax}_{o \in \mathcal{O}} \, \alpha^{(i,j)}_o o(i,j)=argmaxoO?αo(i,j)?.

return sum(w * op(x) for w, op in zip(weights, self.m_ops))

模型中定義forward之外的函數(shù),導致不能正常使用 torch.nn.DataParallel。

delete_min_sk_prob

嵌套定義_get_sk_idx函數(shù)。如果輸入的列表里沒有跳躍連接則返回-1;否則返回原列表switches_bk中的跳躍連接索引。

def _get_sk_idx(switches_in, switches_bk, k):if not switches_in[k][3]:idx = -1else:idx = 0for i in range(3):if switches_bk[k][i]:idx = idx + 1return idx

避免修改輸入,sk_prob記錄每個位置上跳躍連接的權重。從中取最小的置為0。

probs_out = copy.deepcopy(probs_in)sk_prob = [1.0 for i in range(len(switches_bk))]for i in range(len(switches_in)):idx = _get_sk_idx(switches_in, switches_bk, i)if not idx == -1:sk_prob[i] = probs_out[i][idx]d_idx = np.argmin(sk_prob)idx = _get_sk_idx(switches_in, switches_bk, d_idx)probs_out[d_idx][idx] = 0.0return probs_out

keep_1_on

keep_1_onget_min_k_no_zero

對于每個操作位,idxs記錄可選操作的索引。get_min_k_no_zero 查找操作位概率最小且非空的2個操作,丟棄掉。

switches = copy.deepcopy(switches_in)for i in range(len(switches)):idxs = []for j in range(len(PRIMITIVES)):if switches[i][j]:idxs.append(j)drop = get_min_k_no_zero(probs[i, :], idxs, 2)for idx in drop:switches[i][idxs[idx]] = False return switches

keep_2_branches

final_prob為每個操作位上操作最大概率。

switches = copy.deepcopy(switches_in)final_prob = [0.0 for i in range(len(switches))]for i in range(len(switches)):final_prob[i] = max(probs[i])

第1層只有兩個操作位,所以直接保留。
后續(xù)3層依次取出其最大概率,排序后取最大的兩個位置。

keep = [0, 1]n = 3start = 2for i in range(3):end = start + ntb = final_prob[start:end]edge = sorted(range(n), key=lambda x: tb[x])keep.append(edge[-1] + start)keep.append(edge[-2] + start)start = endn = n + 1

遍歷位置,在switches屏蔽未選中的位置。

for i in range(len(switches)):if not i in keep:for j in range(len(PRIMITIVES)):switches[i][j] = False return switches

logging_switches

for i in range(len(switches)):ops = []for j in range(len(switches[i])):if switches[i][j]:ops.append(PRIMITIVES[j])logging.info(ops)

參考資料:

  • Affine parameter in batchnorm
  • AutoML (5) - DARTS: multi-gpu extension
  • Bug in DataParallel? Only works if the dataset device is cuda:0
  • How to print list item + integer/string using logging in Python
  • Python String format()
  • setting CUDA_VISIBLE_DEVICES just has no effect #9158
  • pytorch/examples/imagenet
  • 梯度下降學習率的設定策略

總結

以上是生活随笔為你收集整理的PDARTS 网络结构搜索程序分析的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。

如果覺得生活随笔網(wǎng)站內(nèi)容還不錯,歡迎將生活随笔推薦給好友。