Pytorch基础训练库Pytorch-Base-Trainer(支持模型剪枝 分布式训练)
Pytorch基礎(chǔ)訓(xùn)練庫Pytorch-Base-Trainer(支持模型剪枝 分布式訓(xùn)練)
目錄
Pytorch基礎(chǔ)訓(xùn)練庫Pytorch-Base-Trainer(PBT)(支持分布式訓(xùn)練)
1.Introduction
2.Install
3.訓(xùn)練框架
?(1)訓(xùn)練引擎(Engine)
(2)回調(diào)函數(shù)(Callback)
4.使用方法
5.Example: 構(gòu)建自己的分類Pipeline
6.可視化
7.其他
- 開源不易,麻煩給個(gè)【Star】
- Github:?GitHub - PanJinquan/Pytorch-Base-Trainer: Pytorch分布式訓(xùn)練框架
- pip安裝包:?basetrainer · PyPI
- 博客地址:Pytorch基礎(chǔ)訓(xùn)練庫Pytorch-Base-Trainer(支持模型剪枝 分布式訓(xùn)練)_pan_jinquan的博客-CSDN博客
尊重原創(chuàng),轉(zhuǎn)載請(qǐng)注明出處:https://panjinquan.blog.csdn.net/article/details/122662902
GitHub - PanJinquan/Pytorch-Base-Trainer: Pytorch分布式訓(xùn)練框架Pytorch分布式訓(xùn)練框架. Contribute to PanJinquan/Pytorch-Base-Trainer development by creating an account on GitHub.https://github.com/PanJinquan/Pytorch-Base-Trainer
1.Introduction
考慮到深度學(xué)習(xí)訓(xùn)練過程都有一套約定成俗的流程,鄙人借鑒Keras開發(fā)了一套基礎(chǔ)訓(xùn)練庫:?Pytorch-Base-Trainer(PBT); 這是一個(gè)基于Pytorch開發(fā)的基礎(chǔ)訓(xùn)練庫,支持以下特征:
- ?支持多卡訓(xùn)練訓(xùn)練(DP模式)和分布式多卡訓(xùn)練(DDP模式),參考build_model_parallel
- ?支持argparse命令行指定參數(shù),也支持config.yaml配置文件
- ?支持最優(yōu)模型保存ModelCheckpoint
- ?支持自定義回調(diào)函數(shù)Callback
- ?支持NNI模型剪枝(L1/L2-Pruner,FPGM-Pruner Slim-Pruner)nni_pruning
- ?非常輕便,安裝簡(jiǎn)單
誠(chéng)然,諸多大公司已經(jīng)開源基礎(chǔ)庫,如MMClassification,MMDetection等庫; 但礙于這些開源庫安裝麻煩,依賴庫多,版本差異大等問題;鄙人開發(fā)了一套比較基礎(chǔ)的訓(xùn)練Pipeline:Pytorch-Base-Trainer(PBT), 基于PBT可以快速搭建自己的訓(xùn)練工程; 目前,基于PBT完成了通用分類庫(PBTClassification),通用檢測(cè)庫(PBTDetection),通用語義分割庫( PBTSegmentation)以及,通用姿態(tài)檢測(cè)庫(PBTPose)
| PBTClassification | 通用分類庫 | 集成常用的分類模型,支持多種數(shù)據(jù)格式,樣本重采樣 |
| PBTDetection | 通用檢測(cè)庫 | 集成常用的檢測(cè)類模型,如RFB,SSD和YOLOX |
| PBTSegmentation | 通用語義分割庫 | 集成常用的語義分割模型,如DeepLab,UNet等 |
| PBTPose | 通用姿態(tài)檢測(cè)庫 | 集成常用的人體姿態(tài)估計(jì)模型,如UDP,Simple-base-line |
基于PBT框架訓(xùn)練的模型,已經(jīng)形成了一套完整的Android端上部署流程,支持CPU和GPU
| CPU/GPU:70/50ms | CPU/GPU:30/20ms | CPU/GPU:150/30ms |
PS:受商業(yè)保護(hù),目前,僅開源Pytorch-Base-Trainer(PBT),基于PBT的分類,檢測(cè)和分割以及姿態(tài)估計(jì)訓(xùn)練庫,暫不開源。
2.Install
- 源碼安裝
- pip安裝
- 使用NNI?模型剪枝工具,需要安裝NNI
3.訓(xùn)練框架
PBT基礎(chǔ)訓(xùn)練庫定義了一個(gè)基類(Base),所有訓(xùn)練引擎(Engine)以及回調(diào)函數(shù)(Callback)都會(huì)繼承基類。
?(1)訓(xùn)練引擎(Engine)
Engine類實(shí)現(xiàn)了訓(xùn)練/測(cè)試的迭代方法(如on_batch_begin,on_batch_end),其迭代過程參考如下,用戶可以根據(jù)自己的需要自定義迭代過程:
self.on_train_begin() for epoch in range(num_epochs):self.set_model() # 設(shè)置模型# 開始訓(xùn)練self.on_epoch_begin() # 開始每個(gè)epoch調(diào)用for inputs in self.train_dataset:self.on_batch_begin() # 每次迭代開始時(shí)回調(diào)self.run_step() # 每次迭代返回outputs, lossesself.on_train_summary() # 每次迭代,訓(xùn)練結(jié)束時(shí)回調(diào)self.on_batch_end() # 每次迭代結(jié)束時(shí)回調(diào)# 開始測(cè)試self.on_test_begin()for inputs in self.test_dataset:self.run_step() # 每次迭代返回outputs, lossesself.on_test_summary() # 每次迭代,測(cè)試結(jié)束時(shí)回調(diào)self.on_test_end() # 結(jié)束測(cè)試# 結(jié)束當(dāng)前epochself.on_epoch_end() self.on_train_end()EngineTrainer類繼承Engine類,用戶需要繼承該類,并實(shí)現(xiàn)相關(guān)接口:
| build_train_loader | 定義訓(xùn)練數(shù)據(jù) |
| build_test_loader | 定義測(cè)試數(shù)據(jù) |
| build_model | 定義模型 |
| build_optimizer | 定義優(yōu)化器 |
| build_criterion | 定義損失函數(shù) |
| build_callbacks | 定義回調(diào)函數(shù) |
另外,EngineTrainer類還是實(shí)現(xiàn)了兩個(gè)重要的類方法(build_dataloader和build_model_parallel),用于構(gòu)建分布式訓(xùn)練
| build_dataloader | 用于構(gòu)建加載方式,參數(shù)distributed設(shè)置是否使用分布式加載數(shù)據(jù) |
| build_model_parallel | 用于構(gòu)建模型,參數(shù)distributed設(shè)置是否使用分布式訓(xùn)練模型 |
(2)回調(diào)函數(shù)(Callback)
每個(gè)回調(diào)函數(shù)都需要繼承(Callback),用戶在回調(diào)函數(shù)中,可實(shí)現(xiàn)對(duì)迭代方法輸入/輸出的處理,例如:
| LogHistory | Log歷史記錄回調(diào)函數(shù),可使用Tensorboard可視化 |
| ModelCheckpoint | 保存模型回調(diào)函數(shù),可選擇最優(yōu)模型保存 |
| LossesRecorder | 單個(gè)Loss歷史記錄回調(diào)函數(shù),可計(jì)算每個(gè)epoch的平均值 |
| MultiLossesRecorder | 用于多任務(wù)Loss的歷史記錄回調(diào)函數(shù) |
| AccuracyRecorder | 用于計(jì)算分類Accuracy回調(diào)函數(shù) |
| get_scheduler | 各種學(xué)習(xí)率調(diào)整策略(MultiStepLR,CosineAnnealingLR,ExponentialLR)的回調(diào)函數(shù) |
4.使用方法
basetrainer使用方法可以參考example.py,構(gòu)建自己的訓(xùn)練器,可通過如下步驟實(shí)現(xiàn):
- step1: 新建一個(gè)類ClassificationTrainer,繼承trainer.EngineTrainer
- step2: 實(shí)現(xiàn)接口
5.Example: 構(gòu)建自己的分類Pipeline
- basetrainer使用方法可以參考example.py
- 目標(biāo)支持的backbone有:resnet[18,34,50,101], ,mobilenet_v2等,詳見backbone等 ,其他backbone可以自定義添加
- 訓(xùn)練參數(shù)可以通過兩種方法指定: (1) 通過argparse命令行指定 (2)通過config.yaml配置文件,當(dāng)存在同名參數(shù)時(shí),以配置文件為默認(rèn)值
| train_data | str, list | - | 訓(xùn)練數(shù)據(jù)文件,可支持多個(gè)文件 |
| test_data | str, list | - | 測(cè)試數(shù)據(jù)文件,可支持多個(gè)文件 |
| work_dir | str | work_space | 訓(xùn)練輸出工作空間 |
| net_type | str | resnet18 | backbone類型,{resnet,resnest,mobilenet_v2,...} |
| input_size | list | [128,128] | 模型輸入大小[W,H] |
| batch_size | int | 32 | batch size |
| lr | float | 0.1 | 初始學(xué)習(xí)率大小 |
| optim_type | str | SGD | 優(yōu)化器,{SGD,Adam} |
| loss_type | str | CELoss | 損失函數(shù) |
| scheduler | str | multi-step | 學(xué)習(xí)率調(diào)整策略,{multi-step,cosine} |
| milestones | list | [30,80,100] | 降低學(xué)習(xí)率的節(jié)點(diǎn),僅僅scheduler=multi-step有效 |
| momentum | float | 0.9 | SGD動(dòng)量因子 |
| num_epochs | int | 120 | 循環(huán)訓(xùn)練的次數(shù) |
| num_warn_up | int | 3 | warn_up的次數(shù) |
| num_workers | int | 12 | DataLoader開啟線程數(shù) |
| weight_decay | float | 5e-4 | 權(quán)重衰減系數(shù) |
| gpu_id | list | [ 0 ] | 指定訓(xùn)練的GPU卡號(hào),可指定多個(gè) |
| log_freq | in | 20 | 顯示LOG信息的頻率 |
| finetune | str | model.pth | finetune的模型 |
| use_prune | bool | True | 是否進(jìn)行模型剪枝 |
| progress | bool | True | 是否顯示進(jìn)度條 |
| distributed | bool | False | 是否使用分布式訓(xùn)練 |
?一個(gè)簡(jiǎn)單分類例子如下:
# -*-coding: utf-8 -*- """@Author : panjq@E-mail : pan_jinquan@163.com@Date : 2021-07-28 22:09:32 """ import os import syssys.path.append(os.getcwd()) import argparse import basetrainer from torchvision import transforms from torchvision.datasets import ImageFolder from basetrainer.engine import trainer from basetrainer.engine.launch import launch from basetrainer.criterion.criterion import get_criterion from basetrainer.metric import accuracy_recorder from basetrainer.callbacks import log_history, model_checkpoint, losses_recorder, multi_losses_recorder from basetrainer.scheduler import build_scheduler from basetrainer.optimizer.build_optimizer import get_optimizer from basetrainer.utils import log, file_utils, setup_config, torch_tools from basetrainer.models import build_modelsprint(basetrainer.__version__)class ClassificationTrainer(trainer.EngineTrainer):""" Training Pipeline """def __init__(self, cfg):super(ClassificationTrainer, self).__init__(cfg)torch_tools.set_env_random_seed()cfg.model_root = os.path.join(cfg.work_dir, "model")cfg.log_root = os.path.join(cfg.work_dir, "log")if self.is_main_process:file_utils.create_dir(cfg.work_dir)file_utils.create_dir(cfg.model_root)file_utils.create_dir(cfg.log_root)file_utils.copy_file_to_dir(cfg.config_file, cfg.work_dir)setup_config.save_config(cfg, os.path.join(cfg.work_dir, "setup_config.yaml"))self.logger = log.set_logger(level="debug",logfile=os.path.join(cfg.log_root, "train.log"),is_main_process=self.is_main_process)# build projectself.build(cfg)self.logger.info("=" * 60)self.logger.info("work_dir :{}".format(cfg.work_dir))self.logger.info("config_file :{}".format(cfg.config_file))self.logger.info("gpu_id :{}".format(cfg.gpu_id))self.logger.info("main device :{}".format(self.device))self.logger.info("num_samples(train):{}".format(self.num_samples))self.logger.info("num_classes :{}".format(cfg.num_classes))self.logger.info("mean_num :{}".format(self.num_samples / cfg.num_classes))self.logger.info("=" * 60)def build_optimizer(self, cfg, **kwargs):"""build_optimizer"""self.logger.info("build_optimizer")self.logger.info("optim_type:{},init_lr:{},weight_decay:{}".format(cfg.optim_type, cfg.lr, cfg.weight_decay))optimizer = get_optimizer(self.model,optim_type=cfg.optim_type,lr=cfg.lr,momentum=cfg.momentum,weight_decay=cfg.weight_decay)return optimizerdef build_criterion(self, cfg, **kwargs):"""build_criterion"""self.logger.info("build_criterion,loss_type:{},num_classes:{}".format(cfg.loss_type, cfg.num_classes))criterion = get_criterion(cfg.loss_type, cfg.num_classes, device=self.device)return criteriondef build_train_loader(self, cfg, **kwargs):"""build_train_loader"""self.logger.info("build_train_loader,input_size:{}".format(cfg.input_size))transform = transforms.Compose([transforms.Resize([int(128 * cfg.input_size[1] / 112), int(128 * cfg.input_size[0] / 112)]),transforms.RandomHorizontalFlip(),transforms.RandomCrop([cfg.input_size[1], cfg.input_size[0]]),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),])dataset = ImageFolder(root=cfg.train_data, transform=transform)cfg.num_classes = len(dataset.classes)cfg.classes = dataset.classesloader = self.build_dataloader(dataset, cfg.batch_size, cfg.num_workers, phase="train",shuffle=True, pin_memory=False, drop_last=True, distributed=cfg.distributed)return loaderdef build_test_loader(self, cfg, **kwargs):"""build_test_loader"""self.logger.info("build_test_loader,input_size:{}".format(cfg.input_size))transform = transforms.Compose([transforms.Resize([int(128 * cfg.input_size[1] / 112), int(128 * cfg.input_size[0] / 112)]),transforms.CenterCrop([cfg.input_size[1], cfg.input_size[0]]),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),])dataset = ImageFolder(root=cfg.train_data, transform=transform)loader = self.build_dataloader(dataset, cfg.batch_size, cfg.num_workers, phase="test",shuffle=False, pin_memory=False, drop_last=False, distributed=False)return loaderdef build_model(self, cfg, **kwargs):"""build_model"""self.logger.info("build_model,net_type:{}".format(cfg.net_type))model = build_models.get_models(net_type=cfg.net_type, input_size=cfg.input_size,num_classes=cfg.num_classes, pretrained=True)if cfg.finetune:self.logger.info("finetune:{}".format(cfg.finetune))state_dict = torch_tools.load_state_dict(cfg.finetune)model.load_state_dict(state_dict)if cfg.use_prune:from basetrainer.pruning import nni_pruningsparsity = 0.2self.logger.info("use_prune:{},sparsity:{}".format(cfg.use_prune, sparsity))model = nni_pruning.model_pruning(model,input_size=[1, 3, cfg.input_size[1], cfg.input_size[0]],sparsity=sparsity,reuse=False,output_prune=os.path.join(cfg.work_dir, "prune"))model = self.build_model_parallel(model, cfg.gpu_id, distributed=cfg.distributed)return modeldef build_callbacks(self, cfg, **kwargs):"""定義回調(diào)函數(shù)"""self.logger.info("build_callbacks")# 準(zhǔn)確率記錄回調(diào)函數(shù)acc_record = accuracy_recorder.AccuracyRecorder(target_names=cfg.classes,indicator="Accuracy")# loss記錄回調(diào)函數(shù)loss_record = losses_recorder.LossesRecorder(indicator="loss")# Tensorboard Log等歷史記錄回調(diào)函數(shù)history = log_history.LogHistory(log_dir=cfg.log_root,log_freq=cfg.log_freq,logger=self.logger,indicators=["loss", "Accuracy"],is_main_process=self.is_main_process)# 模型保存回調(diào)函數(shù)checkpointer = model_checkpoint.ModelCheckpoint(model=self.model,optimizer=self.optimizer,moder_dir=cfg.model_root,epochs=cfg.num_epochs,start_save=-1,indicator="Accuracy",logger=self.logger)# 學(xué)習(xí)率調(diào)整策略回調(diào)函數(shù)lr_scheduler = build_scheduler.get_scheduler(cfg.scheduler,optimizer=self.optimizer,lr_init=cfg.lr,num_epochs=cfg.num_epochs,num_steps=self.num_steps,milestones=cfg.milestones,num_warn_up=cfg.num_warn_up)callbacks = [acc_record,loss_record,lr_scheduler,history,checkpointer]return callbacksdef run(self, logs: dict = {}):self.logger.info("start train")super().run(logs)def main(cfg):t = ClassificationTrainer(cfg)return t.run()def get_parser():parser = argparse.ArgumentParser(description="Training Pipeline")parser.add_argument("-c", "--config_file", help="configs file", default="configs/config.yaml", type=str)# parser.add_argument("-c", "--config_file", help="configs file", default=None, type=str)parser.add_argument("--train_data", help="train data", default="./data/dataset/train", type=str)parser.add_argument("--test_data", help="test data", default="./data/dataset/val", type=str)parser.add_argument("--work_dir", help="work_dir", default="output", type=str)parser.add_argument("--input_size", help="input size", nargs="+", default=[224, 224], type=int)parser.add_argument("--batch_size", help="batch_size", default=32, type=int)parser.add_argument("--gpu_id", help="specify your GPU ids", nargs="+", default=[0], type=int)parser.add_argument("--num_workers", help="num_workers", default=0, type=int)parser.add_argument("--num_epochs", help="total epoch number", default=50, type=int)parser.add_argument("--scheduler", help=" learning scheduler: multi-step,cosine", default="multi-step", type=str)parser.add_argument("--milestones", help="epoch stages to decay learning rate", nargs="+",default=[10, 20, 40], type=int)parser.add_argument("--num_warn_up", help="num_warn_up", default=3, type=int)parser.add_argument("--net_type", help="net_type", default="mobilenet_v2", type=str)parser.add_argument("--finetune", help="finetune model file", default=None, type=str)parser.add_argument("--loss_type", help="loss_type", default="CELoss", type=str)parser.add_argument("--optim_type", help="optim_type", default="SGD", type=str)parser.add_argument("--lr", help="learning rate", default=0.1, type=float)parser.add_argument("--weight_decay", help="weight_decay", default=0.0005, type=float)parser.add_argument("--momentum", help="momentum", default=0.9, type=float)parser.add_argument("--log_freq", help="log_freq", default=10, type=int)parser.add_argument('--use_prune', action='store_true', help='use prune', default=False)parser.add_argument('--progress', action='store_true', help='display progress bar', default=True)parser.add_argument('--distributed', action='store_true', help='use distributed training', default=False)parser.add_argument('--polyaxon', action='store_true', help='polyaxon', default=False)return parserif __name__ == "__main__":parser = get_parser()cfg = setup_config.parser_config(parser.parse_args(), cfg_updata=True)launch(main,num_gpus_per_machine=len(cfg.gpu_id),dist_url="tcp://127.0.0.1:28661",num_machines=1,machine_rank=0,distributed=cfg.distributed,args=(cfg,))6.可視化
目前訓(xùn)練過程可視化工具是使用Tensorboard,使用方法:
tensorboard --logdir=path/to/log/7.其他
| 聯(lián)系方式 | pan_jinquan@163.com |
總結(jié)
以上是生活随笔為你收集整理的Pytorch基础训练库Pytorch-Base-Trainer(支持模型剪枝 分布式训练)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 结构光三维重建Projector-Cam
- 下一篇: android 6.0权限