Self-Supervised Learning可以很簡單: BYOL與SimSiam的觀點

Self-Supervised Learning可以很簡單: BYOL與SimSiam的觀點

近年self-supervised learning最知名的兩個方法是Google Brain的SimCLR 與Facebook AI Research的MoCo ,不過這兩個方法都有各自實作上的麻煩處:SimCLR需要夠大的batch提供互斥的樣本,而MoCo在內存中需要維護一個類似memory bank的過往紀錄與額外的momentum encoder

事實上SimCLR與MoCo的這些複雜的設計是為了提供訓練時負樣本,以避免contrasive learning的崩壞。2020年6月DeepMind團隊提出了BYOL **,用預測的概念拿掉了MoCo的過往紀錄。**而後同年11月,Facebook AI Research用交互的stop gradient在BYLO的基礎上建立了更簡潔的SimSiam 。這兩個方法在訓練時都不需要特別提供負樣本,讓SSL的訓練變得簡單又強大,是今年 (2020年) 極度值得學習的方法。

Cover photo by Canva  (組圖來源)

Tip

文章難度:★★★☆☆

閱讀建議: 文章介紹今年 (2020年)的 self-supervised learning的方向,著重於簡化訓練部分。文章前段簡單介紹 self-supervised learning與 contrastive learning的概念,並回顧兩個使用負樣本的代表性作品 SimCLR與 MoCo。後段介紹今年兩個重要的簡化方法 BYOL與 SimSiam。適合對於 SimCLR或相關負樣本方法有概念的讀者。

推薦背景知識:supervised learning, unsupervised learning, self-supervised learning, pre-training, data augmentation, image classification, SimCLR, MoCo.

近期的Self-Supervised Learning設計

Self-supervised learning (SSL) 其實算是近幾年迸出來的新名詞,屬於unsupervised learning的一種,訓練時必須從資料本身設計某些不含目標任務標註且具可學習性的資訊,以進行supervised learning

DeepMind團隊對於self-supervision的解釋 (資料來源)

SSL可以減少對於大量標註資料的依賴,這點與常見的supervised pretrain想法一致,兩者都是在追求representation learning的實踐,但SSL更強調於資料本身的豐富資訊,探索除了標註以外的更多可能性

事實上在自然語言處理領域,常見的word to vector的訓練就是一種經由語句前後關係的SSL。而在電腦視覺領域,contrastive learning (CL、對比學習) 是近年實現SSL最熱門的方法。像小孩子學習一樣,透過比較貓狗的同類相同之處與異類不同之處,在即使式在不知道什麼是貓、什麼是狗的情況下 (甚至沒有語言定義的情況),也可以學會分辨貓狗。

很簡單的contrastive learning概念,來自SimCLR (資料來源)

想像一個簡單的contrastive learning的概念,輸入x通過不同的transform得到x_i與x_j,通過某些參數的計算f(x)後得到representationh_i與h_j,最後通過projection到特定維度g(h)得到z_i與z_j,而訓練的目標就是最大化z_i與z_j之間的相似度 (或簡單想像為最小化兩者的差距)。

基本contrastive learning的概念 (資料來源)

其實很直觀的可以發現這樣的CL設計有天生就會崩壞的問題,也就是不論是f(x)或是g(h),都可以透過將輸出變成固定的數值,得到永遠最大的相似度。這樣的輸出一般稱為collapsing output。

維持Negative Sample的方法

**為了避免所謂的collapsing output,通常會藉由導入負樣本,讓學習目標不只是得到得到最大的相似度,同時也要與負樣本有最小的相似度。**這樣的設計讓輸出定值不會是天然的minimun。近年Google Brain的SimCLR [1] 與Facebook AI Research的MoCo [3] 都是走這樣的思路。

SimCLR是self-supervised learning與contrastive learning中重要的一個相當重要的里程碑,其最大的特點在於研究各種資料增強 (data augmentation) 作為SSL的歸納偏好 (inductive bias),並利用不同data間彼此的互斥強化學習目標,避免contrastive learning的output collapse。

整體運作概念分為三個階段:

  1. 先sample一些圖片(batch of image)
  2. 對batch裡的image做兩種不同的data augmentation
  3. 希望同一張影像、不同augmentation的結果相近,並互斥其他結果。

如果要將SimCLR的架構劃分階段,大致可以分成兩個階段,首先是大個embedding網路執行特徵抽取得到y**,接下來使用一個**小的網路投影到某個固定為度的空間得到z

SimCLR的網路與訓練架構圖 (資料來源)

這個小網路投影也是SimCLR的另一個特點。對於同一個x,用data augmentation得到不同的v,通過網路抽取、投影得到固定維度的特徵,計算z的contrastive loss,直接用gradient decent同時訓練兩個階段的網路。

SimCLR的方法雖然簡單,但是一個麻煩的點在於需要大量的online負樣本提供斥力。在論文中使用了4096的batch size,還需要為了特別大的batch使用LARS作為optimizer。

MoCo是Facebook AI Research (FAIR) 推出的SSL方法,也是Kaiming He這一兩年的重點研究,整體運作的精華在於改善過往SSL常用的memory bank。原本memory bank是將過往看過sample的representation儲存起來,在每次training step時從中sample出一個mini batch作為負樣本使用。這個方法的問題在於儲存在memory bank中的representation可能經過幾百或更多step的參數更新後,已經失去與現在抽出來的representation的一致性。

MoCo (momentum contrast) 的想法是維持兩個encoder,一個使用gradient decent訓練,另一個的參數則是跟著第一個encode的參數,但是使用momentum更新,保持一個相似但不同的狀態。而這個momentum encoder的輸出會被一個queue儲存起來,取代原本的memory bank。相比原本的memory bank作法,MoCo的儲存的representation比較新,保持一致性的同時也藉由momentum encoder確保了足夠的對比性。

(資料來源)

至於MoCo v2 [4] 只是將SimCLR的兩個元件 (stronger data augmentation與投影網路 )加入MoCo的架構。整體架構可以視為依照SimCLR的擴充版,多了momentum encoder與memory bank,但是不需要特別大的batch。

MoCo v2的網路與訓練架構圖 (資料來源)

以上的方法雖然展示了self-supervised learning的強大,不過都有一些實作上的麻煩之處,降低許多應用開發人員嘗試的意願。

在SimCLR中,訓練時需要大量online的負樣本,導致了batch size要開得很大效果才好,論文中使用了4096的batch size,甚至為了optimize這樣不尋常的大batch,還使用了特殊設計的optimizer LARS。這樣大的batch以現階段的硬體設備不可能放在一張GPU上,也因此cross GPU的batch normalization也需要被導入。

在MoCo中,除了需要維護representation的queue(其實也就是某種memory bank),還需要在訓練時運行兩個網路,gradient訓練的encoder與momentum encoder,造成了兩倍大的記憶體需求。

簡化SSL訓練方法

今年,分別由DeepMind團隊與Facebook AI Research團隊都提出了不需要負樣本的contrastive learning架構BYOL [5] 與SimSiam [6],示範了在computer vision領域實現SSL的簡單方法。雖然不需要負樣本也不會導致collapsing output,至今 (2020年12月) 相關論文方法都不能提供具有相當理論的合理解釋。

BYOL與MoCo設計相似,需要兩組網路 (online與target network),online網路依照gradient更新參數,target network的參數則是藉由online network參數的momentum更新。與MoCo不同的是網路多了一個預測層,分為三個階段:特徵抽取(embedding, or encoding) 得到y、特徵投影 (projection) 得到z、最後通過預測 (prediction) 得到p。Online network的任務是用預測網路產生的 p 來預測target network的projection `z

BYOL的網路與訓練架構圖 (資料來源)

也就是說BYOL並不需要儲存過去的representation,拔除了MoCo的memory bank。換句話說,BYOL的contrastive learning過程中並沒有顯著地使用到任何負樣本,但是卻不會造成collapsing output。這本質上是一件非常非常不可思議的事情。論文的解釋是他們認為在loss的設計上並沒有同時optimize兩個encoder,就像GAN (generative adversarial network) 一樣,因此網路並不會converge到整體loss最小的地方,也就是collapsing output。

事實上這樣的說法很難說服人,論文也承認了存在一個明顯的 undesirable equilibria會導致 collapsing,這個實驗結果僅僅是實務上 “we did not observe convergence to such equilibria”

至於效果,BYOL是無庸置疑的強大,在ImageNet的linear evaluation下與其他下游任務的transfer learning上表現都相當優異。

BYOL在ImageNet上的linear evaluation (資料來源)
BYOL在下游任務的transfer learning效果 (資料來源)

SimSiam簡單用一句話描述就是沒有momentum encoder的BYOL。BYOL拿掉了MoCo的memory bank,SimSiam進一步地拿掉了momentum encoder。方法簡單,實務上同樣能避免collapsing output的發生。

SimSiam的架構與BYOL一樣是三個階段的架構,先過主網路embedding,再過小網路projection,最後過小網路prediction。與BYOL不同之處在於SimSiam並沒有兩組網路參數,同一個網路對於不同的view交互地用latent projection互為彼此的prediction target。在更新參數時,都只計算prediction那條路線的gradient,也自然沒有什麼momentum encoder。

SimSiam的網路與訓練架構圖

作者認為近期的unsupervised或self-supervised learning大多都使用了某種形式的siamese network,也就是對於成對或多個輸入使用共享參數的網路,藉以計算某種相似度。論文進一步的推斷,在這樣的Siamese架構下存在兩個不同的最佳化任務,實現兩個observation的matching。就像是convolution是modeling translation invariance的歸納偏置 (inductive bias),Siamese架構也是一種modeling invariance的inductive biases。這個invariance指的就是data augmentation invariance。

論文透過實驗展示,不論是predict head的設計、batch的大小、batch normalization的使用、similarity function的設計、或是對比性的symmetrization都不是避免collapsing output的關鍵。並且透過額外的實驗旁敲側擊地佐證,導入stop gradient的操作等同於引入一些虛構的參數,讓SimSiam類似於某種expectation maximization (EM)的演算法,去估算data augmentation的期望值。

事實上這樣的實驗與說法依然相當地玄,與 BYOL一樣, SimSiam論文也承認了這僅僅是實務上的結果 (“SimSiam and its variants non-collapsing behavior still remains as an empirical observation”)

在實驗結果上,SimSiam並不是最強的,但是依然是所有方法中最簡單、最好實現的。

在ImageNet上的linear evalueation (資料來源)

Self-supervised learning在這兩年內獲得莫大的進展,而且許多優質論文都是來自強大軟體公司的研究團隊,可以想像這是一個學界與業界都熱烈渴求的目標。畢竟在蒐集標註資料的時候,收斂需要的數量越少,成本越低在同樣的標註資料量下,可以從資料本身獲取越多資訊的模型自然也越強大

時至今日,SSL的相關技術還完全稱不上是成熟,工業上的應用應該也相較地少,不過相信在一兩年內就有機會看到更完整的論文與應用落地。

好了~這篇文章就先到這邊。老話一句,Deep Learning領域每年都會有大量高質量的論文產出,說真的要跟緊不是一件容易的事。所以我的觀點可能也會存在瑕疵,若有發現什麼錯誤或值得討論的地方,歡迎回覆文章或來信一起討論 :)

Supplementary

今年其實還有一篇優秀的SSL論文,來自Inria與Facebook AI Research的SwAV (“Unsupervised Learning of Visual Features by Contrasting Cluster Assignments”) [7]。論文的設計包含了類似clustering以及swap prediction的概念,效果也相當地好。不過由於方法有點繁雜,與本篇文章追求簡單的SSL有些不同,就不在這篇文章中詳述。

SwAV的CL運作概念 (資料來源)

Reference

  1. A Simple Framework for Contrastive Learning of Visual Representations[ICML 2020]
  2. Advancing Self-Supervised and Semi-Supervised Learning with SimCLR[Google AI Blog]
  3. Momentum Contrast for Unsupervised Visual Representation Learning[CVPR 2020]
  4. Improved Baselines with Momentum Contrastive Learning[arXiv 2020]
  5. Bootstrap your own latent: A new approach to self-supervised Learning[NeurIPS 2020]
  6. Exploring Simple Siamese Representation Learning[arXiv 2020]
  7. Unsupervised Learning of Visual Features by Contrasting Cluster Assignments[NeurIPS 2020]
comments powered by Disqus