Deep Equilibrium Models
Transcript of Deep Equilibrium Models
![Page 1: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/1.jpg)
Deep Equilibrium Models
Shaojie Bai, J. Zico Kolter, Vladlen Koltun
Presented by: Alex Wang, Gary Leung, Sasha Doubov
![Page 2: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/2.jpg)
Agenda
1. Weight tying and infinite depth models
2. Implicit layer formulation
3. Approximation and computational considerations
4. DEQ stacking?
5. Experiments
![Page 3: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/3.jpg)
Motivation
Let’s start with a typical deep NN architecture:
Image courtesy of Deep Implicit Layers Tutorial
![Page 4: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/4.jpg)
Motivation
Weight-Tying: Use the same W and inject the input for each layer
Image courtesy of Deep Implicit Layers Tutorial
Just as expressive!
![Page 5: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/5.jpg)
Motivation
Weight-Tying: Use the same W and inject the input for each layer
Image courtesy of Deep Implicit Layers Tutorial
Focusing on activation iteration:
![Page 6: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/6.jpg)
In the infinite limit, as i → ∞ (under suitable stability conditions):

Key insight: The network’s activations z* approach a fixed point!

This is a DEQ!
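The iteration above can be sketched in a few lines; this is a minimal illustration (an assumed weight-tied tanh layer with small random weights so the map is a contraction, not the paper's architecture):

```python
# Weight-tied layer z_{i+1} = tanh(W z_i + U x + b), iterated to a fixed point.
# The 0.1 scale on W keeps the map contractive so the iteration converges.
import numpy as np

rng = np.random.default_rng(0)
d = 4
W = 0.1 * rng.standard_normal((d, d))   # small spectral norm -> contraction
U = rng.standard_normal((d, d))
b = rng.standard_normal(d)
x = rng.standard_normal(d)

def f(z):
    return np.tanh(W @ z + U @ x + b)

z = np.zeros(d)
for _ in range(100):
    z = f(z)

# At convergence, z is (numerically) a fixed point: z = f(z).
residual = np.linalg.norm(f(z) - z)
```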
![Page 7: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/7.jpg)
Deep Equilibrium Model Overview
![Page 8: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/8.jpg)
Implicit vs. Explicit Layers
Explicit layers: typical neural network layers, which directly compute the output and whose backward pass comes from ordinary backprop

Implicit layers: defined by the solution to some problem, i.e. finding z such that x and z satisfy some condition

- Arise naturally in some domains, such as ODEs and fixed points

Image courtesy of Deep Implicit Layers Tutorial
![Page 9: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/9.jpg)
Forward Pass
Naive Approach: we could repeatedly apply the function until convergence

Better Way: Use a root-finding algorithm to find the fixed point

1. Reformulate the fixed point z* = f_θ(z*; x) as finding a root: g_θ(z; x) = f_θ(z; x) − z = 0
2. Apply a generic root-finding algorithm (e.g. Newton’s method!)

We’ll use this notation from here on!
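The two steps above can be sketched as follows; this is an illustrative toy layer with a hand-written Jacobian, not the paper's model:

```python
# Solve the fixed point z = f(z) by applying Newton's method to the
# root problem g(z) = f(z) - z = 0.
import numpy as np

rng = np.random.default_rng(1)
d = 3
W = 0.1 * rng.standard_normal((d, d))
x = rng.standard_normal(d)

def f(z):
    return np.tanh(W @ z + x)

def g(z):
    return f(z) - z

def jac_g(z):
    # d/dz [tanh(Wz + x) - z] = diag(1 - tanh^2) @ W - I
    s = 1.0 - np.tanh(W @ z + x) ** 2
    return s[:, None] * W - np.eye(d)

z = np.zeros(d)
for _ in range(20):
    z = z - np.linalg.solve(jac_g(z), g(z))   # Newton step

newton_residual = np.linalg.norm(g(z))
```

Newton converges quadratically near the root, so far fewer iterations are needed than with plain fixed-point iteration.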
![Page 10: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/10.jpg)
Backward Pass
We need to update the parameters θ in our model to minimize our loss function

Challenge: differentiating through the fixed point

Naive Approach: built-in autodiff through the solver’s computation graph
- Memory issues
- Floating-point errors

Better Approach: Use the Implicit Function Theorem!
![Page 11: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/11.jpg)
Implicit Function Theorem (Informal)

Let f(x, z) be a relation such that:
1. f(x₀, z₀) = 0
2. f is continuously differentiable with non-singular Jacobian ∂f/∂z at (x₀, z₀)

Then there exist neighborhoods (open sets) around x₀ & z₀ and a function z*(x) such that:
1. z*(x₀) = z₀
2. f(x, z*(x)) = 0 on the neighborhood
3. z* is differentiable on the neighborhood

Image courtesy of Wikipedia
![Page 12: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/12.jpg)
Implicit Function Theorem
High-level Idea:

- Convert a relation f(x, z) = 0 into a function in a local region and find its derivative
- Explicit function near a point A: z = g(x) locally
- IFT allows us to find the derivative of g, without the explicit form
![Page 13: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/13.jpg)
Implicit Function Theorem
Let f: Rⁿ × Rᵐ → Rᵐ and (x₀, z₀) be such that:
1. f(x₀, z₀) = 0
2. f is continuously differentiable with non-singular Jacobian ∂f/∂z at (x₀, z₀)

Then there exist open sets U ⊂ Rⁿ and V ⊂ Rᵐ containing x₀ and z₀ respectively, and a unique continuous function z*: U → V such that:
1. z*(x₀) = z₀
2. f(x, z*(x)) = 0 for all x ∈ U
3. z* is differentiable on U
![Page 14: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/14.jpg)
Backward Pass

By IFT, we can differentiate the equilibrium z* directly through the relation f_θ(z*; x) − z* = 0.
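Reconstructing the slide's equations from the paper's setup (ℓ is the loss; (·) stands for θ or x), applying the IFT to g_θ(z; x) = f_θ(z; x) − z = 0 gives:

```latex
\frac{\partial z^\star}{\partial(\cdot)}
  = \Bigl(I - \frac{\partial f_\theta}{\partial z}\Big|_{z^\star}\Bigr)^{-1}
    \frac{\partial f_\theta(z^\star; x)}{\partial(\cdot)},
\qquad
\frac{\partial \ell}{\partial(\cdot)}
  = \frac{\partial \ell}{\partial z^\star}
    \Bigl(I - \frac{\partial f_\theta}{\partial z}\Big|_{z^\star}\Bigr)^{-1}
    \frac{\partial f_\theta(z^\star; x)}{\partial(\cdot)}
```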
![Page 15: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/15.jpg)
Backward Pass - DEQ

● Backward Pass:
○ Solve the resulting linear system using root finding (e.g. Newton’s method)

VJP easily obtained from PyTorch/JAX/etc.
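A sanity check of the implicit gradient on a scalar toy problem (f, x, w are illustrative choices, not the paper's model): by IFT, dz*/dw = (∂f/∂w) / (1 − ∂f/∂z) evaluated at z*, which should match a finite-difference estimate.

```python
# Implicit (IFT) gradient vs. finite differences for f(z, w) = tanh(w*z + x),
# where z*(w) solves the fixed point z = f(z, w).
import math

x = 0.5
w = 0.4   # |w| < 1 keeps the iteration contractive

def solve(wv, iters=200):
    z = 0.0
    for _ in range(iters):
        z = math.tanh(wv * z + x)
    return z

z_star = solve(w)
sech2 = 1.0 - math.tanh(w * z_star + x) ** 2   # tanh'(w z* + x)
df_dz = sech2 * w
df_dw = sech2 * z_star
implicit_grad = df_dw / (1.0 - df_dz)          # dz*/dw by IFT

eps = 1e-6
fd_grad = (solve(w + eps) - solve(w - eps)) / (2 * eps)
```

Note that the implicit gradient never differentiates through the solver loop; it only uses partial derivatives of f at the equilibrium.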
![Page 16: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/16.jpg)
Approximate Inverse Jacobian - Broyden’s Method
● Expensive to compute the inverse Jacobian during root finding, in both the forward and backward passes!
● Broyden’s Method (quasi-Newton solver):
○ During root finding, approximates the inverse Jacobian using low-rank (rank-one) updates

Initial guess: B₀ = −I
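A hedged sketch of Broyden's method in its generic "good" form (assumed here; not the paper's exact implementation): maintain an approximation H of the inverse Jacobian and apply rank-one Sherman-Morrison updates, so the Jacobian is never formed or inverted directly.

```python
import numpy as np

def broyden(g, z0, tol=1e-10, max_iter=50):
    z = z0.astype(float)
    gz = g(z)
    H = -np.eye(len(z0))            # initial inverse-Jacobian guess B0^{-1} = -I
    for _ in range(max_iter):
        dz = -H @ gz                # quasi-Newton step
        z_new = z + dz
        gz_new = g(z_new)
        if np.linalg.norm(gz_new) < tol:
            return z_new
        s, y = z_new - z, gz_new - gz
        Hy = H @ y
        # Rank-one Sherman-Morrison update of the inverse approximation
        H += np.outer(s - Hy, s @ H) / (s @ Hy)
        z, gz = z_new, gz_new
    return z

# Toy usage on g(z) = tanh(W z + x) - z (illustrative, not the paper's f).
rng = np.random.default_rng(2)
d = 3
W = 0.2 * rng.standard_normal((d, d))
x = rng.standard_normal(d)
g = lambda z: np.tanh(W @ z + x) - z
root = broyden(g, np.zeros(d))
```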
![Page 17: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/17.jpg)
DEQ Memory
● Very memory-efficient, because the forward and backward passes just use root-finding algorithms.
● Avoids all the overhead of unrolling and backpropagating through solver steps.
● Storage:
○ Equilibrium point z*
○ Network input x
○ Model parameters
○ VJP (*no Jacobian construction needed!)
![Page 18: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/18.jpg)
Expressivity of DEQs
Intuition:

Consider a simple function composition h(x) = f₂(f₁(x)). Transform this into a DEQ over the stacked state z = (z₁, z₂) with update F(z; x) = (f₁(x), f₂(z₁)).

Thus, the equilibrium point is z* = (f₁(x), f₂(f₁(x))).

The output component of the equilibrium is the output of the function! (Can be extended to an arbitrary computation graph.)
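The construction runs in a few lines; f₁ and f₂ below are arbitrary illustrative functions:

```python
# Wrap a composition h(x) = f2(f1(x)) as one equilibrium problem over the
# stacked state z = (z1, z2), with update F(z, x) = (f1(x), f2(z1)).
import numpy as np

def f1(x):
    return np.tanh(x) + 1.0

def f2(z1):
    return 2.0 * z1 ** 2

def F(z, x):
    z1, z2 = z
    return np.array([f1(x), f2(z1)])

x = 0.3
z = np.zeros(2)
for _ in range(3):          # the iteration reaches the fixed point in two steps
    z = F(z, x)

# The second component of the equilibrium equals the composed output.
direct = f2(f1(x))
```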
![Page 19: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/19.jpg)
DEQ Stacking?
● Stacking DEQs doesn’t really help, as a single DEQ layer can model any number of stacked DEQ layers.

Intuition:

Consider a stack of 2 DEQ layers: z₁* = f_θ₁(z₁*; x), followed by z₂* = f_θ₂(z₂*; z₁*).

This is equivalent to the following single equilibrium problem:
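Written out (a sketch of the paper's construction, with z concatenating the two states):

```latex
z_1^\star = f_{\theta_1}(z_1^\star;\, x), \quad
z_2^\star = f_{\theta_2}(z_2^\star;\, z_1^\star)
\;\Longleftrightarrow\;
z^\star = F(z^\star;\, x), \quad
z = \begin{bmatrix} z_1 \\ z_2 \end{bmatrix}, \quad
F(z;\, x) = \begin{bmatrix} f_{\theta_1}(z_1;\, x) \\ f_{\theta_2}(z_2;\, z_1) \end{bmatrix}
```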
![Page 20: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/20.jpg)
Experiments
● Apply Deep Equilibrium Models to sequence modelling tasks
○ The sequences empirically converge to an equilibrium
○ These architectures already use weight tying over the temporal sequence (Trellis Nets & Transformer models)
○ Tasks: long-range copy memory, Penn Treebank language modelling, WikiText-103
● Demonstrate the memory efficiency and expressivity of DEQ models at similar parameter counts, as well as the speed of computation
![Page 21: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/21.jpg)
Convergence Caveat
● One might expect the network to diverge in the infinite-depth limit
● In practice, many networks do not, which is explored more formally in later work [Kolter et al. 2020]
● In this work, the authors empirically show that the contributions of subsequent layers diminish at very large depths
![Page 22: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/22.jpg)
Experiments
Not SOTA, but competitive at the same parameter count

Efficient!
![Page 23: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/23.jpg)
Experiments - Runtime
● Notice that DEQs are slower!
● This is a consequence of solving an inner optimization problem inside the network
![Page 24: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/24.jpg)
Conclusion
● Deep Equilibrium Models are weight-tied, effectively infinite-depth neural networks
○ The output is the fixed point of some neural network function
● Solves a root-finding problem in each of the forward and backward passes
○ Uses the IFT to compute gradient updates, rather than backpropagation and autodiff through the iterative solver graph
● Performs comparably to SOTA models of the same size while being considerably more memory-efficient
○ Typically slower, due to the inner-loop optimization in both the forward and backward passes
![Page 25: Deep Equilibrium Models](https://reader033.fdocuments.in/reader033/viewer/2022041718/6253ec208f83940817410c47/html5/thumbnails/25.jpg)
References
Deep Implicit Layers Tutorial - http://implicit-layers-tutorial.org/
Deep Equilibrium Models [Bai et al. 2019] - https://arxiv.org/abs/1909.01377