機器之心分析師網(wǎng)絡
作者:周宇
編輯:H4O
本文重點探討分布式學習框架中針對隨機梯度下降(SGD)算法的拜占庭問題。
分布式學習(Distributed Learning)是一種廣泛應用的大規(guī)模模型訓練框架。在分布式學習框架中,服務器通過聚合在分布式設備中訓練的本地模型(local model)來利用各個設備的計算能力。分布式機器學習的典型架構(gòu)——參數(shù)服務器架構(gòu)中,包括一個服務器(稱為參數(shù)服務器 - Parameter Server,PS)和多個計算節(jié)點(workers,也稱為節(jié)點 nodes)[1]。其中,隨機梯度下降(Stochastic Gradient Descent,SGD)是一種廣泛使用的、效果較好的分布式優(yōu)化算法。在每一輪中,每個計算節(jié)點根據(jù)不同的本地數(shù)據(jù)集在它的設備上訓練一個本地模型,并與服務器共享最終的參數(shù)。然后,服務器聚合不同計算節(jié)點的參數(shù),并通過與計算節(jié)點共享得到的組合參數(shù)來啟動下一輪訓練。關于基于 SGD 優(yōu)化的分布式框架的網(wǎng)絡結(jié)構(gòu)(包括:層數(shù)、類型、大小等)在訓練開始之前由所有計算節(jié)點共同商定確認。
近年來,分布式學習的安全性越來越受到人們的關注,其中,最重要的就是拜占庭威脅模型。在拜占庭威脅模型中,計算節(jié)點可以任意和惡意地行事。機器之心在前期的文章中也探討過分布式學習中的拜占庭問題,主要針對聯(lián)邦學習中的拜占庭問題。在這篇文章中,我們重點探討的是分布式學習框架中針對隨機梯度下降(SGD)算法的拜占庭問題。如圖 1 所示,在 SGD 學習框架中,一些惡意節(jié)點(Malicious worker)向服務器發(fā)送拜占庭梯度(Byzantine Gradient),而不是計算得到的真實梯度,而拜占庭梯度可以是任意值。惡意節(jié)點可以控制計算節(jié)點設備本身,也可以控制節(jié)點和服務器之間的通信。以 Algorithm 1 中提出的同步 SGD(sync-SGD)協(xié)議為例 [4]。攻擊者(惡意節(jié)點)在使其效果最大化的時間內(nèi)(即在 Algorithm 1 的第 6 行和第 7 行之間)干擾進程。在此期間,攻擊者可以將節(jié)點 i 中的參數(shù)(p_i)^(t+1) 替換為任意值,然后將此任意值發(fā)送到服務器中。攻擊方法在設置參數(shù)值的方式上有所不同,而防御方法則試圖識別損壞的參數(shù)并丟棄它們。Algorithm 1 使用平均值(第 8 行中的 AggregationRule( ))聚合計算節(jié)點參數(shù)。
圖 1. SGD 學習框架工作流程 [3]
本文所討論的分布式學習的核心是這樣一個假設:經(jīng)過訓練的網(wǎng)絡參數(shù)是獨立同分布的(Independent and identically distributed,i.i.d.)