這應(yīng)該是網(wǎng)上最簡單的元學(xué)習(xí)入門教程了
作者:涼爽的安迪
知乎:https://www.zhihu.com/people/wang-wj-38
「寫在前面:迄今為止,本文應(yīng)該是網(wǎng)上介紹【元學(xué)習(xí)(Meta-Learning)】」 最通俗易懂的文章了( 保命),主要目的是想對自己對于元學(xué)習(xí)的內(nèi)容和問題進(jìn)行總結(jié),同時(shí)為想要學(xué)習(xí)Meta-Learning的同學(xué)提供一下簡單的入門。筆者挑選了經(jīng)典的paper詳讀,看了李宏毅老師深度學(xué)習(xí)課程元學(xué)習(xí)部分,并附了MAML的代碼。為了通俗易懂,我將數(shù)學(xué)推導(dǎo)和工程實(shí)踐分開兩篇文章進(jìn)行介紹。~
「如果大家覺得有幫助,可以幫忙點(diǎn)個(gè)贊或者收藏一下,這將是我繼續(xù)分享的動(dòng)力~」
以下是本文的主要框架:
Introduction Meta Learning 實(shí)施——以 MAML 為例 Reptile What's more
1. Introduction
通常在機(jī)器學(xué)習(xí)里,我們會(huì)使用某個(gè)場景的大量數(shù)據(jù)來訓(xùn)練模型;然而當(dāng)場景發(fā)生改變,模型就需要重新訓(xùn)練。但是對于人類而言,一個(gè)小朋友成長過程中會(huì)見過許多物體的照片,某一天,當(dāng) Ta(第一次)僅僅看了幾張狗的照片,就可以很好地對狗和其他物體進(jìn)行區(qū)分。
元學(xué)習(xí) Meta Learning,含義為學(xué)會(huì)學(xué)習(xí),即 learn to learn,就是帶著這種對人類這種“學(xué)習(xí)能力”的期望誕生的。Meta Learning 希望使得模型獲取一種“學(xué)會(huì)學(xué)習(xí)”的能力,使其可以在獲取已有“知識”的基礎(chǔ)上快速學(xué)習(xí)新的任務(wù),如:
讓 Alphago 迅速學(xué)會(huì)下象棋 讓一個(gè)貓咪圖片分類器,迅速具有分類其他物體的能力
「需要注意的是,雖然同樣有“預(yù)訓(xùn)練”的意思在里面,但是元學(xué)習(xí)的內(nèi)核區(qū)別于遷移學(xué)習(xí)(Transfer Learning)」,關(guān)于他們的區(qū)別,我會(huì)在下文進(jìn)行闡述。
接下來,我們通過對比機(jī)器學(xué)習(xí)和元學(xué)習(xí)這兩個(gè)概念的要素來加深對元學(xué)習(xí)這個(gè)概念的理解。

在機(jī)器學(xué)習(xí)中,「訓(xùn)練單位是一條數(shù)據(jù)」,通過數(shù)據(jù)來對模型進(jìn)行優(yōu)化;數(shù)據(jù)可以分為訓(xùn)練集、測試集和驗(yàn)證集。在元學(xué)習(xí)中,訓(xùn)練單位分層級了,「第一層訓(xùn)練單位是任務(wù),也就是說,元學(xué)習(xí)中要準(zhǔn)備許多任務(wù)來進(jìn)行學(xué)習(xí),第二層訓(xùn)練單位才是每個(gè)任務(wù)對應(yīng)的數(shù)據(jù)」。
二者的目的都是找一個(gè) Function,只是兩個(gè) Function 的功能不同,要做的事情不一樣。機(jī)器學(xué)習(xí)中的 Function 直接作用于特征和標(biāo)簽,去尋找特征與標(biāo)簽之間的關(guān)聯(lián);而元學(xué)習(xí)中的 Function 是用于尋找新的 f,新的 f 才會(huì)應(yīng)用于具體的任務(wù)。「有種不同階導(dǎo)數(shù)的感覺」。又有種**老千層餅的感覺,**你看到我在第二層,你把我想象成第一層,而其實(shí)我在第五層。。。
2. Meta Learning 實(shí)施——以 MAML 為例
我們先對比機(jī)器學(xué)習(xí)的過程來進(jìn)一步理解元學(xué)習(xí)。如下圖所示,機(jī)器學(xué)習(xí)的一般過程如下:
設(shè)計(jì)網(wǎng)絡(luò)網(wǎng)絡(luò)結(jié)構(gòu),如 CNN、RNN 等; 選定某個(gè)分布來初始化參數(shù);(以上其實(shí)決定了初始的f的長相,選擇不同的網(wǎng)絡(luò)結(jié)構(gòu)或參數(shù)相當(dāng)于定義了不同的f); 喂訓(xùn)練數(shù)據(jù),根據(jù)選定的 Loss Function 計(jì)算 Loss; 梯度下降,逐步更新 ; 得到最終的 f

機(jī)器學(xué)習(xí)過程,引自李宏毅《深度學(xué)習(xí)》
其中,紅色方框里的“配置”都是由人為設(shè)計(jì)的,我們又叫做“超參數(shù)“。Meta Learning 中希望把這些配置,如網(wǎng)絡(luò)結(jié)構(gòu),參數(shù)初始化,優(yōu)化器等由機(jī)器自行設(shè)計(jì)(注:此處區(qū)別于 AutoML,遷移學(xué)習(xí)(Transfer Learning)和終身學(xué)習(xí)(Life Long Learning) ),使網(wǎng)絡(luò)有更強(qiáng)的學(xué)習(xí)能力和表現(xiàn)。
上文已經(jīng)提到,「【元學(xué)習(xí)中要準(zhǔn)備許多任務(wù)來進(jìn)行學(xué)習(xí),而每個(gè)任務(wù)又有各自的訓(xùn)練集和測試集】」。我們結(jié)合一個(gè)具體的任務(wù),來介紹元學(xué)習(xí)和MAML的實(shí)施過程。
有一個(gè)圖像數(shù)據(jù)集叫 Omniglot:https://github.com/brendenlake/omniglot。Omniglot 包含 1623 個(gè)不同的火星文字符,每個(gè)字符包含 20 個(gè)手寫的 case。這個(gè)任務(wù)是判斷每個(gè)手寫的 case 屬于哪一個(gè)火星文字符。
如果我們要進(jìn)行 N-ways,K-shot(數(shù)據(jù)中包含 N 個(gè)字符類別,每個(gè)字符有 K 張圖像)的一個(gè)圖像分類任務(wù)。比如 20-ways,1-shot 分類的意思是說,要做一個(gè) 20 分類,但是每個(gè)分類下只有 1 張圖像的任務(wù)。我們可以依據(jù) Omniglot 構(gòu)建很多 N-ways,K-shot 任務(wù),這些任務(wù)將作為元學(xué)習(xí)的任務(wù)來源。構(gòu)建的任務(wù)分為訓(xùn)練任務(wù)(Train Task),測試任務(wù)(Test Task)。特別地,每個(gè)任務(wù)包含自己的「訓(xùn)練數(shù)據(jù)、測試數(shù)據(jù)」,在元學(xué)習(xí)里,分別稱為 「Support Set 和 Query Set」。
「MAML 的目的是獲取一組更好的模型初始化參數(shù)(即讓模型自己學(xué)會(huì)初始化)」。我們通過(許多)N-ways,K-shot 的任務(wù)(訓(xùn)練任務(wù))進(jìn)行元學(xué)習(xí)的訓(xùn)練,使得模型學(xué)習(xí)到“先驗(yàn)知識”(初始化的參數(shù))。這個(gè)“先驗(yàn)知識”在新的 N-ways,K-shot 任務(wù)上可以表現(xiàn)的更好。
接下來介紹 MAML 的算法流程:

?「當(dāng)然,在“預(yù)訓(xùn)練”階段,也可以sample出1個(gè)batch的幾個(gè)任務(wù),那么在更新meta網(wǎng)絡(luò)時(shí),要使用sample出所有任務(wù)的梯度之和?!?/strong>**注意:**在MAML中,「meta網(wǎng)絡(luò)與子任務(wù)的網(wǎng)絡(luò)結(jié)構(gòu)必須完全相同」。
?
這里面有幾個(gè)小問題:
MAML的執(zhí)行過程與model pretraining & transfer learning的區(qū)別是什么? 為何在meta網(wǎng)絡(luò)賦值給具體訓(xùn)練任務(wù)(如任務(wù)m)后,要先更訓(xùn)練任務(wù)的參數(shù),再計(jì)算梯度,更新meta網(wǎng)絡(luò)? 在更新訓(xùn)練任務(wù)的網(wǎng)絡(luò)時(shí),只走了一步,然后更新meta網(wǎng)絡(luò)。為什么是一步,可以是多步嗎?
這三個(gè)問題是MAML中很核心的問題,大家可以先思考一下,我們將在后文進(jìn)行解答。我們先看一下MAML的實(shí)現(xiàn)代碼。
## 網(wǎng)絡(luò)構(gòu)建部分: refer: https://github.com/dragen1860/MAML-TensorFlow
#################################################
# 任務(wù)描述:5-ways,1-shot圖像分類任務(wù),圖像統(tǒng)一處理成 84 * 84 * 3 = 21168的尺寸。
# support set:5 * 1
# query set:5 * 15
# 訓(xùn)練取1個(gè)batch的任務(wù):batch size:4
# 對訓(xùn)練任務(wù)進(jìn)行訓(xùn)練時(shí),更新5次:K = 5
#################################################
print(support_x) # (4, 5, 21168)
print(query_x) # (4, 75, 21168)
print(support_y) # (4, 5, 5)
print(query_y) # (4, 75, 5)
print(meta_batchsz) # 4
print(K) # 5
model = MAML()
model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train')
class MAML:
def __init__(self):
pass
def build(self, support_xb, support_yb, query_xb, query_yb, K, meta_batchsz, mode='train'):
"""
:param support_xb: [4, 5, 84*84*3]
:param support_yb: [4, 5, n-way]
:param query_xb: [4, 75, 84*84*3]
:param query_yb: [4, 75, n-way]
:param K: 訓(xùn)練任務(wù)的網(wǎng)絡(luò)更新步數(shù)
:param meta_batchsz: 任務(wù)數(shù),4
"""
self.weights = self.conv_weights() # 創(chuàng)建或者復(fù)用網(wǎng)絡(luò)參數(shù);訓(xùn)練任務(wù)對應(yīng)的網(wǎng)絡(luò)復(fù)用meta網(wǎng)絡(luò)的參數(shù)
training = True if mode is 'train' else False
def meta_task(input):
"""
:param support_x: [setsz, 84*84*3] (5, 21168)
:param support_y: [setsz, n-way] (5, 5)
:param query_x: [querysz, 84*84*3] (75, 21168)
:param query_y: [querysz, n-way] (75, 5)
:param training: training or not, for batch_norm
:return:
"""
support_x, support_y, query_x, query_y = input
query_preds, query_losses, query_accs = [], [], [] # 子網(wǎng)絡(luò)更新K次,記錄每一次queryset的結(jié)果
## 第0次對網(wǎng)絡(luò)進(jìn)行更新
support_pred = self.forward(support_x, self.weights, training) # 前向計(jì)算support set
support_loss = tf.nn.softmax_cross_entropy_with_logits(logits=support_pred, labels=support_y) # support set loss
support_acc = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),
tf.argmax(support_y, axis=1))
grads = tf.gradients(support_loss, list(self.weights.values())) # 計(jì)算support set的梯度
gvs = dict(zip(self.weights.keys(), grads))
# 使用support set的梯度計(jì)算的梯度更新參數(shù),theta_pi = theta - alpha * grads
fast_weights = dict(zip(self.weights.keys(), \
[self.weights[key] - self.train_lr * gvs[key] for key in self.weights.keys()]))
# 使用梯度更新后的參數(shù)對quert set進(jìn)行前向計(jì)算
query_pred = self.forward(query_x, fast_weights, training)
query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
query_preds.append(query_pred)
query_losses.append(query_loss)
# 第1到 K-1次對網(wǎng)絡(luò)進(jìn)行更新
for _ in range(1, K):
loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.forward(support_x, fast_weights, training),
labels=support_y)
grads = tf.gradients(loss, list(fast_weights.values()))
gvs = dict(zip(fast_weights.keys(), grads))
fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.train_lr * gvs[key]
for key in fast_weights.keys()]))
query_pred = self.forward(query_x, fast_weights, training)
query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
# 子網(wǎng)絡(luò)更新K次,記錄每一次queryset的結(jié)果
query_preds.append(query_pred)
query_losses.append(query_loss)
for i in range(K):
query_accs.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(query_preds[i], dim=1), axis=1),
tf.argmax(query_y, axis=1)))
result = [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
return result
# return: [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
out_dtype = [tf.float32, tf.float32, tf.float32, [tf.float32] * K, [tf.float32] * K, [tf.float32] * K]
result = tf.map_fn(meta_task, elems=(support_xb, support_yb, query_xb, query_yb),
dtype=out_dtype, parallel_iterations=meta_batchsz, name='map_fn')
support_pred_tasks, support_loss_tasks, support_acc_tasks, \
query_preds_tasks, query_losses_tasks, query_accs_tasks = result
if mode is 'train':
self.support_loss = support_loss = tf.reduce_sum(support_loss_tasks) / meta_batchsz
self.query_losses = query_losses = [tf.reduce_sum(query_losses_tasks[j]) / meta_batchsz
for j in range(K)]
self.support_acc = support_acc = tf.reduce_sum(support_acc_tasks) / meta_batchsz
self.query_accs = query_accs = [tf.reduce_sum(query_accs_tasks[j]) / meta_batchsz
for j in range(K)]
# 更新meta網(wǎng)絡(luò),只使用了第 K步的query loss。這里應(yīng)該是個(gè)超參,更新幾步可以調(diào)調(diào)
optimizer = tf.train.AdamOptimizer(self.meta_lr, name='meta_optim')
gvs = optimizer.compute_gradients(self.query_losses[-1])
# def ********
接下來回答一下上面的三個(gè)問題:
「問題1:MAML的執(zhí)行過程與model pretraining & transfer learning的區(qū)別是什么?」
我們將meta learning與model pretraining的loss函數(shù)寫出來。
meta learning與model pretraining的loss函數(shù)
注意這兩個(gè)loss函數(shù)的區(qū)別:
meta learning的L來「源于訓(xùn)練任務(wù)上網(wǎng)絡(luò)的參數(shù)更新過一次后」(該網(wǎng)絡(luò)更新過一次以后,網(wǎng)絡(luò)的參數(shù)與meta網(wǎng)絡(luò)的參數(shù)已經(jīng)有一些區(qū)別)「,然后使用Query Set」計(jì)算的loss; model pretraining的L來源于「同一個(gè)model的參數(shù)」(只有一個(gè)),使用訓(xùn)練數(shù)據(jù)計(jì)算的loss和梯度對model進(jìn)行更新;如果有多個(gè)訓(xùn)練任務(wù),我們可以將這個(gè)參數(shù)在很多任務(wù)上進(jìn)行預(yù)訓(xùn)練,訓(xùn)練的所有梯度都會(huì)直接更新到model的參數(shù)上。
看一下二者的更新過程簡圖:
meta learning與model pretraining訓(xùn)練過程,引自李宏毅《深度學(xué)習(xí)》
MAML是使用子任務(wù)的參數(shù),「第二次更新」的gradient的方向來更新參數(shù)(所以左圖,第一個(gè)藍(lán)色箭頭的方向與第二個(gè)綠色箭頭的方向平行;左圖第二個(gè)藍(lán)色箭頭的方向與第二個(gè)橘色箭頭的方向平行) 而model pretraining是使用子任務(wù)第一步更新的gradient的方向來更新參數(shù)(子任務(wù)的梯度往哪個(gè)方向走,model的參數(shù)就往哪個(gè)方向走)。
從sense上直觀理解:
model pretraining最小化當(dāng)前的model(只有一個(gè))在所有任務(wù)上的loss,所以model pretraining希望找到一個(gè)在所有任務(wù)(實(shí)際情況往往是大多數(shù)任務(wù))上都表現(xiàn)較好的一個(gè)初始化參數(shù),這個(gè)參數(shù)要在多數(shù)任務(wù)上「當(dāng)前表現(xiàn)較好」。 meta learning最小化每一個(gè)子任務(wù)訓(xùn)練一步之后,第二次計(jì)算出的loss,用第二步的gradient更新meta網(wǎng)絡(luò),這代表了什么呢?子任務(wù)從【狀態(tài)0】,到【狀態(tài)1】,我們希望狀態(tài)1的loss小,說明meta learning更c(diǎn)are的是「初始化參數(shù)未來的潛力」。
「一個(gè)關(guān)注當(dāng)下,一個(gè)關(guān)注潛力?!?/strong>
如下圖所示,model pretraining找到的參數(shù) ,在兩個(gè)任務(wù)上當(dāng)前的表現(xiàn)比較好(「當(dāng)下好」,但訓(xùn)練之后不保證好); 而MAML的參數(shù) 在兩個(gè)子任務(wù)當(dāng)前的表現(xiàn)可能都不是很好,但是如果在兩個(gè)子任務(wù)上繼續(xù)訓(xùn)練下去,可能會(huì)達(dá)到各自任務(wù)的局部最優(yōu)(「潛力好」)。

引自李宏毅《深度學(xué)習(xí)》
這里有一個(gè)toy example可以表現(xiàn)MAML的執(zhí)行過程與model pretraining & transfer learning的區(qū)別。
訓(xùn)練任務(wù):給定N個(gè)函數(shù),y = asinx + b(通過給a和b不同的取值可以得到很多sin函數(shù)),從每個(gè)函數(shù)中sample出K個(gè)點(diǎn),用sample出的K個(gè)點(diǎn)來預(yù)估最初的函數(shù),即求解a和b的值。
訓(xùn)練過程:用這N個(gè)訓(xùn)練任務(wù)sample出的數(shù)據(jù)點(diǎn)分別通過MAML與model pretraining訓(xùn)練網(wǎng)絡(luò),得到預(yù)訓(xùn)練的參數(shù)。
如下圖,用橘黃色的sin函數(shù)作為測試任務(wù),三角形的點(diǎn)是測試任務(wù)中sample出的樣本點(diǎn),在測試任務(wù)中,我們希望用sample出的樣本點(diǎn)還原橘黃色的線。

Toy example,引自李宏毅《深度學(xué)習(xí)》
model pretraining的結(jié)果,在測試任務(wù)上,在finetuning之前,綠色線是一條水平線,finetuning之后還原的線基本還是一條水平線。因?yàn)樵陬A(yù)訓(xùn)練的時(shí)候,有很多sin函數(shù),model pretraining希望找到一個(gè)在所有任務(wù)上都效果較好的初始化結(jié)果,但是許多sin函數(shù)波峰和波谷重疊起來,基本就是一條水平線。用這個(gè)初始化的結(jié)果取finetuning,得到的結(jié)果仍然是水平線。 MAML的初始化結(jié)果是綠色的線,和橘黃色的線有差異。但是隨著finetuning的進(jìn)行,結(jié)果與橘黃色的線更加接近。
「問題2:為何在meta網(wǎng)絡(luò)賦值給具體訓(xùn)練任務(wù)(如任務(wù)m)后,要先更訓(xùn)練任務(wù)的參數(shù),再計(jì)算梯度,更新meta網(wǎng)絡(luò)?」
這個(gè)問題其實(shí)在問題1中已經(jīng)進(jìn)行了回答,更新一步之后,避免了meta learning陷入了和model pretraining一樣的訓(xùn)練模式,更重要的是,可以使得meta模型更關(guān)注參數(shù)的**“潛力”**。
「問題3:在更新訓(xùn)練任務(wù)的網(wǎng)絡(luò)時(shí),只走了一步,然后更新meta網(wǎng)絡(luò)。為什么是一步,可以是多步嗎?」
李宏毅老師的課程中提到:
只更新一次,速度比較快;因?yàn)閙eta learning中,子任務(wù)有很多,都更新很多次,訓(xùn)練時(shí)間比較久。 MAML希望得到的初始化參數(shù)在新的任務(wù)中finetuning的時(shí)候效果好。如果只更新一次,就可以在新任務(wù)上獲取很好的表現(xiàn)。把這件事情當(dāng)成目標(biāo),可以使得meta網(wǎng)絡(luò)參數(shù)訓(xùn)練是很好(目標(biāo)與需求一致)。 當(dāng)初始化參數(shù)應(yīng)用到具體的任務(wù)中時(shí),也可以finetuning很多次。 Few-shot learning往往數(shù)據(jù)較少。
那么MAML中的訓(xùn)練任務(wù)的網(wǎng)絡(luò)可以更新多次后,再更新meta網(wǎng)絡(luò)嗎?
我覺得可以。直觀上感覺,更新次數(shù)決定了子任務(wù)對于meta網(wǎng)絡(luò)的影響程度,我覺得這個(gè)步數(shù)可以作為一個(gè)參數(shù)來調(diào)。
另外,即將介紹的下一個(gè)網(wǎng)絡(luò)——Reptile,也是對訓(xùn)練任務(wù)網(wǎng)絡(luò)進(jìn)行多次更新的。
「3. Reptile」
Reptile與MAML有點(diǎn)像,我們先看一下Reptile的訓(xùn)練簡圖:
Reptile訓(xùn)練過程,引自李宏毅《深度學(xué)習(xí)》
Reptile的訓(xùn)練過程如下:

Reptile,每次sample出1個(gè)訓(xùn)練任務(wù)

Reptile,每次sample出1個(gè)batch訓(xùn)練任務(wù)
在Reptile中:
訓(xùn)練任務(wù)的網(wǎng)絡(luò)可以更新多次 reptile不再像MAML一樣計(jì)算梯度(因此帶來了工程性能的提升),而是直接用一個(gè)參數(shù) 乘以meta網(wǎng)絡(luò)與訓(xùn)練任務(wù)的網(wǎng)絡(luò)參數(shù)的差來更新meta網(wǎng)絡(luò)參數(shù) 從效果上來看,Reptile效果與MAML基本持平
「4. What's more」
元學(xué)習(xí)入門部分的文章基本就分享到這里了~
從出發(fā)點(diǎn)上來看,元學(xué)習(xí)和model pretraining有點(diǎn)像,即,都是讓網(wǎng)絡(luò)具有一些先驗(yàn)知識。 從訓(xùn)練過程的設(shè)計(jì)來看,元學(xué)習(xí)更關(guān)注模型的潛力,而model pretraining更注重模型當(dāng)下在多數(shù)情況下的表現(xiàn),效果孰好孰壞很難直接判定。這大概也就是仰望天空和腳踏實(shí)地的區(qū)別hahaha 元學(xué)習(xí)除了可以初始化參數(shù)以外,還有一些設(shè)計(jì)可以幫助確定網(wǎng)絡(luò)結(jié)構(gòu),如何更新參數(shù)等等這里有李宏毅老師的一個(gè)課程大家可以關(guān)注一下https://www.youtube.com/watch?v=c10nxBcSH14 。
分享一個(gè)關(guān)于元學(xué)習(xí)的搞笑的圖。。。

老千層餅,你永遠(yuǎn)都不知道你咬下去的這一口有多少層。。

接下來可能會(huì)分享一篇MAML的數(shù)學(xué)推導(dǎo),以及想把當(dāng)前工作里的model pretraining模型切到meta learning看一下效果。
「最后的最后,求贊求收藏求關(guān)注~」
「參考文獻(xiàn)」
Finn C, Abbeel P, Levine S. Model-agnostic meta-learning for fast adaptation of deep networks[C]//Proceedings of the 34th International Conference on Machine Learning-Volume 70. JMLR. org, 2017: 1126-1135. Nichol A, Schulman J. Reptile: a scalable metalearning algorithm[J]. arXiv preprint arXiv:1803.02999, 2018, 2: 2. https://github.com/dragen1860/MAML-TensorFlow https://www.youtube.com/watch?v=c10nxBcSH14 [https://www.bilibili.com/video/BV1J](
