
Bayesian Nonparametric Federated Learning of Neural Networks

Mikhail Yurochkin 1,2  Mayank Agarwal 1,2  Soumya Ghosh 1,2,3  Kristjan Greenewald 1,2  Trong Nghia Hoang 1,2  Yasaman Khazaeni 1,2

Abstract

In federated learning problems, data is scattered across different servers and exchanging or pooling it is often impractical or prohibited. We develop a Bayesian nonparametric framework for federated learning with neural networks. Each data server is assumed to provide local neural network weights, which are modeled through our framework. We then develop an inference approach that allows us to synthesize a more expressive global network without additional supervision, data pooling and with as few as a single communication round. We then demonstrate the efficacy of our approach on federated learning problems simulated from two popular image classification datasets.¹

1. Introduction

The standard machine learning paradigm involves algorithms that learn from centralized data, possibly pooled together from multiple data sources. The computations involved may be done on a single machine or farmed out to a cluster of machines. However, in the real world, data often live in silos and amalgamating them may be prohibitively expensive due to communication costs, time sensitivity, or privacy concerns. Consider, for instance, data recorded from sensors embedded in wearable devices. Such data is inherently private, can be voluminous depending on the sampling rate of the sensors, and may be time sensitive depending on the analysis of interest. Pooling data from many users is technically challenging owing to the severe computational burden of moving large amounts of data, and is fraught with privacy concerns stemming from potential data breaches that may expose a user's protected health information (PHI).

1 IBM Research, Cambridge  2 MIT-IBM Watson AI Lab  3 Center for Computational Health. Correspondence to: Mikhail Yurochkin <[email protected]>.

Proceedings of the 36th International Conference on Machine Learning, Long Beach, California, PMLR 97, 2019. Copyright 2019 by the author(s).

¹ Code is available at https://github.com/IBM/probabilistic-federated-neural-matching

Federated learning addresses these pitfalls by obviating the need for centralized data, instead designing algorithms that learn from sequestered data sources. These algorithms iterate between training local models on each data source and distilling them into a global federated model, all without explicitly combining data from different sources. Typical federated learning algorithms, however, require access to locally stored data for learning. A more extreme case surfaces when one has access to models pre-trained on local data but not the data itself. Such situations may arise from catastrophic data loss but increasingly also from regulations such as the General Data Protection Regulation (GDPR) (EU, 2016), which place severe restrictions on the storage and sharing of personal data. Learned models that capture only aggregate statistics of the data can typically be disseminated with fewer limitations. A natural question then is: can "legacy" models trained independently on data from different sources be combined into an improved federated model?

Here, we develop and carefully investigate a probabilistic federated learning framework with a particular emphasis on training and aggregating neural network models. We assume that either local data or pre-trained models trained on local data are available. When data is available, we proceed by training local models for each data source, in parallel. We then match the estimated local model parameters (groups of weight vectors in the case of neural networks) across data sources to construct a global network. The matching, to be formally defined later, is governed by the posterior of a Beta-Bernoulli process (BBP) (Thibaux & Jordan, 2007), a Bayesian nonparametric (BNP) model that allows the local parameters to either match existing global ones or to create new global parameters if existing ones are poor matches.

Our construction provides several advantages over existing approaches. First, it decouples the learning of local models from their amalgamation into a global federated model. This decoupling allows us to remain agnostic about the local learning algorithms, which may be adapted as necessary, with each data source potentially even using a different learning algorithm. Moreover, given only pre-trained models, our BBP informed matching procedure is able to combine them into a federated global model without requiring additional data or knowledge of the learning algorithms used to generate the pre-trained models. This is in sharp contrast with existing work on federated learning of neural networks (McMahan et al., 2017), which requires strong assumptions about the local learners, for instance, that they share the same random initialization, and is not applicable for combining pre-trained models. Next, the BNP nature of our model ensures that we recover compressed global models with fewer parameters than the cardinality of the set of all local parameters. Unlike naive ensembles of local models, this allows us to store fewer parameters and perform more efficient inference at test time, requiring only a single forward pass through the compressed model as opposed to J forward passes, one for each local model. While techniques such as knowledge distillation (Hinton et al., 2015) allow the cost of multiple forward passes to be amortized, training the distilled model itself requires access to data pooled across all sources or an auxiliary dataset, luxuries unavailable in our scenario. Finally, even in the traditional federated learning scenario, where local and global models are learned together, we show empirically that our proposed method outperforms existing distributed training and federated learning algorithms (Dean et al., 2012; McMahan et al., 2017) while requiring far fewer communications between the local data sources and the global model server.

The remainder of the paper is organized as follows. We briefly introduce the Beta-Bernoulli process in Section 2 before describing our model for federated learning in Section 3. We thoroughly evaluate the proposed model and demonstrate its utility empirically in Section 4. Finally, Section 5 discusses current limitations of our work and open questions.

2. Background and Related Works

Our approach builds on tools from Bayesian nonparametrics, in particular the Beta-Bernoulli process (BBP) (Thibaux & Jordan, 2007) and the closely related Indian Buffet Process (IBP) (Griffiths & Ghahramani, 2011). We briefly review these ideas before describing our approach.

2.1. Beta-Bernoulli Process (BBP)

Let Q be a random measure distributed by a Beta process with mass parameter γ0 and base measure H. That is, Q | γ0, H ∼ BP(1, γ0 H). It follows that Q is a discrete (not probability) measure Q = Σ_i q_i δ_{θ_i} formed by an infinitely countable set of (weight, atom) pairs (q_i, θ_i) ∈ [0, 1] × Ω. The weights {q_i}_{i=1}^∞ are distributed by a stick-breaking process (Teh et al., 2007): c_i ∼ Beta(γ0, 1), q_i = ∏_{j=1}^{i} c_j, and the atoms are drawn i.i.d. from the normalized base measure, θ_i ∼ H/H(Ω), with domain Ω. In this paper, Ω is simply R^D for some D. Subsets of atoms in the random measure Q are then selected using a Bernoulli process with base measure Q. That is, each subset T_j, with j = 1, ..., J, is characterized by a Bernoulli process with base measure Q, T_j | Q ∼ BeP(Q). Each subset T_j is also a discrete measure formed by pairs (b_{ji}, θ_i) ∈ {0, 1} × Ω, T_j := Σ_i b_{ji} δ_{θ_i}, where b_{ji} | q_i ∼ Bernoulli(q_i) ∀i is a binary random variable indicating whether atom θ_i belongs to subset T_j. The collection of such subsets is then said to be distributed by a Beta-Bernoulli process.

2.2. Indian Buffet Process (IBP)

The above subsets are conditionally independent given Q. Thus, marginalizing Q will induce dependencies among them. In particular,
\[ T_J \mid T_1, \ldots, T_{J-1} \sim \mathrm{BeP}\Big( \frac{\gamma_0}{J} H + \sum_i \frac{m_i}{J}\, \delta_{\theta_i} \Big), \]
where m_i = Σ_{j=1}^{J−1} b_{ji} (dependency on J is suppressed in the notation for simplicity); this marginal process is sometimes called the Indian Buffet Process. The IBP can be equivalently described by the following culinary metaphor. Imagine J customers arrive sequentially at a buffet and choose dishes to sample as follows: the first customer tries Poisson(γ0) dishes; every subsequent j-th customer tries each of the previously selected dishes according to their popularity, i.e. dish i with probability m_i/j, and then tries Poisson(γ0/j) new dishes.
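The culinary metaphor translates directly into a short simulation; the sketch below (again our own illustration) generates the sparse binary customer-by-dish matrix for J customers.

```python
import numpy as np

def sample_ibp(J=5, gamma0=3.0, seed=0):
    """Simulate the Indian Buffet Process via its sequential culinary metaphor."""
    rng = np.random.default_rng(seed)
    dish_counts = []                      # m_i: how many previous customers tried dish i
    Z = []                                # rows of the binary customer-by-dish matrix
    for j in range(1, J + 1):
        # try each previously selected dish i with probability m_i / j
        row = [rng.random() < m / j for m in dish_counts]
        # then try Poisson(gamma0 / j) brand new dishes
        new = rng.poisson(gamma0 / j)
        row += [True] * new
        dish_counts = [m + int(t) for m, t in zip(dish_counts, row)] + [1] * new
        Z.append(row)
    # pad rows to a common length to get a J x (total dishes) binary matrix
    total = len(dish_counts)
    return np.array([r + [False] * (total - len(r)) for r in Z])

print(sample_ibp().astype(int))
```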

The IBP, which specifies a distribution over sparse binary matrices with infinitely many columns, was originally demonstrated for latent factor analysis (Ghahramani & Griffiths, 2005). Several extensions to the IBP (and the equivalent BBP) have been developed; see Griffiths & Ghahramani (2011) for a review. Our work is related to a recent application of these ideas to distributed topic modeling (Yurochkin et al., 2018), where the authors use the BBP for modeling topics learned from multiple collections of documents, and provide an inference scheme based on the Hungarian algorithm (Kuhn, 1955).

2.3. Federated and Distributed Learning

Federated learning has garnered interest from the machine learning community of late. Smith et al. (2017) pose federated learning as a multi-task learning problem, which exploits the convexity and decomposability of the cost function of the underlying support vector machine (SVM) model for distributed learning. This approach however does not extend to the neural network structure considered in our work. McMahan et al. (2017) use strategies based on simple averaging of the local learner weights to learn the federated model. However, as pointed out by the authors, such naive averaging of model parameters can be disastrous for non-convex cost functions. To cope, they have to use a scheme where the local learners are forced to share the same random initialization. In contrast, our proposed framework is naturally immune to such issues since its development assumes nothing specific about how the local models were trained. Moreover, unlike existing work in this area, our framework is nonparametric in nature, allowing the federated model to flexibly grow or shrink its complexity (i.e., its size) to account for varying data complexity.

There is also significant work on distributed deep learning (Lian et al., 2015; 2017; Moritz et al., 2015; Li et al., 2014; Dean et al., 2012). However, the emphasis of these works is on scalable training from large data, and they typically require frequent communication between the distributed nodes to be effective. Yet others explore distributed optimization with a specific emphasis on communication efficiency (Zhang et al., 2013; Shamir et al., 2014; Yang, 2013; Ma et al., 2015; Zhang & Lin, 2015). However, as pointed out by McMahan et al. (2017), these works primarily focus on settings with convex cost functions and often assume that each distributed data source contains an equal number of data instances. These assumptions, in general, do not hold in our scenario. Finally, neither these distributed learning approaches nor existing federated learning approaches decouple local training from global model aggregation. As a result they are not suitable for combining pre-trained legacy models, a particular problem of interest in this paper.

3. Probabilistic Federated Neural Matching

We now describe how the Bayesian nonparametric machinery can be applied to the problem of federated learning with neural networks. Our goal will be to identify subsets of neurons in each of the J local models that match neurons in other local models. We will then appropriately combine the matched neurons to form a global model.

Our approach to federated learning builds upon the following basic problem. Suppose we have trained J Multilayer Perceptrons (MLPs) with one hidden layer each. For the j-th MLP, j = 1, ..., J, let V_j^{(0)} ∈ R^{D×L_j} and v_j^{(0)} ∈ R^{L_j} be the weights and biases of the hidden layer; V_j^{(1)} ∈ R^{L_j×K} and v_j^{(1)} ∈ R^K be the weights and biases of the softmax layer; D be the data dimension, L_j the number of neurons on the hidden layer; and K the number of classes. We consider a simple architecture:
\[ f_j(x) = \mathrm{softmax}\big(\sigma(x V_j^{(0)} + v_j^{(0)})\, V_j^{(1)} + v_j^{(1)}\big), \]
where σ(·) is some nonlinearity (sigmoid, ReLU, etc.). Given the collection of weights and biases {V_j^{(0)}, v_j^{(0)}, V_j^{(1)}, v_j^{(1)}}_{j=1}^J, we want to learn a global neural network with weights and biases Θ^{(0)} ∈ R^{D×L}, θ^{(0)} ∈ R^L, Θ^{(1)} ∈ R^{L×K}, θ^{(1)} ∈ R^K, where L ≤ Σ_{j=1}^J L_j is an unknown number of hidden units of the global network to be inferred.

Our first observation is that the ordering of neurons of the hidden layer of an MLP is permutation invariant. Consider any permutation τ(1, ..., L_j) of the j-th MLP: reordering the columns of V_j^{(0)}, the biases v_j^{(0)}, and the rows of V_j^{(1)} according to τ(1, ..., L_j) will not affect the outputs f_j(x) for any value of x. Therefore, instead of treating weights as matrices and biases as vectors, we view them as unordered collections of vectors V_j^{(0)} = {v_{jl}^{(0)} ∈ R^D}_{l=1}^{L_j}, V_j^{(1)} = {v_{jl}^{(1)} ∈ R^K}_{l=1}^{L_j}, and scalars v_j^{(0)} = {v_{jl}^{(0)} ∈ R}_{l=1}^{L_j} correspondingly.
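This permutation invariance is easy to verify numerically. The sketch below (with random rather than trained weights, and ReLU as the nonlinearity) permutes the hidden units of a one-hidden-layer network as described above and checks that the outputs are unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
D, L_j, K = 8, 16, 10                                        # input dim, hidden width, classes
V0, v0 = rng.normal(size=(D, L_j)), rng.normal(size=L_j)     # hidden-layer weights and biases
V1, v1 = rng.normal(size=(L_j, K)), rng.normal(size=K)       # softmax-layer weights and biases

def f(x, V0, v0, V1, v1):
    h = np.maximum(x @ V0 + v0, 0.0)                         # sigma = ReLU
    logits = h @ V1 + v1
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)                 # softmax outputs

x = rng.normal(size=(5, D))
tau = rng.permutation(L_j)                                   # permute the hidden neurons
out_perm = f(x, V0[:, tau], v0[tau], V1[tau, :], v1)
print(np.allclose(f(x, V0, v0, V1, v1), out_perm))           # True: outputs are unaffected
```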

Hidden layers in neural networks are commonly viewed as feature extractors. This perspective can be justified by the fact that the last layer of a neural network classifier simply performs a softmax regression. Since neural networks often outperform basic softmax regression, they must be learning high quality feature representations of the raw input data. Mathematically, in our setup, every hidden neuron of the j-th MLP represents a new feature x_l(v^{(0)}_{jl}, v^{(0)}_{jl}) = σ(⟨x, v^{(0)}_{jl}⟩ + v^{(0)}_{jl}). Our second observation is that each (v^{(0)}_{jl}, v^{(0)}_{jl}) pair parameterizes the corresponding neuron's feature extractor. Since the J MLPs are trained on the same general type of data (not necessarily homogeneous), we assume that they share at least some feature extractors that serve the same purpose. However, due to the permutation invariance issue discussed previously, a feature extractor indexed by l from the j-th MLP is unlikely to correspond to the feature extractor with the same index from a different MLP. In order to construct a set of global feature extractors (neurons) {θ^{(0)}_i ∈ R^D, θ^{(0)}_i ∈ R}_{i=1}^L, we must model the process of grouping and combining the feature extractors of the collection of MLPs.

3.1. Single Layer Neural Matching

We now present the key building block of our framework, a Beta-Bernoulli process (Thibaux & Jordan, 2007) based model of MLP weight parameters. Our model assumes the following generative process. First, draw a collection of global atoms (hidden layer neurons) from a Beta process prior with base measure H and mass parameter γ0, Q = Σ_i q_i δ_{θ_i}. In our experiments we choose H = N(μ0, Σ0) as the base measure, with μ0 ∈ R^{D+1+K} and diagonal Σ0. Each θ_i ∈ R^{D+1+K} is a concatenated vector [θ^{(0)}_i ∈ R^D, θ^{(0)}_i ∈ R, θ^{(1)}_i ∈ R^K] formed from a feature extractor weight-bias pair together with the corresponding weights of the softmax regression. In what follows, we will use "batch" to refer to a partition of the data.

Next, for each j = 1, ..., J, select a subset of the global atoms for batch j via the Bernoulli process:
\[ T_j := \sum_i b_{ji}\,\delta_{\theta_i}, \quad \text{where } b_{ji} \mid q_i \sim \mathrm{Bern}(q_i)\ \forall i. \qquad (1) \]
T_j is supported by the atoms {θ_i : b_{ji} = 1, i = 1, 2, ...}, which represent the identities of the atoms (neurons) used by batch j. Finally, assume that the observed local atoms are noisy measurements of the corresponding global atoms:
\[ v_{jl} \mid T_j \sim \mathcal{N}(T_{jl}, \Sigma_j) \quad \text{for } l = 1, \ldots, L_j;\ L_j := \mathrm{card}(T_j), \qquad (2) \]
with v_{jl} = [v^{(0)}_{jl}, v^{(0)}_{jl}, v^{(1)}_{jl}] being the weights, biases, and softmax regression weights corresponding to the l-th neuron of the j-th MLP trained with L_j neurons on the data of batch j.

Under this model, the key quantity to be inferred is the collection of random variables that match observed atoms (neurons) at any batch to the global atoms. We denote the collection of these random variables as {B^j}_{j=1}^J, where B^j_{i,l} = 1 implies that T_{jl} = θ_i (there is a one-to-one correspondence between {b_{ji}}_{i=1}^∞ and B^j).

Maximum a posteriori estimation. We now derive an algorithm for MAP estimation of the global atoms for the model presented above. The objective function to be maximized is the posterior of {θ_i}_{i=1}^∞ and {B^j}_{j=1}^J:
\[ \arg\max_{\{\theta_i\},\{B^j\}} P(\{\theta_i\},\{B^j\} \mid \{v_{jl}\}) \propto P(\{v_{jl}\} \mid \{\theta_i\},\{B^j\})\, P(\{B^j\})\, P(\{\theta_i\}). \qquad (3) \]

Note that the next proposition easily follows from Gaussian-Gaussian conjugacy:

Proposition 1. Given {B^j}, the MAP estimate of {θ_i} is given by
\[ \theta_i = \frac{\mu_0/\sigma_0^2 + \sum_{j,l} B^j_{i,l}\, v_{jl}/\sigma_j^2}{1/\sigma_0^2 + \sum_{j,l} B^j_{i,l}/\sigma_j^2} \quad \text{for } i = 1, \ldots, L, \qquad (4) \]
where for simplicity we assume Σ0 = Iσ0² and Σ_j = Iσ_j².
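Given the assignments, (4) is a precision-weighted average of the matched local atoms and the prior mean. The following is a minimal sketch of this update; the helper name and the argument shapes are our own illustrative choices, not from the released code.

```python
import numpy as np

def map_global_atoms(B, v, mu0, sigma0_sq, sigma_sq):
    """MAP estimate of the global atoms (Eq. 4).
    B: list of J binary assignment matrices, B[j] of shape (L, L_j).
    v: list of J local atom matrices, v[j] of shape (L_j, D_atom).
    sigma_sq: list of per-batch variances sigma_j^2; mu0, sigma0_sq are the prior parameters."""
    L = B[0].shape[0]
    num = np.tile(mu0 / sigma0_sq, (L, 1))          # mu_0 / sigma_0^2 contribution, per global atom
    den = np.full((L, 1), 1.0 / sigma0_sq)          # 1 / sigma_0^2 contribution
    for Bj, vj, s2 in zip(B, v, sigma_sq):
        num += Bj @ vj / s2                         # sum_{j,l} B^j_{i,l} v_{jl} / sigma_j^2
        den += Bj.sum(axis=1, keepdims=True) / s2   # sum_{j,l} B^j_{i,l} / sigma_j^2
    return num / den                                # (L x D_atom) matrix of global atoms theta_i
```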

Using this fact, we can cast the optimization corresponding to (3) as being with respect to only {B^j}_{j=1}^J. Taking the natural logarithm we obtain:
\[ \arg\max_{\{B^j\}} \; \frac{1}{2} \sum_i \frac{\big\| \mu_0/\sigma_0^2 + \sum_{j,l} B^j_{i,l}\, v_{jl}/\sigma_j^2 \big\|^2}{1/\sigma_0^2 + \sum_{j,l} B^j_{i,l}/\sigma_j^2} + \log P(\{B^j\}). \qquad (5) \]
We consider an iterative optimization approach: fixing all but one B^j we find the corresponding optimal assignment, then pick a new j at random and proceed until convergence. In the following we will use the notation −j to denote "all but j". Let L_{−j} = max{i : B^{−j}_{i,l} = 1} denote the number of active global weights outside of group j. We now rearrange the first term of (5) by partitioning it into i = 1, ..., L_{−j} and i = L_{−j} + 1, ..., L_{−j} + L_j. We are interested in solving for B^j, hence we can modify the objective function by subtracting terms independent of B^j (denoted by ≅) and noting that Σ_l B^j_{i,l} ∈ {0, 1}, i.e. it is 1 if some neuron from batch j is matched to global neuron i and 0 otherwise:
\[
\frac{1}{2} \sum_i \frac{\big\| \mu_0/\sigma_0^2 + \sum_{j,l} B^j_{i,l} v_{jl}/\sigma_j^2 \big\|^2}{1/\sigma_0^2 + \sum_{j,l} B^j_{i,l}/\sigma_j^2}
\cong \sum_{i=1}^{L_{-j}+L_j} \sum_{l=1}^{L_j} B^j_{i,l} \left(
\frac{\big\| \mu_0/\sigma_0^2 + v_{jl}/\sigma_j^2 + \sum_{-j,l} B^j_{i,l} v_{jl}/\sigma_j^2 \big\|^2}{1/\sigma_0^2 + 1/\sigma_j^2 + \sum_{-j,l} B^j_{i,l}/\sigma_j^2}
- \frac{\big\| \mu_0/\sigma_0^2 + \sum_{-j,l} B^j_{i,l} v_{jl}/\sigma_j^2 \big\|^2}{1/\sigma_0^2 + \sum_{-j,l} B^j_{i,l}/\sigma_j^2}
\right). \qquad (6)
\]

Now we consider the second term of (5):
\[ \log P(\{B^j\}) = \log P(B^j \mid \{B^{-j}\}) + \log P(\{B^{-j}\}). \]
First, because we are optimizing for B^j, we can ignore log P({B^{−j}}). Second, due to the exchangeability of batches (i.e. customers of the IBP), we can always consider B^j to be the last batch (i.e. the last customer of the IBP). Let m_i^{−j} = Σ_{−j,l} B^j_{i,l} denote the number of times batch weights were assigned to global weight i outside of group j. We then obtain:
\[
\log P(B^j \mid \{B^{-j}\}) \cong \sum_{i=1}^{L_{-j}} \sum_{l=1}^{L_j} B^j_{i,l} \log\frac{m_i^{-j}}{J - m_i^{-j}}
+ \sum_{i=L_{-j}+1}^{L_{-j}+L_j} \sum_{l=1}^{L_j} B^j_{i,l} \Big( \log\frac{\gamma_0}{J} - \log(i - L_{-j}) \Big). \qquad (7)
\]

Combining (6) and (7) we obtain the assignment cost objective, which we solve with the Hungarian algorithm.

Proposition 2. The (negative) assignment cost specification for finding B^j is
\[
-C^j_{i,l} =
\begin{cases}
\dfrac{\left\| \mu_0/\sigma_0^2 + v_{jl}/\sigma_j^2 + \sum_{-j,l} B^j_{i,l} v_{jl}/\sigma_j^2 \right\|^2}{1/\sigma_0^2 + 1/\sigma_j^2 + \sum_{-j,l} B^j_{i,l}/\sigma_j^2}
- \dfrac{\left\| \mu_0/\sigma_0^2 + \sum_{-j,l} B^j_{i,l} v_{jl}/\sigma_j^2 \right\|^2}{1/\sigma_0^2 + \sum_{-j,l} B^j_{i,l}/\sigma_j^2}
+ 2\log\dfrac{m_i^{-j}}{J - m_i^{-j}}, & i \le L_{-j}, \\[1.5ex]
\dfrac{\left\| \mu_0/\sigma_0^2 + v_{jl}/\sigma_j^2 \right\|^2}{1/\sigma_0^2 + 1/\sigma_j^2}
- \dfrac{\left\| \mu_0/\sigma_0^2 \right\|^2}{1/\sigma_0^2}
- 2\log\dfrac{i - L_{-j}}{\gamma_0/J}, & L_{-j} < i \le L_{-j} + L_j.
\end{cases}
\qquad (8)
\]

We then apply the Hungarian algorithm to find the minimizer of Σ_i Σ_l B^j_{i,l} C^j_{i,l} and obtain the neuron matching assignments. The proof is described in Supplement A.

We summarize the overall single layer inference procedure in Figure 1 below.

Algorithm 1 Single Layer Neural Matching
1: Collect weights and biases from the J batches and form v_jl.
2: Form the assignment cost matrix per (8).
3: Compute the matching assignments B^j using the Hungarian algorithm.
4: Enumerate all resulting unique global neurons and use (4) to infer the associated global weight vectors from all instances of the global neurons across the J batches.
5: Concatenate the global neurons and the inferred weights and biases to form the new global hidden layer.

Figure 1: Single layer Probabilistic Federated Neural Matching algorithm showing matching of three MLPs. Nodes in the graphs indicate neurons; neurons of the same color have been matched. Our approach consists of using the corresponding neurons in the output layer to convert the neurons in each of the J batches to weight vectors referencing the output layer. These weight vectors are then used to form a cost matrix, which the Hungarian algorithm then uses to do the matching. The matched neurons are then aggregated via Proposition 1 to form the global model. [Diagram omitted.]
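To make steps 2 and 3 of Algorithm 1 concrete, the sketch below forms the Proposition 2 cost for one batch against the current global estimate and solves the assignment with scipy's Hungarian solver. It is a simplified illustration under the isotropic-variance assumption; the function name, the summary statistics passed in (global_sum, global_count, popularity), and the padding with L_j candidate new atoms are our own conventions rather than the paper's released implementation. Iterating such a step over the batches (recomputing the summaries each time) and then applying (4) completes the algorithm.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def match_batch(vj, global_sum, global_count, popularity, J,
                mu0, sigma0_sq, sigma_sq, gamma0):
    """Match the L_j local atoms in vj (L_j x D_atom) to existing or new global atoms.
    global_sum[i]   = sum over other batches of v / sigma^2 for atoms assigned to global atom i,
    global_count[i] = corresponding sum of 1 / sigma^2,
    popularity[i]   = m_i^{-j} (>= 1 for every existing global atom).
    Returns the global-atom index matched to each local atom."""
    Lj, L_exist = vj.shape[0], global_sum.shape[0]
    mu_term = mu0 / sigma0_sq
    cost = np.zeros((Lj, L_exist + Lj))
    for i in range(L_exist):                                         # existing global atoms
        num_with = mu_term + global_sum[i] + vj / sigma_sq           # (Lj, D_atom)
        den_with = 1.0 / sigma0_sq + global_count[i] + 1.0 / sigma_sq
        num_wo = mu_term + global_sum[i]
        den_wo = 1.0 / sigma0_sq + global_count[i]
        gain = ((num_with ** 2).sum(1) / den_with
                - (num_wo ** 2).sum() / den_wo
                + 2.0 * np.log(popularity[i] / (J - popularity[i])))
        cost[:, i] = -gain                                           # Hungarian minimizes cost
    for k in range(Lj):                                              # candidate new global atoms
        gain = (((mu_term + vj / sigma_sq) ** 2).sum(1) / (1.0 / sigma0_sq + 1.0 / sigma_sq)
                - (mu_term ** 2).sum() / (1.0 / sigma0_sq)
                - 2.0 * np.log((k + 1) / (gamma0 / J)))
        cost[:, L_exist + k] = -gain
    rows, cols = linear_sum_assignment(cost)                         # optimal 1-to-1 assignment
    return cols[np.argsort(rows)]                                    # global index per local atom
```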

3.2. Multilayer Neural Matching

The model we have presented thus far can handle any arbitrary width single layer neural network, which is known to be theoretically sufficient for approximating any function of interest (Hornik et al., 1989). However, deep neural networks with moderate layer widths are known to be beneficial both practically (LeCun et al., 2015) and theoretically (Poggio et al., 2017). We extend our neural matching approach to these deep architectures by defining a generative model of deep neural network weights from outputs back to inputs (top-down). Let C denote the number of hidden layers and L^c the number of neurons on the c-th layer. Then L^{C+1} = K is the number of labels and L^0 = D is the input dimension. In the top-down approach, we consider the global atoms to be vectors of outgoing weights from a neuron instead of weights forming a neuron as it was in the single hidden layer model. This change is needed to avoid base measures with unbounded dimensions.

Starting with the top hidden layer c = C, we generate each layer following a model similar to that used in the single layer case. For each layer we generate a collection of global atoms and select a subset of them for each batch using the Beta-Bernoulli process construction. L^{c+1} is the number of neurons on layer c + 1, which controls the dimension of the atoms in layer c.

Definition 1 (Multilayer generative process). Starting with layer c = C, generate (as in the single layer process)
\[ Q^c \mid \gamma_0^c, H^c, L^{c+1} \sim \mathrm{BP}(1, \gamma_0^c H^c), \qquad (9) \]
then
\[ Q^c = \sum_i q_i^c\, \delta_{\theta_i^c}, \quad \theta_i^c \sim \mathcal{N}(\mu_0^c, \Sigma_0^c),\ \mu_0^c \in \mathbb{R}^{L^{c+1}}, \]
\[ T_j^c := \sum_i b_{ji}^c\, \delta_{\theta_i^c}, \quad \text{where } b_{ji}^c \mid q_i^c \sim \mathrm{Bern}(q_i^c). \]
This T_j^c is the set of global atoms (neurons) used by batch j in layer c; it contains atoms {θ_i^c : b_{ji}^c = 1, i = 1, 2, ...}. Finally, generate the observed local atoms:
\[ v_{jl}^c \mid T_j^c \sim \mathcal{N}(T_{jl}^c, \Sigma_j^c) \quad \text{for } l = 1, \ldots, L_j^c, \qquad (10) \]
where we have set L_j^c := card(T_j^c). Next, compute the generated number of global neurons L^c = card(∪_{j=1}^J T_j^c) and repeat this generative process for the next layer, c − 1. Repeat until all layers are generated (c = C, ..., 1).

An important difference from the single layer model is that we should now set to 0 some of the dimensions of v_{jl}^c ∈ R^{L^{c+1}}, since they correspond to weights outgoing to neurons of layer c + 1 not present on batch j, i.e. v_{jli}^c := 0 if b_{ji}^{c+1} = 0 for i = 1, ..., L^{c+1}. The resulting model can be understood as follows. There is a global fully connected neural network with L^c neurons on layer c, and there are J partially connected neural networks with L_j^c active neurons on layer c, while the weights corresponding to the remaining L^c − L_j^c neurons are zeros and have no effect locally.

Remark 1. Our model can conceptually handle permuted ordering of the input dimensions across batches; however, in most practical cases the ordering of input dimensions is consistent across batches. Thus, we assume that the weights connecting the first hidden layer to the inputs exhibit permutation invariance only on the side of the first hidden layer. Similarly to how all weights were concatenated in the single hidden layer model, we consider μ_0^c ∈ R^{D+L^{c+1}} for c = 1. We also note that the bias term can be added to the model; we omitted it to simplify notation.

Inference. Following the top-down generative model, we adopt a greedy inference procedure that first infers the matching of the top layer and then proceeds down the layers of the network. This is possible because the generative process for each layer depends only on the identity and number of the global neurons in the layer above it; hence once we infer the (c+1)-th layer of the global model we can apply the single layer inference algorithm (Algorithm 1) to the c-th layer. This greedy setup is illustrated in Supplement Figure 4. The per-layer inference follows directly from the single layer case, yielding the following propositions.

Proposition 3. The (negative) assignment cost specification for finding B^{j,c} is
\[
-C^{j,c}_{i,l} =
\begin{cases}
\dfrac{\left\| \mu_0^c/(\sigma_0^c)^2 + v^c_{jl}/(\sigma_j^c)^2 + \sum_{-j,l} B^{j,c}_{i,l} v^c_{jl}/(\sigma_j^c)^2 \right\|^2}{1/(\sigma_0^c)^2 + 1/(\sigma_j^c)^2 + \sum_{-j,l} B^{j,c}_{i,l}/(\sigma_j^c)^2}
- \dfrac{\left\| \mu_0^c/(\sigma_0^c)^2 + \sum_{-j,l} B^{j,c}_{i,l} v^c_{jl}/(\sigma_j^c)^2 \right\|^2}{1/(\sigma_0^c)^2 + \sum_{-j,l} B^{j,c}_{i,l}/(\sigma_j^c)^2}
+ 2\log\dfrac{m_i^{-j,c}}{J - m_i^{-j,c}}, & i \le L^c_{-j}, \\[1.5ex]
\dfrac{\left\| \mu_0^c/(\sigma_0^c)^2 + v^c_{jl}/(\sigma_j^c)^2 \right\|^2}{1/(\sigma_0^c)^2 + 1/(\sigma_j^c)^2}
- \dfrac{\left\| \mu_0^c/(\sigma_0^c)^2 \right\|^2}{1/(\sigma_0^c)^2}
- 2\log\dfrac{i - L^c_{-j}}{\gamma_0/J}, & L^c_{-j} < i \le L^c_{-j} + L^c_j,
\end{cases}
\]
where for simplicity we assume Σ_0^c = I(σ_0^c)² and Σ_j^c = I(σ_j^c)². We then apply the Hungarian algorithm to find the minimizer of Σ_i Σ_l B^{j,c}_{i,l} C^{j,c}_{i,l} and obtain the neuron matching assignments.

Proposition 4. Given the assignment {B^{j,c}}, the MAP estimate of {θ_i^c} is given by
\[ \theta_i^c = \frac{\mu_0^c/(\sigma_0^c)^2 + \sum_{j,l} B^{j,c}_{i,l}\, v^c_{jl}/(\sigma_j^c)^2}{1/(\sigma_0^c)^2 + \sum_{j,l} B^{j,c}_{i,l}/(\sigma_j^c)^2} \quad \text{for } i = 1, \ldots, L. \qquad (11) \]

We combine these propositions and summarize the overall multilayer inference procedure in Supplement Algorithm 2.
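A compact sketch of the greedy top-down schedule (Supplement Algorithm 2) is given below. It assumes a hypothetical routine match_layer that runs the single layer matching of Section 3.1 jointly over all batches for one layer and returns the global atoms and assignments; its exact signature and the data layout are our assumptions, not the paper's released code.

```python
def multilayer_pfnm(local_weights, match_layer, num_outputs):
    """Greedy top-down matching of a C-hidden-layer network (cf. Supplement Algorithm 2).
    local_weights[c][j]: outgoing-weight vectors of batch j's neurons at hidden layer c,
    ordered from the top hidden layer (c = C) down to the bottom one (c = 1).
    match_layer(atoms_per_batch, out_dim, in_dim) -> (global_atoms, assignments)."""
    global_layers, assignments = [], []
    out_dim = num_outputs                       # L^{C+1} = K, the number of labels
    for c, layer_atoms in enumerate(local_weights):
        is_bottom = (c == len(local_weights) - 1)
        # only the bottom layer also uses the weights connecting to the inputs
        in_dim = layer_atoms[0].shape[1] - out_dim if is_bottom else 0
        theta_c, B_c = match_layer(layer_atoms, out_dim, in_dim)
        global_layers.append(theta_c)
        assignments.append(B_c)
        out_dim = theta_c.shape[0]              # L^c controls the atom dimension one layer down
    return global_layers, assignments
```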

3.3. Neural Matching with Additional Communications

In the traditional federated learning scenario, where local and global models are learned together, a common approach (see e.g., McMahan et al. (2017)) is to learn via rounds of communication between the local and global models. Typically, local model parameters are trained for a few epochs, sent to the server for updating the global model, and then reinitialized with the global model parameters for the new round. One of the key factors in federated learning is the number of communications required to achieve an accurate global model. In the preceding sections we proposed Probabilistic Federated Neural Matching (PFNM) to aggregate local models in a single communication round. Our approach can be naturally extended to benefit from additional communication rounds as follows.

Let t denote a communication round. To initialize the local models at round t + 1 we set v_{jl}^{t+1} = Σ_i B^{j,t}_{i,l} θ_i^t. Recall that Σ_i B^{j,t}_{i,l} = 1 for all l = 1, ..., L_j and j = 1, ..., J; hence a local model is initialized with a subset of the global model, keeping the local model size L_j constant across communication rounds (this also holds for the multilayer case). After the local models are updated, we again apply matching to obtain a new global model. Note that the global model size can change across communication rounds; in particular, we expect it to shrink as the local models improve at each step.
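Assuming hypothetical helpers for local training and for the matching step (the names are ours), one communication round with the re-initialization rule v_{jl}^{t+1} = Σ_i B^{j,t}_{i,l} θ_i^t could be sketched as follows.

```python
import numpy as np

def pfnm_round(local_atoms, train_local, match_all):
    """One PFNM communication round (single hidden layer case).
    local_atoms[j]: (L_j x D_atom) array of batch j's current neurons.
    train_local(j, atoms) -> updated atoms after a few local epochs.
    match_all(atoms_list) -> (theta, B) with theta (L x D_atom) and B[j] (L x L_j) binary."""
    # local update step on each batch (in practice run in parallel)
    local_atoms = [train_local(j, atoms) for j, atoms in enumerate(local_atoms)]
    # aggregate into a new global model by matching
    theta, B = match_all(local_atoms)
    # re-initialize each local model with its matched subset of the global model:
    # v_{jl}^{t+1} = sum_i B^{j,t}_{i,l} theta_i^t  (keeps L_j fixed across rounds)
    new_local = [Bj.T @ theta for Bj in B]
    return theta, new_local
```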

4. Experiments

To verify our methodology we simulate federated learning scenarios using two standard datasets: MNIST and CIFAR-10. We randomly partition each of these datasets into J batches. Two partition strategies are of interest: (a) a homogeneous partition, where each batch has approximately equal proportions of each of the K classes; and (b) a heterogeneous partition, for which batch sizes and class proportions are unbalanced. We simulate a heterogeneous partition by sampling p_k ∼ Dir_J(0.5) and allocating a p_{k,j} proportion of the instances of class k to batch j. Note that due to the small concentration parameter (0.5) of the Dirichlet distribution, some sampled batches may not have any examples of certain classes of data. For each of the four combinations of partition strategy and dataset we run 10 trials to obtain mean performances with standard deviations.
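The heterogeneous split described above can be reproduced in a few lines; the sketch below is our own illustration of drawing p_k ∼ Dir_J(0.5) per class and allocating proportions of each class to the J batches (it assumes integer class labels).

```python
import numpy as np

def heterogeneous_partition(labels, J, alpha=0.5, seed=0):
    """Split example indices into J batches with Dirichlet-skewed class proportions."""
    rng = np.random.default_rng(seed)
    batches = [[] for _ in range(J)]
    for k in np.unique(labels):
        idx = np.where(labels == k)[0]
        rng.shuffle(idx)
        p_k = rng.dirichlet(alpha * np.ones(J))          # p_k ~ Dir_J(0.5)
        # allocate a p_{k,j} proportion of the class-k instances to batch j
        splits = np.split(idx, (np.cumsum(p_k)[:-1] * len(idx)).astype(int))
        for j in range(J):
            batches[j].extend(splits[j].tolist())
    return [np.array(sorted(b)) for b in batches]

# example: labels = np.random.randint(0, 10, size=60000); parts = heterogeneous_partition(labels, J=10)
```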

In our empirical studies below, we will show that our framework can aggregate multiple local neural networks trained independently on different batches of data into an efficient, modest-size global neural network with as few as a single communication round. We also demonstrate enhanced performance when additional communication is allowed.

Learning with single communication. First we consider a scenario where a global neural network needs to be constructed with a single communication round. This imitates the real-world scenario where data is no longer available and we only have access to pre-trained local models (i.e. "legacy" models). To be useful, this global neural network needs to outperform the individual local models. Ensemble methods (Dietterich, 2000; Breiman, 2001) are a classic approach for combining predictions of multiple learners. They often perform well in practice even when the ensemble members are of poor quality. Unfortunately, in the case of neural networks, ensembles have large storage and inference costs, stemming from having to store and forward propagate through all local networks.

The performance of local NNs and the ensemble method define the lower and upper extremes of aggregation quality when limited to a single communication round. We also compare to other strong baselines, including federated averaging of local neural networks trained with the same random initialization as proposed by McMahan et al. (2017). We note that a federated averaging variant without the shared initialization would likely be more realistic when trying to aggregate pre-trained models, but this variant performs significantly worse than all other baselines. We also consider k-Means clustering (Lloyd, 1982) of vectors constructed by concatenating the weights and biases of the local neural networks. The key difference between k-Means and our approach is that clustering, unlike matching, allows several neurons from a single neural network to be assigned to the same global neuron, potentially averaging out their individual feature representations. Further, k-Means requires us to choose k, which we set to k = min(500, 50J). In contrast, PFNM nonparametrically learns the global model size, and its other hyperparameters, i.e. σ, σ0, γ0, are chosen based on the training data. We discuss parameter sensitivity in Supplement D.3.

Figure 2: Single communication federated learning. TOP (panels (a)-(d)): test accuracy and normalized model size (log(L / Σ_j L_j)) as a function of the number of batches J for single hidden layer networks (C = 1) on MNIST and CIFAR10 with homogeneous and heterogeneous partitions; compared methods are PFNM, Ensemble, Local NNs, KMeans, and FedAvg. BOTTOM (panels (e)-(h)): test accuracy and normalized model size for multi-layer networks as a function of the number of hidden layers C with J = 10 batches; compared methods are PFNM, Ensemble, Local NNs, and FedAvg. PFNM consistently outperforms local models and federated averaging while performing comparably to ensembles at a fraction of the storage and computational costs. [Plots omitted.]

Figure 2 presents our results with single hidden layer neural networks for a varying number of batches J. Note that a higher number of batches implies fewer data instances per batch, leading to poorer local model performance. The upper plots summarize test data accuracy, while the lower plots show the model size compression achieved by PFNM. Specifically, we plot log(L / Σ_j L_j), which is the log ratio of the PFNM global model size L to the total number of neurons across all local models (i.e. the size of an ensemble model). In this and subsequent experiments each local neural network has L_j = 100 hidden neurons. We see that PFNM produces strong results, occasionally even outperforming ensembles. In the heterogeneous setting we observe a noticeable degradation in the performance of the local NNs and of k-means, while PFNM retains its good performance. It is worth noting that the gap between PFNM and the ensemble increases on CIFAR10 with J, while it is constant (and even in favor of PFNM) on MNIST. This is not surprising. Ensemble methods are known to perform particularly well at aggregating "weak" learners (recall that higher J implies smaller batches) (Breiman, 2001), while PFNM assumes the neural networks being aggregated already perform reasonably well.

Next, we investigate aggregation of multi-layer neural networks, each using a hundred neurons per layer. The extension of k-means to this setting is unclear and k-means is excluded from further comparisons. In Figure 2, we show that PFNM again provides drastic and consistent improvements over local models and federated averaging. It performs marginally worse than ensembles, especially for deeper networks on CIFAR10. This aligns with our previous observation: when there is insufficient data for training good local models, PFNM's performance marginally degrades with respect to ensembles, but it still provides significant compression over ensembles.

We briefly discuss the complexity of our algorithms and experiment run-times in Supplement C.

Figure 3: Federated learning with communications. Test accuracy and normalized model size as a function of the number of communication rounds for J = 25 batches, for one (TOP, panels (a)-(d)) and two layer (BOTTOM, panels (e)-(h)) neural networks on MNIST and CIFAR with homogeneous and heterogeneous partitions; compared methods are PFNM, Ensemble, FedAvg, and D-SGD. PFNM consistently outperforms strong competitors. [Plots omitted.]

Learning with limited communication. While in some scenarios limiting communication to a single round may be a hard constraint, we also consider situations, frequently arising in practice, where a limited amount of communication is permissible. To this end, we investigate federated learning with J = 25 batches and up to twenty communication rounds when the data has a homogeneous partition, and up to fifty rounds under a heterogeneous partition. We compare PFNM, using the communication procedure from Section 3.3 (σ = σ0 = γ0 = 1 across experiments), to federated averaging and the distributed optimization approach downpour SGD (D-SGD) of Dean et al. (2012). In this limited communication setting, the ensembles can be outperformed by many distributed learning algorithms provided a large enough communication budget. An interesting metric then is the number of communication rounds required to outperform ensembles.

We report results with both one and two layer neural networks in Figure 3. In either case, we use a hundred neurons per layer. PFNM outperforms ensembles in all scenarios given sufficient communications. Moreover, in all experiments, PFNM requires significantly fewer communication rounds than both federated averaging and D-SGD to achieve a given performance level. In addition to improved performance, additional rounds of communication allow PFNM to shrink the size of the global model, as demonstrated in the figure. In Figures 3a to 3f we note steady improvement in accuracy and reduction in the global model size. In the CIFAR10 experiments, the two layer PFNM network's performance temporarily drops, which corresponds to a sharp reduction in the size of the global network; see Figures 3g and 3h.

5. Discussion

In this work, we have developed methods for federated learning of neural networks, and empirically demonstrated their favorable properties. Our methods are particularly effective at learning compressed federated networks from pre-trained local networks, and with a modest communication budget they can outperform state-of-the-art algorithms for federated learning of neural networks. In future work, we plan to explore more sophisticated ways of combining local networks, especially in the regime where each local network has very few training instances. Our current matching approach is completely unsupervised; incorporating some form of supervision may help further improve the performance of the global network, especially when the local networks are of poor quality. Finally, it is of interest to extend our modeling framework to other architectures such as Convolutional Neural Networks (CNNs) and Recurrent Neural Networks (RNNs). The permutation invariance necessitating matching inference also arises in CNNs, since any permutation of the filters results in the same output; however, additional bookkeeping is needed due to the pooling operations.


References

Breiman, L. Random forests. Machine Learning, 45(1):5–32, 2001.

Date, K. and Nagi, R. GPU-accelerated Hungarian algorithms for the linear assignment problem. Parallel Computing, 57:52–72, 2016.

Dean, J., Corrado, G., Monga, R., Chen, K., Devin, M., Mao, M., Senior, A., Tucker, P., Yang, K., Le, Q. V., et al. Large scale distributed deep networks. In Advances in Neural Information Processing Systems, pp. 1223–1231, 2012.

Dietterich, T. G. Ensemble methods in machine learning. In International Workshop on Multiple Classifier Systems, pp. 1–15. Springer, 2000.

EU. Regulation (EU) 2016/679 of the European Parliament and of the Council of 27 April 2016 on the protection of natural persons with regard to the processing of personal data and on the free movement of such data, and repealing Directive 95/46/EC (General Data Protection Regulation). Official Journal of the European Union, L119:1–88, May 2016. URL http://eur-lex.europa.eu/legal-content/EN/TXT/?uri=OJ:L:2016:119:TOC.

Ghahramani, Z. and Griffiths, T. L. Infinite latent feature models and the Indian buffet process. In Advances in Neural Information Processing Systems, pp. 475–482, 2005.

Griffiths, T. L. and Ghahramani, Z. The Indian buffet process: An introduction and review. Journal of Machine Learning Research, 12:1185–1224, 2011.

Hinton, G., Vinyals, O., and Dean, J. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.

Hornik, K., Stinchcombe, M., and White, H. Multilayer feedforward networks are universal approximators. Neural Networks, 2(5):359–366, 1989.

Kuhn, H. W. The Hungarian method for the assignment problem. Naval Research Logistics (NRL), 2(1-2):83–97, 1955.

LeCun, Y., Bengio, Y., and Hinton, G. Deep learning. Nature, 521(7553):436, 2015.

Li, M., Andersen, D. G., Park, J. W., Smola, A. J., Ahmed, A., Josifovski, V., Long, J., Shekita, E. J., and Su, B.-Y. Scaling distributed machine learning with the parameter server. In OSDI, volume 14, pp. 583–598, 2014.

Lian, X., Huang, Y., Li, Y., and Liu, J. Asynchronous parallel stochastic gradient for nonconvex optimization. In Advances in Neural Information Processing Systems, pp. 2737–2745, 2015.

Lian, X., Zhang, C., Zhang, H., Hsieh, C.-J., Zhang, W., and Liu, J. Can decentralized algorithms outperform centralized algorithms? A case study for decentralized parallel stochastic gradient descent. In Advances in Neural Information Processing Systems 30, pp. 5330–5340, 2017.

Lloyd, S. Least squares quantization in PCM. IEEE Transactions on Information Theory, 28(2):129–137, March 1982.

Ma, C., Smith, V., Jaggi, M., Jordan, M., Richtarik, P., and Takac, M. Adding vs. averaging in distributed primal-dual optimization. In Proceedings of the 32nd International Conference on Machine Learning, pp. 1973–1982, 2015.

McMahan, B., Moore, E., Ramage, D., Hampson, S., and y Arcas, B. A. Communication-efficient learning of deep networks from decentralized data. In Artificial Intelligence and Statistics, pp. 1273–1282, 2017.

Moritz, P., Nishihara, R., Stoica, I., and Jordan, M. I. SparkNet: Training deep networks in Spark. arXiv preprint arXiv:1511.06051, 2015.

Paszke, A., Gross, S., Chintala, S., Chanan, G., Yang, E., DeVito, Z., Lin, Z., Desmaison, A., Antiga, L., and Lerer, A. Automatic differentiation in PyTorch. In NIPS-W, 2017.

Poggio, T., Mhaskar, H., Rosasco, L., Miranda, B., and Liao, Q. Why and when can deep, but not shallow, networks avoid the curse of dimensionality: A review. International Journal of Automation and Computing, 14(5):503–519, 2017.

Reddi, S. J., Kale, S., and Kumar, S. On the convergence of Adam and beyond. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=ryQu7f-RZ.

Shamir, O., Srebro, N., and Zhang, T. Communication-efficient distributed optimization using an approximate Newton-type method. In International Conference on Machine Learning, pp. 1000–1008, 2014.

Smith, V., Chiang, C.-K., Sanjabi, M., and Talwalkar, A. S. Federated multi-task learning. In Advances in Neural Information Processing Systems, pp. 4424–4434, 2017.

Teh, Y. W., Görür, D., and Ghahramani, Z. Stick-breaking construction for the Indian buffet process. In Artificial Intelligence and Statistics, pp. 556–563, 2007.

Thibaux, R. and Jordan, M. I. Hierarchical Beta processes and the Indian buffet process. In Artificial Intelligence and Statistics, pp. 564–571, 2007.

Yang, T. Trading computation for communication: Distributed stochastic dual coordinate ascent. In Advances in Neural Information Processing Systems, pp. 629–637, 2013.

Yurochkin, M., Fan, Z., Guha, A., Koutris, P., and Nguyen, X. Scalable inference of topic evolution via models for latent geometric structures. arXiv preprint arXiv:1809.08738, 2018.

Zhang, Y. and Lin, X. DiSCO: Distributed optimization for self-concordant empirical loss. In International Conference on Machine Learning, pp. 362–370, 2015.

Zhang, Y., Duchi, J., Jordan, M. I., and Wainwright, M. J. Information-theoretic lower bounds for distributed statistical estimation with communication constraints. In Advances in Neural Information Processing Systems, pp. 2328–2336, 2013.


A. Single Hidden Layer Inference

The goal of maximum a posteriori (MAP) estimation is to maximize the posterior probability of the latent variables, the global atoms {θ_i}_{i=1}^∞ and the assignments of observed neural network weight estimates to global atoms {B^j}_{j=1}^J, given the estimates of the batch weights {v_jl for l = 1, ..., L_j}_{j=1}^J:
\[ \arg\max_{\{\theta_i\},\{B^j\}} P(\{\theta_i\},\{B^j\} \mid \{v_{jl}\}) \propto P(\{v_{jl}\} \mid \{\theta_i\},\{B^j\})\, P(\{B^j\})\, P(\{\theta_i\}). \qquad (12) \]

MAP estimates given matching. First we note that given {B^j} it is straightforward to find the MAP estimates of {θ_i} based on Gaussian-Gaussian conjugacy:
\[ \theta_i = \frac{\sum_{j,l} B^j_{i,l}\, v_{jl}/\sigma_j^2}{1/\sigma_0^2 + \sum_{j,l} B^j_{i,l}/\sigma_j^2} \quad \text{for } i = 1, \ldots, L, \qquad (13) \]
where L = max{i : B^j_{i,l} = 1 for l = 1, ..., L_j, j = 1, ..., J} is the number of active global atoms, which is an (unknown) latent random variable identified by {B^j}. For simplicity we assume Σ0 = Iσ0², Σ_j = Iσ_j², and μ0 = 0.

Inference of atom assignments (Proposition 2). We can now cast the optimization corresponding to (12) with respect to only {B^j}_{j=1}^J. Taking the natural logarithm we obtain:
\[ -\frac{1}{2}\sum_i \Big( \frac{\|\theta_i\|^2}{\sigma_0^2} + D\log(2\pi\sigma_0^2) + \sum_{j,l} B^j_{i,l}\, \frac{\|v_{jl} - \theta_i\|^2}{\sigma_j^2} \Big) + \log P(\{B^j\}). \qquad (14) \]

We now simplify the first term of (14) (in this and subsequent derivations we use ≅ to say that two objective functions are equivalent up to terms independent of the variables of interest):
\[
\begin{aligned}
&-\frac{1}{2}\sum_i \Big( \frac{\|\theta_i\|^2}{\sigma_0^2} + D\log(2\pi\sigma_0^2) + \sum_{j,l} B^j_{i,l}\, \frac{\|v_{jl}-\theta_i\|^2}{\sigma_j^2} \Big) \\
&\quad = -\frac{1}{2}\sum_i \Big( \frac{\langle\theta_i,\theta_i\rangle}{\sigma_0^2} + D\log(2\pi\sigma_0^2) + \sum_{j,l} B^j_{i,l}\, \frac{\langle v_{jl},v_{jl}\rangle - 2\langle v_{jl},\theta_i\rangle + \langle\theta_i,\theta_i\rangle}{\sigma_j^2} \Big) \\
&\quad \cong -\frac{1}{2}\sum_i \Big( \langle\theta_i,\theta_i\rangle\Big(\frac{1}{\sigma_0^2} + \sum_{j,l}\frac{B^j_{i,l}}{\sigma_j^2}\Big) + D\log(2\pi\sigma_0^2) - 2\Big\langle\theta_i,\ \sum_{j,l}\frac{B^j_{i,l}\, v_{jl}}{\sigma_j^2}\Big\rangle \Big) \\
&\quad = \frac{1}{2}\sum_i \Big( \frac{\big\|\sum_{j,l} B^j_{i,l}\, v_{jl}/\sigma_j^2\big\|^2}{1/\sigma_0^2 + \sum_{j,l} B^j_{i,l}/\sigma_j^2} - D\log(2\pi\sigma_0^2) \Big), \qquad (15)
\end{aligned}
\]
where the last equality follows by substituting the MAP estimate (13) for θ_i.

We consider an iterative optimization approach: fixing all but one B^j we find the corresponding optimal assignment, then pick a new j at random and repeat until convergence. We define the notation −j to denote "all but j", and let L_{−j} = max{i : B^{−j}_{i,l} = 1} denote the number of active global weights outside of group j. We partition (15) between i = 1, ..., L_{−j} and i = L_{−j} + 1, ..., L_{−j} + L_j, and since we are solving for B^j, we subtract terms independent of B^j:
\[
\begin{aligned}
&\sum_i \Big( \frac{\big\|\sum_{j,l} B^j_{i,l} v_{jl}/\sigma_j^2\big\|^2}{1/\sigma_0^2 + \sum_{j,l} B^j_{i,l}/\sigma_j^2} - D\log(2\pi\sigma_0^2) \Big) \\
&\quad \cong \sum_{i=1}^{L_{-j}} \Big( \frac{\big\|\sum_{l} B^j_{i,l} v_{jl}/\sigma_j^2 + \sum_{-j,l} B^j_{i,l} v_{jl}/\sigma_j^2\big\|^2}{1/\sigma_0^2 + \sum_{l} B^j_{i,l}/\sigma_j^2 + \sum_{-j,l} B^j_{i,l}/\sigma_j^2} - \frac{\big\|\sum_{-j,l} B^j_{i,l} v_{jl}/\sigma_j^2\big\|^2}{1/\sigma_0^2 + \sum_{-j,l} B^j_{i,l}/\sigma_j^2} \Big) \\
&\qquad + \sum_{i=L_{-j}+1}^{L_{-j}+L_j} \frac{\big\|\sum_{l} B^j_{i,l} v_{jl}/\sigma_j^2\big\|^2}{1/\sigma_0^2 + \sum_{l} B^j_{i,l}/\sigma_j^2}. \qquad (16)
\end{aligned}
\]

Now observe that Σ_l B^j_{i,l} ∈ {0, 1}, i.e. it is 1 if some neuron from batch j is matched to global neuron i and 0 otherwise. Due to this we can rewrite (16) as a linear sum assignment problem:
\[
\begin{aligned}
&\sum_{i=1}^{L_{-j}} \sum_{l=1}^{L_j} B^j_{i,l} \Big( \frac{\big\| v_{jl}/\sigma_j^2 + \sum_{-j,l} B^j_{i,l} v_{jl}/\sigma_j^2 \big\|^2}{1/\sigma_0^2 + 1/\sigma_j^2 + \sum_{-j,l} B^j_{i,l}/\sigma_j^2} - \frac{\big\| \sum_{-j,l} B^j_{i,l} v_{jl}/\sigma_j^2 \big\|^2}{1/\sigma_0^2 + \sum_{-j,l} B^j_{i,l}/\sigma_j^2} \Big) \\
&\qquad + \sum_{i=L_{-j}+1}^{L_{-j}+L_j} \sum_{l=1}^{L_j} B^j_{i,l}\, \frac{\big\| v_{jl}/\sigma_j^2 \big\|^2}{1/\sigma_0^2 + 1/\sigma_j^2}. \qquad (17)
\end{aligned}
\]

Now we consider the second term of (14):
\[ \log P(\{B^j\}) = \log P(B^j \mid \{B^{-j}\}) + \log P(\{B^{-j}\}). \]
First, because we are optimizing for B^j, we can ignore log P({B^{−j}}). Second, due to the exchangeability of batches (i.e. customers of the IBP), we can always consider B^j to be the last batch (i.e. the last customer of the IBP). Let m_i^{−j} = Σ_{−j,l} B^j_{i,l} denote the number of times batch weights were assigned to global atom i outside of group j. We now obtain the following:
\[
\begin{aligned}
\log P(B^j \mid \{B^{-j}\}) \cong{}& \sum_{i=1}^{L_{-j}} \Big( \sum_{l=1}^{L_j} B^j_{i,l} \log\frac{m_i^{-j}}{J} + \Big(1 - \sum_{l=1}^{L_j} B^j_{i,l}\Big) \log\frac{J - m_i^{-j}}{J} \Big) \\
&- \log\Big[\Big( \sum_{i=L_{-j}+1}^{L_{-j}+L_j} \sum_{l=1}^{L_j} B^j_{i,l} \Big)!\Big] + \sum_{i=L_{-j}+1}^{L_{-j}+L_j} \sum_{l=1}^{L_j} B^j_{i,l} \log\frac{\gamma_0}{J}. \qquad (18)
\end{aligned}
\]

We now rearrange (18) as a linear sum assignment problem:
\[ \sum_{i=1}^{L_{-j}} \sum_{l=1}^{L_j} B^j_{i,l} \log\frac{m_i^{-j}}{J - m_i^{-j}} + \sum_{i=L_{-j}+1}^{L_{-j}+L_j} \sum_{l=1}^{L_j} B^j_{i,l} \Big( \log\frac{\gamma_0}{J} - \log(i - L_{-j}) \Big). \qquad (19) \]

Combining (17) and (19) we arrive at the cost specification for finding B^j as the minimizer of Σ_i Σ_l B^j_{i,l} C^j_{i,l}, where the (negative) cost is:
\[
-C^j_{i,l} =
\begin{cases}
\dfrac{\left\| v_{jl}/\sigma_j^2 + \sum_{-j,l} B^j_{i,l} v_{jl}/\sigma_j^2 \right\|^2}{1/\sigma_0^2 + 1/\sigma_j^2 + \sum_{-j,l} B^j_{i,l}/\sigma_j^2}
- \dfrac{\left\| \sum_{-j,l} B^j_{i,l} v_{jl}/\sigma_j^2 \right\|^2}{1/\sigma_0^2 + \sum_{-j,l} B^j_{i,l}/\sigma_j^2}
+ 2\log\dfrac{m_i^{-j}}{J - m_i^{-j}}, & i \le L_{-j}, \\[1.5ex]
\dfrac{\left\| v_{jl}/\sigma_j^2 \right\|^2}{1/\sigma_0^2 + 1/\sigma_j^2}
- 2\log\dfrac{i - L_{-j}}{\gamma_0/J}, & L_{-j} < i \le L_{-j} + L_j.
\end{cases}
\qquad (20)
\]

This completes the proof of Proposition 2.

B. Multilayer Inference Details

Figure 4 illustrates the overall multilayer inference procedure visually, and Algorithm 2 provides the details.

Algorithm 2 Multilayer PFNM
1: L^{C+1} ← number of outputs
2: # Top down iteration through layers
3: for layers c = C, C − 1, ..., 2 do
4:   Collect hidden layer c from the J batches and form v^c_{jl}.
5:   Call the Single Layer Neural Matching algorithm with output dimension L^{c+1} and input dimension 0 (since we do not use the weights connecting to lower layers here).
6:   Form global neuron layer c from the output of the single layer matching.
7:   L^c ← card(∪_{j=1}^J T^c_j) (greedy approach).
8: end for
9: # Match the bottom layer using weights connecting to both the input and the layer above.
10: Call the Single Layer Neural Matching algorithm with output dimension L^2 and input dimension equal to the number of inputs.
11: Return the global assignments and form the global multilayer model.

Figure 4: Probabilistic Federated Neural Matching algorithm showing matching of three multilayer MLPs. Nodes in the graphs indicate neurons; neurons of the same color have been matched. On the left, the individual layer matching approach is shown, consisting of using the matching assignments of the next highest layer to convert the neurons in each of the J batches to weight vectors referencing the global previous layer. These weight vectors are then used to form a cost matrix, which the Hungarian algorithm then uses to do the matching. Finally, the matched neurons are aggregated and averaged to form the new layer of the global model. As shown on the right, in the multilayer setting the resulting global layer is then used to match the next lower layer, and so on, until the bottom hidden layer is reached (Steps 1, 2, 3, ... in order). [Diagram omitted.]


C. Complexity Analysis

In this section we present a brief discussion of the complexity of our algorithms. The worst case complexity per layer is achieved when no neurons are matched and is equal to O(D(J L_j)²) for building the cost matrix and O((J L_j)³) for running the Hungarian algorithm, where L_j is the number of neurons per batch (here for simplicity we assume that each batch has the same number of neurons) and J is the number of batches. The best case complexity per layer (i.e. when all neurons are matched) is O(D L_j² + L_j³); also note that the complexity is independent of the data size. In practice the complexity is closer to the best case since the global model size is moderate (i.e. L ≪ Σ_j L_j). Actual timings with our code for the experiments in Figure 2 are as follows: 40 sec for Fig. 2a, 2b at J = 30 groups; 500 sec for Fig. 2c, 2d at J = 30 (the D L_j² term dominates, as the CIFAR10 dimension is much higher than MNIST's); 60 sec for Fig. 2e, 2f (J = 10) at C = 6 layers; 150 sec for Fig. 2g, 2h (J = 10) at C = 6. The computations were done using 2 CPU cores and 4GB memory on a machine with 3.0 GHz core speed. We note that (i) this computation only needs to be performed once, (ii) the cost matrix construction, which appears to be the dominating cost, can be trivially sped up using GPUs, and (iii) recent work demonstrates impressive large scale running times for the Hungarian algorithm using GPUs (Date & Nagi, 2016).

D. Experimental Details and Additional Results

Data partitioning. In the federated learning setup, we analyze data from multiple sources, which we call batches. Data on the batches does not overlap and may have different distributions. To simulate the federated learning scenario we consider two partition strategies for MNIST and CIFAR-10. For each pair of partition strategy and dataset we run 10 trials to obtain mean accuracies and standard deviations. The easier case is homogeneous partitioning, i.e. when the class distributions on the batches are approximately equal, as are the batch sizes. To generate a homogeneous partitioning with J batches we split the examples for each of the classes into J approximately equal parts to form the J batches. In the heterogeneous case, batches are allowed to have highly imbalanced class distributions as well as highly variable sizes. To simulate a heterogeneous partition, for each class k we sample p_k ∼ Dir_J(0.5) and allocate a p_{k,j} proportion of the instances of class k of the complete dataset to batch j. Note that due to the small concentration parameter, 0.5, of the Dirichlet distribution, some batches may entirely miss examples of a subset of classes.

Batch networks training. Our modeling framework and the ensemble-related methods operate on the collections of neural network weights from all batches. Any optimization procedure and software can be used locally on the batches for training the neural networks. We used PyTorch (Paszke et al., 2017) to implement the networks and trained them using the AMSGrad optimizer (Reddi et al., 2018) with default parameters unless otherwise specified. For reproducibility we summarize all parameter settings in Table 1.

Table 1: Parameter settings for batch neural networks training

                          MNIST        CIFAR-10
  Neurons per layer       100          100
  Learning rate           0.01         0.001
  L2 regularization       10⁻⁶         10⁻⁵
  Minibatch size          32           32
  Epochs                  10           10
  Weights initialization  N(0, 0.01)   N(0, 0.01)
  Bias initialization     0.1          0.1
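For concreteness, a minimal PyTorch sketch of local training with the MNIST column of Table 1 is shown below. The single-hidden-layer architecture and the data loader are placeholders, N(0, 0.01) is read here as a standard deviation of 0.01, and AMSGrad is obtained through Adam's amsgrad flag.

```python
import torch
import torch.nn as nn


def train_local_mnist(train_loader, device="cpu"):
    """Train one batch's network with the MNIST settings of Table 1 (sketch)."""
    model = nn.Sequential(                      # one hidden layer of 100 neurons
        nn.Linear(784, 100), nn.ReLU(), nn.Linear(100, 10)
    ).to(device)
    for m in model.modules():                   # Table 1 initialization
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.01)
            nn.init.constant_(m.bias, 0.1)
    # AMSGrad with lr = 0.01 and L2 regularization 1e-6 (as weight decay).
    opt = torch.optim.Adam(model.parameters(), lr=0.01,
                           weight_decay=1e-6, amsgrad=True)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(10):                         # 10 epochs, minibatches of size 32
        for x, y in train_loader:               # x: (32, 784), y: (32,)
            opt.zero_grad()
            loss_fn(model(x.to(device)), y.to(device)).backward()
            opt.step()
    return model
```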

D.1. Parameter Settings for the Baselines

We first formally define the ensemble procedure. Let y_j ∈ Δ^{K−1} denote the probability distribution over the K classes output by the neural network trained on the data from batch j for some test input x. The ensemble prediction is then argmax_k (1/J) ∑_{j=1}^{J} y_{j,k}. In our experiments, we train each individual network on its batch's dataset using the parameters listed in Table 1, and then compute the performance using this ensemble aggregation.
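The ensemble rule amounts to averaging the J networks' class-probability outputs and taking the argmax; a minimal sketch (with hypothetical per-batch predict functions) is:

```python
import numpy as np


def ensemble_predict(prob_fns, x):
    """prob_fns: J callables, each mapping a test input x to a length-K
    probability vector (the softmax output of one batch's network)."""
    probs = np.stack([f(x) for f in prob_fns])   # shape (J, K)
    return int(np.argmax(probs.mean(axis=0)))    # argmax_k (1/J) sum_j y_{j,k}
```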

For downpour SGD (Dean et al., 2012) we used PyTorch with the SGD optimizer and the parameter settings of Table 1 for the local learners. The master neural network was optimized with Adam using the same initial learning rate as in Table 1. The local learners communicated their accumulated gradients back to the master network after every minibatch update. This translates to the setting of Dean et al. (2012) with parameters n_push = n_fetch = 1. Note that with this approach the global network and the networks of the individual batches are constrained to have an identical number of neurons per layer, which is 100 in our experiments.
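A simplified, single-process sketch of this communication pattern (n_push = n_fetch = 1) is shown below; in the actual setting of Dean et al. (2012) the local learners run asynchronously on separate machines, and the master and local models are assumed here to share an architecture.

```python
import torch


def downpour_pass(master, master_opt, workers, loaders, loss_fn):
    """Each local learner pushes its gradient after every minibatch and
    immediately fetches the updated master parameters (n_push = n_fetch = 1)."""
    for worker, loader in zip(workers, loaders):
        for x, y in loader:
            # Fetch: synchronize the local learner with the current master.
            worker.load_state_dict(master.state_dict())
            # Local gradient on one minibatch.
            worker.zero_grad()
            loss_fn(worker(x), y).backward()
            # Push: copy gradients to the master and take an Adam step.
            for p_m, p_w in zip(master.parameters(), worker.parameters()):
                p_m.grad = p_w.grad.detach().clone()
            master_opt.step()
```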

For Federated Averaging (McMahan et al., 2017), we use the SGD optimizer for learning the local networks, with the rest of the parameters as defined in Table 1. We initialize all the local networks with the same seed, and train these networks for 10 epochs initially and for 5 epochs after the first communication round. At each communication round, we utilize all the local networks (C = 1) for the central model update.
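A minimal sketch of the central-model update in this setup (all J local networks participate, i.e. C = 1, with identical architectures) is shown below; uniform averaging is used here, whereas McMahan et al. (2017) weight clients by their local data sizes.

```python
import torch


def fedavg_update(central, local_models):
    """Set the central model's parameters to the average of the local networks."""
    avg = {k: torch.zeros_like(v) for k, v in central.state_dict().items()}
    for model in local_models:
        for k, v in model.state_dict().items():
            avg[k] += v / len(local_models)
    central.load_state_dict(avg)
    return central
```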


Figure 5: Parameter sensitivity analysis for J = 25. [Figure: heatmaps of training accuracy (%) over a grid of σ² and γ₀ values, one heatmap per σ₀² ∈ {1.0, 3.0, 10.0}; panels: (a) MNIST homogeneous, (b) MNIST heterogeneous, (c) CIFAR homogeneous, (d) CIFAR heterogeneous.]

Figure 6: Sensitivity analysis of σ² for fixed σ₀² = 10 and γ₀ = 1, for varying J. [Figure: train accuracy (%) versus σ² for J ∈ {2, 10, 15, 20, 25}; panels: (a) MNIST homogeneous, (b) MNIST heterogeneous, (c) CIFAR homogeneous, (d) CIFAR heterogeneous.]

D.2. Parameter Settings for Matching with Additional Communications

For neural matching with additional communications, we train the local networks for 10 epochs for the first communication round, and for 5 epochs thereafter. All the other parameters are as listed in Table 1. The local networks are trained using the AMSGrad optimizer (Reddi et al., 2018), and the optimizer parameters are reset after every communication. We also found it useful to decay the initial learning rate by a factor of 0.99 after every communication.
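The schedule above can be summarized with the following sketch; train_local and pfnm_match are user-supplied callables standing in for the local AMSGrad training (with a freshly constructed optimizer each round, i.e. reset state) and the matching procedure of Figure 4, respectively.

```python
def pfnm_with_communications(train_local, pfnm_match, local_datasets,
                             rounds, base_lr=0.01):
    """Sketch of PFNM with extra communication rounds.

    train_local(data, init, epochs, lr) -> locally trained network,
    pfnm_match(list of local networks)  -> matched global network.
    """
    global_model = None
    lr = base_lr
    for r in range(rounds):
        epochs = 10 if r == 0 else 5          # 10 local epochs first, 5 thereafter
        local_models = [
            # Local nets start from the current global model; a new optimizer
            # is built inside train_local every round (optimizer state reset).
            train_local(data, init=global_model, epochs=epochs, lr=lr)
            for data in local_datasets
        ]
        global_model = pfnm_match(local_models)
        lr *= 0.99                            # decay the initial lr per round
    return global_model
```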

D.3. Parameter Sensitivity Analysis for PFNM

Our models presented in Section 3 have three parameters: σ₀², γ₀ and σ² = σ₁² = . . . = σ_J². The first parameter, σ₀², is the prior variance of the weights of the global neural network. The second parameter, γ₀, controls the discovery of new neurons; correspondingly, increasing γ₀ increases the size of the learned global network. The third parameter, σ², is the variance of the local neural network weights around the corresponding global network weights. We empirically analyze the effect of these parameters on accuracy for the single hidden layer model with J = 25 batches in Figure 5. The heatmaps indicate the accuracy on the training data; we see that for all parameter values considered the performance does not fluctuate significantly. PFNM appears to be robust to the choices of σ₀² and γ₀, which we set to 10 and 1, respectively, in the experiments with a single communication round. The parameter σ² has a slightly higher impact on performance, and we set it using the training data during the experiments. To quantify the importance of σ² for fixed σ₀² = 10 and γ₀ = 1, we plot the average training accuracies for varying σ² in Figure 6. We see that for homogeneous partitioning and one hidden layer, σ² has almost no effect on performance (Fig. 6a and Fig. 6c). In the case of heterogeneous partitioning (Fig. 6b and Fig. 6d), the effect of σ² is more noticeable; however, all considered values result in competitive performance.