pytorch时空数据处理4——图像转文本/字幕Image-Captionning(二)
pytorch時空數據處理4——圖像轉文本/字幕Image-Captionning(二)
- pytorch時空數據處理4——圖像轉文本/字幕Image-Captionning(二)
- Dataset
- Inputs to model
- Caption Lengths
- Data pipeline
- Encoder
- Attention
- Decoder
- 代碼
- 數據集初始化 create_input_files.py
- 訓練 train.py
- 測試
pytorch時空數據處理4——圖像轉文本/字幕Image-Captionning(二)
書接上文,本篇主要講解工程代碼結構和代碼運行。
代碼來源:git
Dataset
我正在使用MSCOCO '14數據集。您需要下載訓練(13GB)和驗證(6GB)。
我們將使用安德烈·卡帕西的訓練、驗證和測試分割方法。這個壓縮文件包含標題。您還可以找到FlushT 8K和FlushT 30K數據集的拆分和標題,所以如果MSCOCO對您的計算機來說太大,請隨意使用它們來代替MSCOCO。
Inputs to model
圖像由于我們使用的是預處理編碼器,我們需要將圖像處理成預處理編碼器習慣的形式。
預訓練的ImageNet模型可作為PyTorch的torchvision模塊的一部分。論文原文詳細說明了我們需要執行的預處理或轉換——像素值必須在[0,1]范圍內,然后我們必須通過ImageNet圖像的RGB通道的平均值和標準偏差對圖像進行歸一化。
此外,PyTorch遵循NCHW慣例,這意味著通道尺寸?必須在尺寸尺寸之前。我們將調整所有MSCOCO圖像的大小為256x256,以保持一致性。因此,饋送到模型的圖像必須是維度為N,3,256,256的浮動張量,并且必須通過前述的平均值和標準偏差進行歸一化。n為批量大小。字幕字幕是解碼器的目標和輸入,因為每個單詞都用來生成下一個單詞。然而,要生成第一個單詞,我們需要第零個單詞< start >。最后,我們應該預測解碼器必須學會預測字幕的結束。這是必要的,因為我們需要知道在推理過程中什么時候停止解碼。
例如:<start> a man holds a football <end>
因為我們將標題作為固定大小的張量傳遞,所以我們需要用< pad >標記將標題(自然長度可變)填充到相同的長度。
<start> a man holds a football <end> <pad> <pad> <pad>…
此外,我們創建一個word_map,它是語料庫中每個單詞的索引映射,包括<start>,<end>和<pad>標記。像其他庫一樣,PyTorch也需要編碼為索引的單詞來為其查找嵌入或標識其在預測單詞分數中的位置。
例如:9876 1 5 120 1 5406 9877 9878 9878 9878…
因此,提供給模型的字幕必須是尺寸為N,L的Int張量,其中L是填充長度。
Caption Lengths
由于字幕是填充的,因此我們需要跟蹤每個字幕的長度。這是實際長度+ 2(對于和標記)。
字幕長度也很重要,因為您可以使用PyTorch構建動態圖形。我們僅處理序列的長度,并且不會在上浪費計算量。
因此,提供給模型的字幕長度必須是維度N的Int張量。
Data pipeline
請參閱utils.py中的create_input_files()。
在我們保存這些文件,我們可以選擇只使用字幕是短于閾值,并且倉不太頻繁的話到標記。
我們將HDF5文件用于圖像,因為我們將在訓練/驗證期間直接從磁盤讀取它們。它們太大了,無法一次放入RAM。但是我們確實將所有字幕及其長度加載到內存中。
請參閱datasets.py中的CaptionDataset。
這是PyTorch數據集的子類。它需要定義一個__len__方法,該方法返回數據集的大小,以及一個__getitem__方法,該方法返回第i個圖像,標題和標題長度。
我們從磁盤讀取圖像,將像素轉換為[0,255],然后在此類內對其進行規范化。
PyTorch DataLoader在train.py中將使用該數據集,以創建一批數據并將其饋送到模型中以進行訓練或驗證。
Encoder
請參閱models.py中的編碼器。
我們使用PyTorch的Torchvision模塊中已經提供的經過預訓練的ResNet-101。丟棄最后兩層(池化層和線性層),因為我們只需要對圖像進行編碼,而無需對其進行分類。
我們確實添加了AdaptiveAvgPool2d()層,以將編碼大小調整為固定大小。這樣就可以將可變大小的圖像饋送到編碼器。 (但是,我們確實將輸入圖像的大小調整為256、256,因為我們必須將它們存儲為單個張量。)由于我們可能想對編碼器進行微調,因此我們添加了fine_tune()方法來啟用或禁用計算編碼器參數的梯度。我們僅在ResNet中微調卷積塊2到4,因為第一個卷積塊通常會學到一些非常重要的圖像處理基礎知識,例如檢測直線,邊緣,曲線等。我們不會打亂基準特征。
Attention
請參閱models.py中的Attention。
注意網絡很簡單–它僅由線性層和幾個激活組成。
單獨的線性層將解碼器的編碼圖像(展平為N,14 * 14,2048)和隱藏狀態(輸出)都轉換為相同尺寸,即。注意大小。然后添加它們并激活ReLU。第三線性層將此結果轉換為1的維度,隨后我們應用softmax生成權重alpha。
Decoder
請參閱models.py中的DecoderWithAttention。
此處接收編碼器的輸出,并將其展平為N,14 * 14,2048尺寸。這很方便,并且避免了多次調整張量的形狀。
我們使用init_hidden_??state()方法使用編碼圖像初始化LSTM的隱藏狀態和單元狀態,該方法使用兩個單獨的線性層。
首先,我們通過減少字幕長度來對N個圖像和字幕進行排序。這樣一來,我們只能處理有效的時間步,即不能處理。
我們可以遍歷每個時間步,僅處理有色區域,該區域是該時間步的有效批次大小N_t。通過排序,可以使任何時間步長的頂部N_t與上一步的輸出對齊。例如,在第三時間步,我們使用上一步的前5個輸出僅處理前5個圖像。
使用PyTorch LSTMCell在for循環中手動執行此迭代,而不是使用PyTorch LSTM在沒有循環的情況下自動迭代。這是因為我們需要在每個解碼步驟之間執行Attention機制。 LSTMCell是單個時間步操作,而LSTM將連續地在多個時間步上迭代并立即提供所有輸出。
我們使用Attention網絡在每個時間步計算權重和注意力加權編碼。在論文的第4.2.1節中,他們建議通過濾波器或門傳遞注意力加權編碼。此門是解碼器先前隱藏狀態的S型激活線性變換。作者指出,這有助于Attention網絡將更多的重點放在圖像中的對象上。
我們將過濾后的注意力加權編碼與上一個單詞的嵌入(開始)連接起來,然后運行LSTMCell生成新的隱藏狀態(或輸出)。線性層將這種新的隱藏狀態轉換為詞匯表中每個單詞的分數,并將其存儲起來。
我們還存儲每個時間步長的注意力網絡返回的權重。您會很快明白為什么。
代碼
數據集初始化 create_input_files.py
為了快速訓練和跑通
選擇使用flickr8k數據集,樣本量較小,調整和訓練都節約時間。
訓練 train.py
原文代碼pytorch0.4 這里的代碼是修改后可以在 pytorch1.0以后版本可以運行的。
import time import torch.backends.cudnn as cudnn import torch.optim import torch.utils.data import torchvision.transforms as transforms from torch import nn from torch.nn.utils.rnn import pack_padded_sequence from models import Encoder, DecoderWithAttention from datasets import * from utils import * from nltk.translate.bleu_score import corpus_bleu torch.cuda.set_device(9)# Data parameters data_folder = '/home/wy/docker/resource/cocodataset/Flickr8k/data' # folder with data files saved by create_input_files.py data_name = 'flickr8k_5_cap_per_img_5_min_word_freq' # base name shared by data files# Model parameters emb_dim = 512 # dimension of word embeddings attention_dim = 512 # dimension of attention linear layers decoder_dim = 512 # dimension of decoder RNN dropout = 0.5 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # sets device for model and PyTorch tensors cudnn.benchmark = True # set to true only if inputs to model are fixed size; otherwise lot of computational overhead# Training parameters start_epoch = 0 epochs = 10 # number of epochs to train for (if early stopping is not triggered) epochs_since_improvement = 0 # keeps track of number of epochs since there's been an improvement in validation BLEU batch_size = 32 workers = 1 # for data-loading; right now, only 1 works with h5py encoder_lr = 1e-4 # learning rate for encoder if fine-tuning decoder_lr = 4e-4 # learning rate for decoder grad_clip = 5. # clip gradients at an absolute value of alpha_c = 1. # regularization parameter for 'doubly stochastic attention', as in the paper best_bleu4 = 0. # BLEU-4 score right now print_freq = 100 # print training/validation stats every __ batches fine_tune_encoder = False # fine-tune encoder? checkpoint = None # path to checkpoint, None if nonedef main():"""Training and validation."""global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map# Read word mapword_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')with open(word_map_file, 'r') as j:word_map = json.load(j)# Initialize / load checkpointif checkpoint is None:decoder = DecoderWithAttention(attention_dim=attention_dim,embed_dim=emb_dim,decoder_dim=decoder_dim,vocab_size=len(word_map),dropout=dropout)decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),lr=decoder_lr)encoder = Encoder()encoder.fine_tune(fine_tune_encoder)encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),lr=encoder_lr) if fine_tune_encoder else Noneelse:checkpoint = torch.load(checkpoint)start_epoch = checkpoint['epoch'] + 1epochs_since_improvement = checkpoint['epochs_since_improvement']best_bleu4 = checkpoint['bleu-4']decoder = checkpoint['decoder']decoder_optimizer = checkpoint['decoder_optimizer']encoder = checkpoint['encoder']encoder_optimizer = checkpoint['encoder_optimizer']if fine_tune_encoder is True and encoder_optimizer is None:encoder.fine_tune(fine_tune_encoder)encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),lr=encoder_lr)# Move to GPU, if availabledecoder = decoder.to(device)encoder = encoder.to(device)# Loss functioncriterion = nn.CrossEntropyLoss().to(device)# Custom dataloadersnormalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])train_loader = torch.utils.data.DataLoader(CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)val_loader = torch.utils.data.DataLoader(CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)# Epochsfor epoch in range(start_epoch, epochs):# Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20if epochs_since_improvement == 20:breakif epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:adjust_learning_rate(decoder_optimizer, 0.8)if fine_tune_encoder:adjust_learning_rate(encoder_optimizer, 0.8)# One epoch's trainingtrain(train_loader=train_loader,encoder=encoder,decoder=decoder,criterion=criterion,encoder_optimizer=encoder_optimizer,decoder_optimizer=decoder_optimizer,epoch=epoch)# One epoch's validationrecent_bleu4 = validate(val_loader=val_loader,encoder=encoder,decoder=decoder,criterion=criterion)# Check if there was an improvementis_best = recent_bleu4 > best_bleu4best_bleu4 = max(recent_bleu4, best_bleu4)if not is_best:epochs_since_improvement += 1print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))else:epochs_since_improvement = 0# Save checkpoint 保存在utils包里面,全部倒導入的。save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,decoder_optimizer, recent_bleu4, is_best)def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):"""Performs one epoch's training.:param train_loader: DataLoader for training data:param encoder: encoder model:param decoder: decoder model:param criterion: loss layer:param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning):param decoder_optimizer: optimizer to update decoder's weights:param epoch: epoch number"""decoder.train() # train mode (dropout and batchnorm is used)encoder.train()#utils 方法 AverageMeterbatch_time = AverageMeter() # forward prop. + back prop. timedata_time = AverageMeter() # data loading timelosses = AverageMeter() # loss (per word decoded)top5accs = AverageMeter() # top5 accuracystart = time.time()# Batchesfor i, (imgs, caps, caplens) in enumerate(train_loader):data_time.update(time.time() - start)# Move to GPU, if availableimgs = imgs.to(device)caps = caps.to(device)caplens = caplens.to(device)# Forward prop.imgs = encoder(imgs)scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)# Since we decoded starting with <start>, the targets are all words after <start>, up to <end>targets = caps_sorted[:, 1:]# Remove timesteps that we didn't decode at, or are pads# pack_padded_sequence is an easy trick to do thisscores = pack_padded_sequence(scores, decode_lengths, batch_first=True).datatargets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data# Calculate lossloss = criterion(scores, targets)# Add doubly stochastic attention regularizationloss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()# Back prop.decoder_optimizer.zero_grad()if encoder_optimizer is not None:encoder_optimizer.zero_grad()loss.backward()# Clip gradientsif grad_clip is not None:clip_gradient(decoder_optimizer, grad_clip)if encoder_optimizer is not None:clip_gradient(encoder_optimizer, grad_clip)# Update weightsdecoder_optimizer.step()if encoder_optimizer is not None:encoder_optimizer.step()# Keep track of metricstop5 = accuracy(scores, targets, 5)losses.update(loss.item(), sum(decode_lengths))top5accs.update(top5, sum(decode_lengths))batch_time.update(time.time() - start)start = time.time()# Print statusif i % print_freq == 0:print('Epoch: [{0}][{1}/{2}]\t''Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t''Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t''Loss {loss.val:.4f} ({loss.avg:.4f})\t''Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),batch_time=batch_time,data_time=data_time, loss=losses,top5=top5accs))def validate(val_loader, encoder, decoder, criterion):"""Performs one epoch's validation.:param val_loader: DataLoader for validation data.:param encoder: encoder model:param decoder: decoder model:param criterion: loss layer:return: BLEU-4 score"""decoder.eval() # eval mode (no dropout or batchnorm)if encoder is not None:encoder.eval()batch_time = AverageMeter()losses = AverageMeter()top5accs = AverageMeter()start = time.time()references = list() # references (true captions) for calculating BLEU-4 scorehypotheses = list() # hypotheses (predictions)# explicitly disable gradient calculation to avoid CUDA memory error# solves the issue #57with torch.no_grad():# Batchesfor i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):# Move to device, if availableimgs = imgs.to(device)caps = caps.to(device)caplens = caplens.to(device)# Forward prop.if encoder is not None:imgs = encoder(imgs)scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)# Since we decoded starting with <start>, the targets are all words after <start>, up to <end>targets = caps_sorted[:, 1:]# Remove timesteps that we didn't decode at, or are pads# pack_padded_sequence is an easy trick to do thisscores_copy = scores.clone()scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).datatargets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data# Calculate lossloss = criterion(scores, targets)# Add doubly stochastic attention regularizationloss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()# Keep track of metricslosses.update(loss.item(), sum(decode_lengths))top5 = accuracy(scores, targets, 5)top5accs.update(top5, sum(decode_lengths))batch_time.update(time.time() - start)start = time.time()if i % print_freq == 0:print('Validation: [{0}/{1}]\t''Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t''Loss {loss.val:.4f} ({loss.avg:.4f})\t''Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,loss=losses, top5=top5accs))# Store references (true captions), and hypothesis (prediction) for each image# If for n images, we have n hypotheses, and references a, b, c... for each image, we need -# references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]# Referencesallcaps = allcaps[sort_ind] # because images were sorted in the decoderfor j in range(allcaps.shape[0]):img_caps = allcaps[j].tolist()img_captions = list(map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],img_caps)) # remove <start> and padsreferences.append(img_captions)# Hypotheses_, preds = torch.max(scores_copy, dim=2)preds = preds.tolist()temp_preds = list()for j, p in enumerate(preds):temp_preds.append(preds[j][:decode_lengths[j]]) # remove padspreds = temp_predshypotheses.extend(preds)assert len(references) == len(hypotheses)# Calculate BLEU-4 scoresbleu4 = corpus_bleu(references, hypotheses)print('\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(loss=losses,top5=top5accs,bleu=bleu4))return bleu4if __name__ == '__main__':main()測試
使用了原git上提供的預訓練模型
import torch import torch.nn.functional as F import numpy as np import json import torchvision.transforms as transforms import matplotlib.pyplot as plt import matplotlib.cm as cm import skimage.transform import argparse from scipy.misc import imread, imresize from PIL import Imagetorch.cuda.set_device(9)device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(device)def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3):"""Reads an image and captions it with beam search.:param encoder: encoder model:param decoder: decoder model:param image_path: path to image:param word_map: word map:param beam_size: number of sequences to consider at each decode-step:return: caption, weights for visualization"""k = beam_sizevocab_size = len(word_map)# Read image and processimg = imread(image_path)#當為單通道圖像時,轉化為三通道if len(img.shape) == 2:img = img[:, :, np.newaxis] #增加緯度img = np.concatenate([img, img, img], axis=2) #拼接為三通道img = imresize(img, (256, 256))img = img.transpose(2, 0, 1)#矩陣轉置 通道數放在前面img = img / 255.img = torch.FloatTensor(img).to(device)normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])transform = transforms.Compose([normalize])image = transform(img) # (3, 256, 256)# Encodeimage = image.unsqueeze(0) # (1, 3, 256, 256)encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim) 1,14,14,2048enc_image_size = encoder_out.size(1)print('enc_image_size:',enc_image_size)encoder_dim = encoder_out.size(3)print('encoder_dim:',encoder_dim)# Flatten encodingencoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim) 1,196,2048#表示了圖像的196個區域各自的特征# print('encoder_out:',encoder_out)num_pixels = encoder_out.size(1)#第二位 196 #print('num_pixels:',num_pixels)# We'll treat the problem as having a batch size of k#print(encoder_out.size())encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim)1->k緯度擴展,五份特征#print(encoder_out.size())# Tensor to store top k previous words at each step; now they're just <start>k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device) # (k, 1)#print('k_prev_words:',k_prev_words)# Tensor to store top k sequences; now they're just <start>seqs = k_prev_words # (k, 1)# Tensor to store top k sequences' scores; now they're just 0top_k_scores = torch.zeros(k, 1).to(device) # (k, 1)# Tensor to store top k sequences' alphas; now they're just 1s 這里其實就是存儲每個字對應圖像上的關注區域,映射在14*14的張量上面seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size)# Lists to store completed sequences, their alphas and scorescomplete_seqs = list()complete_seqs_alpha = list()complete_seqs_scores = list()# Start decodingstep = 1h, c = decoder.init_hidden_state(encoder_out)#h0print('h, c',h.size(),c.size())# s is a number less than or equal to k, because sequences are removed from this process once they hit <end>while True:embeddings = decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim) (5,隱層512)print('embeddings',embeddings.size())#encode的圖片表示 和 隱狀態awe, alpha = decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels)(5,2048(),5,196(attention 存儲字對應圖像各部分的權重))print(' awe, alpha',awe.size(),alpha.size())#0/0alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size)(5,14,14)gate = decoder.sigmoid(decoder.f_beta(h)) # gating scalar, (s, encoder_dim)awe = gate * awe#給特征賦予權重h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim)輸入(512,2048),(512,512)帶權重的特征和上一次的lstm輸出和細胞狀態值scores = decoder.fc(h) # (s, vocab_size)scores = F.log_softmax(scores, dim=1)print('scores',scores.size())# Add 每一句 含有多少詞 更新scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)print('top_k_scores,scores',top_k_scores.size(),scores.size())# For the first step, all k points will have the same scores (since same k previous words, h, c)if step == 1:top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)else:# Unroll and find top scores, and their unrolled indicestop_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 取詞,topprint('top_k_scores,top_k_words',top_k_scores.size(),top_k_words.size())# Convert unrolled indices to actual indices of scoresprev_word_inds = torch.floor_divide(top_k_words, vocab_size)#prev_word_inds = top_k_words / vocab_size # (s)next_word_inds = top_k_words % vocab_size # (s)print('top_k_scores,top_k_words,prev_word_inds,next_word_inds',top_k_words,top_k_scores,prev_word_inds,next_word_inds)# Add new words to sequences, alphasseqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)#詞加一seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)], #詞對應圖像區域加一dim=1) # (s, step+1, enc_image_size, enc_image_size)# Which sequences are incomplete (didn't reach <end>)? 挑出這次循環完結的 句子incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) ifnext_word != word_map['<end>']]complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))# Set aside complete sequences 挑出完整序列if len(complete_inds) > 0:complete_seqs.extend(seqs[complete_inds].tolist()) #追加全部序列complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())complete_seqs_scores.extend(top_k_scores[complete_inds])k -= len(complete_inds) # reduce beam length accordingly# Proceed with incomplete sequencesif k == 0:break#更新參數 只保留未完全序列參數seqs = seqs[incomplete_inds]seqs_alpha = seqs_alpha[incomplete_inds]h = h[prev_word_inds[incomplete_inds]]c = c[prev_word_inds[incomplete_inds]]encoder_out = encoder_out[prev_word_inds[incomplete_inds]]top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)# Break if things have been going on too longif step > 50:breakstep += 1#標記 scores分數最高序列作為返回值。i = complete_seqs_scores.index(max(complete_seqs_scores))seq = complete_seqs[i]alphas = complete_seqs_alpha[i]return seq, alphasdef visualize_att(image_path, seq, alphas, rev_word_map, smooth=True):"""Visualizes caption with weights at every word.Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb:param image_path: path to image that has been captioned:param seq: caption:param alphas: weights:param rev_word_map: reverse word mapping, i.e. ix2word:param smooth: smooth weights?"""image = Image.open(image_path)image = image.resize([14 * 12, 14 * 12], Image.LANCZOS)words = [rev_word_map[ind] for ind in seq]print(words)for t in range(len(words)):if t > 50:breakplt.subplot(np.ceil(len(words) / 5.), 5, t + 1)plt.text(0, 1, '%s' % (words[t]), color='black', backgroundcolor='white', fontsize=12)plt.imshow(image)current_alpha = alphas[t, :]if smooth:alpha = skimage.transform.pyramid_expand(current_alpha.numpy(), upscale=12, sigma=8)else:alpha = skimage.transform.resize(current_alpha.numpy(), [14 * 12, 14 * 12])if t == 0:plt.imshow(alpha, alpha=0)else:plt.imshow(alpha, alpha=0.8)plt.set_cmap(cm.Greys_r)plt.axis('off')plt.show()import scipyprint(scipy.__version__) checkpoint = torch.load('./BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar', map_location=str(device)) decoder = checkpoint['decoder'] decoder = decoder.to(device) decoder.eval() encoder = checkpoint['encoder'] encoder = encoder.to(device) encoder.eval()# Load word map (word2ix) with open('./WORDMAP_coco_5_cap_per_img_5_min_word_freq.json', 'r') as j:word_map = json.load(j) rev_word_map = {v: k for k, v in word_map.items()} # ix2word # Encode, decode with attention and beam search seq, alphas = caption_image_beam_search(encoder, decoder,'img/q.jpg', word_map,5) alphas = torch.FloatTensor(alphas)# Visualize caption and attention of best sequence visualize_att('img/q.jpg', seq, alphas, rev_word_map,True)這里使用了齊天大圣作為測試圖片,輸出很有趣,一個長頭發的女人在看著相機。
更多細節請查看原文和git。
總結
以上是生活随笔為你收集整理的pytorch时空数据处理4——图像转文本/字幕Image-Captionning(二)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 数据库知识点汇总
- 下一篇: 文件生成BASE64,base64转文件