1 Scalable and Practical Natural Gradient for Large-Scale ... · 50 on ImageNet classification as...

12
1 Scalable and Practical Natural Gradient for Large-Scale Deep Learning Kazuki Osawa, Student Member, IEEE, Yohei Tsuji, Yuichiro Ueno, Akira Naruse, Chuan-Sheng Foo, and Rio Yokota Abstract—Large-scale distributed training of deep neural networks results in models with worse generalization performance as a result of the increase in the effective mini-batch size. Previous approaches attempt to address this problem by varying the learning rate and batch size over epochs and layers, or ad hoc modifications of batch normalization. We propose Scalable and Practical Natural Gradient Descent (SP-NGD), a principled approach for training models that allows them to attain similar generalization performance to models trained with first-order optimization methods, but with accelerated convergence. Furthermore, SP-NGD scales to large mini-batch sizes with a negligible computational overhead as compared to first-order methods. We evaluated SP-NGD on a benchmark task where highly optimized first-order methods are available as references: training a ResNet-50 model for image classification on ImageNet. We demonstrate convergence to a top-1 validation accuracy of 75.4% in 5.5 minutes using a mini-batch size of 32,768 with 1,024 GPUs, as well as an accuracy of 74.9% with an extremely large mini-batch size of 131,072 in 873 steps of SP-NGD. Index Terms—Natural gradient descent, distributed deep learning, deep convolutional neural networks, image classification. 1 I NTRODUCTION A S the size of deep neural network models and the data which they are trained on continues to increase rapidly, the demand for distributed parallel computing is increasing. A common approach for achieving distributed parallelism in deep learning is to use the data-parallel approach, where the data is distributed across different processes while the model is replicated across them. When the mini-batch size per process is kept constant to increase the ratio of computation over communication, the effective mini-batch size over the entire system grows proportional to the number of processes. When the mini-batch size is increased beyond a certain point, the generalization performance starts to degrade. This generalization gap caused by large mini-batch sizes have been studied extensively for various models and datasets [1]. Hoffer et al. [2] attribute this generalization gap to the limited number of updates, and suggest to train longer. This has lead to strategies such as scaling the learning rate proportional to the mini-batch size, while using the first few epochs to gradually warmup the learning rate [3]. Such methods have enabled the training for mini-batch sizes of 8K, where ImageNet [4] with ResNet-50 [5] could be trained for 90 epochs to achieve 76.3% top-1 validation accuracy in 60 minutes [6]. Combining this learning rate scaling with other techniques such as RMSprop warm-up, batch normalization without moving averages, and a slow-start K. Osawa and Y. Tsuji are with the Tokyo Institute of Technology, Japan. E-mail: (oosawa.k.ad, tsuji.y.ae)@m.titech.ac.jp Y. Ueno and R. Yokota are with the Tokyo Institute of Technology, Japan and AIST-Tokyo Tech RWBC-OIL, AIST, Japan. E-mail: [email protected], [email protected] C.S. Foo is with the Institute for Infocomm Research, A*STAR, Singapore. E-mail: foo chuan [email protected] A. Naruse is with NVIDIA, Japan. E-mail: [email protected] Manuscript received MONTH DD, 20XX; revised MONTH DD, 20XX. learning rate schedule, Akiba et al. 
[7] were able to train the same dataset and model with a mini-batch size of 32K to achieve 74.9% accuracy in 15 minutes. More complex approaches for manipulating the learning rate were proposed, such as LARS [8], where a different learning rate is used for each layer by normalizing them with the ratio between the layer-wise norms of the weights and gradients. This enabled the training with a mini-batch size of 32K without the use of ad hoc modifications, which achieved 74.9% accuracy in 14 minutes (64 epochs) [8]. It has been reported that combining LARS with counter intuitive modifications to the batch normalization, can yield 75.8% accuracy even for a mini-batch size of 65K [9]. The use of small batch sizes to encourage rapid conver- gence in early epochs, and then progressively increasing the batch size is yet another successful approach. Using such an adaptive batch size method, Mikami et al. [10] were able to train in 122 seconds with an accuracy of 75.3%, and Yamazaki et al. [11] were able to train in 75 seconds with a accuracy of 75.1%. The hierarchical synchronization of mini- batches have also been proposed [12], but such methods have not been tested at scale to the extent of the authors’ knowledge. In this work, we take a more mathematically rigorous approach to close the generalization gap when large mini- batches are used. Our approach builds on Natural Gradient Descent (NGD) [14], a second-order optimization method that leverages curvature information to accelerate optimiza- tion. This approach is made feasible by the use of large mini-batches, which enables stable estimation of curvature even in models with a large number of parameters. To this end, we propose an efficient distributed NGD design that scales to massively parallel settings and large mini- batch sizes. In particular, we demonstrate scalability to batch sizes of 32,768 over 1024 GPUs across 256 nodes. Another unique aspect of our approach is the accuracy at which we

Transcript of 1 Scalable and Practical Natural Gradient for Large-Scale ... · 50 on ImageNet classification as...

Page 1: 1 Scalable and Practical Natural Gradient for Large-Scale ... · 50 on ImageNet classification as a benchmark. We converged to over 75% top-1 validation accuracy for large mini-batch

1

Scalable and Practical Natural Gradientfor Large-Scale Deep Learning

Kazuki Osawa, Student Member, IEEE, Yohei Tsuji, Yuichiro Ueno, Akira Naruse,Chuan-Sheng Foo, and Rio Yokota

Abstract—Large-scale distributed training of deep neural networks results in models with worse generalization performance as aresult of the increase in the effective mini-batch size. Previous approaches attempt to address this problem by varying the learning rateand batch size over epochs and layers, or ad hoc modifications of batch normalization. We propose Scalable and Practical NaturalGradient Descent (SP-NGD), a principled approach for training models that allows them to attain similar generalization performance tomodels trained with first-order optimization methods, but with accelerated convergence. Furthermore, SP-NGD scales to largemini-batch sizes with a negligible computational overhead as compared to first-order methods. We evaluated SP-NGD on a benchmarktask where highly optimized first-order methods are available as references: training a ResNet-50 model for image classification onImageNet. We demonstrate convergence to a top-1 validation accuracy of 75.4% in 5.5 minutes using a mini-batch size of 32,768 with1,024 GPUs, as well as an accuracy of 74.9% with an extremely large mini-batch size of 131,072 in 873 steps of SP-NGD.

Index Terms—Natural gradient descent, distributed deep learning, deep convolutional neural networks, image classification.

F

1 INTRODUCTION

A S the size of deep neural network models and thedata which they are trained on continues to increase

rapidly, the demand for distributed parallel computing isincreasing. A common approach for achieving distributedparallelism in deep learning is to use the data-parallelapproach, where the data is distributed across differentprocesses while the model is replicated across them. Whenthe mini-batch size per process is kept constant to increasethe ratio of computation over communication, the effectivemini-batch size over the entire system grows proportionalto the number of processes.

When the mini-batch size is increased beyond a certainpoint, the generalization performance starts to degrade. Thisgeneralization gap caused by large mini-batch sizes havebeen studied extensively for various models and datasets[1]. Hoffer et al. [2] attribute this generalization gap to thelimited number of updates, and suggest to train longer.This has lead to strategies such as scaling the learning rateproportional to the mini-batch size, while using the firstfew epochs to gradually warmup the learning rate [3]. Suchmethods have enabled the training for mini-batch sizes of8K, where ImageNet [4] with ResNet-50 [5] could be trainedfor 90 epochs to achieve 76.3% top-1 validation accuracyin 60 minutes [6]. Combining this learning rate scalingwith other techniques such as RMSprop warm-up, batchnormalization without moving averages, and a slow-start

• K. Osawa and Y. Tsuji are with the Tokyo Institute of Technology, Japan.E-mail: (oosawa.k.ad, tsuji.y.ae)@m.titech.ac.jp

• Y. Ueno and R. Yokota are with the Tokyo Institute of Technology, Japanand AIST-Tokyo Tech RWBC-OIL, AIST, Japan.E-mail: [email protected], [email protected]

• C.S. Foo is with the Institute for Infocomm Research, A*STAR, Singapore.E-mail: foo chuan [email protected]

• A. Naruse is with NVIDIA, Japan.E-mail: [email protected]

Manuscript received MONTH DD, 20XX; revised MONTH DD, 20XX.

learning rate schedule, Akiba et al. [7] were able to train thesame dataset and model with a mini-batch size of 32K toachieve 74.9% accuracy in 15 minutes.

More complex approaches for manipulating the learningrate were proposed, such as LARS [8], where a differentlearning rate is used for each layer by normalizing themwith the ratio between the layer-wise norms of the weightsand gradients. This enabled the training with a mini-batchsize of 32K without the use of ad hoc modifications, whichachieved 74.9% accuracy in 14 minutes (64 epochs) [8]. It hasbeen reported that combining LARS with counter intuitivemodifications to the batch normalization, can yield 75.8%accuracy even for a mini-batch size of 65K [9].

The use of small batch sizes to encourage rapid conver-gence in early epochs, and then progressively increasing thebatch size is yet another successful approach. Using such anadaptive batch size method, Mikami et al. [10] were ableto train in 122 seconds with an accuracy of 75.3%, andYamazaki et al. [11] were able to train in 75 seconds with aaccuracy of 75.1%. The hierarchical synchronization of mini-batches have also been proposed [12], but such methodshave not been tested at scale to the extent of the authors’knowledge.

In this work, we take a more mathematically rigorousapproach to close the generalization gap when large mini-batches are used. Our approach builds on Natural GradientDescent (NGD) [14], a second-order optimization methodthat leverages curvature information to accelerate optimiza-tion. This approach is made feasible by the use of largemini-batches, which enables stable estimation of curvatureeven in models with a large number of parameters. Tothis end, we propose an efficient distributed NGD designthat scales to massively parallel settings and large mini-batch sizes. In particular, we demonstrate scalability to batchsizes of 32,768 over 1024 GPUs across 256 nodes. Anotherunique aspect of our approach is the accuracy at which we

Page 2: 1 Scalable and Practical Natural Gradient for Large-Scale ... · 50 on ImageNet classification as a benchmark. We converged to over 75% top-1 validation accuracy for large mini-batch

2

TABLE 1Training time and top-1 single-crop validation accuracy of ResNet-50 for ImageNet reported by related work and this work.

Hardware Software Batch size Optimizer # Steps Time/step Time Accuracy

Goyal et al. [6] Tesla P100 × 256 Caffe2 8,192 SGD 14,076 0.255 s 1 hr 76.3 %You et al. [8] KNL × 2048 Intel Caffe 32,768 SGD 3,519 0.341 s 20 min 75.4 %

Akiba et al. [7] Tesla P100 × 1024 Chainer 32,768 RMSprop/SGD 3,519 0.255 s 15 min 74.9 %You et al. [8] KNL × 2048 Intel Caffe 32,768 SGD 2,503 0.335 s 14 min 74.9 %Jia et al. [9] Tesla P40 × 2048 TensorFlow 65,536 SGD 1,800 0.220 s 6.6 min 75.8 %

Ying et al. [13] TPU v3 × 1024 TensorFlow 32,768 SGD 3,519 0.037 s 2.2 min 76.3 %Mikami et al. [10] Tesla V100 × 3456 NNL 55,296 SGD 2,086 0.057 s 2.0 min 75.3 %

Yamazaki et al. [11] Tesla V100 × 2048 MXNet 81,920 SGD 1,440 0.050 s 1.2 min 75.1 %

This work

Tesla V100 × 128

Chainer

4,096

SP-NGD

10,948 0.178 s 32.5 min 74.8 %Tesla V100 × 256 8,192 5,434 0.186 s 16.9 min 75.3 %Tesla V100 × 512 16,384 2,737 0.149 s 6.8 min 75.2 %Tesla V100 × 1024 32,768 1,760 0.187 s 5.5 min 75.4 %

n/a 65,536 1,173 n/a n/a 75.6 %n/a 131,072 873 n/a n/a 74.9 %

2000 4000 6000 8000 10000 12000 14000# Steps

74.5

75.0

75.5

76.0

76.5

Top-

1 ac

cura

cy (%

)

Goyal+

Akiba+

Jia+

Ying+

Mikami+

Yamazaki+BS131K

BS65KBS32K

BS16KBS8K

BS4K

0 10 20 30 40 50 60Time (min)

74.5

75.0

75.5

76.0

76.5Goyal+

Akiba+

Jia+

Ying+

Mikami+

Yamazaki+

BS32KBS16K

BS8K

BS4K

SP-NGDSGD

Fig. 1. Top-1 validation accuracy vs the number of steps to converge (left) and vs training time (right) of ResNet-50 on ImageNet (1000 class)classification by related work with SGD and this work with Scalable and Practical NGD (SP-NGD).

can approximate the Fisher information matrix (FIM) whencompared to other second-order optimization methods. Un-like methods that use very crude approximations of the FIM,such as the TONGA [15], Hessian free methods [16], weadopt the Kronecker-Factored Approximate Curvature (K-FAC) method [17]. The two main characteristics of K-FACare that it converges faster than first-order stochastic gradi-ent descent (SGD) methods, and that it can tolerate relativelylarge mini-batch sizes without any ad hoc modifications. K-FAC has been successfully applied to convolutional neuralnetworks [18], distributed memory training of ImageNet[19], recurrent neural networks [20], Bayesian deep learning[21], reinforcement learning [22] , and Transformer models[23].

Our contributions are:

• Extremely large mini-batch training. We were ableto show for the first time that approximated NGDcan achieve similar generalization capability com-pared to highly optimized SGD, by training ResNet-50 on ImageNet classification as a benchmark. Weconverged to over 75% top-1 validation accuracyfor large mini-batch sizes of 4,096, 8,192, 16,384,32,768 and 65,536. We also achieved 74.9% with anextremely large mini-batch size of 131,072, whichtook only 873 steps.

• Scalable natural gradient. We propose a distributed

NGD design using data and model hybrid paral-lelism that shows superlinear scaling up to 64 GPUs.

• Practical natural gradient. We propose practicalNGD techniques based on analysis of the FIM es-timation in large mini-batch settings. Our practicaltechniques make the overhead of NGD compare toSGD almost negligible. Combining these techniqueswith our distributed NGD, we see an ideal scalingup to 1024 GPUs as shown in Figure 5.

• Training ResNet-50 on ImageNet in 5.5 minutes.Using 1024 NVIDIA Tesla V100, we achieve 75.4 %top-1 accuracy with ResNet-50 for ImageNet in 5.5minutes (1760 steps = 45 epochs, including a valida-tion after each epoch). The comparison is shown inFigure 1 and Table 1.

A preliminary version of this manuscript was publishedpreviously [24]. Since then, the performance optimization ofthe distributed second-order optimization has been studied[25], and our distributed NGD framework has been appliedto accelerate Bayesian deep learning with the natural gradi-ent at ImageNet scale [26]. We extend the previous work andpropose Scalable and Practical Natural Gradient Descent (SP-NGD) framework, which includes more detailed analysison the FIM estimation and significant improvements on theperformance of the distributed NGD.

Page 3: 1 Scalable and Practical Natural Gradient for Large-Scale ... · 50 on ImageNet classification as a benchmark. We converged to over 75% top-1 validation accuracy for large mini-batch

3

2 RELATED WORK

With respect to large-scale distributed training of deepneural networks, there have been very few studies that usesecond-order methods. At a smaller scale, there have beenprevious studies that used K-FAC to train ResNet-50 onImageNet [19]. However, the SGD they used as referencewas not showing state-of-the-art Top-1 validation accuracy(only around 70%), so the advantage of K-FAC over SGDthat they claim was not obvious from the results. In thepresent work, we compare the Top-1 validation accuracywith state-of-the-art SGD methods for large mini-batchesmentioned in the introduction (Table 1).

The previous studies that used K-FAC to train ResNet-50 on ImageNet [19] also were not considering large mini-batches and were only training with mini-batch size of512 on 8 GPUs. In contrast, the present work uses mini-batch sizes up to 131,072, which is equivalent to 32 perGPU on 4096 GPUs, and we are able to achieve a muchhigher Top-1 validation accuracy of 74.9%. Note that suchlarge mini-batch sizes can also be achieved by accumulatingthe gradient over multiple iterations before updating theparameters, which can mimic the behavior of the executionon many GPUs without actually running them on manyGPUs.

The previous studies using K-FAC also suffered fromlarge overhead of the communication since they used aparameter-server approach for their TensorFlow [27] im-plementation of K-FAC with a single parameter-server.Since the parameter server requires all workers to sendthe gradients and receive the latest model’s parametersfrom the parameter server, the parameter server becomesa huge communication bottleneck especially at large scale.Our implementation uses a decentralized approach usingMPI/NCCL1 collective communications among the pro-cesses. The decentralized approach has been used in highperformance computing for a long time, and is known toscale to thousands of GPUs without modification. Although,software like Horovod2 can alleviate the problems withparameter servers by working as a TensorFlow wrapperfor NCCL, a workable realization of K-FAC requires solv-ing many engineering and modeling challenges, and oursolution is the first one that succeeds on a large scale task.

3 NOTATION AND BACKGROUND

3.1 Mini-batch Stochastic Learning

Throughout this paper, we use E[·] to denote the empiricalexpectation among the samples in the mini-batch {(x, t)},and compute the cross-entropy loss for a supervised learningas

L(θ) = E[− log pθ(t|x)] . (1)

where x, t are the training input and label (one-hot vector),pθ(t|x) is the likelihood of each sample (x, t) calculated bythe probabilistic model using a feed-forward deep neuralnetwork (DNN) with the parameters θ ∈ RN .

1. https://developer.nvidia.com/nccl2. https://github.com/horovod/horovod

For the standard mini-batch stochastic gradient descent(SGD), the parameters θ are updated based on the gradientof the loss function at the current point:

θ(t+1) ← θ(t) − η∇L(θ(t)) . (2)

where η > 0 is the learning rate.

3.2 Natural Gradient Descent in Deep Learning

Natural Gradient Descent (NGD) [14] is an optimizer whichupdates the parameters using the first-order gradient of theloss function preconditioned by the Fisher information matrix(FIM) of the probabilistic model:

θ(t+1) ← θ(t) − η (F + λI)−1∇L(θ(t)) . (3)

The FIM F ∈ RN×N of a DNN with the learnable parameterθ ∈ RN is defined as:

F := Ex∼q[Ey∼pθ

[∇ log pθ(y|x)∇ log pθ(y|x)>

]]. (4)

Ev[·] is an expectation w.r.t. the random variable v, andq is the training data distribution. To limit the step size,a damping value λ > 0 is added to the diagonal of Fbefore inverting it. In the training of DNNs, the FIM maybe thought of as the curvature matrix in parameter space[14], [17], [28].

To realize an efficient NGD training procedure for deepneural networks, we make the following approximations tothe FIM:

• Layer-wise block-diagonal approximation. We as-sume that the correlation between parameters indifferent layers (Figure 2) is negligible and can beignored. This assumption significantly reduces thecomputational cost of inverting F especially whenN is large.

• Stochastic natural gradient. We approximate theexpectation over the input data distribution Ex∼q[·]using the empirical expectation over a mini-batchE[·]. This enables estimation of F during mini-batchstochastic learning.

• Monte Carlo estimation. We approximate the ex-pectation over the model predictive distributionEy∼pθ [·] using a single Monte Carlo sample (a singlebackward-pass). We note that for a K-class classi-fication model, K backward-passes are required toapproximate F .

Using these approximations, we estimate the FIM F` ∈RN`×N` for the `-th layer using a Monte Carlo sampley ∼ pθ(y|x) for each input x in a mini-batch as

F` ≈ F` := E[∇w`

log pθ(y|x)∇w`log pθ(y|x)>

]. (5)

With this F`, the parameters w` ∈ RN` for the `-th layer arethen updated using the FIM preconditioned gradients:

w(t+1)` ← w

(t)` − η

(F

(t)` + λI

)−1∇w`L(t) . (6)

Here ∇w`L(t) ∈ RN` denotes the gradient of the loss

function w.r.t. w` for w` = w(t)` .

Page 4: 1 Scalable and Practical Natural Gradient for Large-Scale ... · 50 on ImageNet classification as a benchmark. We converged to over 75% top-1 validation accuracy for large mini-batch

4

1<latexit sha1_base64="8jLNIv/S8WBvHtfFxahEaJPNMnU=">AAAB6HicbVBNS8NAEJ3Ur1q/qh69LBbBU0lU0GPRi8cW7Ae0oWy2k3btZhN2N0IJ/QVePCji1Z/kzX/jts1BWx8MPN6bYWZekAiujet+O4W19Y3NreJ2aWd3b/+gfHjU0nGqGDZZLGLVCahGwSU2DTcCO4lCGgUC28H4bua3n1BpHssHM0nQj+hQ8pAzaqzU8Prlilt15yCrxMtJBXLU++Wv3iBmaYTSMEG17npuYvyMKsOZwGmpl2pMKBvTIXYtlTRC7WfzQ6fkzCoDEsbKljRkrv6eyGik9SQKbGdEzUgvezPxP6+bmvDGz7hMUoOSLRaFqSAmJrOvyYArZEZMLKFMcXsrYSOqKDM2m5INwVt+eZW0LqreZdVtXFVqt3kcRTiBUzgHD66hBvdQhyYwQHiGV3hzHp0X5935WLQWnHzmGP7A+fwBe0eMtw==</latexit> L<latexit sha1_base64="9PF+tKjSQN8i8QYPTB/Bgp+YVSU=">AAAB6HicbVA9SwNBEJ3zM8avqKXNYhCswp0KWgZtLCwSMB+QHGFvM5es2ds7dveEcOQX2FgoYutPsvPfuEmu0MQHA4/3ZpiZFySCa+O6387K6tr6xmZhq7i9s7u3Xzo4bOo4VQwbLBaxagdUo+ASG4Ybge1EIY0Cga1gdDv1W0+oNI/lgxkn6Ed0IHnIGTVWqt/3SmW34s5AlomXkzLkqPVKX91+zNIIpWGCat3x3MT4GVWGM4GTYjfVmFA2ogPsWCpphNrPZodOyKlV+iSMlS1pyEz9PZHRSOtxFNjOiJqhXvSm4n9eJzXhtZ9xmaQGJZsvClNBTEymX5M+V8iMGFtCmeL2VsKGVFFmbDZFG4K3+PIyaZ5XvIuKW78sV2/yOApwDCdwBh5cQRXuoAYNYIDwDK/w5jw6L8678zFvXXHymSP4A+fzB6QzjNI=</latexit>

L<latexit sha1_base64="9PF+tKjSQN8i8QYPTB/Bgp+YVSU=">AAAB6HicbVA9SwNBEJ3zM8avqKXNYhCswp0KWgZtLCwSMB+QHGFvM5es2ds7dveEcOQX2FgoYutPsvPfuEmu0MQHA4/3ZpiZFySCa+O6387K6tr6xmZhq7i9s7u3Xzo4bOo4VQwbLBaxagdUo+ASG4Ybge1EIY0Cga1gdDv1W0+oNI/lgxkn6Ed0IHnIGTVWqt/3SmW34s5AlomXkzLkqPVKX91+zNIIpWGCat3x3MT4GVWGM4GTYjfVmFA2ogPsWCpphNrPZodOyKlV+iSMlS1pyEz9PZHRSOtxFNjOiJqhXvSm4n9eJzXhtZ9xmaQGJZsvClNBTEymX5M+V8iMGFtCmeL2VsKGVFFmbDZFG4K3+PIyaZ5XvIuKW78sV2/yOApwDCdwBh5cQRXuoAYNYIDwDK/w5jw6L8678zFvXXHymSP4A+fzB6QzjNI=</latexit>

<latexit sha1_base64="tO8xm45MXnYyaIGHhHOD/E2+VfU=">AAAB63icbVBNS8NAEJ3Ur1q/qh69LBbBU0lU0GPRi8cK9gPaUDbbSbt0swm7G6GE/gUvHhTx6h/y5r9x0+agrQ8GHu/NMDMvSATXxnW/ndLa+sbmVnm7srO7t39QPTxq6zhVDFssFrHqBlSj4BJbhhuB3UQhjQKBnWByl/udJ1Sax/LRTBP0IzqSPOSMmlzqoxCDas2tu3OQVeIVpAYFmoPqV38YszRCaZigWvc8NzF+RpXhTOCs0k81JpRN6Ah7lkoaofaz+a0zcmaVIQljZUsaMld/T2Q00noaBbYzomasl71c/M/rpSa88TMuk9SgZItFYSqIiUn+OBlyhcyIqSWUKW5vJWxMFWXGxlOxIXjLL6+S9kXdu6y7D1e1xm0RRxlO4BTOwYNraMA9NKEFDMbwDK/w5kTOi/PufCxaS04xcwx/4Hz+AA4Qjj0=</latexit>

⌦<latexit sha1_base64="/AopNi+UvXmbne7B02/2UzbATF0=">AAAB7nicbVDLSgNBEOyNrxhfUY9eBoPgKeyqoN4CXjxGMA9IljA7mU2GzM4sM71CCPkILx4U8er3ePNvnCR70MSChqKqm+6uKJXCou9/e4W19Y3NreJ2aWd3b/+gfHjUtDozjDeYltq0I2q5FIo3UKDk7dRwmkSSt6LR3cxvPXFjhVaPOE55mNCBErFgFJ3U6moUCbe9csWv+nOQVRLkpAI56r3yV7evWZZwhUxSazuBn2I4oQYFk3xa6maWp5SN6IB3HFXULQkn83On5MwpfRJr40ohmau/JyY0sXacRK4zoTi0y95M/M/rZBjfhBOh0gy5YotFcSYJajL7nfSF4Qzl2BHKjHC3EjakhjJ0CZVcCMHyy6ukeVENLqv+w1WldpvHUYQTOIVzCOAaanAPdWgAgxE8wyu8ean34r17H4vWgpfPHMMfeJ8/g+CPpg==</latexit>

⇡<latexit sha1_base64="YFCHK6WsItiBx7hYztqr3rYTJRk=">AAAB7nicbVBNSwMxEJ2tX7V+VT16CRbBU9nVgnorePFYwX5Au5Rsmm1Ds0lIsmJZ+iO8eFDEq7/Hm//GtN2Dtj4YeLw3w8y8SHFmrO9/e4W19Y3NreJ2aWd3b/+gfHjUMjLVhDaJ5FJ3ImwoZ4I2LbOcdpSmOIk4bUfj25nffqTaMCke7ETRMMFDwWJGsHVSu4eV0vKpX674VX8OtEqCnFQgR6Nf/uoNJEkTKizh2Jhu4CsbZlhbRjidlnqpoQqTMR7SrqMCJ9SE2fzcKTpzygDFUrsSFs3V3xMZToyZJJHrTLAdmWVvJv7ndVMbX4cZEyq1VJDFojjlyEo0+x0NmKbE8okjmGjmbkVkhDUm1iVUciEEyy+vktZFNbis+ve1Sv0mj6MIJ3AK5xDAFdThDhrQBAJjeIZXePOU9+K9ex+L1oKXzxzDH3ifP5FXj68=</latexit>

⇡<latexit sha1_base64="YFCHK6WsItiBx7hYztqr3rYTJRk=">AAAB7nicbVBNSwMxEJ2tX7V+VT16CRbBU9nVgnorePFYwX5Au5Rsmm1Ds0lIsmJZ+iO8eFDEq7/Hm//GtN2Dtj4YeLw3w8y8SHFmrO9/e4W19Y3NreJ2aWd3b/+gfHjUMjLVhDaJ5FJ3ImwoZ4I2LbOcdpSmOIk4bUfj25nffqTaMCke7ETRMMFDwWJGsHVSu4eV0vKpX674VX8OtEqCnFQgR6Nf/uoNJEkTKizh2Jhu4CsbZlhbRjidlnqpoQqTMR7SrqMCJ9SE2fzcKTpzygDFUrsSFs3V3xMZToyZJJHrTLAdmWVvJv7ndVMbX4cZEyq1VJDFojjlyEo0+x0NmKbE8okjmGjmbkVkhDUm1iVUciEEyy+vktZFNbis+ve1Sv0mj6MIJ3AK5xDAFdThDhrQBAJjeIZXePOU9+K9ex+L1oKXzxzDH3ifP5FXj68=</latexit>

1<latexit sha1_base64="8jLNIv/S8WBvHtfFxahEaJPNMnU=">AAAB6HicbVBNS8NAEJ3Ur1q/qh69LBbBU0lU0GPRi8cW7Ae0oWy2k3btZhN2N0IJ/QVePCji1Z/kzX/jts1BWx8MPN6bYWZekAiujet+O4W19Y3NreJ2aWd3b/+gfHjU0nGqGDZZLGLVCahGwSU2DTcCO4lCGgUC28H4bua3n1BpHssHM0nQj+hQ8pAzaqzU8Prlilt15yCrxMtJBXLU++Wv3iBmaYTSMEG17npuYvyMKsOZwGmpl2pMKBvTIXYtlTRC7WfzQ6fkzCoDEsbKljRkrv6eyGik9SQKbGdEzUgvezPxP6+bmvDGz7hMUoOSLRaFqSAmJrOvyYArZEZMLKFMcXsrYSOqKDM2m5INwVt+eZW0LqreZdVtXFVqt3kcRTiBUzgHD66hBvdQhyYwQHiGV3hzHp0X5935WLQWnHzmGP7A+fwBe0eMtw==</latexit>

<latexit sha1_base64="tO8xm45MXnYyaIGHhHOD/E2+VfU=">AAAB63icbVBNS8NAEJ3Ur1q/qh69LBbBU0lU0GPRi8cK9gPaUDbbSbt0swm7G6GE/gUvHhTx6h/y5r9x0+agrQ8GHu/NMDMvSATXxnW/ndLa+sbmVnm7srO7t39QPTxq6zhVDFssFrHqBlSj4BJbhhuB3UQhjQKBnWByl/udJ1Sax/LRTBP0IzqSPOSMmlzqoxCDas2tu3OQVeIVpAYFmoPqV38YszRCaZigWvc8NzF+RpXhTOCs0k81JpRN6Ah7lkoaofaz+a0zcmaVIQljZUsaMld/T2Q00noaBbYzomasl71c/M/rpSa88TMuk9SgZItFYSqIiUn+OBlyhcyIqSWUKW5vJWxMFWXGxlOxIXjLL6+S9kXdu6y7D1e1xm0RRxlO4BTOwYNraMA9NKEFDMbwDK/w5kTOi/PufCxaS04xcwx/4Hz+AA4Qjj0=</latexit>

F `<latexit sha1_base64="XBGIeJBncvIVbDIfNsshi8pK/jg=">AAACBHicbVDLSsNAFJ34rPUVddnNYCu4KokKdlkQxGUF+4AmhMlk0g6dZMLMRCwhCzf+ihsXirj1I9z5N07aLLT1wIXDOfdy7z1+wqhUlvVtrKyurW9sVraq2zu7e/vmwWFP8lRg0sWccTHwkSSMxqSrqGJkkAiCIp+Rvj+5Kvz+PRGS8vhOTRPiRmgU05BipLTkmTUn8vlD5vicBRFSY9i4buRe5hDGcs+sW01rBrhM7JLUQYmOZ345AcdpRGKFGZJyaFuJcjMkFMWM5FUnlSRBeIJGZKhpjCIi3Wz2RA5PtBLAkAtdsYIz9fdEhiIpp5GvO4tD5aJXiP95w1SFLTejcZIqEuP5ojBlUHFYJAIDKghWbKoJwoLqWyEeI4Gw0rlVdQj24svLpHfWtM+b1u1Fvd0q46iAGjgGp8AGl6ANbkAHdAEGj+AZvII348l4Md6Nj3nrilHOHIE/MD5/AJHDmAA=</latexit>

F 1<latexit sha1_base64="JZ7dOqzXrefvw4FBHK7nbdpfpic=">AAACAXicbVDLSsNAFJ34rPUVdSO4GWwFVyVRwS4LgrisYB/QhDCZTNqhk0yYmYglxI2/4saFIm79C3f+jZM2C209cOFwzr3ce4+fMCqVZX0bS8srq2vrlY3q5tb2zq65t9+VPBWYdDBnXPR9JAmjMekoqhjpJ4KgyGek54+vCr93T4SkPL5Tk4S4ERrGNKQYKS155qET+fwhc3zOggipEaxf13Mvs3PPrFkNawq4SOyS1ECJtmd+OQHHaURihRmScmBbiXIzJBTFjORVJ5UkQXiMhmSgaYwiIt1s+kEOT7QSwJALXbGCU/X3RIYiKSeRrzuLK+W8V4j/eYNUhU03o3GSKhLj2aIwZVBxWMQBAyoIVmyiCcKC6lshHiGBsNKhVXUI9vzLi6R71rDPG9btRa3VLOOogCNwDE6BDS5BC9yANugADB7BM3gFb8aT8WK8Gx+z1iWjnDkAf2B8/gDkAJZ6</latexit>

F<latexit sha1_base64="b+8+pXUQp7kkZVh9s/561xdGwtU=">AAAB/XicbVDLSgMxFM34rPU1PnZugq3gqsyoYJcFQVxWsA/oDCWTybShyWRIMmIdir/ixoUibv0Pd/6NmXYW2nogcDjnXu7JCRJGlXacb2tpeWV1bb20Ud7c2t7Ztff220qkEpMWFkzIboAUYTQmLU01I91EEsQDRjrB6Cr3O/dEKiriOz1OiM/RIKYRxUgbqW8fejwQD5kXCBZypIewel2d9O2KU3OmgIvELUgFFGj27S8vFDjlJNaYIaV6rpNoP0NSU8zIpOyliiQIj9CA9AyNESfKz6bpJ/DEKCGMhDQv1nCq/t7IEFdqzAMzmSdU814u/uf1Uh3V/YzGSapJjGeHopRBLWBeBQypJFizsSEIS2qyQjxEEmFtCiubEtz5Ly+S9lnNPa85txeVRr2oowSOwDE4BS64BA1wA5qgBTB4BM/gFbxZT9aL9W59zEaXrGLnAPyB9fkD3UOUyg==</latexit>

F L<latexit sha1_base64="2v8D538eE3uycrZydXeNJ8jDWNc=">AAACAXicbVDLSsNAFJ3UV62vqBvBzWAruCqJCnZZEMSFiwr2AU0Ik8m0HTrJhJmJWELc+CtuXCji1r9w5984abPQ1gMXDufcy733+DGjUlnWt1FaWl5ZXSuvVzY2t7Z3zN29juSJwKSNOeOi5yNJGI1IW1HFSC8WBIU+I11/fJn73XsiJOXRnZrExA3RMKIDipHSkmceOKHPH1LH5ywIkRrB2lUt89KbzDOrVt2aAi4SuyBVUKDlmV9OwHESkkhhhqTs21as3BQJRTEjWcVJJIkRHqMh6WsaoZBIN51+kMFjrQRwwIWuSMGp+nsiRaGUk9DXnfmVct7Lxf+8fqIGDTelUZwoEuHZokHCoOIwjwMGVBCs2EQThAXVt0I8QgJhpUOr6BDs+ZcXSee0bp/VrdvzarNRxFEGh+AInAAbXIAmuAYt0AYYPIJn8ArejCfjxXg3PmatJaOY2Qd/YHz+AA0WlpU=</latexit>

F `<latexit sha1_base64="XBGIeJBncvIVbDIfNsshi8pK/jg=">AAACBHicbVDLSsNAFJ34rPUVddnNYCu4KokKdlkQxGUF+4AmhMlk0g6dZMLMRCwhCzf+ihsXirj1I9z5N07aLLT1wIXDOfdy7z1+wqhUlvVtrKyurW9sVraq2zu7e/vmwWFP8lRg0sWccTHwkSSMxqSrqGJkkAiCIp+Rvj+5Kvz+PRGS8vhOTRPiRmgU05BipLTkmTUn8vlD5vicBRFSY9i4buRe5hDGcs+sW01rBrhM7JLUQYmOZ345AcdpRGKFGZJyaFuJcjMkFMWM5FUnlSRBeIJGZKhpjCIi3Wz2RA5PtBLAkAtdsYIz9fdEhiIpp5GvO4tD5aJXiP95w1SFLTejcZIqEuP5ojBlUHFYJAIDKghWbKoJwoLqWyEeI4Gw0rlVdQj24svLpHfWtM+b1u1Fvd0q46iAGjgGp8AGl6ANbkAHdAEGj+AZvII348l4Md6Nj3nrilHOHIE/MD5/AJHDmAA=</latexit>

G`<latexit sha1_base64="p7WtBmPDgLLuOXBjzelfCkAYUy4=">AAACBHicbVDLSsNAFJ34rPUVddnNYCu4KokKdllwocsK9gFNCJPJpB06yYSZiVhCFm78FTcuFHHrR7jzb5y0WWjrgQuHc+7l3nv8hFGpLOvbWFldW9/YrGxVt3d29/bNg8Oe5KnApIs542LgI0kYjUlXUcXIIBEERT4jfX9yVfj9eyIk5fGdmibEjdAopiHFSGnJM2tO5POHzPE5CyKkxrBx3ci9zCGM5Z5Zt5rWDHCZ2CWpgxIdz/xyAo7TiMQKMyTl0LYS5WZIKIoZyatOKkmC8ASNyFDTGEVEutnsiRyeaCWAIRe6YgVn6u+JDEVSTiNfdxaHykWvEP/zhqkKW25G4yRVJMbzRWHKoOKwSAQGVBCs2FQThAXVt0I8RgJhpXOr6hDsxZeXSe+saZ83rduLertVxlEBNXAMToENLkEb3IAO6AIMHsEzeAVvxpPxYrwbH/PWFaOcOQJ/YHz+AJNQmAE=</latexit>

A`�1<latexit sha1_base64="WdxgCPjI8vSf/nTWtinUCiWBK38=">AAACBnicbVDLSsNAFJ3UV62vqEsRBlvBjSVRwS4rblxWsA9oQphMJu3QSSbMTMQSsnLjr7hxoYhbv8Gdf+O0zUJbD1w4nHMv997jJ4xKZVnfRmlpeWV1rbxe2djc2t4xd/c6kqcCkzbmjIuejyRhNCZtRRUjvUQQFPmMdP3R9cTv3hMhKY/v1DghboQGMQ0pRkpLnnnoRD5/yByfsyBCaghrV7XcyxzC2Kmde2bVqltTwEViF6QKCrQ888sJOE4jEivMkJR920qUmyGhKGYkrzipJAnCIzQgfU1jFBHpZtM3cnislQCGXOiKFZyqvycyFEk5jnzdOTlVznsT8T+vn6qw4WY0TlJFYjxbFKYMKg4nmcCACoIVG2uCsKD6VoiHSCCsdHIVHYI9//Ii6ZzV7fO6dXtRbTaKOMrgAByBE2CDS9AEN6AF2gCDR/AMXsGb8WS8GO/Gx6y1ZBQz++APjM8fdx6YbQ==</latexit>

F `,unitBN<latexit sha1_base64="IyiqVnUnB9RfTUDhPtqFFNFx3l0=">AAACEnicbVDJSgNBEO2JW4zbqEcvjYmgIGFGBXMMCuJJIpgFMiH0dDpJk16G7h4xDPMNXvwVLx4U8erJm39jZzlo4oOCx3tVVNULI0a18bxvJ7OwuLS8kl3Nra1vbG652zs1LWOFSRVLJlUjRJowKkjVUMNII1IE8ZCReji4HPn1e6I0leLODCPS4qgnaJdiZKzUdo8CHsqHJAgl63Bk+rBwVUjbSUAYO04CxWEsqIEXN2nadvNe0RsDzhN/SvJgikrb/Qo6EsecCIMZ0rrpe5FpJUgZihlJc0GsSYTwAPVI01KBONGtZPxSCg+s0oFdqWwJA8fq74kEca2HPLSdo7P1rDcS//OasemWWgkVUWyIwJNF3ZhBI+EoH9ihimDDhpYgrKi9FeI+Uggbm2LOhuDPvjxPaidF/7To3Z7ly6VpHFmwB/bBIfDBOSiDa1ABVYDBI3gGr+DNeXJenHfnY9KacaYzu+APnM8fxMudew==</latexit>

Layer

Layer

Layer

……

1<latexit sha1_base64="8jLNIv/S8WBvHtfFxahEaJPNMnU=">AAAB6HicbVBNS8NAEJ3Ur1q/qh69LBbBU0lU0GPRi8cW7Ae0oWy2k3btZhN2N0IJ/QVePCji1Z/kzX/jts1BWx8MPN6bYWZekAiujet+O4W19Y3NreJ2aWd3b/+gfHjU0nGqGDZZLGLVCahGwSU2DTcCO4lCGgUC28H4bua3n1BpHssHM0nQj+hQ8pAzaqzU8Prlilt15yCrxMtJBXLU++Wv3iBmaYTSMEG17npuYvyMKsOZwGmpl2pMKBvTIXYtlTRC7WfzQ6fkzCoDEsbKljRkrv6eyGik9SQKbGdEzUgvezPxP6+bmvDGz7hMUoOSLRaFqSAmJrOvyYArZEZMLKFMcXsrYSOqKDM2m5INwVt+eZW0LqreZdVtXFVqt3kcRTiBUzgHD66hBvdQhyYwQHiGV3hzHp0X5935WLQWnHzmGP7A+fwBe0eMtw==</latexit>

<latexit sha1_base64="tO8xm45MXnYyaIGHhHOD/E2+VfU=">AAAB63icbVBNS8NAEJ3Ur1q/qh69LBbBU0lU0GPRi8cK9gPaUDbbSbt0swm7G6GE/gUvHhTx6h/y5r9x0+agrQ8GHu/NMDMvSATXxnW/ndLa+sbmVnm7srO7t39QPTxq6zhVDFssFrHqBlSj4BJbhhuB3UQhjQKBnWByl/udJ1Sax/LRTBP0IzqSPOSMmlzqoxCDas2tu3OQVeIVpAYFmoPqV38YszRCaZigWvc8NzF+RpXhTOCs0k81JpRN6Ah7lkoaofaz+a0zcmaVIQljZUsaMld/T2Q00noaBbYzomasl71c/M/rpSa88TMuk9SgZItFYSqIiUn+OBlyhcyIqSWUKW5vJWxMFWXGxlOxIXjLL6+S9kXdu6y7D1e1xm0RRxlO4BTOwYNraMA9NKEFDMbwDK/w5kTOi/PufCxaS04xcwx/4Hz+AA4Qjj0=</latexit>

L<latexit sha1_base64="9PF+tKjSQN8i8QYPTB/Bgp+YVSU=">AAAB6HicbVA9SwNBEJ3zM8avqKXNYhCswp0KWgZtLCwSMB+QHGFvM5es2ds7dveEcOQX2FgoYutPsvPfuEmu0MQHA4/3ZpiZFySCa+O6387K6tr6xmZhq7i9s7u3Xzo4bOo4VQwbLBaxagdUo+ASG4Ybge1EIY0Cga1gdDv1W0+oNI/lgxkn6Ed0IHnIGTVWqt/3SmW34s5AlomXkzLkqPVKX91+zNIIpWGCat3x3MT4GVWGM4GTYjfVmFA2ogPsWCpphNrPZodOyKlV+iSMlS1pyEz9PZHRSOtxFNjOiJqhXvSm4n9eJzXhtZ9xmaQGJZsvClNBTEymX5M+V8iMGFtCmeL2VsKGVFFmbDZFG4K3+PIyaZ5XvIuKW78sV2/yOApwDCdwBh5cQRXuoAYNYIDwDK/w5jw6L8678zFvXXHymSP4A+fzB6QzjNI=</latexit>

x<latexit sha1_base64="DQivoxdGTcXKHyd2yP4tsN1A+g4=">AAAB+HicbVDLSgMxFL3js9ZHR126CRbBVZlRQd0V3LisYB/QDiWTSdvQTDIkGbEO/RI3LhRx66e482/MtLPQ1gMhh3PuJScnTDjTxvO+nZXVtfWNzdJWeXtnd6/i7h+0tEwVoU0iuVSdEGvKmaBNwwynnURRHIectsPxTe63H6jSTIp7M0loEOOhYANGsLFS3630QskjPYntlT1Oy3236tW8GdAy8QtShQKNvvvViyRJYyoM4Vjrru8lJsiwMoxwOi33Uk0TTMZ4SLuWChxTHWSz4FN0YpUIDaSyRxg0U39vZDjWeTY7GWMz0oteLv7ndVMzuAoyJpLUUEHmDw1SjoxEeQsoYooSwyeWYKKYzYrICCtMjO0qL8Ff/PIyaZ3V/POad3dRrV8XdZTgCI7hFHy4hDrcQgOaQCCFZ3iFN+fJeXHenY/56IpT7BzCHzifPwMuk0c=</latexit>

K-FAC(Conv, FC)

Unit-wise(BatchNorm)

Feed-forward DNN Layer-wise block-diagonal FIM Practical FIM estimation

F `<latexit sha1_base64="d4XskDj/BLk4GxrsAe7ZVeqWxCY=">AAACC3icbVDLSsNAFJ34rPUVdelmaCu4KokKuiwK4rKCfUATwmQyaYdOMmFmIpaQvRt/xY0LRdz6A+78GydtFtp64MLhnHu59x4/YVQqy/o2lpZXVtfWKxvVza3tnV1zb78reSow6WDOuOj7SBJGY9JRVDHSTwRBkc9Izx9fFX7vnghJeXynJglxIzSMaUgxUlryzJozQipzIp8/ZI7PWRAhNYKN60aee5lDGMurnlm3mtYUcJHYJamDEm3P/HICjtOIxAozJOXAthLlZkgoihnJq04qSYLwGA3JQNMYRUS62fSXHB5pJYAhF7piBafq74kMRVJOIl93FqfKea8Q//MGqQov3IzGSapIjGeLwpRBxWERDAyoIFixiSYIC6pvhXiEBMJKx1eEYM+/vEi6J037tGndntVbl2UcFXAIauAY2OActMANaIMOwOARPINX8GY8GS/Gu/Exa10yypkD8AfG5w/iX5rr</latexit>

p✓(y|x)<latexit sha1_base64="oLfFcH9NvBiapvErNtK4pt4dg2Y=">AAACKXicbVDLSgMxFM34rPVVdekm2Ap1U2ZUsMuCG5cV7AM6Q8mkaRuaeZDckQ7j/I4bf8WNgqJu/REz7Sy09ULIyTn3knuOGwquwDQ/jZXVtfWNzcJWcXtnd2+/dHDYVkEkKWvRQASy6xLFBPdZCzgI1g0lI54rWMedXGd6555JxQP/DuKQOR4Z+XzIKQFN9UuNsJ/YbiAGKvb0ZcOYAUmrtn5M54JHYIwrcSV9WCKnlfSsXyqbNXNWeBlYOSijvJr90qs9CGjkMR+oIEr1LDMEJyESOBUsLdqRYiGhEzJiPQ194jHlJDOnKT7VzAAPA6mPD3jG/p5IiKcyJ7ozW1Etahn5n9aLYFh3Eu6HETCfzj8aRgJDgLPY8IBLRkHEGhAqud4V0zGRhIIOt6hDsBYtL4P2ec26qJm3l+VGPY+jgI7RCaoiC12hBrpBTdRCFD2iZ/SG3o0n48X4ML7mrStGPnOE/pTx/QN386ff</latexit>

(Sec. 3.2)

(Sec. 4.2)

(Sec. 3.3)

Fig. 2. Illustration of Fisher information matrix approximations for feed-forward deep neural networks used in this work.

3.3 K-FAC

Kronecker-Factored Approximate Curvature (K-FAC) [17]is a second-order optimization method for deep neuralnetworks, that is based on an accurate and mathematicallyrigorous approximation of the FIM. Using K-FAC, we fur-ther approximate the FIM F` for `-th layer as a Kroneckerproduct of two matrices (Figure 2):

F` ≈ F` = G` ⊗A`−1 . (7)

This is called Kronecker factorization and G`,A`−1 are calledKronecker factors. G` is computed from the gradient of theloss w.r.t. the output of the `-th layer, andA`−1 is computedfrom the activation of the (` − 1)-th layer (the input of `-thlayer).

The definition and the sizes of the Kronecker factorsG`/A`−1 depend on the dimension of the output/input andthe type of layer [17], [18], [20], [23].

3.3.1 K-FAC for fully-connected layers.In a fully-connected (FC) layer of a feed-forward DNN, theoutput s` ∈ Rd` is calculated as

s` ←W`a`−1, (8)

where a`−1 ∈ Rd`−1 is the input to this layer (the activationfrom the previous layer), and W` ∈ Rd`×d`−1 is the weightmatrix (the bias is ignored for simplicity). The Kroneckerfactors for this FC layer are defined as

G` := E[∇s` log pθ(y|x)∇s` log pθ(y|x)>

],

A`−1 := E[a`−1a

>`−1],

(9)

and G` ∈ Rd`×d` ,A`−1 ∈ Rd`−1×d`−1 [17]. From this defini-tion, we can consider that K-FAC is based on an assumptionthat the input to the layer and the gradient w.r.t. the layeroutput are statistically independent.

3.3.2 K-FAC for convolutional layersIn a convolutional (Conv) layer of a feed-forward DNN, theoutput S` ∈ Rc`×h`×w` is calculated as

MA`−1← im2col(A`−1) ∈ Rc`−1k

2`×h`w` ,

MS` ←W`MA`−1∈ Rc`×h`w` ,

S` ← reshape(MS`) ∈ Rc`×h`×w` ,

(10)

where A`−1 ∈ Rc`−1×h`−1×w`−1 is the input to this layer,and W` ∈ Rc`×c`−1k

2` is the weight matrix (the bias is

ignored for simplicity). c`, c`−1 are the number of output,input channels, respectively, and k` is the kernel size (as-suming square kernels for simplicity). The Kronecker factorsfor this Conv layer are defined as

G` := E[∇MS` log pθ(y|x)∇MS` log pθ(y|x)>

],

A`−1 :=1

h`w`E[MA`−1

M>A`−1

],

(11)

and G` ∈ Rc`×c` ,A`−1 ∈ Rc`−1k2`×c`−1k

2` [18].

3.3.3 Inverting Kronecker-factored FIM

By the property of the Kronecker product and the Tikhonovdamping technique used in [17], the inverse of F` + λI isapproximated by the Kronecker product of the inverse ofeach Kronecker factor(F` + λI

)−1= (G` ⊗A`−1 + λI)

−1

≈(G` +

1

π`

√λI

)−1⊗(A`−1 + π`

√λI)−1

,

(12)

where π2` is the average eigenvalue of A`−1 divided by the

average eigenvalue ofG`. π > 0 because bothG` andA`−1are positive-semidefinite matrices as defined above.

4 PRACTICAL NATURAL GRADIENT

K-FAC for FC layer (9) and Conv layer (11) enables usto realize NGD in training deep ConvNets [18], [19]. Fordeep and wide neural architectures with a huge number oflearnable parameters, however, due to the extra computa-tion for the FIM, even NGD with K-FAC has considerableoverhead compared to SGD. In this section, we introducepractical techniques to accelerate NGD for such huge neuralarchitectures. Using our techniques, we are able to reducethe overhead of NGD to almost a negligible amount asshown in Section 7.

Page 5: 1 Scalable and Practical Natural Gradient for Large-Scale ... · 50 on ImageNet classification as a benchmark. We converged to over 75% top-1 validation accuracy for large mini-batch

5

Algorithm 1: Natural gradient with the stale statistics.input : set of the statistics S (damped inverses)input : initial parameters θoutput: parameters θ

t← 1foreach X ∈ S do

tX ← 1endwhile not converge do

foreach X ∈ S doif t == tX then

refresh statistics X∆,∆−1 ←Algorithm 2(X,X−1, X−2,∆,∆−1)tX ← tX + ∆X−1, X−2 ← X,X−1

endupdate θ by natural gradient (6) using St← t+ 1

endreturn θ

Algorithm 2: Estimate the acceptable interval until nextrefresh based on the staleness of statistics.

input : current statistics Xinput : last statistics X−1input : statistics before the last X−2input : last interval ∆−1input : interval before the last ∆−2output: next interval, last interval ∆,∆−1

if X is not similar to X−1 then∆← max {1, b∆−1/2c}

else if X is not similar to X−2 then∆← ∆−1

else∆← ∆−1 + ∆−2

endreturn ∆, ∆−1

4.1 Fast Estimation with Empirical Fisher

Instead of using an estimation by a single Monte Carlosampling defined in Eq. (5) (F`,1mc), we adopt the empiricalFisher [17] to estimate the FIM F`:

F` ≈ F`,emp := E[∇w`

log pθ(t|x)∇w`log pθ(t|x)>

].

(13)We implemented an efficient F`,emp computation in theChainer framework [29] that allows us to compute F`,emp

during the forward-pass and the backward-pass for the lossL(θ)3. Therefore, we do not need an extra backward-pass tocompute F`,emp, while an extra backward-pass is necessary forcomputing F`,1mc. This difference is critical especially for adeeper network which takes longer time for a backward-pass.

3. The same empirical Fisher computation can be implemented onPyTorch.

��2<latexit sha1_base64="/uxtYGuOBkRzofj0LwCKsCRXd2c=">AAAB8nicbVBNS8NAEJ34WetX1aOXYBG8WJIq2GNBDx4r2A9IS9lsN+3SzW7YnQgl9Gd48aCIV3+NN/+N2zYHbX0w8Hhvhpl5YSK4Qc/7dtbWNza3tgs7xd29/YPD0tFxy6hUU9akSijdCYlhgkvWRI6CdRLNSBwK1g7HtzO//cS04Uo+4iRhvZgMJY84JWiloHvHBJJ+dlmd9ktlr+LN4a4SPydlyNHol766A0XTmEmkghgT+F6CvYxo5FSwabGbGpYQOiZDFlgqScxML5ufPHXPrTJwI6VtSXTn6u+JjMTGTOLQdsYER2bZm4n/eUGKUa2XcZmkyCRdLIpS4aJyZ/+7A64ZRTGxhFDN7a0uHRFNKNqUijYEf/nlVdKqVvyrivdwXa7X8jgKcApncAE+3EAd7qEBTaCg4Ble4c1B58V5dz4WrWtOPnMCf+B8/gC9ypDc</latexit>

��1<latexit sha1_base64="1lw9VMZwLZBF0nhY1fFxIqs+jns=">AAAB8nicbVBNS8NAEJ3Ur1q/qh69LBbBiyVRwR4LevBYwX5AWspmu2mXbjZhdyKU0J/hxYMiXv013vw3btsctPXBwOO9GWbmBYkUBl332ymsrW9sbhW3Szu7e/sH5cOjlolTzXiTxTLWnYAaLoXiTRQoeSfRnEaB5O1gfDvz209cGxGrR5wkvBfRoRKhYBSt5HfvuETazy68ab9ccavuHGSVeDmpQI5Gv/zVHcQsjbhCJqkxvucm2MuoRsEkn5a6qeEJZWM65L6likbc9LL5yVNyZpUBCWNtSyGZq78nMhoZM4kC2xlRHJllbyb+5/kphrVeJlSSIldssShMJcGYzP4nA6E5QzmxhDIt7K2EjaimDG1KJRuCt/zyKmldVr2rqvtwXanX8jiKcAKncA4e3EAd7qEBTWAQwzO8wpuDzovz7nwsWgtOPnMMf+B8/gC8RZDb</latexit>

�<latexit sha1_base64="u6jOjrrPrrH6NNpVh6FDZXbtl0s=">AAAB7XicbVBNS8NAEN3Ur1q/qh69BIvgqSQq2GNBDx4r2A9oQ9lsJ+3azW7YnQgl9D948aCIV/+PN/+N2zYHbX0w8Hhvhpl5YSK4Qc/7dgpr6xubW8Xt0s7u3v5B+fCoZVSqGTSZEkp3QmpAcAlN5Cigk2igcSigHY5vZn77CbThSj7gJIEgpkPJI84oWqnVuwWBtF+ueFVvDneV+DmpkByNfvmrN1AsjUEiE9SYru8lGGRUI2cCpqVeaiChbEyH0LVU0hhMkM2vnbpnVhm4kdK2JLpz9fdERmNjJnFoO2OKI7PszcT/vG6KUS3IuExSBMkWi6JUuKjc2evugGtgKCaWUKa5vdVlI6opQxtQyYbgL7+8SloXVf+y6t1fVeq1PI4iOSGn5Jz45JrUyR1pkCZh5JE8k1fy5ijnxXl3PhatBSefOSZ/4Hz+AF40jvQ=</latexit>

tX�1<latexit sha1_base64="ZFj095VISfUmJOKZUeYR6IkxQEg=">AAAB8XicbVBNS8NAEJ3Ur1q/qh69BIvgxZKoYI8FLx4r2A9sQ9hsN+3SzSbsToQS8i+8eFDEq//Gm//GbZuDtj4YeLw3w8y8IBFco+N8W6W19Y3NrfJ2ZWd3b/+genjU0XGqKGvTWMSqFxDNBJesjRwF6yWKkSgQrBtMbmd+94kpzWP5gNOEeREZSR5yStBIj+hnPT+7cPPcr9acujOHvUrcgtSgQMuvfg2GMU0jJpEKonXfdRL0MqKQU8HyyiDVLCF0Qkasb6gkEdNeNr84t8+MMrTDWJmSaM/V3xMZibSeRoHpjAiO9bI3E//z+imGDS/jMkmRSbpYFKbCxtievW8PuWIUxdQQQhU3t9p0TBShaEKqmBDc5ZdXSeey7l7VnfvrWrNRxFGGEziFc3DhBppwBy1oAwUJz/AKb5a2Xqx362PRWrKKmWP4A+vzB1+8kK4=</latexit>

tX�2<latexit sha1_base64="RwUgobEuC1GAwjsa9EeWudf/JR0=">AAAB8XicbVBNS8NAEJ34WetX1aOXYBG8WJIq2GPBi8cK9gPbEDbbTbt0swm7E6GE/AsvHhTx6r/x5r9x2+agrQ8GHu/NMDMvSATX6Djf1tr6xubWdmmnvLu3f3BYOTru6DhVlLVpLGLVC4hmgkvWRo6C9RLFSBQI1g0mtzO/+8SU5rF8wGnCvIiMJA85JWikR/Sznp9d1vPcr1SdmjOHvUrcglShQMuvfA2GMU0jJpEKonXfdRL0MqKQU8Hy8iDVLCF0Qkasb6gkEdNeNr84t8+NMrTDWJmSaM/V3xMZibSeRoHpjAiO9bI3E//z+imGDS/jMkmRSbpYFKbCxtievW8PuWIUxdQQQhU3t9p0TBShaEIqmxDc5ZdXSadec69qzv11tdko4ijBKZzBBbhwA024gxa0gYKEZ3iFN0tbL9a79bFoXbOKmRP4A+vzB2FCkK8=</latexit>

tX<latexit sha1_base64="7wYGO1I3Kv2vSBneVG5IPSomCNU=">AAAB6nicbVBNS8NAEJ3Ur1q/qh69LBbBU0m0YI8FLx4r2g9oQ9lsN+3SzSbsToQS+hO8eFDEq7/Im//GbZuDtj4YeLw3w8y8IJHCoOt+O4WNza3tneJuaW//4PCofHzSNnGqGW+xWMa6G1DDpVC8hQIl7yaa0yiQvBNMbud+54lrI2L1iNOE+xEdKREKRtFKDzjoDsoVt+ouQNaJl5MK5GgOyl/9YczSiCtkkhrT89wE/YxqFEzyWamfGp5QNqEj3rNU0YgbP1ucOiMXVhmSMNa2FJKF+nsio5Ex0yiwnRHFsVn15uJ/Xi/FsO5nQiUpcsWWi8JUEozJ/G8yFJozlFNLKNPC3krYmGrK0KZTsiF4qy+vk/ZV1buuuve1SqOex1GEMziHS/DgBhpwB01oAYMRPMMrvDnSeXHenY9la8HJZ07hD5zPHz5cjbs=</latexit>

t<latexit sha1_base64="ZXEgEOUFHF4QZ8VPQ+GiYQy6k0w=">AAAB6HicbVBNS8NAEN3Ur1q/qh69LBbBU0lUsMeCF48t2A9oQ9lsJ+3azSbsToQS+gu8eFDEqz/Jm//GbZuDtj4YeLw3w8y8IJHCoOt+O4WNza3tneJuaW//4PCofHzSNnGqObR4LGPdDZgBKRS0UKCEbqKBRYGETjC5m/udJ9BGxOoBpwn4ERspEQrO0EpNHJQrbtVdgK4TLycVkqMxKH/1hzFPI1DIJTOm57kJ+hnTKLiEWamfGkgYn7AR9CxVLALjZ4tDZ/TCKkMaxtqWQrpQf09kLDJmGgW2M2I4NqveXPzP66UY1vxMqCRFUHy5KEwlxZjOv6ZDoYGjnFrCuBb2VsrHTDOONpuSDcFbfXmdtK+q3nXVbd5U6rU8jiI5I+fkknjkltTJPWmQFuEEyDN5JW/Oo/PivDsfy9aCk8+ckj9wPn8A3dGM8A==</latexit>

tXnext<latexit sha1_base64="s0Ifm+DQTJ9gGHpN5u0PcdiS5u0=">AAAB+3icbVBNS8NAEN34WetXrEcvi0XwVBIV7LHgxWMF+wFtCJvttl26uwm7E2kJ+StePCji1T/izX/jts1BWx8MPN6bYWZelAhuwPO+nY3Nre2d3dJeef/g8OjYPam0TZxqylo0FrHuRsQwwRVrAQfBuolmREaCdaLJ3dzvPDFteKweYZawQJKR4kNOCVgpdCsQZt0wy/paYsWmkOd56Fa9mrcAXid+QaqoQDN0v/qDmKaSKaCCGNPzvQSCjGjgVLC83E8NSwidkBHrWaqIZCbIFrfn+MIqAzyMtS0FeKH+nsiINGYmI9spCYzNqjcX//N6KQzrQcZVkgJTdLlomAoMMZ4HgQdcMwpiZgmhmttbMR0TTSjYuMo2BH/15XXSvqr51zXv4abaqBdxlNAZOkeXyEe3qIHuURO1EEVT9Ixe0ZuTOy/Ou/OxbN1wiplT9AfO5w/J9JTj</latexit>

Fig. 3. Estimate the interval (steps) ∆ until the next step tXnext torefresh the statistics X.

Although it is insisted that F`,emp is not a proper ap-proximation of the FIM, and F`,1mc is better estimation inthe literature [30], [31], we do not see any difference in theconvergence behavior nor the final model accuracy betweenNGD with F`,emp and that with F`,1mc in training deepConvNets for ImageNet classification as shown in Section 7.

4.2 Practical FIM Estimation for BatchNorm LayersIt is often the case that a deep ConvNet has Conv layers thatare followed by BatchNorm layers [32]. The `-th BatchNormlayer after the (`−1)-th Conv layer has scale γ` ∈ Rc`−1 andbias β` ∈ Rc`−1 to be applied to the normalized features.When we see these parameters as learnable parameters, wecan define

w` :=(γ`,1 β`,1 · · · γ`,c`−1

β`,c`−1

)> ∈ R2c`−1 ,(14)

where γ`,i and β`,i are the i-th element of γ` and β`,respectively (i = 1, . . . , c`−1). For this BatchNorm layer, theFIM F` (5) is a 2c`−1 × 2c`−1 matrix, and the computationcost of inverting this matrix can not be ignored when c`−1(the number of output channels of the previous Conv layer)is large (e.g. cout = 1024 for a Conv layer in ResNet-50 [5]).

4.2.1 Unit-wise Natural GradientWe approximate this FIM by applying unit-wise naturalgradient [33] to the learnable parameters of BatchNormlayers (Figure 2). A ”unit” in a neural network is a collectionof input/output nodes connected to each other. The ”unit-wise” natural gradient only takes into account the corre-lation of the parameters in the same node. Hence, for aBatchNorm layer, we only consider the correlation betweenγ`c and β`,c of the same channel c:

F` ≈ F`= F`,unit BN

:= diag(F

(1)` . . .F

(i)` . . .F

(c`−1)`

)∈ R2c`−1×2c`−1 ,

(15)

where

F(i)` = E

∇(i)γ`

2∇(i)γ`∇(i)

β`

∇(i)β`∇(i)γ` ∇(i)

β`

2

∈ R2×2 . (16)

∇(i)γ` , ∇(i)

β`are the i-th element of ∇γ`

log pθ(y|x),∇β`

log pθ(y|x), respectively. The number of the elementsto be computed and communicated is significantly re-duced from 4c2`−1 to 4c`−1, and we can get the inverse(F`,unit BN + λI)

−1 with little computation cost using theinverse matrix formula:[

a bc d

]−1=

1

ad− bc

[d −b−c a

]. (17)

Page 6: 1 Scalable and Practical Natural Gradient for Large-Scale ... · 50 on ImageNet classification as a benchmark. We converged to over 75% top-1 validation accuracy for large mini-batch

6

We observed that the unit-wise approximation on F` ofBatchNorm does not affect the accuracy of a deep ConvNeton ImageNet classification as shown in Section 7.

4.3 Natural Gradient with Stale StatisticsTo achieve even faster training with NGD, it is critical toutilize stale statistics in order to avoid re-computing thematrices A`−1,G` and F`,unitBN (Figure 2) at every step.Previous work on K-FAC [17] used a simple strategy wherethey refresh the Kronecker factors only once in 20 steps.However, as observed in [24], the statistics rapidly fluctuatesat the beginning of training, and this simple strategy causesserious defects to the convergence. It was also observedthat the degree of fluctuation of the statistics depends onthe mini-batch size, the layer, and the type of the statistics(e.g, statistics with a larger mini-batch fluctuates less thanthat with a smaller mini-batch). Although the previousstrategy [24] to reduce the frequency worked without anydegradation of the accuracy, it requires the prior observationon the fluctuation of the statistics, and its effectiveness ontraining time has not been well studied.

4.3.1 Adaptive Frequency To Refresh StatisticsWe propose an improved strategy which adaptively de-termines the frequency to refresh each statistics based onits staleness during the training. Our strategy is shown inAlgorithm 1. We calculate the timing (step) to refresh eachstatistics based on the acceptable interval (steps) estimated inAlgorithm 2 (Figure 3). In Algorithm 2, matrix A is consid-ered to be similar to matrix B when ‖A − B‖F /‖B‖F < α,where ‖ · ‖F is Frobenius norm, and α > 0 is the thresh-old4. We tuned α by running training for a few epochs tocheck if the threshold is not too large (preserves the sameconvergence). And it aims to find the acceptable interval ∆where the statistics X calculated at step t = tX+∆ is similarto that calculated at step t = tX . With this strategy, we canestimate the acceptable interval for each statistics and skip thecomputation efficiently — it keeps almost the same trainingtime as the original, while reducing the cost for constructingand inverting A`−1,G` and F`,unitBN as much as possible.This significantly reduces the overhead of NGD. We observethe effectiveness of our approach in Section 7.

5 SCALABLE NATURAL GRADIENT

Based on the practical estimation of the FIM proposed in theprevious section, we designed a distributed parallelizationscheme among multiple GPUs so that the overhead ofNGD compare to SGD decreases as the number of GPUs(processes) is increased.

5.1 Distributed Natural GradientFigure 4 shows the overview of our design, which showsa single step of training with our distributed NGD. We usethe term Stage to refer to each phase of computation andcommunication, which is indicated at the top of the figure.Algorithm 3 shows the pseudo code of our distributed NGDdesign.

4. α = 0.1 for all the experiments in this work.

Algorithm 3: Distributed Natural Gradient

while not converge do// Stage 1foreach ` = 1, · · · , L do

forward in `-th layerif `-th layer is Conv or FC then

compute A`−1end

// Stage 2ReduceScatterV(A0:L−1)foreach ` = L, · · · , 1 do

backward in `-th layerif `-th layer is Conv or FC then

compute G`

else if `-th layer is BatchNorm thencompute F`,unitBN

end

// Stage 3ReduceScatterV(G1:L/F1:L,unitBN and ∇w1:L

L)

// Stage 4for ` = 1, · · · , L do in parallel

update w` by natural gradient (6)end

// Stage 5AllGatherV(w1:L)

endreturn θ =

(w>1 , . . . ,w

>L

)>In Stage 1, each process (GPU) receives a different part

of the mini-batch and performs a forward-pass in whichthe Kronecker factor A`−1 is computed for the receivedsamples, if `-th layer is a Conv layer or a FC layer.

In Stage 2, two procedures are performed in parallel— communication among all the processes and a backward-pass in each process. Since Stage 1 is done in a data-parallel fashion, each process computes the statistics onlyfor the different parts of the mini-batch. In order to computethese statistics for the entire mini-batch, we need to averagethese statistics over all the processes. This is performedusing a ReduceScatterV collective communication, whichtransitions our approach from data-parallelism to model-parallelism by reducing (taking the sum of) A`−1 for differ-ent ` to different processes. This collective is much moreefficient than an AllReduce , where A`−1 for all ` arereduced to all the processes (Figure 4). While A`−1 iscommunicated, each process also performs a backward-passto get the gradient∇w`

L, the Kronecker factorG` for Conv,FC layers, and F`,unitBN for BatchNorm layer, for each `.

In Stage 3, G`,F`,unitBN and ∇w`L are communicated

in the same way as A` by ReduceScatterV collective. Atthis point, only a single process has the FIM estimation F`and the gradient∇w`

Lwith the statistics for the entire mini-batch for the `-th layer. In Stage 4, only the process thathas the FIM computes the matrix inverse and applies theNGD update (6) to the weights w` of the `-th layer. Hence,these computations are performed in a model-parallel fash-ion. When the number of layers is larger than the number ofprocesses, multiple layers are handled by a process.

Page 7: 1 Scalable and Practical Natural Gradient for Large-Scale ... · 50 on ImageNet classification as a benchmark. We converged to over 75% top-1 validation accuracy for large mini-batch

7

GPU

data-parallelism model-parallelism

GPU

GPU

......

......

......

......

ReduceScatterV

w`<latexit sha1_base64="/3f6nxDwmeRZ4qp/U/FZd4ChzMA=">AAACBHicbVDLSsNAFJ34rPUVddnNYCu4KokKdllw47KCfUATwmQyaYdOMmFmopaQhRt/xY0LRdz6Ee78GydtFtp64MLhnHu59x4/YVQqy/o2VlbX1jc2K1vV7Z3dvX3z4LAneSow6WLOuBj4SBJGY9JVVDEySARBkc9I359cFX7/jghJeXyrpglxIzSKaUgxUlryzJoT+fwhc3zOggipMWzcN3IvcwhjuWfWraY1A1wmdknqoETHM7+cgOM0IrHCDEk5tK1EuRkSimJG8qqTSpIgPEEjMtQ0RhGRbjZ7IocnWglgyIWuWMGZ+nsiQ5GU08jXncWhctErxP+8YarClpvROEkVifF8UZgyqDgsEoEBFQQrNtUEYUH1rRCPkUBY6dyqOgR78eVl0jtr2udN6+ai3m6VcVRADRyDU2CDS9AG16ADugCDR/AMXsGb8WS8GO/Gx7x1xShnjsAfGJ8/3cCYMQ==</latexit>

w1<latexit sha1_base64="MTyCxJGPQW1I3owI+uwS4WkNgvA=">AAACAXicbVDLSsNAFJ34rPUVdSO4GWwFVyVRwS4LblxWsA9oQphMJu3QSSbMTNQS4sZfceNCEbf+hTv/xkmbhbYeuHA4517uvcdPGJXKsr6NpeWV1bX1ykZ1c2t7Z9fc2+9KngpMOpgzLvo+koTRmHQUVYz0E0FQ5DPS88dXhd+7I0JSHt+qSULcCA1jGlKMlJY889CJfP6QOT5nQYTUCNbv67mX2bln1qyGNQVcJHZJaqBE2zO/nIDjNCKxwgxJObCtRLkZEopiRvKqk0qSIDxGQzLQNEYRkW42/SCHJ1oJYMiFrljBqfp7IkORlJPI153FlXLeK8T/vEGqwqab0ThJFYnxbFGYMqg4LOKAARUEKzbRBGFB9a0Qj5BAWOnQqjoEe/7lRdI9a9jnDevmotZqlnFUwBE4BqfABpegBa5BG3QABo/gGbyCN+PJeDHejY9Z65JRzhyAPzA+fwAveZar</latexit>

wL<latexit sha1_base64="1KZtiwYiUJjlT1poOt+mOzcDeD4=">AAACAXicbVDLSsNAFJ3UV62vqBvBzWAruCqJCnZZcOPCRQX7gDaEyWTSDp1kwsxELSFu/BU3LhRx61+482+ctFlo64ELh3Pu5d57vJhRqSzr2ygtLa+srpXXKxubW9s75u5eR/JEYNLGnHHR85AkjEakrahipBcLgkKPka43vsz97h0RkvLoVk1i4oRoGNGAYqS05JoHg9DjD+nA48wPkRrB2n0tc9PrzDWrVt2aAi4SuyBVUKDlml8Dn+MkJJHCDEnZt61YOSkSimJGssogkSRGeIyGpK9phEIinXT6QQaPteLDgAtdkYJT9fdEikIpJ6GnO/Mr5byXi/95/UQFDSelUZwoEuHZoiBhUHGYxwF9KghWbKIJwoLqWyEeIYGw0qFVdAj2/MuLpHNat8/q1s15tdko4iiDQ3AEToANLkATXIEWaAMMHsEzeAVvxpPxYrwbH7PWklHM7IM/MD5/AFiAlsY=</latexit>

Update

(w1 . . . w` . . . wL)<latexit sha1_base64="CsXzt+mx1AvqZde2abP36qhWC1o=">AAACRHicfVDNS8MwHE39nPWr6tFLcBPmZbQquOPAiwcPE9wHrGOkWbqFpU1JUnWU/nFe/AO8+Rd48aCIVzHdetBt+CDweO/3ledFjEpl2y/G0vLK6tp6YcPc3Nre2bX29puSxwKTBuaMi7aHJGE0JA1FFSPtSBAUeIy0vNFl5rfuiJCUh7dqHJFugAYh9SlGSks9q1N2A48/JK7HWT9AaghL96W0lzip6fa5kuZi2yWM/V9xnZ70rKJdsSeA88TJSRHkqPesZz0RxwEJFWZIyo5jR6qbIKEoZkTviyWJEB6hAeloGqKAyG4yCSGFx1rpQ58L/UIFJ+rvjgQFUo4DT1dmZ8pZLxMXeZ1Y+dVuQsMoViTE00V+zKDiMEsU9qkgWLGxJggLqm+FeIgEwkrnbuoQnNkvz5PmacU5q9g358VaNY+jAA7BESgDB1yAGrgCddAAGDyCV/AOPown4834NL6mpUtG3nMA/sD4/gFdVLGh</latexit>

(w1 . . . w` . . . wL)<latexit sha1_base64="CsXzt+mx1AvqZde2abP36qhWC1o=">AAACRHicfVDNS8MwHE39nPWr6tFLcBPmZbQquOPAiwcPE9wHrGOkWbqFpU1JUnWU/nFe/AO8+Rd48aCIVzHdetBt+CDweO/3ledFjEpl2y/G0vLK6tp6YcPc3Nre2bX29puSxwKTBuaMi7aHJGE0JA1FFSPtSBAUeIy0vNFl5rfuiJCUh7dqHJFugAYh9SlGSks9q1N2A48/JK7HWT9AaghL96W0lzip6fa5kuZi2yWM/V9xnZ70rKJdsSeA88TJSRHkqPesZz0RxwEJFWZIyo5jR6qbIKEoZkTviyWJEB6hAeloGqKAyG4yCSGFx1rpQ58L/UIFJ+rvjgQFUo4DT1dmZ8pZLxMXeZ1Y+dVuQsMoViTE00V+zKDiMEsU9qkgWLGxJggLqm+FeIgEwkrnbuoQnNkvz5PmacU5q9g358VaNY+jAA7BESgDB1yAGrgCddAAGDyCV/AOPown4834NL6mpUtG3nMA/sD4/gFdVLGh</latexit>

(w1 . . . w` . . . wL)<latexit sha1_base64="CsXzt+mx1AvqZde2abP36qhWC1o=">AAACRHicfVDNS8MwHE39nPWr6tFLcBPmZbQquOPAiwcPE9wHrGOkWbqFpU1JUnWU/nFe/AO8+Rd48aCIVzHdetBt+CDweO/3ledFjEpl2y/G0vLK6tp6YcPc3Nre2bX29puSxwKTBuaMi7aHJGE0JA1FFSPtSBAUeIy0vNFl5rfuiJCUh7dqHJFugAYh9SlGSks9q1N2A48/JK7HWT9AaghL96W0lzip6fa5kuZi2yWM/V9xnZ70rKJdsSeA88TJSRHkqPesZz0RxwEJFWZIyo5jR6qbIKEoZkTviyWJEB6hAeloGqKAyG4yCSGFx1rpQ58L/UIFJ+rvjgQFUo4DT1dmZ8pZLxMXeZ1Y+dVuQsMoViTE00V+zKDiMEsU9qkgWLGxJggLqm+FeIgEwkrnbuoQnNkvz5PmacU5q9g358VaNY+jAA7BESgDB1yAGrgCddAAGDyCV/AOPown4834NL6mpUtG3nMA/sD4/gFdVLGh</latexit>

AllGatherVMini-batch

......

Stage 1 Stage 2 Stage 3 Stage 4 Stage 5

forward backward

forward backward

forward backward

F 1,rw1L

<latexit sha1_base64="ea+n2gVAJ0aLz+owsC8bm7Dm/B0=">AAACN3icbVBNS8NAEN34WeNX1aOXxVbwICVRQY9FQTyIVLBWaEqYbLd26WY37G7UEvKvvPg3vOnFgyJe/QcmbQ/a+mDh7Zs3zMwLIs60cZwXa2p6ZnZuvrBgLy4tr6wW19avtYwVoXUiuVQ3AWjKmaB1wwynN5GiEAacNoLeSV5v3FGlmRRXph/RVgi3gnUYAZNJfvHC64JJvDCQD4kXSN4OwXRx+bScpn7ipvau7QkIOPiTnvvywJJ6+ZcAT85T2y+WnIozAJ4k7oiU0Ag1v/jstSWJQyoM4aB103Ui00pAGUY4TW0v1jQC0oNb2syogJDqVjK4O8XbmdLGHamyJwweqL87Egi17odB5sx31OO1XPyv1oxN56iVMBHFhgoyHNSJOTYS5yHiNlOUGN7PCBDFsl0x6YICYrKo8xDc8ZMnyfVexd2vOJcHperxKI4C2kRbaAe56BBV0RmqoToi6BG9onf0YT1Zb9an9TW0Tlmjng30B9b3D0J2rL4=</latexit>

F L,rwLL

<latexit sha1_base64="gssSbADDP0+Rl1s31fTkaIxjiNA=">AAACN3icbVBNS8NAEN34bf2KevSy2AoepCQq6FEUxIOIgq2FJoTJZmuXbnbD7kYtIf/Ki3/Dm148KOLVf2BSe1Drg4W3b94wMy9MONPGcZ6ssfGJyanpmdnK3PzC4pK9vNLUMlWENojkUrVC0JQzQRuGGU5biaIQh5xehb2jsn51Q5VmUlyafkL9GK4F6zACppAC+8zrgsm8OJR3mRdKHsVgurh2XMvzIDvNK1sVT0DIIRj13NYGltwrvwR4aQ/sqlN3BsCjxB2SKhriPLAfvUiSNKbCEA5at10nMX4GyjDCaV7xUk0TID24pu2CCoip9rPB3TneKJQId6QqnjB4oP7syCDWuh+HhbPcUf+tleJ/tXZqOvt+xkSSGirI96BOyrGRuAwRR0xRYni/IEAUK3bFpAsKiCmiLkNw/548SprbdXen7lzsVg8Oh3HMoDW0jjaRi/bQATpB56iBCLpHz+gVvVkP1ov1bn18W8esYc8q+gXr8wua86z0</latexit>

F `,rw`L

<latexit sha1_base64="fIHk7CK04HuqPAJClRzRIIAyLP4=">AAACPXicbVBNS8NAFNz4bfyqevSy2AoepCQq6LEoiAcPCq0VmhJettt26WY37G7UEvLHvPgfvHnz4kERr15N2graOrAwzMxj35sg4kwbx3m2pqZnZufmFxbtpeWV1bXC+sa1lrEitEYkl+omAE05E7RmmOH0JlIUwoDTetA7zf36LVWaSVE1/Yg2Q+gI1mYETCb5harXBZN4YSDvEy+QvBWC6eLSWSlN/cSjnKf2nu0JCDj4k7G70k8q9XKFAE8uUtsvFJ2yMwCeJO6IFNEIl37hyWtJEodUGMJB64brRKaZgDKMcJraXqxpBKQHHdrIqICQ6mYyuD7FO5nSwm2psicMHqi/JxIIte6HQZbMd9TjXi7+5zVi0z5uJkxEsaGCDD9qxxwbifMqcYspSgzvZwSIYtmumHRBATFZ4XkJ7vjJk+R6v+welJ2rw2LlZFTHAtpC22gXuegIVdA5ukQ1RNADekFv6N16tF6tD+tzGJ2yRjOb6A+sr28voK/K</latexit>

AllGatherV

ReduceScatterV

AllReduce

sum

sum

Fig. 4. (Left) Overview of our distributed natural gradient descent (a single step of training). (Right) Illustrations of AllReduce, ReduceScatterV,and AllGatherV collective. Different colors correspond to data (and its communication) from different data sources.

Once the weights w` of each ` are updated, we syn-chronize the updated weights among all the processesby calling an AllGatherV (Figure 4) collective, and weswitch back to data-parallelism. Combining the practicalestimation of the FIM proposed in the previous section,we are able to reduce a significant amount of commu-nication required for the Kronecker factors A`−1,G` andF`,unitBN. Therefore, the amount of communication for ourdistirbuted NGD is similar to distributed SGD, where theAllReduce for the gradient ∇w`

L is implemented as aReduceScatter+AllGather.

5.2 Further accelerationOur data-parallel and model-parallel hybrid approach al-lows us to minimize the overhead of NGD in a distributedsetting. However, NGD still has a large overhead comparedto SGD. There are two hotspots in our distributed NGD de-sign. The first is the construction of the statistics A`−1,G`,and F`,unitBN, that cannot be done in a model-parallel fash-ion. The second is the communication (ReduceScatterV)for distributing these statistics. In this section, we discusshow we accelerate these two hotspots to achieve even fastertraining time.

Mixed-precision computation. We use the Tensor Cores of the NVIDIA Volta architecture (https://www.nvidia.com/en-us/data-center/tensorcore/) for constructing the statistics. This more than doubles the speed of this part of the computation. One might expect this low-precision computation to affect the overall accuracy of the training, but in our experiments we did not find any difference between training with half-precision floating point computation and training with full-precision floating point computation.
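The following sketch illustrates the kind of half-precision computation involved; it is an assumption-laden illustration rather than our Chainer/cuDNN code, and `kronecker_factor_A` is a hypothetical helper. The layer inputs are cast to float16 before the matrix product that forms a Kronecker factor, and the result is cast back to float32 (real Tensor Cores additionally accumulate in float32 in hardware).

```python
import numpy as np

def kronecker_factor_A(activations, use_half=True):
    """Compute the Kronecker factor A = (1/N) * a^T a from a layer's
    inputs `activations` (an N x d array)."""
    a = activations.astype(np.float16) if use_half else activations
    # Low-precision product, cast back to float32 for storage and inversion.
    return (a.T @ a).astype(np.float32) / activations.shape[0]

rng = np.random.default_rng(0)
acts = rng.standard_normal((32, 64)).astype(np.float32)
diff = np.abs(kronecker_factor_A(acts, True) - kronecker_factor_A(acts, False))
print(diff.max())  # the half-precision error stays small in this example
```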

Symmetry-aware communication. The statistics matrices A_{ℓ−1}, G_ℓ, and F_{ℓ,unitBN} are symmetric. We exploit this property to reduce the amount of communication without loss of information: to communicate a symmetric matrix of size N × N, we only need to send its upper triangular part, i.e., N(N + 1)/2 elements.
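A minimal sketch of this packing, assuming NumPy arrays for the statistics (the helper names are ours):

```python
import numpy as np

def pack_symmetric(S):
    """Return the upper-triangular elements of a symmetric matrix S as a
    flat buffer of N*(N+1)/2 values (the part that is communicated)."""
    return S[np.triu_indices(S.shape[0])]

def unpack_symmetric(buf, n):
    """Rebuild the full N x N symmetric matrix from the packed buffer."""
    S = np.zeros((n, n), dtype=buf.dtype)
    iu = np.triu_indices(n)
    S[iu] = buf
    return S + S.T - np.diag(np.diag(S))

rng = np.random.default_rng(0)
M = rng.standard_normal((5, 5))
S = M + M.T                    # a symmetric statistic, e.g. A or G
packed = pack_symmetric(S)     # 15 values instead of 25
assert np.allclose(unpack_symmetric(packed, 5), S)
```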

In addition to these two optimizations, we also adoptedthe performance optimizations done by [25]:

• Explicitly use the NHWC (mini-batch, height, width, and channels) format instead of the NCHW format for the input/output tensors of Conv layers. This allows the cuDNN API to fully benefit from the Tensor Cores.

• Data I/O pipeline using the NVIDIA Data Loading Library (DALI, https://developer.nvidia.com/DALI).

• Hierarchical AllReduce collective proposed by Ueno et al. [34], which alleviates the latency of the ring-AllReduce communication among a large number of GPUs.

• Half-precision communication for the AllGatherV collective.

6 TRAINING FOR IMAGENET CLASSIFICATION

The behavior of NGD on large models and datasets has not been studied in depth. Also, there are very few studies that use NGD (K-FAC) for large mini-batches (over 4K) with distributed parallelism at scale [19]. Contrary to SGD, whose hyperparameters have been tuned by many practitioners even for large mini-batches, there is very little insight on how to tune hyperparameters for NGD. In this section, we describe the methods, which we call training schemes, that we explored to achieve higher accuracy in our large mini-batch training with NGD for ImageNet classification.

6.1 Data augmentation

To achieve good generalization performance while keeping the benefit of the fast convergence that comes from NGD, we adopt the data augmentation techniques commonly used for training networks with large mini-batch sizes. We resize all the images in ImageNet to 256×256, ignoring the aspect ratio of the original images, and compute the mean value of the upper-left 224×224 portion of the resized images. When reading an image, we randomly crop a 224×224 patch from it, randomly flip it horizontally, subtract the mean value, and scale every pixel to [0, 1].
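A minimal NumPy sketch of this preprocessing, assuming the 256×256 resize has already been done and that the subtracted mean is a per-pixel mean image of the upper-left 224×224 region (the exact form of the mean is our assumption):

```python
import numpy as np

def augment(image, mean_patch, rng):
    """image: 256 x 256 x 3 array with values in [0, 255];
    mean_patch: 224 x 224 x 3 per-pixel mean. Returns a mean-subtracted
    224 x 224 x 3 crop scaled by 1/255."""
    top = rng.integers(0, 256 - 224 + 1)
    left = rng.integers(0, 256 - 224 + 1)
    crop = image[top:top + 224, left:left + 224].astype(np.float32)
    if rng.random() < 0.5:          # random horizontal flip
        crop = crop[:, ::-1]
    return (crop - mean_patch) / 255.0

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(256, 256, 3))
sample = augment(img, np.zeros((224, 224, 3), dtype=np.float32), rng)
```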

Running mixup. We extend mixup [35] to increase its regularization effect. We synthesize virtual training samples from the raw samples and the virtual samples of the previous step, whereas the original mixup method synthesizes new samples only from raw samples:

\bar{x}^{(t)} = \lambda \cdot x^{(t)} + (1 - \lambda) \cdot \bar{x}^{(t-1)} ,  (18)

\bar{t}^{(t)} = \lambda \cdot t^{(t)} + (1 - \lambda) \cdot \bar{t}^{(t-1)} .  (19)

x^{(t)}, t^{(t)} are the raw input and label (one-hot vector), and \bar{x}^{(t)}, \bar{t}^{(t)} are the virtual input and label for the t-th step. λ is sampled from the Beta distribution with the beta function

B(\alpha, \beta) = \int_0^1 t^{\alpha - 1} (1 - t)^{\beta - 1} \, dt ,  (20)

where we set α = β = α_mixup.
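A minimal sketch of this running variant, assuming NumPy batches and that the very first step falls back to the raw batch (our assumption; the class name is hypothetical):

```python
import numpy as np

class RunningMixup:
    """Mix the current raw batch with the previous virtual batch
    (Eqs. (18)-(19)) instead of with another raw batch."""

    def __init__(self, alpha, rng):
        self.alpha = alpha      # alpha_mixup (= beta) of Eq. (20)
        self.rng = rng
        self.prev_x = None
        self.prev_t = None

    def __call__(self, x, t):
        if self.prev_x is None:              # first step: nothing to mix with yet
            self.prev_x, self.prev_t = x, t
            return x, t
        lam = self.rng.beta(self.alpha, self.alpha)
        x_bar = lam * x + (1.0 - lam) * self.prev_x
        t_bar = lam * t + (1.0 - lam) * self.prev_t
        self.prev_x, self.prev_t = x_bar, t_bar
        return x_bar, t_bar

rng = np.random.default_rng(0)
mix = RunningMixup(alpha=0.4, rng=rng)
x = rng.standard_normal((8, 3, 224, 224))
t = np.eye(1000)[rng.integers(0, 1000, size=8)]   # one-hot labels
x_bar, t_bar = mix(x, t)
```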

Random erasing with zero value. We also implemented Random Erasing [36]. We set the elements within the erasing region of each input to zero instead of a random value as used in the original method. We set the erasing probability p = 0.5, the erasing area ratio S_e ∈ [0.02, 0.25], and the erasing aspect ratio r_e ∈ [0.3, 1]. We randomly switch the size of the erasing area from (H_e, W_e) to (W_e, H_e).
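A sketch of this zero-valued erasing for a single (C, H, W) image; clamping the rectangle to the image boundary is our assumption:

```python
import numpy as np

def random_erase_zero(image, rng, p=0.5, area=(0.02, 0.25), aspect=(0.3, 1.0)):
    """Zero out a random rectangle of a (C, H, W) image with probability p.
    The area ratio and aspect ratio are drawn from `area` and `aspect`, and
    the rectangle's height/width are randomly swapped."""
    if rng.random() >= p:
        return image
    _, h, w = image.shape
    s_e = rng.uniform(*area) * h * w          # erasing area in pixels
    r_e = rng.uniform(*aspect)                # erasing aspect ratio
    h_e = min(int(round(np.sqrt(s_e * r_e))), h)
    w_e = min(int(round(np.sqrt(s_e / r_e))), w)
    if rng.random() < 0.5:                    # randomly swap (He, We) -> (We, He)
        h_e, w_e = min(w_e, h), min(h_e, w)
    top = rng.integers(0, h - h_e + 1)
    left = rng.integers(0, w - w_e + 1)
    out = image.copy()
    out[:, top:top + h_e, left:left + w_e] = 0.0   # zero instead of random values
    return out

rng = np.random.default_rng(0)
erased = random_erase_zero(rng.standard_normal((3, 224, 224)), rng)
```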

6.2 Learning rate and momentum

The learning rate used in all of our experiments is scheduled by polynomial decay. The learning rate η^{(e)} for the e-th epoch is determined as follows:

\eta^{(e)} = \eta^{(0)} \cdot \left( 1 - \frac{e - e_{\mathrm{start}}}{e_{\mathrm{end}} - e_{\mathrm{start}}} \right)^{p_{\mathrm{decay}}} .  (21)

η^{(0)} is the initial learning rate, and e_start, e_end are the epochs when the decay starts and ends. The decay rate p_decay controls the speed of the learning rate decay.
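A sketch of this schedule; holding the rate at η^{(0)} before e_start and clipping it at zero after e_end are our assumptions, since Eq. (21) only defines the decay between the two epochs:

```python
def lr_schedule(epoch, lr0, e_start, e_end, p_decay):
    """Polynomial decay of Eq. (21). Assumption: constant before e_start,
    clipped at zero after e_end."""
    frac = (epoch - e_start) / (e_end - e_start)
    frac = min(max(frac, 0.0), 1.0)
    return lr0 * (1.0 - frac) ** p_decay

# Example with the BS=32K row of Table 2:
# eta0 = 3.0e-2, e_start = 1.5, e_end = 49.5, p_decay = 3.5.
for epoch in (0, 10, 25, 45):
    print(epoch, lr_schedule(epoch, 3.0e-2, 1.5, 49.5, 3.5))
```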

We use the momentum method for the NGD updates. Because the learning rate decays rapidly in the final stage of training with the polynomial decay, the current update can become smaller than the previous update. We adjust the momentum rate m^{(e)} for the e-th epoch so that the ratio between m^{(e)} and η^{(e)} is fixed throughout the training:

m^{(e)} = \frac{m^{(0)}}{\eta^{(0)}} \cdot \eta^{(e)} ,  (22)

where m^{(0)} is the initial momentum rate. The weights are updated as follows:

w_\ell^{(t+1)} \leftarrow w_\ell^{(t)} - \eta^{(e)} \left( F_\ell^{(t)} + \lambda I \right)^{-1} \nabla_{w_\ell} \mathcal{L}^{(t)} + m^{(e)} v^{(t)} ,  (23)

where v^{(t)} = w_\ell^{(t)} - w_\ell^{(t-1)}.
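A per-layer sketch of Eqs. (22) and (23), using a small dense matrix as a stand-in for the Fisher block (in the actual method the inverse is applied via the Kronecker factors); the helper names are ours:

```python
import numpy as np

def momentum_schedule(m0, lr0, lr_epoch):
    """Eq. (22): keep m(e) / eta(e) equal to m0 / eta0 throughout training."""
    return (m0 / lr0) * lr_epoch

def ngd_update(w, w_prev, grad, F, lr, m, damping):
    """One step of Eq. (23) for a single layer: precondition the gradient with
    (F + damping * I)^{-1} and add the momentum term v = w - w_prev."""
    d = w.size
    precond = np.linalg.solve(F + damping * np.eye(d), grad)
    v = w - w_prev
    return w - lr * precond + m * v

rng = np.random.default_rng(0)
d = 8
w, w_prev = rng.standard_normal(d), rng.standard_normal(d)
grad = rng.standard_normal(d)
F = np.eye(d) + 0.1 * np.ones((d, d))   # symmetric positive definite stand-in
lr = 3.0e-2
m = momentum_schedule(m0=0.97, lr0=3.0e-2, lr_epoch=lr)
w_next = ngd_update(w, w_prev, grad, F, lr, m, damping=2.0e-4)
```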

6.3 Weights rescaling

To prevent the scale of the weights from becoming too large, we adopt the Normalizing Weights [37] technique for the w_ℓ of FC and Conv layers. We rescale w_ℓ to have a norm \sqrt{2 \cdot d_{\mathrm{out}}} after the update (23):

w_\ell^{(t+1)} \leftarrow \sqrt{2 \cdot d_{\mathrm{out}}} \cdot \frac{w_\ell^{(t+1)}}{\left\| w_\ell^{(t+1)} \right\| + \epsilon} ,  (24)

where we use ε = 1 · 10^{−9} to stabilize the computation. d_out is the output dimension or channels of the layer.
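A sketch of Eq. (24), assuming the norm is the Frobenius norm of the whole layer weight (our assumption):

```python
import numpy as np

def rescale_weights(w, d_out, eps=1e-9):
    """Eq. (24): rescale w to have norm sqrt(2 * d_out) after the NGD update."""
    return np.sqrt(2.0 * d_out) * w / (np.linalg.norm(w) + eps)

w = np.random.default_rng(0).standard_normal((256, 512))   # d_out x d_in weight
w = rescale_weights(w, d_out=256)
print(np.linalg.norm(w))   # ~ sqrt(2 * 256) ~ 22.6
```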

7 EXPERIMENTS

We train ResNet-50 [5] for ImageNet [4] in all of our experiments. We use the same hyperparameters for the same mini-batch size. The hyperparameters for our results are shown in Table 2. We implement all of our methods using Chainer [29]. Our Chainer extension is available at https://github.com/tyohei/chainerkfac. We initialize the weights with the HeNormal initializer of Chainer (https://docs.chainer.org/en/stable/reference/generated/chainer.initializers.HeNormal.html) with the default parameters.

7.1 Experiment Environment

We conduct all experiments on the ABCI (AI Bridging Cloud Infrastructure, https://abci.ai/) operated by the National Institute of Advanced Industrial Science and Technology (AIST) in Japan. ABCI has 1088 nodes with four NVIDIA Tesla V100 GPUs per node. Due to the additional memory required by NGD, all of our experiments use a mini-batch size of 32 per GPU. We were only given a 24-hour window to use the full machine, so we had to tune the hyperparameters on a smaller number of nodes while mimicking the global mini-batch size of the full-node run. For large mini-batch size experiments that cannot be executed directly (BS=65K and 131K require 2,048 and 4,096 GPUs, respectively), we used an accumulation method to mimic the behavior by accumulating the statistics A_{ℓ−1}, G_ℓ, F_{ℓ,unitBN}, and ∇_{w_ℓ}L over multiple steps.
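A minimal sketch of this accumulation, assuming the per-step statistics are averaged (they are batch averages, so averaging over several steps mimics the average over the larger effective mini-batch); the helper name is ours:

```python
import numpy as np

def accumulate_statistics(per_step_stats):
    """Average per-step statistics (e.g. gradients or Kronecker factors computed
    with the largest batch that fits on the available GPUs) over several steps
    to mimic the statistics of a larger effective mini-batch."""
    return {k: np.mean([s[k] for s in per_step_stats], axis=0)
            for k in per_step_stats[0]}

# e.g. mimic BS=131K on hardware that only fits BS=32K: accumulate 4 steps.
rng = np.random.default_rng(0)
steps = [{"grad": rng.standard_normal(10), "A": rng.standard_normal((4, 4))}
         for _ in range(4)]
mimicked = accumulate_statistics(steps)
```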

7.2 Extremely Large Mini-batch Training

We trained ResNet-50 for the ImageNet classification task with extremely large mini-batch sizes BS={4,096, 8,192, 16,384, 32,768, 65,536, 131,072} and achieved a competitive top-1 validation accuracy (≥ 74.9%) compared to highly tuned SGD training. The summary of the training is shown in Table 1. For BS={4K, 8K, 16K, 32K, 65K}, the training converges in much less than 90 epochs, which is the usual number of epochs required by SGD-based training of ImageNet [6], [7], [8], [9], [10]. For BS={4K, 8K, 16K}, the number of epochs required to reach higher than 75% top-1 validation accuracy is 35. Even for a relatively large mini-batch size of BS=32K, NGD still converges in 45 epochs, half the number of epochs compared to SGD. Note that the calculation time keeps decreasing as long as the number of epochs grows by less than a factor of two when we double the mini-batch size, assuming that doubling the mini-batch corresponds to doubling the number of GPUs (and halving the execution time per epoch). When we increase the mini-batch size to BS={32K, 65K, 131K}, the number of steps it takes to converge becomes remarkably small. At BS=32K and 65K, it takes 1,760 steps (45 epochs) and 1,173 steps (60 epochs), respectively. At BS=131K, there are fewer than 10 iterations per epoch, since the ImageNet dataset contains 1,281,167 images, and it takes only 873 steps to converge (90 epochs). No SGD-based training of ImageNet has sustained the top-1 validation accuracy at this mini-batch size. Furthermore, this is the first work that uses NGD for training with extremely large mini-batch sizes BS={16K, 32K, 65K, 131K} and achieves a competitive top-1 validation accuracy.


TABLE 2
The hyperparameters of the training with large mini-batch size (BS) used for our schemes in Section 6, and the top-1 single-crop validation accuracy of ResNet-50 for ImageNet. reduction and speedup are the reduction rate of the communication amount and the resulting speedup, respectively, for emp+unitBN+stale compared to emp+unitBN in Figure 5.

BS    α_mixup  p_decay  e_start  e_end  η^(0)       m^(0)  λ          # steps  top-1 accuracy    reduction ↓  speedup ↑
4K    0.4      11.0     1        53     8.18·10^−3  0.997  2.5·10^−4  10,948   75.2 % → 74.8 %   23.6 %       × 1.33
8K    0.4      8.0      1        53.5   1.25·10^−2  0.993  2.5·10^−4  5,434    75.5 % → 75.3 %   15.1 %       × 1.32
16K   0.4      8.0      1        53.5   2.5·10^−2   0.985  2.5·10^−4  2,737    75.3 % → 75.2 %   5.4 %        × 1.68
32K   0.6      3.5      1.5      49.5   3.0·10^−2   0.97   2.0·10^−4  1,760    75.6 % → 75.4 %   7.8 %        × 1.40
65K   0.6      2.9      2        64.5   4.0·10^−2   0.95   1.5·10^−4  1,173    75.6 %            n/a          n/a
131K  1.0      2.9      3        100    7.0·10^−2   0.93   1.0·10^−4  873      74.9 %            n/a          n/a

Fig. 5. Time per step (seconds) for training ResNet-50 (107 layers in total) on ImageNet with our scalable and practical NGD, for 1 to 1,024 GPUs. Each GPU processes 32 images. 1mc and emp correspond to NGD with F_{ℓ,1mc} and NGD with F_{ℓ,emp}, respectively. fullBN and unitBN correspond to NGD and unit-wise NGD on the BatchNorm parameters, respectively. stale corresponds to NGD with stale statistics.

7.3 Scalability

We measure the scalability of our distributed NGD implementation for training ResNet-50 on ImageNet. Figure 5 shows the time for one step with different numbers of GPUs and the different techniques proposed in Section 4. Note that since we fix the number of images processed per GPU (=32), doubling the number of GPUs means doubling the total number of images (mini-batch size) processed in a step (e.g., 32K images are processed per step with 1,024 GPUs). In distributed training with multiple GPUs, it is considered ideal if this plot shows a flat line parallel to the x-axis, that is, the time per step is independent of the number of GPUs, and the number of images processed in a given time increases linearly with the number of GPUs. From 1 GPU to 64 GPUs, however, we observe superlinear scaling. For example, the time per step with 64 GPUs is 300% faster than with 1 GPU for emp+fullBN. This is a consequence of our model-parallel design, since ResNet-50 has 107 layers in total when all the Conv, FC, and BatchNorm layers are accounted for. With more than 128 GPUs, we observe a slight performance degradation due to the communication overhead of the ReduceScatterV and AllGatherV collectives. Yet for emp+unitBN+stale, we see almost ideal scaling from 128 GPUs to 1,024 GPUs.

Moreover, with 512 GPUs, which corresponds to BS=16K, we again see superlinear scaling. We discuss this in the next subsection.

7.4 Effectiveness of Practical Natural Gradient

We examine the effectiveness of our practical NGD approaches proposed in Section 4 for training ResNet-50 on ImageNet with extremely large mini-batches. We show that our practical techniques make the overhead of NGD close to negligible and improve the training time significantly. The summary of the training time is shown in Table 1 and Figure 1.

Natural Gradient by Empirical Fisher. We compare the time and model accuracy of training by NGD with the empirical Fisher and with a Fisher estimation by a single Monte Carlo sampling (F_{ℓ,emp} vs F_{ℓ,1mc}). In Figure 5, the time per step of each training is labeled as emp and 1mc, respectively. Due to the extra backward pass required for constructing F_{ℓ,1mc}, 1mc is slower than emp at any number of GPUs. We do not see any difference in the convergence behavior (accuracy vs steps) or in the final accuracy for training ResNet-50 on ImageNet with BS={4K, 8K, 16K, 32K, 65K, 131K}. Note that we used the same hyperparameters tuned for emp for each BS (shown in Table 2), due to the limited computational resources available for tuning 1mc.

Unit-Wise Natural Gradient. We also compare training with the natural gradient and with the unit-wise natural gradient on the BatchNorm parameters (F_ℓ vs F_{ℓ,unitBN}). In Figure 5, the time per step of each method is labeled as fullBN and unitBN, respectively. From 1 GPU to 16 GPUs, we can see that unitBN effectively reduces the time per step compared to fullBN. For more than 32 GPUs, we see only slight improvements, since inverting the statistics (A_{ℓ−1}, G_ℓ, and F_ℓ) for all the layers is already distributed among a sufficient number of processes, and inverting F_ℓ is no longer a bottleneck of a step. We do not see any difference in the convergence behavior (accuracy vs steps) or in the final accuracy for training ResNet-50 on ImageNet with BS={4K, 8K, 16K, 32K, 65K, 131K}.

Natural Gradient with Stale Statistics. We apply our adaptive frequency strategy described in Section 4 to training ResNet-50 on ImageNet with BS={4K, 8K, 16K, 32K}. In Figure 5, the time per step with all the practical techniques is labeled as emp+unitBN+stale.


Fig. 6. The communication amount (bytes) for the statistics (A_{ℓ−1}, G_ℓ, F_{ℓ,unitBN}) in each step of training ResNet-50 on ImageNet with BS={4K, 8K, 16K, 32K} (stacked graph: the amount for G/F is stacked on the amount for A). A and G/F correspond to the communication amounts for A_{ℓ−1} and G_ℓ/F_{ℓ,unitBN}, respectively. The reduction rate (smaller is better) of the communication amount for all the statistics throughout the training is shown as a percentage (%).

The model accuracy before and after applying this technique, the reduction rate (smaller is better) of the communication volume for the statistics, and the speedup (emp+unitBN vs emp+unitBN+stale) are shown in Table 2. The communication amount (bytes) in the ReduceScatterV collective in each step of training and the reduction rate are plotted in Figure 6. With BS=16K and 32K, we can reduce the communication amount for the statistics (A_{ℓ−1}, G_ℓ, and F_ℓ) to 5.4% and 7.8%, respectively. We might be able to attribute this significant reduction rate to the fact that the statistics with larger BS (16K, 32K) are more stable than those with smaller BS (4K, 8K). Note that although we report the reduction rate of the communication amount, the same rate applies to the amount of computation for the statistics, and the cost of inverting them is also removed. With these improvements to NGD, we see almost ideal scaling from 128 GPUs to 1,024 GPUs, which corresponds to BS=4K to 32K.

7.5 Training ResNet-50 on ImageNet with NGD in 5.5 minutes

Finally, we combine all the practical techniques: the empirical Fisher, unit-wise NGD, and NGD with stale statistics. Using 1,024 NVIDIA Tesla V100 GPUs, we achieve 75.4% top-1 accuracy with ResNet-50 for ImageNet in 5.5 minutes (1,760 steps = 45 epochs, including a validation after each epoch). We used the same hyperparameters shown in Table 2. The training time and the validation accuracy are competitive with the results reported by related work that uses SGD for training (the comparison is shown in Table 1). We refer to our training method as Scalable and Practical NGD (SP-NGD).

8 DISCUSSION AND FUTURE WORK

In this work, we proposed Scalable and Practical Natural Gradient Descent (SP-NGD), a framework which combines i) a large-scale distributed computational design with data- and model-hybrid parallelism for Natural Gradient Descent (NGD) [14] and ii) practical Fisher information estimation techniques, including Kronecker-Factored Approximate Curvature (K-FAC) [17], that alleviate the computational overhead of NGD over SGD. Using our SP-NGD framework, we showed the advantages of NGD over first-order stochastic gradient descent (SGD) for training ResNet-50 on ImageNet classification with extremely large mini-batches. We introduced several schemes for training with NGD using mini-batch sizes up to 131,072 and achieved over 75% top-1 accuracy in far fewer steps than existing work using SGD with large mini-batches. Contrary to prior claims that models trained with second-order methods do not generalize as well as those trained with SGD, we were able to show that this is not at all the case, even for extremely large mini-batches. Our SP-NGD framework allowed us to train on 1,024 GPUs and achieve 75.4% top-1 accuracy in 5.5 minutes. This is the first work that observes the relationship between the FIM of ResNet-50 and its training on large mini-batches ranging from 4K to 131K. The advantage of designing better optimizers by taking this approach is that we start from the most mathematically rigorous form, and every improvement we make is a systematic design decision based on observation of the FIM. Even if we end up with performance similar to the best known first-order methods, we will at least have a better understanding of why they work, by starting from second-order methods including NGD.

More accurate and efficient FIM estimation. We showed that NGD using the empirical Fisher matrix [17] (emp) is much faster than NGD with an estimation using a single Monte Carlo sample (1mc), which is widely used by related work on approximate natural gradients. Although it is stated in the literature that emp is not a good approximation of NGD [30], [31], we observed the same convergence behavior as 1mc for training ResNet-50 on ImageNet. We might attribute this result to the fact that emp is a good enough approximation to preserve the behavior of the true NGD, or to the possibility that even 1mc is not a good approximation. To know whether these hypotheses are correct and to examine the actual value of the true NGD, we need a more accurate and efficient estimation of the NGD with less computational cost.

Towards Bayesian deep learning. NGD has been applied to Bayesian deep learning for estimating the posterior distribution of the network parameters. For example, K-FAC [17] has been applied to Bayesian deep learning to realize Noisy Natural Gradient [21], and our distributed NGD has been applied to that approach at ImageNet scale [26]. We similarly expect that our SP-NGD framework will accelerate Bayesian deep learning research using natural gradient methods.

ACKNOWLEDGMENTS

Computational resources of the AI Bridging Cloud Infrastructure (ABCI) were awarded by the "ABCI Grand Challenge" Program, National Institute of Advanced Industrial Science and Technology (AIST). This work was supported by JSPS KAKENHI Grant Numbers JP18H03248 and JP19J13477. This work is supported by JST CREST Grant Number JPMJCR19F5, Japan. (Part of) This work is conducted as research activities of the AIST - Tokyo Tech Real World Big-Data Computation Open Innovation Laboratory (RWBC-OIL). This work is supported by the "Joint Usage/Research Center for Interdisciplinary Large-scale Information Infrastructures" in Japan (Project ID: jh180012-NAHI). This research used computational resources of the HPCI system provided by (the names of the HPCI System Providers) through the HPCI System Research Project (Project ID: hp190122). We would like to thank Yaroslav Bulatov (South Park Commons) for helpful comments on the manuscript.

REFERENCES

[1] C. J. Shallue, J. Lee, J. Antognini, J. Sohl-Dickstein, R. Frostig, and G. E. Dahl, "Measuring the Effects of Data Parallelism on Neural Network Training," Journal of Machine Learning Research, vol. 20, no. 112, pp. 1–49, 2019.

[2] E. Hoffer, R. Banner, I. Golan, and D. Soudry, "Norm matters: Efficient and accurate normalization schemes in deep networks," in Advances in Neural Information Processing Systems, 2018, pp. 2160–2170.

[3] S. L. Smith, P.-J. Kindermans, C. Ying, and Q. V. Le, "Don't Decay the Learning Rate, Increase the Batch Size," in International Conference on Learning Representations, 2018.

[4] J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, and L. Fei-Fei, "ImageNet: A Large-Scale Hierarchical Image Database," in IEEE Conference on Computer Vision and Pattern Recognition, 2009, pp. 248–255.

[5] K. He, X. Zhang, S. Ren, and J. Sun, "Deep Residual Learning for Image Recognition," in IEEE Conference on Computer Vision and Pattern Recognition, 2016, pp. 770–778.

[6] P. Goyal, P. Dollar, R. Girshick, P. Noordhuis, L. Wesolowski, A. Kyrola, A. Tulloch, Y. Jia, and K. He, "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour," arXiv preprint arXiv:1706.02677, 2017.

[7] T. Akiba, S. Suzuki, and K. Fukuda, "Extremely Large Minibatch SGD: Training ResNet-50 on ImageNet in 15 Minutes," arXiv preprint arXiv:1711.04325, 2017.

[8] Y. You, I. Gitman, and B. Ginsburg, "Large Batch Training of Convolutional Networks," arXiv preprint arXiv:1708.03888, 2017.

[9] X. Jia, S. Song, W. He, Y. Wang, H. Rong, F. Zhou, L. Xie, Z. Guo, Y. Yang, L. Yu, T. Chen, G. Hu, S. Shi, and X. Chu, "Highly Scalable Deep Learning Training System with Mixed-Precision: Training ImageNet in Four Minutes," arXiv preprint arXiv:1807.11205, 2018.

[10] H. Mikami, H. Suganuma, P. U-chupala, Y. Tanaka, and Y. Kageyama, "Massively Distributed SGD: ImageNet/ResNet-50 Training in a Flash," arXiv preprint arXiv:1811.05233, 2018.

[11] M. Yamazaki, A. Kasagi, A. Tabuchi, T. Honda, M. Miwa, N. Fukumoto, T. Tabaru, A. Ike, and K. Nakashima, "Yet Another Accelerated SGD: ResNet-50 Training on ImageNet in 74.7 seconds," arXiv preprint arXiv:1903.12650, 2019.

[12] T. Lin, S. U. Stich, and M. Jaggi, "Don't Use Large Mini-Batches, Use Local SGD," arXiv preprint arXiv:1808.07217, 2018.

[13] C. Ying, S. Kumar, D. Chen, T. Wang, and Y. Cheng, "Image Classification at Supercomputer Scale," arXiv preprint arXiv:1811.06992, 2018.

[14] S.-i. Amari, "Natural Gradient Works Efficiently in Learning," Neural Computation, vol. 10, pp. 251–276, 1998.

[15] N. L. Roux, P.-A. Manzagol, and Y. Bengio, "Topmoumoute Online Natural Gradient Algorithm," in Advances in Neural Information Processing Systems 20, 2008, pp. 849–856.

[16] J. Martens, "Deep learning via Hessian-free optimization," in Proceedings of the 27th International Conference on Machine Learning, 2010, pp. 735–742.

[17] J. Martens and R. Grosse, "Optimizing Neural Networks with Kronecker-factored Approximate Curvature," in Proceedings of the 32nd International Conference on Machine Learning, 2015, pp. 2408–2417.

[18] R. Grosse and J. Martens, "A Kronecker-factored approximate Fisher matrix for convolution layers," in Proceedings of the 33rd International Conference on Machine Learning, 2016, pp. 573–582.

[19] J. Ba, R. Grosse, and J. Martens, "Distributed second-order optimization using Kronecker-factored approximations," in International Conference on Learning Representations, 2017.

[20] J. Martens and M. Johnson, "Kronecker-factored Curvature Approximations for Recurrent Neural Networks," in International Conference on Learning Representations, 2018.

[21] G. Zhang, S. Sun, D. Duvenaud, and R. Grosse, "Noisy Natural Gradient as Variational Inference," in Proceedings of the 35th International Conference on Machine Learning, 2018.

[22] Y. Wu, E. Mansimov, S. Liao, R. Grosse, and J. Ba, "Scalable trust-region method for deep reinforcement learning using Kronecker-factored approximation," in Advances in Neural Information Processing Systems, 2017, pp. 5279–5288.

[23] G. Zhang, L. Li, Z. Nado, J. Martens, S. Sachdeva, G. E. Dahl, C. J. Shallue, and R. Grosse, "Which Algorithmic Choices Matter at Which Batch Sizes? Insights From a Noisy Quadratic Model," arXiv preprint arXiv:1907.04164, 2019.

[24] K. Osawa, Y. Tsuji, Y. Ueno, A. Naruse, R. Yokota, and S. Matsuoka, "Large-Scale Distributed Second-Order Optimization Using Kronecker-Factored Approximate Curvature for Deep Convolutional Neural Networks," in IEEE Conference on Computer Vision and Pattern Recognition, 2019, pp. 12359–12367.

[25] Y. Tsuji, K. Osawa, Y. Ueno, A. Naruse, R. Yokota, and S. Matsuoka, "Performance Optimizations and Analysis of Distributed Deep Learning with Approximated Second-Order Optimization Method," in Proceedings of the 48th International Conference on Parallel Processing: Workshops, 2019, pp. 21:1–21:8.

[26] K. Osawa, S. Swaroop, A. Jain, R. Eschenhagen, R. E. Turner, R. Yokota, and M. E. Khan, "Practical Deep Learning with Bayesian Principles," arXiv preprint arXiv:1906.02506, 2019.

[27] M. Abadi, P. Barham, J. Chen, Z. Chen, A. Davis, J. Dean, M. Devin, S. Ghemawat, G. Irving, M. Isard, M. Kudlur, J. Levenberg, R. Monga, S. Moore, D. G. Murray, B. Steiner, P. Tucker, V. Vasudevan, P. Warden, M. Wicke, Y. Yu, and X. Zheng, "TensorFlow: A system for large-scale machine learning," in 12th USENIX Symposium on Operating Systems Design and Implementation, 2016, pp. 265–283.

[28] A. Botev, H. Ritter, and D. Barber, "Practical Gauss-Newton Optimisation for Deep Learning," in Proceedings of the 34th International Conference on Machine Learning, 2017, pp. 557–565.

[29] S. Tokui, H. Yamazaki Vincent, R. Okuta, T. Akiba, Y. Niitani, T. Ogawa, S. Saito, S. Suzuki, K. Uenishi, and B. Vogel, "Chainer: A Deep Learning Framework for Accelerating the Research Cycle," in Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, 2019, pp. 2002–2011.

[30] V. Thomas, F. Pedregosa, B. van Merrienboer, P.-A. Manzagol, Y. Bengio, and N. L. Roux, "Information matrices and generalization," arXiv preprint arXiv:1906.07774, 2019.

[31] F. Kunstner, L. Balles, and P. Hennig, "Limitations of the Empirical Fisher Approximation," arXiv preprint arXiv:1905.12558, 2019.

[32] S. Ioffe and C. Szegedy, "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift," in Proceedings of the 32nd International Conference on Machine Learning, 2015, pp. 448–456.

[33] S.-i. Amari, R. Karakida, and M. Oizumi, "Fisher Information and Natural Gradient Learning in Random Deep Networks," in The 22nd International Conference on Artificial Intelligence and Statistics, 2019, pp. 694–702.

[34] Y. Ueno and R. Yokota, "Exhaustive Study of Hierarchical AllReduce Patterns for Large Messages Between GPUs," in 19th IEEE/ACM International Symposium on Cluster, Cloud and Grid Computing (CCGRID), 2019, pp. 430–439.

[35] H. Zhang, M. Cisse, Y. N. Dauphin, and D. Lopez-Paz, "Mixup: Beyond Empirical Risk Minimization," in International Conference on Learning Representations, 2018.

[36] Z. Zhong, L. Zheng, G. Kang, S. Li, and Y. Yang, "Random Erasing Data Augmentation," arXiv preprint arXiv:1708.04896, 2017.

[37] T. van Laarhoven, "L2 Regularization versus Batch and Weight Normalization," arXiv preprint arXiv:1706.05350, 2017.


Kazuki Osawa received his BS and MS from Tokyo Institute of Technology in 2016 and 2018, respectively. He is currently a PhD candidate at Tokyo Institute of Technology and a Research Fellow of the Japan Society for the Promotion of Science (JSPS). His research interests include optimization, approximate Bayesian inference, and distributed computing for deep learning. He is a student member of IEEE.

Yohei Tsuji received his BS and MS from Tokyo Institute of Technology in 2017 and 2019, respectively. He is currently a PhD student at Tokyo Institute of Technology. His research interests include high performance computing for machine learning and probabilistic programming.

Yuichiro Ueno received his BS from Tokyo Institute of Technology in 2019. He is currently a master's course student at Tokyo Institute of Technology. His research interests include a range of high-performance computing topics, such as GPU computing and networking, and their application to deep learning. He is a student member of ACM.

Akira Naruse is a senior developer technology engineer at NVIDIA. He holds an MS degree in computer science from Nagoya University. Prior to joining NVIDIA, he was a research engineer at Fujitsu Laboratory and was involved in various high performance computing projects. His main interest is the performance analysis and optimization of scientific computing and deep learning applications on very large systems.

Chuan-Sheng Foo is a Scientist at the Institute for Infocomm Research, A*STAR. He received his BS, MS, and PhD from Stanford University. His research focuses on developing deep learning algorithms that can learn from less labeled data, inspired by applications in healthcare and medicine.

Rio Yokota received his BS, MS, and PhD from Keio University in 2003, 2005, and 2009, respectively. He is currently an Associate Professor at GSIC, Tokyo Institute of Technology. His research interests range from high performance computing and hierarchical low-rank approximation methods to scalable deep learning. He was part of the team that won the Gordon Bell prize for price/performance in 2009.