深度学习和目标检测系列教程 9-300:TorchVision和Albumentation性能对比,如何使用Albumentation对图片数据做数据增强
@Author:Runsen
上次對xml文件進行提取,使用到一個Albumentation模塊。Albumentation模塊是一個數據增強的工具,目標檢測圖像預處理通過使用“albumentation”來應用的,這是一個易于與PyTorch數據轉換集成的python庫。
Albumentation 是一種工具,可以在將(圖像/圖片)插入模型之前自定義 處理(彈性、網格、運動模糊、移位、縮放、旋轉、轉置、對比度、亮度等])到圖像/圖片。
對此,Albumentation 官方文檔:
- https://albumentations.ai/
為什么要看看這個東西?因為將 Torchvision 代碼重構為 Albumentation 的效果最好,運行更快。
上圖是使用 Intel Xeon Platinum 8168 CPU 在 ImageNet中通過 2000 個驗證集圖像的測試結果。每個單元格中的值表示在單個核心中處理的圖像數量。可以看到 Albumentation在許多轉換方面比所有其他庫至少高出 2 倍。
Albumentation Github 的官方 CPU 基準測試https://github.com/albumentations-team/albumentations
下面,我導入了下面的模塊:
from PIL import Image import time import torch import torchvision from torch.utils.data import Dataset from torchvision import transforms import albumentations import albumentations.pytorch from matplotlib import pyplot as plt import cv2 import numpy as np為了演示的目的,我找了一張前幾天畢業回校拍的照片
原始 TorchVision 數據管道
創建一個 Dataloader 來使用 PyTorch 和 Torchvision 處理圖像數據管道。
- 創建一個簡單的 Pytorch 數據集類
- 調用圖像并進行轉換
- 用 100 個循環測量整個處理時間
首先,從torch.utils.data獲取 Dataset抽象類,并創建一個 TorchVision數據集類。然后我插入圖像并使用__getitem__方法進行轉換。另外,我用來total_time = (time.time() - start_t測量需要多長時間
class TorchvisionDataset(Dataset):def __init__(self, file_paths, labels, transform=None):self.file_paths = file_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.file_paths)def __getitem__(self, idx):label = self.labels[idx]file_path = self.file_paths[idx]image = Image.open(file_path)start_t = time.time()if self.transform:image = self.transform(image)total_time = (time.time() - start_t)return image, label, total_time然后將圖像大小調整為 256x256(高度 * 重量)并隨機裁剪到 224x224 大小。然后以 50% 的概率應用水平翻轉并將其轉換為張量。輸入文件路徑應該是您的圖像所在的 Google Drive 的路徑。
torchvision_transform = transforms.Compose([transforms.Resize((256, 256)),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(), ])torchvision_dataset = TorchvisionDataset(file_paths=["demo.jpg"],labels=[1],transform=torchvision_transform, )下面計算從 torchvision_dataset 中提取樣本圖像并對其進行轉換所花費的時間,然后運行 ??100 次循環以檢查它所花費的平均毫秒。
torchvision time/sample: 7.31137752532959 ms在torch中的GPU,原始 TorchVision 數據管道數據預處理的速度大約是0.0731137752532959 ms。最后輸出的圖像則為 224x224而且發生了翻轉!
Albumentation 數據管道
現在創建了一個 Albumentations Dataset 類,具體的transform和原始 TorchVision 數據管道完全一樣。
from PIL import Image import time import torch import torchvision from torch.utils.data import Dataset from torchvision import transforms import albumentations import albumentations.pytorch from matplotlib import pyplot as plt import cv2 import numpy as npclass AlbumentationsDataset(Dataset):"""__init__ and __len__ functions are the same as in TorchvisionDataset"""def __init__(self, file_paths, labels, transform=None):self.file_paths = file_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.file_paths)def __getitem__(self, idx):label = self.labels[idx]file_path = self.file_paths[idx]# Read an image with OpenCVimage = cv2.imread(file_path)# By default OpenCV uses BGR color space for color images,# so we need to convert the image to RGB color space.image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)start_t = time.time()if self.transform:augmented = self.transform(image=image)image = augmented['image']total_time = (time.time() - start_t)return image, label, total_timealbumentations_transform = albumentations.Compose([albumentations.Resize(256, 256),albumentations.RandomCrop(224, 224),albumentations.HorizontalFlip(), # Same with transforms.RandomHorizontalFlip()albumentations.pytorch.transforms.ToTensor() ]) albumentations_dataset = AlbumentationsDataset(file_paths=["demo.jpg"],labels=[1],transform=albumentations_transform, )total_time = 0 for i in range(100):sample, _, transform_time = albumentations_dataset[0]total_time += transform_timeprint("albumentations time/sample: {} ms".format(total_time*10))plt.figure(figsize=(10, 10)) plt.imshow(transforms.ToPILImage()(sample)) plt.show()具體輸出如下:
albumentations time/sample: 0.5056881904602051 ms在torch中的GPU,Albumentation 數據管道 數據管道數據預處理的速度大約是0.005056881904602051 ms。
因此,在真正的工業落地,基本需要將原始 TorchVision 數據管道改寫成Albumentation 數據管道,因為落地項目的速度很重要。
Albumentation數據增強
最后,我將展示如何使用albumentations中OneOf函數進行書增強,我個人覺得這個函數在 Albumentation 中非常有用。
from PIL import Image import time import torch import torchvision from torch.utils.data import Dataset from torchvision import transforms import albumentations import albumentations.pytorch from matplotlib import pyplot as plt import cv2class AlbumentationsDataset(Dataset):"""__init__ and __len__ functions are the same as in TorchvisionDataset"""def __init__(self, file_paths, labels, transform=None):self.file_paths = file_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.file_paths)def __getitem__(self, idx):label = self.labels[idx]file_path = self.file_paths[idx]image = cv2.imread(file_path)image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)if self.transform:augmented = self.transform(image=image)image = augmented['image']return image, label# OneOf隨機采用括號內列出的變換之一。 # 我們甚至可以將發生的概率放在函數本身中。例如,如果 ([…], p=0.5) 之一,它會以 50% 的機會跳過整個變換,并以 1/6 的機會隨機選擇三個變換之一。 albumentations_transform_oneof = albumentations.Compose([albumentations.Resize(256, 256),albumentations.RandomCrop(224, 224),albumentations.OneOf([albumentations.HorizontalFlip(p=1),albumentations.RandomRotate90(p=1),albumentations.VerticalFlip(p=1)], p=1),albumentations.OneOf([albumentations.MotionBlur(p=1),albumentations.OpticalDistortion(p=1), albumentations.GaussNoise(p=1)], p=1),albumentations.pytorch.ToTensor() ])albumentations_dataset = AlbumentationsDataset(file_paths=["demo.jpg"],labels=[1],transform=albumentations_transform_oneof, )num_samples = 5 fig, ax = plt.subplots(1, num_samples, figsize=(25, 5)) for i in range(num_samples):ax[i].imshow(transforms.ToPILImage()(albumentations_dataset[0][0]))ax[i].axis('off')plt.show()
上面的OneOf是在水平翻轉、旋轉、垂直翻轉中隨機選擇,在模糊、失真、噪聲中隨機選擇。所以在這種情況下,我們允許 3x3 = 9 種組合
總結
以上是生活随笔為你收集整理的深度学习和目标检测系列教程 9-300:TorchVision和Albumentation性能对比,如何使用Albumentation对图片数据做数据增强的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 抵押车供求关系
- 下一篇: 深度学习和目标检测系列教程 10-300