PyTorch之torch.nn.CrossEntropyLoss()
簡介
信息熵: 按照真實分布p來衡量識別一個樣本所需的編碼長度的期望,即平均編碼長度
交叉熵: 使用擬合分布q來表示來自真實分布p的編碼長度的期望,即平均編碼長度
多分類任務中的交叉熵損失函數(shù)
代碼
1)導入包
import torch import torch.nn as nn2)準備數(shù)據(jù)
在圖片單標簽分類時,輸入m張圖片,輸出一個m x N的Tensor,其中N是分類個數(shù)。比如輸入3張圖片,分三類,最后的輸出是一個3 x 3的Tensor,舉個例子:
3)計算概率分布
第123行分別是第123張圖片的結果,假設第123列分別是貓、狗和豬的分類得分。
然后對每一行使用Softmax,這樣可以得到每張圖片的概率分布。
這里dim的意思是計算Softmax的維度,這里設置dim=1,可以看到每一行的加和為1。比如第一行0.1022+0.3831+0.5147=1。
4)對Softmax的結果取自然對數(shù)
log_output=torch.log(soft_output) print('log_output:\n',log_output)
對比softmax與log的結合與nn.LogSoftmaxloss(負對數(shù)似然損失)的輸出結果,兩者是一致的。
5)NLLLoss
NLLLoss的結果就是把上面的輸出與y_label對應的那個值拿出來,再去掉負號,再求均值。
y_target中[1, 2, 0]對應上述第一行的第二個,第二行的第三個,第三行的第1個:
(0.9594+0.4241+0.5265)/3=0.6367
6) CrossEntropyLoss()
參考鏈接:
https://blog.csdn.net/qq_22210253/article/details/85229988
https://zhuanlan.zhihu.com/p/98785902
https://zhuanlan.zhihu.com/p/56638625
總結
以上是生活随笔為你收集整理的PyTorch之torch.nn.CrossEntropyLoss()的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Chrome浏览器密码框自动填充的bug
- 下一篇: Servlet的认识