轻量高效!清华智能计算实验室开源基于PyTorch的视频 (图片) 去模糊框架SimDeblur
作者丨科技猛獸
編輯丨極市平臺
清華大學自動化系智能計算實驗室團隊開源基于 PyTorch 的視頻 (圖片) 去模糊框架 SimDeblur。
基于 PyTorch 的視頻 (圖片) 去模糊框架 SimDeblur
它的特點是:
- 全面: 涵蓋經典的視頻 (圖像) 去模糊算法,如 MSCNN, SRN, DeblurGAN, EDVR, 等等。
- 高效: 支持 DDP 多機多卡訓練。
- 輕量: 便于拓展,易上手,讓更多的人能更快地上手使用。
- 專注: 使我們在實現自己的新模型時只需要關注一個文件或很少的幾個文件。
Github link:
ljzycmd/SimDeblur?github.com
目錄
1 為什么要做這個開源框架?
1.1 怎么總是這幾個baseline?
1.2 同一個baseline,在不同論文中的質量差別很大
1.3 同一個baseline,同一個數據集實驗結果可比嗎?
1.4 低質量的代碼開源
2 SimDeblur: 基于PyTorch的視頻 (圖片) 去模糊框架
2.1 已實現模型
2.2 使用方法
2.3 代碼解讀
3 作者團隊信息
1 為什么要做這個開源框架?
在深度學習領域,有幾個問題我覺得很有必要提一下:
1.1 怎么總是這幾個baseline?
比如說
在檢測領域,baseline一般都有:
在分割領域,baseline一般都有:
在Vision Transformer領域,baseline一般都有:
在超分領域,baseline一般都有:
大家都不比較那些“最好”的baseline,而是去比較很 Popular 的baseline。
這就像買顯卡時,
1060說:我比960好。
1080說:我比960好。
2080Ti說:我比960好。
有很多自稱達到了 SOTA 的模型,漲到了比較高的性能,但是很難考證。所以后續研究者在選擇比較對象的時候就會選擇一些性能相對較低的,但是代碼高質量開源的論文去比較。原因有2點:
這樣做的好處是有百花齊放百家爭鳴的感覺。但壞處是有的真正好的 baseline 模型被忽略掉了,導致了劣幣驅逐良幣。
如果今天你問一個你所在領域的專家,隨便挑一個人,你問他:
" 我們這個任務目前最好的模型是哪個?"
他一定也很難回答。
你可能會問了:
" 這有啥難的?我直接把最新的論文都找出來,看看這個任務里面,誰超過baseline最多,誰提升的幅度最大,誰不就是最好的嗎?"
這就引出了第2個問題:
1.2 同一個baseline,在不同論文中的質量差別很大
這句話的意思是說:同一個baseline模型,相同的任務,不同論文中給出的結果性能是不同的。 為什么呢?
這是因為:很多研究者對baseline的復現,其實并沒有做到“全心全意”。換句話說,對baseline參數的調整其實帶有相當大的隨意性,對baseline的調整不會下過多的功夫,導致得到的baseline的性能沒有達到其可以達到的最佳狀態。
在這種情況下,如果你想比較2個自稱達到了SOTA的模型的性能,因為它們對比的baseline的性能有差距,所以假設它們都相對baseline漲了3個點,但其實它們的性能是有差別的,所以就不具備很好的可比性。可能甲把baseline調得非常好,另一個乙把baseline沒有調得很好,那么乙的提升就不具備很高的可信度。
你可能又會問了:
" 那我就直接找出baseline論文中給出的它在某個數據集上的性能,直接使用它的結果不就好了嗎?"
這就引出了第3個問題:
- 1.3 同一個baseline,同一個數據集實驗結果可比嗎?
即使baseline在用一個數據集上,其實驗結果也是不可比的。這是因為實驗中的很多其他變量無法得到相同的控制。比如在數據預處理環節,每篇論文所列的baseline方法是否做到了完全一致?再比如在超參數的設置上,每篇論文所列的baseline方法是否做到了完全相同?
我們看下面的2張圖,圖1是DeiT模型的超參數設置 (DeiT是一種用于分類任務的視覺Transformer模型),圖2是不同超參數設置下的模型性能對比。我們可以看到,相同的模型在相同的數據集下面,性能還是有差別的。所以這些看似不起眼的設定,其實是對模型的性能有著相對重大的影響,而這些卻不會出現在引用DeiT的論文里面。所以你可能會看到:相同的模型在相同的數據集下面,結果又是會出現很大的差異。假設我們有8個超參數,每個超參數只有2種選擇,那么不同的組合就多達282^{8}28種。
圖1:DeiT模型的超參數設置
圖2:DeiT模型不同超參數設置下的模型性能對比
總之這里想說的就是:很難保證 A 和 B 兩篇論文的一切實驗設置都是相同的。這就導致即使我們找到了 A 和 B 兩篇在相同的模型在相同的數據集下面進行的實驗,它們的結果也不是那么的可比。
你可能又會問了:
" 那很多論文都提供了開源代碼,我直接下載下來在自己的任務上跑跑不就行了嗎?"
這就引出了第4個問題:
1.4 低質量的代碼開源
目前一篇頂會論文開源代碼的最低要求是:能復現論文中所列的實驗結果。但遺憾的是,許多開源代碼根本無法達到這個要求。對于有些達到了這個要求的代碼,它們的可重用性也非常差,想把它移植到你自己的實驗環境下也十分地困難。我之前遇到過很多種奇葩的開源代碼,這里隨便舉一個例子 (具體的論文就不說了。。)。比如它做 NAS 的論文,開源的代碼里面沒有 NAS 搜索的代碼,只有模型的 model.py,那這樣的開源代碼就缺乏了最核心的 NAS 算法的開源,就是無意義的。那遇到這樣的情況可能一周過去了,你還是無法復現出原論文的結果,這時候開組會時:
導師:你這周干了啥?
你:復現某某某論文失敗了。
導師:這代碼不是開源了嗎,怎么還是復現不出來,你有沒有認真做實驗?
你:。。。。。。(委屈臉)
這種情況其實是很普遍且很不合理的情況,真的不是你的能力不行,而是目前領域中廣泛存在的問題,Are we really making progress?所以在目前領域文章看似百花齊放的前提下,其實隱藏著一個潛在的,使領域停滯不前的問題。
這里我在舉一個良性的例子。
比如2020年是視覺Transformer爆火的一年,從20年下半年開始一直持續到21年,Transformer模型被應用在了視覺的各個領域,想詳細了解的童鞋們可以參考:
科技猛獸:Vision Transformer 超詳細解讀 (原理分析+代碼解讀) (目錄)?zhuanlan.zhihu.com
但是,在2020年爆火的Vision Transformer背后,其實是有一個重要的依托,就是**Ross Wightman大佬創建的timm庫**。PyTorchImageModels,簡稱timm,包含很多種PyTorch的視覺模型,是一個巨大的PyTorch代碼集合,包括了一系列:
- image models
- layers
- utilities
- optimizers
- schedulers
- data-loaders / augmentations
- training / validation scripts
旨在將各種SOTA模型整合在一起,并具有復現ImageNet訓練結果的能力,詳細的介紹如下:
科技猛獸:視覺Transformer優秀開源工作:timm庫vision transformer代碼解讀?zhuanlan.zhihu.com
許多Vision Transformer,包含高引的DeiT,CaiT等,其實都是基于timm庫來實現的。所以這給了我們啟發:我們需要一個benchmark平臺,包含多種模型,使得它們在同一條件下得到公平的評測,這也是我們開發這一框架的初衷。
在設計這個框架時,我們的思想是:
- 首先它應該輕量,易上手,讓更多的人能更快地上手使用。
- 其次它應該高效,使使用者專注于模型的實現,對于訓練和評估的過程盡量少關心。
- 其次它應該靈活,適配不同的數據輸入格式和實驗設定。
- 最后就是專注,使我們在實現新模型時只需要關注一個文件。
2 SimDeblur: 基于PyTorch的視頻 (圖片) 去模糊框架
2.1 已實現模型
(粗體表示已經實現的模型,其他是待實現的模型)
-
Single Image Deblurring
- MSCNN [Paper, Project]
- SRN [Paper, Project]
-
Video Deblurring
- DBN [Paper, Project]
- STRCNN [paper]
- DBLRNet [Paper]
- EDVR [Paper, Project]
- STFAN [Paper, Project]
- IFIRNN [Paper]
- CDVD-TSP [Paper, Project]
- ESTRNN [Paper, Project]
-
Benchmarks
- GoPro [Paper, Data]
- DVD [Paper, Data]
- REDS [Paper, Data]
2.2 使用方法
1) 安裝依賴
Python 3 (Conda is recommended) Pytorch 1.5.1 (with GPU) CUDA 10.2+Clone the repositry or download the zip file:
git clone https://github.com/ljzycmd/SimDeblur.gitInstall SimDeblur:
# create a pytorch env conda create -n simdeblur python=3.7 conda activate simdeblur # install the packages cd SimDeblur bash Install.sh2) 使用默認的 trainer 來搭建一個訓練進程,如下所示:
from simdeblur.config import build_config, merge_args from simdeblur.engine.parse_arguments import parse_arguments from simdeblur.engine.trainer import Trainerargs = parse_arguments()cfg = build_config(args.config_file) cfg = merge_args(cfg, args) cfg.args = argstrainer = Trainer(cfg) trainer.train()3) 單卡訓練:
CUDA_VISIBLE_DEVICES=0 bash ./tools/train.sh ./config/dbn/dbn_dvd.yaml 14) 多卡訓練:
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./tools/train.sh ./config/dbn/dbn_dvd.yaml 4train.sh:
CONFIG=$1 GPUS=$2 PORT=${PORT:=10086} # PORT=10086 # single gpu training if [ GPUS == 1 ] then echo start single GPU training python train.py $CONFIG --gpus=$GPUS else echo start distributed training # distributed training PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \train.py $CONFIG --gpus=$GPUS fi5) 也可以直接通過 SimDeblur 中的函數構建各種模塊:
build the a dataset:
from easydict import EasyDict as edict from simdeblur.dataset import build_datasetdataset = build_dataset(edict({"name": "DVD","mode": "train","sampling": "n_c","overlapping": True,"interval": 1,"root_gt": "./dataset/DVD/quantitative_datasets","num_frames": 5,"augmentation": {"RandomCrop": {"size": [256, 256] },"RandomHorizontalFlip": {"p": 0.5 },"RandomVerticalFlip": {"p": 0.5 },"RandomRotation90": {"p": 0.5 },} }))print(dataset[0])build the model:
from simdeblur.model import build_backbonemodel = build_backbone({"name": "DBN","num_frames": 5,"in_channels": 3,"inner_channels": 64 })x = torch.randn(1, 5, 3, 256, 256) out = model(x)build the loss:
from simdeblur.model import build_losscriterion = build_loss({"name": "MSELoss", }) x = torch.randn(2, 3, 256, 256) y = torch.randn(2, 3, 256, 256) print(criterion(x, y))2.3 代碼解讀:
1 框架架構:
/configs
→ /dblrnet: dblrnet配置文件
→ /dbn: dbn配置文件
→ /edvr: edvr配置文件
→ /…
/datasets: 數據集位置
/docs
/simdeblur
→ __init__.py
→ /config
→ → __init__.py
→ → build.py:讀取配置信息的一些函數
→ → default_config.py:默認配置信息
→ /dataset
→ → __init__.py
→ → build.py:創建數據集的接口
→ → augment.py:數據增強的函數
→ → dvd.py
→ → gopro.py
→ → red.py
→ /engine
→ → __init__.py
→ → parse_arguments.py
→ → trainer.py:主要的訓練代碼
→ → hook.py
→ /model
→ → __init__.py
→ → build.py:創建模型的接口
→ → /backbone:各種 backbone 具體實現
→ → →/dblrnet:dblrnet 具體實現
→ → →/dbn:dbn 具體實現
→ → →/edvr:edvr 具體實現
→ → →/ifirnn:ifirnn 具體實現
→ → →/stfan:stfan 具體實現
→ → →/strcnn:strcnn 具體實現
→ → /layer:各種 layer 具體實現
→ → →__init__.py
→ → →non_local.py:non_local block 具體實現
→ → →res_block.py:殘差塊具體實現
→ → →vgg.py:VGG 塊具體實現
→ → /loss:各種損失函數具體實現
→ → →__init__.py
→ → →loss.py
→ → →perceptual_loss.py
→ → /meta_arch
→ /scheduler: 優化器和學習率 scheduler 函數
→ /utils: 打印日志的相關函數
/tools: 生成demo的一些工具函數,以及啟動文件 train.sh
/utils: 其它涉及到的一些工具函數
/requirements.txt: 運行需要的依賴庫
setup.py: 上傳 PYPI 需要的文件
test.py: 模型測試的接口文件,需要傳入.yaml格式的配置文件
train.py: 模型訓練的接口文件,需要傳入.yaml格式的配置文件
2 train.py:
import torchfrom simdeblur.config import build_config, merge_args from simdeblur.engine.parse_arguments import parse_arguments from simdeblur.engine.trainer import Trainerdef main():args = parse_arguments()cfg = build_config(args.config_file)cfg = merge_args(cfg, args)cfg.args = argstrainer = Trainer(cfg)trainer.train()if __name__ == "__main__":main()build_config:根據配置文件 (.yaml) 得到配置信息cfg (字典)。
merge_args:融合命令行參數。
得到包含了所有配置信息的變量 cfg,傳入Trainer類。
3 Trainer 類介紹:
(a) 定義 Trainer 類屬性:
from simdeblur.dataset import build_dataset from simdeblur.scheduler import build_optimizer, build_lr_scheduler from simdeblur.model import build_backbone, build_meta_arch, build_loss from simdeblur.utils.logger import LogBuffer, SimpleMetricPrinter, TensorboardWriter from simdeblur.utils.metrics import calculate_psnr, calculate_ssim from simdeblur.utils import dist_utilsfrom simdeblur.engine import hookslogging.basicConfig(format='%(asctime)s - %(levelname)s - SimDeblur: %(message)s',level=logging.INFO) logging.info("******* A simple deblurring framework ********")class Trainer:def __init__(self, cfg):"""Argscfg(edict): the config file, which contains arguments form comand line"""self.cfg = copy.deepcopy(cfg)# initialize the distributed trainingif cfg.args.gpus > 1:dist_utils.init_distributed(cfg)# create the working dirsself.current_work_dir = os.path.join(cfg.work_dir, cfg.name)if not os.path.exists(self.current_work_dir):os.makedirs(self.current_work_dir, exist_ok=True)self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# self.device = torch.device("cpu")# default loggerlogger = logging.getLogger("simdeblur")logger.setLevel(logging.INFO)logger.addHandler(logging.FileHandler(os.path.join(self.current_work_dir, self.cfg.name.split("_")[0] + ".json")))# construct the modulesself.model = self.build_model(cfg).to(self.device)self.criterion = build_loss(cfg.loss).to(self.device)self.train_dataloader, self.train_sampler = self.build_dataloder(cfg, mode="train")self.val_datalocaer, _ = self.build_dataloder(cfg, mode="val")self.optimizer = self.build_optimizer(cfg, self.model)self.lr_scheduler = self.build_lr_scheduler(cfg, self.optimizer)# trainer hooksself._hooks = self.build_hooks()# some induces when trainingself.epochs = 0self.iters = 0self.batch_idx = 0 self.start_epoch = 0self.start_iter = 0self.total_train_epochs = self.cfg.schedule.epochsself.total_train_iters = self.total_train_epochs * len(self.train_dataloader)# resume or load the ckpt as init-weightsif self.cfg.resume_from != "None":self.resume_or_load_ckpt(ckpt_path=self.cfg.resume_from)# log bufffer(dict to save) self.log_buffer = LogBuffer()(b) 每個 epoch 開始前 shuffle the dataloader when dist training:
def before_epoch(self):for h in self._hooks:h.before_epoch(self)# shuffle the data when dist training ...if self.train_sampler:self.train_sampler.set_epoch(self.epochs)(c) 每個 iteration 開始前 shuffle the dataloader when dist training:
def before_epoch(self):for h in self._hooks:h.before_epoch(self)# shuffle the data when dist training ...if self.train_sampler:self.train_sampler.set_epoch(self.epochs)(d) 準備輸入信息:
def preprocess(self, batch_data):"""prepare for input"""return batch_data["input_frames"].to(self.device)(e) 模型輸出的后處理:
def postprocess(self):"""post process for model outputs"""# When the outputs is a img tensorif isinstance(self.outputs, torch.Tensor) and self.outputs.dim() == 5:self.outputs = self.outputs.flatten(0, 1)(f) 計算損失:
def calculate_loss(self, batch_data, model_outputs):"""calculate the loss"""gt_frames = batch_data["gt_frames"].to(self.device).flatten(0, 1)if model_outputs.dim() == 5:model_outputs = model_outputs.flatten(0, 1) # (b*n, c, h, w)return self.criterion(gt_frames, model_outputs)(g) 優化器更新參數:
def update_params(self):"""update paramspipline: zero_grad, backward and update grad"""self.optimizer.zero_grad()self.loss.backward()self.optimizer.step()(h) 每個 iteration 或者 epoch 結束以后,使用 hook 干一些事情,比如:lr_scheduler 更新,calculate metrics,保存日志等等,具體可以查看 /simdeblur/engine.hook.py 文件。
def after_iter(self):for h in self._hooks:h.after_iter(self)def after_epoch(self):for h in self._hooks:h.after_epoch(self)(i) 根據以上工具函數寫訓練函數 train():
def train(self, **kwargs):self.model.train()self.before_train()logger = logging.getLogger("simdeblur")logger.info("Starting training...")for self.epochs in range(self.start_epoch, self.cfg.schedule.epochs):# shuffle the dataloader when dist training: dist_data_loader.set_epoch(epoch)self.before_epoch()for self.batch_idx, self.batch_data in enumerate(self.train_dataloader):self.before_iter()input_frames = self.preprocess(self.batch_data)self.outputs = self.model(input_frames)self.postprocess()self.loss = self.calculate_loss(self.batch_data, self.outputs)self.update_params()self.iters += 1self.after_iter()if self.epochs % self.cfg.schedule.val_epochs == 0:self.val()self.after_epoch()before_epoch(), after_epoch(), before_iter(), after_iter() 這四個函數都是通過 hook 來定義每個 epoch 之前或之后,每個 iteration 之前或之后要做的事情,具體可以查看 /simdeblur/engine.hook.py 文件。
3 作者團隊信息
曹銘登:
清華大學自動化系19級碩士,目前實習于騰訊 AI Lab。
郵箱:mingdengcao@gmail.com
王家豪:
清華大學自動化系19級碩士,目前實習于北京華為諾亞方舟實驗室。
郵箱:wang-jh19@mails.tsinghua.edu.cn
智能計算實驗室信息:
https://sites.google.com/view/iigroup-thu?sites.google.com
學術合作 or 溝通交流歡迎私信聯系~
cite as:
@Article{wang2021simdeblur,author = {Mingdeng Cao, Jiahao Wang},title = {清華智能計算實驗室團隊開源基于PyTorch的視頻 (圖片) 去模糊框架SimDeblur},journal = {https://zhuanlan.zhihu.com/},howpublished = {\url{https://github.com/ljzycmd/SimDeblur}},year = {2021},url= {https://zhuanlan.zhihu.com/p/368312516/}, }總結
以上是生活随笔為你收集整理的轻量高效!清华智能计算实验室开源基于PyTorch的视频 (图片) 去模糊框架SimDeblur的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 霸榜COCO和Cityscapes!南理
- 下一篇: ICML2021|超越SE、CBAM,中