nn.损失函数
nn.L1Loss
https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html#torch.nn.L1Loss
例子:
input=[1,3,4]
target=[2,3,7]
則loss=(|2-1|+|3-3|+|7-4|)/3=4/3=1.333
其中可以加入?yún)?shù),求和,默認(rèn)是求平均
loss_fn2=nn.L1Loss(reduction='sum') loss2=loss_fn2(input,target) print("loss2:{}".format(loss2)) loss2:4.0nn.MSELoss
均方差損失函數(shù)
https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
input=[1,3,4]
target=[2,3,7]
loss=(2-1)**2+(3-3)**2+(7-4)**2=0+9+1=10
默認(rèn)求平均,即10/3=3.3333,可通過(guò)指定進(jìn)行求和
loss_fn4=nn.MSELoss(reduction='sum') loss4=loss_fn4(input,target) print("loss4:{}".format(loss4)) loss4:10.0交叉熵?fù)p失函數(shù) CrossEntropyLoss
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
計(jì)算公式為:
比如對(duì)于一個(gè)3分類(lèi)的圖片,假設(shè)最終結(jié)果input為[0.2, 0.6 , 0.3],而正確的標(biāo)簽為1,即第2個(gè),注意這是單個(gè)圖片,input的尺寸為1x3,
此時(shí)的loss計(jì)算為:-x[1]+ln(exp(x[1])+exp(x[2])+exp(x[3])=-0.6+ln(exp(0.2)+exp(0.6)+exp(0.3))=0.88009
要注意輸入格式
如果batchsize為N,則input為NxC,C為有幾個(gè)分類(lèi),target為1xN,結(jié)果默認(rèn)是取平均,即除以batchsize的大小
比如:
input=【【0.2, 0.6 , 0.3】
【0.5,0.1,0.7】】
target=【1,2】
第二行的loss=0.8618,
兩行取平均,(0.8801+0.8618)/2=0.8709
https://www.bilibili.com/video/BV1hE411t7RN?p=23
總結(jié)
- 上一篇: 中科大c语言期末考试试卷,中科大–中科院
- 下一篇: STM32的串口中断详解