Focal Loss: Theory and Code Implementation
Table of Contents
- Preface
- 1. Basic Theory
- 2. Implementation
  - 1. Formula
  - 2. Code Implementation
    - 1. Implementation based on binary cross-entropy
    - 2. Implementation from a Zhihu author
Preface
This article draws on a post by the blogger 幾時見得清夢.
Original reference: https://www.jianshu.com/p/30043bcc90b6
1. Basic Theory
1. soft-gamma: increasing gamma in stages over the course of training may yield a further performance gain.
2. alpha is related to the frequency of each class in the training data (a simple inverse-frequency heuristic is sketched after this list).
3. F.nll_loss(torch.log(F.softmax(inputs, dim=1)), target) is functionally equivalent to F.cross_entropy(inputs, target).
F.nll_loss in effect one-hot encodes the target, expanding it into a tensor with the same shape as the input, and then takes the element-wise product with its first argument (i.e., the log-probabilities passed to F.nll_loss). A quick numerical check of this equivalence is given after this list.
[Figure: experimental results with alpha = 1 and different values of gamma]
4. What problems does focal loss solve?
(1) Imbalance between classes
(2) Imbalance between hard and easy examples
5. In RetinaNet, besides using focal loss, the initialization also receives special treatment. What exactly is done?
In RetinaNet, the bias b of the last conv layer of the classification subnet is set to b = -log((1 - π)/π) with π = 0.01, so that at the start of training every anchor is predicted as foreground with probability about π and the overwhelming number of easy background anchors does not dominate the loss. A code sketch of this initialization follows this list.
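As an illustration of point 2, here is one common heuristic, a minimal sketch assuming inverse-frequency weighting (the class counts below are made up; the original post does not prescribe a specific formula):

```python
import torch

# Hypothetical per-class sample counts in the training set.
class_counts = torch.tensor([9000., 900., 100.])

# Weight each class by its inverse frequency, then normalize to sum to 1.
alpha = 1.0 / class_counts
alpha = alpha / alpha.sum()   # tensor([0.0099, 0.0990, 0.8911]): rare classes get larger weights
print(alpha)
```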
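For point 3, a quick numerical check of the equivalence (the tensors below are made up for illustration):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 5)                 # raw scores for 4 samples over 5 classes
target = torch.tensor([0, 2, 1, 4])   # ground-truth class indices

loss_a = F.nll_loss(torch.log(F.softmax(x, dim=1)), target)
loss_b = F.cross_entropy(x, target)
print(torch.allclose(loss_a, loss_b))  # True, up to floating-point error
```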
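For point 5, a minimal sketch of that bias initialization in PyTorch (the layer name and channel sizes below are assumptions for illustration, not RetinaNet's actual code):

```python
import math
import torch.nn as nn

pi = 0.01  # prior probability of an anchor being foreground at the start of training

# Hypothetical final conv of the classification subnet.
num_anchors, num_classes, in_channels = 9, 80, 256
cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1)

# Weights: zero-mean Gaussian with small std (as in the paper);
# bias: b = -log((1 - pi) / pi), so that sigmoid(b) ≈ pi.
nn.init.normal_(cls_logits.weight, std=0.01)
nn.init.constant_(cls_logits.bias, -math.log((1 - pi) / pi))
```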
2. Implementation
1. Formula
The standard Cross Entropy and Focal Loss are

$$\mathrm{CE}(p_t) = -\log(p_t), \qquad \mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t),$$

where $p_t$ is the probability the model assigns to the ground-truth class, $\gamma \ge 0$ is the focusing parameter, and $\alpha_t$ is the class-balancing weight.
For the forward and backward derivations of focal loss, see this Zhihu article: https://zhuanlan.zhihu.com/p/32631517
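A quick numeric illustration of the modulating factor (this example is added here and is not from the original post):

```python
# With gamma = 2, an easy example (p_t = 0.9) is down-weighted far more than a hard one (p_t = 0.1).
for p_t in (0.9, 0.5, 0.1):
    print(f"p_t = {p_t:.1f} -> (1 - p_t)^2 = {(1 - p_t) ** 2:.2f}")
# p_t = 0.9 -> 0.01, p_t = 0.5 -> 0.25, p_t = 0.1 -> 0.81
```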
2. Code Implementation
1. Implementation based on binary cross-entropy
```python
# 1. Implementation based on binary cross-entropy
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha      # balancing factor
        self.gamma = gamma      # focusing parameter
        self.logits = logits    # True if inputs are raw scores, False if already probabilities
        self.reduce = reduce    # True: return the mean loss; False: return per-sample losses

    def forward(self, inputs, targets):
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)  # pt = probability assigned to the ground-truth class
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
```
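A minimal usage sketch for the class above (the tensor shapes and values are illustrative assumptions):

```python
import torch

criterion = FocalLoss(alpha=1, gamma=2, logits=True)  # logits=True: inputs are raw scores
inputs = torch.randn(8, requires_grad=True)           # raw scores for 8 samples
targets = torch.empty(8).random_(2)                    # binary labels in {0., 1.}
loss = criterion(inputs, targets)
loss.backward()
```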
2. Implementation from a Zhihu author
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class FocalLoss(nn.Module):
    r"""This criterion is an implementation of Focal Loss, proposed in
    "Focal Loss for Dense Object Detection":

        Loss(x, class) = - \alpha (1 - softmax(x)[class])^gamma \log(softmax(x)[class])

    The losses are averaged across observations for each minibatch.

    Args:
        alpha (1D Tensor, Variable): the scalar factor for this criterion.
        gamma (float, double): gamma > 0; reduces the relative loss for
            well-classified examples (p > .5), putting more focus on hard,
            misclassified examples.
        size_average (bool): By default, the losses are averaged over
            observations for each minibatch. If set to False, the losses
            are instead summed for each minibatch.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        elif isinstance(alpha, Variable):
            self.alpha = alpha
        else:
            self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs, dim=1)

        # Build a one-hot mask of the target classes.
        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)

        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]

        # Probability the model assigns to the ground-truth class of each sample.
        probs = (P * class_mask).sum(1).view(-1, 1)
        log_p = probs.log()

        batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
```
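A minimal usage sketch for the multi-class version above (shapes and labels are illustrative assumptions):

```python
import torch

criterion = FocalLoss(class_num=5, gamma=2)
inputs = torch.randn(4, 5, requires_grad=True)  # raw scores for 4 samples over 5 classes
targets = torch.tensor([0, 2, 1, 4])            # ground-truth class indices
loss = criterion(inputs, targets)
loss.backward()
```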