理解二分類交叉熵|可視化的方法解釋對數(shù)損失
點擊上方“小白學視覺”,選擇加"星標"或“置頂”
重磅干貨,第一時間送達
?
?如果你在訓(xùn)練一個二分類分類器,很有可能你在使用二值交叉熵,log損失,作為你的損失函數(shù)。
你有沒有想過,使用這個損失函數(shù)到底意味著什么?事實是,現(xiàn)在的各種庫和框架非常的簡單易用,導(dǎo)致大家很容易忽視所使用的損失函數(shù)的真正意義。
我一直在找一個可以通過可視化到的方法清楚而簡單的解釋二元交叉熵(log損失)的背后的真正含義,這樣我可以在 Data Science Retreat上展示給我的學生,但是我一直沒有找到。既然找不到我想要的,那我就自己來:-)
?
讓我們從10個隨機數(shù)開始:
x = [-2.2, -1.4, -0.8, 0.2, 0.4, 0.8, 1.2, 2.2, 2.9, 4.6]這就是我們唯一的特征:x

現(xiàn)在,我們給這些點涂上點顏色:紅色和綠色,作為標簽。
所以,我們的分類問題就很直觀了:給定了特征x,需要我們預(yù)測標簽:紅色或者綠色。
既然是個二分類,我們可以將這個問題描述成:“這個點是綠色的嗎?”,或者,“這個點是綠色的概率是多少?”,理想的狀態(tài)下,綠色點的概率應(yīng)該為1.0,同時紅色點的概率應(yīng)該為0.0。
在這樣的設(shè)定下,綠色點屬于正樣本,紅色點屬于負樣本。
如果我擬合一個模型來進行分類,預(yù)測每個點是綠色的概率。給定點的顏色,我們?nèi)绾蝸碓u估這個預(yù)測的概率的好壞?這就是損失函數(shù)的目的!損失函數(shù)對于好的預(yù)測將返回一個低的值,對于壞的預(yù)測,將返回一個高的值。
對于二分類,比如我們的例子,典型的損失函數(shù)就是二值交叉熵(對數(shù)損失)。
?
如果你仔細看看這個損失函數(shù),你會發(fā)現(xiàn):

y是標簽(1是綠色的,0是紅色的),p(y)是所有的N個點預(yù)測是綠色的概率。
這個公式告訴你,對于每一個綠色(y=1)的點,加了一個log(p(y))到損失中,這就是綠色的對數(shù)概率。相反的,對于每一個紅色(y=0)的點添加了log(1-p(y)),這個是紅色的對數(shù)概率。一點也不難,也很不直觀。
另外,熵和這些有個什么關(guān)系?為什么我們要首先取概率的對數(shù)?這才是有價值的問題,我希望在下面的 “Show me the math” 環(huán)節(jié)中回答。
但是,在我們開始更多的公式之前,我先給你展示一個上面公式的可視化的表示。
?
首先,我們根據(jù)類別將這些點分開,正樣本和負樣本,就像這樣:

現(xiàn)在,我們來訓(xùn)練邏輯回歸模型來分類我們的點。這個回歸的擬合是一個sigmoid的曲線,表示了給定的x是綠色的概率。就像這樣:

對于所有的屬于正樣本的點(綠色),我們的分類器給出的預(yù)測概率是什么?就是sigmoid曲線下面的綠色的條,x的坐標代表了這個點。

到現(xiàn)在為止,一切都好!那么負樣本的點呢?記住,sigmoid曲線之下的綠條表示的該點是綠色的概率。那么,給定的點是紅色的概率是多少呢?當然就是sigmoid曲線上面紅色條啦 :-)

把這些放在一起,我們得到了這樣的東西:

條子代表了每個點對應(yīng)的類別的預(yù)測的概率。
好了,我們有了預(yù)測的概率,是時候計算一下二值交叉熵/對數(shù)損失來評估一下了。
這些概率就是我們需要的東西,所以,我們不需要x的坐標了,我們把豎條一個挨一個排列起來。

現(xiàn)在,這些豎條不再有什么含義了,我們改變一下位置:

既然我們是想計算損失,我們需要懲罰壞的預(yù)測,是嗎?如果對應(yīng)類別的相關(guān)的概率是1.0,我們需要對應(yīng)的loss為零。對應(yīng)的,如果概率很低,比如0.01,我們希望損失很大!
結(jié)果就是,將概率值取對數(shù)能夠很好的滿足我們的需求(實際上,使用對數(shù)的原因是來自于交叉熵的定義)。
下面的圖顯示的很清楚,預(yù)測為真的概率值越趨向于零,損失指數(shù)增加:

很公平!我們?nèi)「怕实膶?shù)——這些就是每個點對應(yīng)的損失。
最后,我們計算所有損失的均值。

好了!我們成功的計算了二元交叉熵/對數(shù)損失的值,是0.3329!
如果你需要重復(fù)確認一些我們的發(fā)現(xiàn),運行下面的代碼,自己看!
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
import numpy as np
x = np.array([-2.2, -1.4, -.8, .2, .4, .8, 1.2, 2.2, 2.9, 4.6])
y = np.array([0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
logr = LogisticRegression(solver='lbfgs')
logr.fit(x.reshape(-1, 1), y)
y_pred = logr.predict_proba(x.reshape(-1, 1))[:, 1].ravel()
loss = log_loss(y, y_pred)
print('x = {}'.format(x))
print('y = {}'.format(y))
print('p(y) = {}'.format(np.round(y_pred, 2)))
print('Log Loss / Cross Entropy = {:.4f}'.format(loss))?
開個玩笑,上面的東西不是那么數(shù)學,如果你想理解熵,對數(shù)在這個里面扮演的角色,我們開始:-)
如果你想深入了解信息論,包括所有的概念——熵,交叉熵等等,可以看看Chris Olah’s寫的的東西http://colah.github.io/posts/2015-09-Visual-Information/,非常的詳細。
?
我們從我們的數(shù)據(jù)分布開始。y代表了我們的點的類別(有3個紅色點,7個綠色點),這就是分布,我們叫做q(y),看起來是這樣的:

?
熵是一個給定的分布的不確定性的度量。
如果所有的點都是綠色的會怎么樣?分布的不確定性是什么樣的?零,對嗎?畢竟,點的顏色是毫無疑問的,永遠是綠色!所以,熵為零!
另外,如果我們知道正好一半的點是綠色而另外一半是紅色呢?這就是最差的情況對嗎?我們猜顏色的時候就沒有任何的優(yōu)勢了:完全的隨機!這種情況下,熵的值由下面的公式給出,我們的類別數(shù)是2:

對于任何一個之間的情況,我們可以計算熵的分布,就像我們的q(y),再使用下面的公式,C是類別的數(shù)量:

所以,如果我們知道了一個隨機變量的真實的分布,我們就可以計算它的熵。但是,為什么一開始要訓(xùn)練個分類器呢?畢竟,我們知道真實的分布了啊
但是,如果我們不知道呢?我們是不是可以通過另外的分布比如說p(y)來估計真實的分布呢?當然可以!
?
我們假設(shè)我們的點服從另外的分布p(y),但是我們知道這個分布是來自于真實(未知)的分布q(y),是嗎?
如果我們計算了熵,我們實際上計算的是這兩個分布的交叉熵:

如果我們可以神奇的將p(y)和q( y)匹配的很好,那么交叉熵的計算值和熵的計算值也會匹配的很好。
既然這個是不太可能發(fā)生的,在真實的分布上,交叉熵永遠會比熵要大那么一點。

原來,交叉熵和熵的差值是有個名字的...
?
KL散度,衡量的是兩個分布之間的差異性:

這個的意思是, p(y)和q(y)越接近,散度的值越小,交叉熵也是這樣。
所以,我們需要找到一個好的p(y)來用,這就是我們的分類器做的事情,是嗎?確實也是這樣!尋找最近的p(y),就是最小化交叉熵。
?
在訓(xùn)練中,分類器使用了N個點找那個的每一個來計算交叉熵的損失,有效的擬合出分布p(y)!既然每個點的概率都是1/N,交叉熵是這樣的:

還記得上面的圖6嗎?我們需要在每個點對應(yīng)的真實類別的概率上計算交叉熵。意思就是正樣本使用綠色條,負樣本使用紅色條,數(shù)學上可以這樣寫:

最后一步是計算所有的點在兩個類別上的平均值,正樣本和負樣本:

最后,再加上一點操作,我們使用任何一個點,不管是正樣本還是負樣本,都用同樣的公式:

好了!我們回到了二元交叉熵/對數(shù)損失最初的公式:-)
?
我真的希望上面的內(nèi)容可以給一些理所當然的概念一些不同的東西。我當然也希望可以展示給你關(guān)于機器學習和信息論是聯(lián)系在一起的。
好消息!?
小白學視覺知識星球
開始面向外開放啦??????
下載1:OpenCV-Contrib擴展模塊中文版教程 在「小白學視覺」公眾號后臺回復(fù):擴展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴展模塊教程中文版,涵蓋擴展模塊安裝、SFM算法、立體視覺、目標跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。 下載2:Python視覺實戰(zhàn)項目52講 在「小白學視覺」公眾號后臺回復(fù):Python視覺實戰(zhàn)項目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個視覺實戰(zhàn)項目,助力快速學校計算機視覺。 下載3:OpenCV實戰(zhàn)項目20講 在「小白學視覺」公眾號后臺回復(fù):OpenCV實戰(zhàn)項目20講,即可下載含有20個基于OpenCV實現(xiàn)20個實戰(zhàn)項目,實現(xiàn)OpenCV學習進階。 交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學影像、GAN、算法競賽等微信群(以后會逐漸細分),請掃描下面微信號加群,備注:”昵稱+學校/公司+研究方向“,例如:”張三?+?上海交大?+?視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~

