【機(jī)器學(xué)習(xí)基礎(chǔ)】(三):理解邏輯回歸及二分類、多分類代碼實(shí)踐
本文是機(jī)器學(xué)習(xí)系列的第三篇,算上前置機(jī)器學(xué)習(xí)系列是第八篇。本文的概念相對簡單,主要側(cè)重于代碼實(shí)踐。
上一篇文章說到,我們可以用線性回歸做預(yù)測,但顯然現(xiàn)實(shí)生活中不止有預(yù)測的問題還有分類的問題。我們可以從預(yù)測值的類型上簡單區(qū)分:連續(xù)變量的預(yù)測為回歸,離散變量的預(yù)測為分類。
一、邏輯回歸:二分類
1.1 理解邏輯回歸
我們把連續(xù)的預(yù)測值進(jìn)行人工定義,邊界的一邊定義為1,另一邊定義為0。這樣我們就把回歸問題轉(zhuǎn)換成了分類問題。

如上圖,我們把連續(xù)的變量分布壓制在0-1的范圍內(nèi),并以0.5作為我們分類決策的邊界,大于0.5的概率則判別為1,小于0.5的概率則判別為0。

我們無法使用無窮大和負(fù)無窮大進(jìn)行算術(shù)運(yùn)算,我們通過邏輯回歸函數(shù)(Sigmoid函數(shù)/S型函數(shù)/Logistic函數(shù))可以講數(shù)值計(jì)算限定在0-1之間。
以上就是邏輯回歸的簡單解釋。下面我們應(yīng)用真實(shí)的數(shù)據(jù)案例來進(jìn)行二分類代碼實(shí)踐。
1.2 代碼實(shí)踐 - 導(dǎo)入數(shù)據(jù)集
添加引用:
import?numpy?as?np
import?pandas?as?pd
import?seaborn?as?sns
import?matplotlib.pyplot?as?plt
導(dǎo)入數(shù)據(jù)集(大家不用在意這個域名):
df?=?pd.read_csv('https://blog.caiyongji.com/assets/hearing_test.csv')
df.head()
| age | physical_score | test_result |
|---|---|---|
| 33 | 40.7 | 1 |
| 50 | 37.2 | 1 |
| 52 | 24.7 | 0 |
| 56 | 31 | 0 |
| 35 | 42.9 | 1 |
該數(shù)據(jù)集,對5000名參與者進(jìn)行了一項(xiàng)實(shí)驗(yàn),以研究年齡和身體健康對聽力損失的影響,尤其是聽高音的能力。此數(shù)據(jù)顯示了研究結(jié)果對參與者進(jìn)行了身體能力的評估和評分,然后必須進(jìn)行音頻測試(通過/不通過),以評估他們聽到高頻的能力。
特征:1. 年齡 2. 健康得分 標(biāo)簽:(1通過/0不通過)
1.3 觀察數(shù)據(jù)
sns.scatterplot(x='age',y='physical_score',data=df,hue='test_result')
我們用seaborn繪制年齡和健康得分特征對應(yīng)測試結(jié)果的散點(diǎn)圖。

sns.pairplot(df,hue='test_result')
我們通過pairplot方法繪制特征兩兩之間的對應(yīng)關(guān)系。

我們可以大致做出判斷,當(dāng)年齡超過60很難通過測試,通過測試者普遍健康得分超過30。
1.4 訓(xùn)練模型
from?sklearn.model_selection?import?train_test_split
from?sklearn.preprocessing?import?StandardScaler
from?sklearn.linear_model?import?LogisticRegression
from?sklearn.metrics?import?accuracy_score,classification_report,plot_confusion_matrix
#準(zhǔn)備數(shù)據(jù)
X?=?df.drop('test_result',axis=1)
y?=?df['test_result']
X_train,?X_test,?y_train,?y_test?=?train_test_split(X,?y,?test_size=0.1,?random_state=50)
scaler?=?StandardScaler()
scaled_X_train?=?scaler.fit_transform(X_train)
scaled_X_test?=?scaler.transform(X_test)
#定義模型
log_model?=?LogisticRegression()
#訓(xùn)練模型
log_model.fit(scaled_X_train,y_train)
#預(yù)測數(shù)據(jù)
y_pred?=?log_model.predict(scaled_X_test)
accuracy_score(y_test,y_pred)
我們經(jīng)過準(zhǔn)備數(shù)據(jù),定義模型為LogisticRegression邏輯回歸模型,通過fit方法擬合訓(xùn)練數(shù)據(jù),最后通過predict方法進(jìn)行預(yù)測。
最終我們調(diào)用accuracy_score方法得到模型的準(zhǔn)確率為92.2%。
二、模型性能評估:準(zhǔn)確率、精確度、召回率
我們是如何得到準(zhǔn)確率是92.2%的呢?我們調(diào)用plot_confusion_matrix方法繪制混淆矩陣。
plot_confusion_matrix(log_model,scaled_X_test,y_test)
我們觀察500個測試實(shí)例,得到矩陣如下:

我們對以上矩陣進(jìn)行定義如下:
真正類TP(True Positive) :預(yù)測為正,實(shí)際結(jié)果為正。如,上圖右下角285。 真負(fù)類TN(True Negative) :預(yù)測為負(fù),實(shí)際結(jié)果為負(fù)。如,上圖左上角176。 假正類FP(False Positive) :預(yù)測為正,實(shí)際結(jié)果為負(fù)。如,上圖左下角19。 假負(fù)類FN(False Negative) :預(yù)測為負(fù),實(shí)際結(jié)果為正。如,上圖右上角20。
準(zhǔn)確率(Accuracy) 公式如下:
帶入本例得:
精確度(Precision) 公式如下:
帶入本例得:
召回率(Recall) 公式如下:
帶入本例得:
我們調(diào)用classification_report方法可驗(yàn)證結(jié)果。
print(classification_report(y_test,y_pred))

三、Softmax:多分類
3.1 理解softmax多元邏輯回歸
Logistic回歸和Softmax回歸都是基于線性回歸的分類模型,兩者無本質(zhì)區(qū)別,都是從伯努利分結(jié)合最大對數(shù)似然估計(jì)。
最大似然估計(jì):簡單來說,最大似然估計(jì)就是利用已知的樣本結(jié)果信息,反推最具有可能(最大概率)導(dǎo)致這些樣本結(jié)果出現(xiàn)的模型參數(shù)值。
術(shù)語“概率”(probability)和“似然”(likelihood)在英語中經(jīng)?;Q使用,但是它們在統(tǒng)計(jì)學(xué)中的含義卻大不相同。給定具有一些參數(shù)θ的統(tǒng)計(jì)模型,用“概率”一詞描述未來的結(jié)果x的合理性(知道參數(shù)值θ),而用“似然”一詞表示描述在知道結(jié)果x之后,一組特定的參數(shù)值θ的合理性。
Softmax回歸模型首先計(jì)算出每個類的分?jǐn)?shù),然后對這些分?jǐn)?shù)應(yīng)用softmax函數(shù),估計(jì)每個類的概率。我們預(yù)測具有最高估計(jì)概率的類,簡單來說就是找得分最高的類。
3.2 代碼實(shí)踐 - 導(dǎo)入數(shù)據(jù)集
導(dǎo)入數(shù)據(jù)集(大家不用在意這個域名):
df?=?pd.read_csv('https://blog.caiyongji.com/assets/iris.csv')
df.head()
| sepal_length | sepal_width | petal_length | petal_width | species |
|---|---|---|---|---|
| 5.1 | 3.5 | 1.4 | 0.2 | setosa |
| 4.9 | 3 | 1.4 | 0.2 | setosa |
| 4.7 | 3.2 | 1.3 | 0.2 | setosa |
| 4.6 | 3.1 | 1.5 | 0.2 | setosa |
| 5 | 3.6 | 1.4 | 0.2 | setosa |
該數(shù)據(jù)集,包含150個鳶尾花樣本數(shù)據(jù),數(shù)據(jù)特征包含花瓣的長度和寬度和萼片的長度和寬度,包含三個屬種的鳶尾花,分別是山鳶尾(setosa)、變色鳶尾(versicolor)和維吉尼亞鳶尾(virginica)。
特征:1. 花萼長度 2. 花萼寬度 3. 花瓣長度 4 花萼寬度 標(biāo)簽:種類:山鳶尾(setosa)、變色鳶尾(versicolor)和維吉尼亞鳶尾(virginica)
3.3 觀察數(shù)據(jù)
sns.scatterplot(x='sepal_length',y='sepal_width',data=df,hue='species')
我們用seaborn繪制花萼長度和寬度特征對應(yīng)鳶尾花種類的散點(diǎn)圖。

sns.scatterplot(x='petal_length',y='petal_width',data=df,hue='species')
我們用seaborn繪制花瓣長度和寬度特征對應(yīng)鳶尾花種類的散點(diǎn)圖。

sns.pairplot(df,hue='species')
我們通過pairplot方法繪制特征兩兩之間的對應(yīng)關(guān)系。

我們可以大致做出判斷,綜合考慮花瓣和花萼尺寸最小的為山鳶尾花,中等尺寸的為變色鳶尾花,尺寸最大的為維吉尼亞鳶尾花。
3.4 訓(xùn)練模型
#準(zhǔn)備數(shù)據(jù)
X?=?df.drop('species',axis=1)
y?=?df['species']
X_train,?X_test,?y_train,?y_test?=?train_test_split(X,?y,?test_size=0.25,?random_state=50)
scaler?=?StandardScaler()
scaled_X_train?=?scaler.fit_transform(X_train)
scaled_X_test?=?scaler.transform(X_test)
#定義模型
softmax_model?=?LogisticRegression(multi_class="multinomial",solver="lbfgs",?C=10,?random_state=50)
#訓(xùn)練模型
softmax_model.fit(scaled_X_train,y_train)
#預(yù)測數(shù)據(jù)
y_pred?=?softmax_model.predict(scaled_X_test)
accuracy_score(y_test,y_pred)
我們經(jīng)過準(zhǔn)備數(shù)據(jù),定義模型LogisticRegression的multi_class="multinomial"多元邏輯回歸模型,設(shè)置求解器為lbfgs,通過fit方法擬合訓(xùn)練數(shù)據(jù),最后通過predict方法進(jìn)行預(yù)測。
最終我們調(diào)用accuracy_score方法得到模型的準(zhǔn)確率為92.1%。
我們調(diào)用classification_report方法查看準(zhǔn)確率、精確度、召回率。
print(classification_report(y_test,y_pred))

3.5 拓展:繪制花瓣分類
我們僅提取花瓣長度和花瓣寬度的特征來繪制鳶尾花的分類圖像。
#提取特征
X?=?df[['petal_length','petal_width']].to_numpy()?
y?=?df["species"].factorize(['setosa',?'versicolor','virginica'])[0]
#定義模型
softmax_reg?=?LogisticRegression(multi_class="multinomial",solver="lbfgs",?C=10,?random_state=50)
#訓(xùn)練模型
softmax_reg.fit(X,?y)
#隨機(jī)測試數(shù)據(jù)
x0,?x1?=?np.meshgrid(
????????np.linspace(0,?8,?500).reshape(-1,?1),
????????np.linspace(0,?3.5,?200).reshape(-1,?1),
????)
X_new?=?np.c_[x0.ravel(),?x1.ravel()]
#預(yù)測
y_proba?=?softmax_reg.predict_proba(X_new)
y_predict?=?softmax_reg.predict(X_new)
#繪制圖像
zz1?=?y_proba[:,?1].reshape(x0.shape)
zz?=?y_predict.reshape(x0.shape)
plt.figure(figsize=(10,?4))
plt.plot(X[y==2,?0],?X[y==2,?1],?"g^",?label="Iris?virginica")
plt.plot(X[y==1,?0],?X[y==1,?1],?"bs",?label="Iris?versicolor")
plt.plot(X[y==0,?0],?X[y==0,?1],?"yo",?label="Iris?setosa")
from?matplotlib.colors?import?ListedColormap
custom_cmap?=?ListedColormap(['#fafab0','#9898ff','#a0faa0'])
plt.contourf(x0,?x1,?zz,?cmap=custom_cmap)
contour?=?plt.contour(x0,?x1,?zz1,?cmap=plt.cm.brg)
plt.clabel(contour,?inline=1,?fontsize=12)
plt.xlabel("Petal?length",?fontsize=14)
plt.ylabel("Petal?width",?fontsize=14)
plt.legend(loc="center?left",?fontsize=14)
plt.axis([0,?7,?0,?3.5])
plt.show()
得到鳶尾花根據(jù)花瓣分類的圖像如下:

四、小結(jié)
相比于概念的理解,本文更側(cè)重上手實(shí)踐,通過動手編程你應(yīng)該有“手熱”的感覺了。截至到本文,你應(yīng)該對機(jī)器學(xué)習(xí)的概念有了一定的掌握,我們簡單梳理一下:
機(jī)器學(xué)習(xí)的分類 機(jī)器學(xué)習(xí)的工業(yè)化流程 特征、標(biāo)簽、實(shí)例、模型的概念 過擬合、欠擬合 損失函數(shù)、最小二乘法 梯度下降、學(xué)習(xí)率 7.線性回歸、邏輯回歸、多項(xiàng)式回歸、逐步回歸、嶺回歸、套索(Lasso)回歸、彈性網(wǎng)絡(luò)(ElasticNet)回歸是最常用的回歸技術(shù) Sigmoid函數(shù)、Softmax函數(shù)、最大似然估計(jì)
如果你還有不清楚的地方請參考:
機(jī)器學(xué)習(xí)(二):理解線性回歸與梯度下降并做簡單預(yù)測 機(jī)器學(xué)習(xí)(一):5分鐘理解機(jī)器學(xué)習(xí)并上手實(shí)踐 前置機(jī)器學(xué)習(xí)(五):30分鐘掌握常用Matplotlib用法 前置機(jī)器學(xué)習(xí)(四):一文掌握Pandas用法 前置機(jī)器學(xué)習(xí)(三):30分鐘掌握常用NumPy用法 前置機(jī)器學(xué)習(xí)(二):30分鐘掌握常用Jupyter Notebook用法 前置機(jī)器學(xué)習(xí)(一):數(shù)學(xué)符號及希臘字母
往期精彩回顧
本站知識星球“黃博的機(jī)器學(xué)習(xí)圈子”(92416895)
本站qq群704220115。
加入微信群請掃碼:
