Recurrent / Relational / Reinforcement
Networks · Learning
Petar Veličković
Artificial Intelligence Group, Department of Computer Science and Technology, University of Cambridge, UK
MILA Graph Representation Reading Group Meeting 22 August 2018
Introduction
▶ In this talk, I will survey the Relational Network architecture, and its recent deployment in recurrent neural networks and deep reinforcement learning.
▶ The discussion will span the following papers:
  ▶ A simple neural network module for relational reasoning (Santoro, Raposo et al., NIPS 2017)
  ▶ Relational deep reinforcement learning (Zambaldi, Raposo, Santoro et al., 2018)
  ▶ Relational recurrent neural networks (Santoro, Faulkner, Raposo et al., 2018)
▶ These papers form a substantial part of DeepMind’s recent “graph networks surge”.
Relational reasoning
▶ Being able to reason about relations between entities present in an input is an important aspect of intelligence!
▶ Consider the simple task of inferring which two points from a given point set are furthest apart—this requires computing and comparing all* of their pairwise distances.
▶ Keep this task in mind—it will be revisited!
Approaches to relational reasoning
▶ Relations can be naturally expressed within symbolic methods (defined by e.g. the rules of logic)—but these are not robust to small variations of inputs/tasks.
▶ Robustness is often achievable with standard neural network architectures (such as MLPs), but it is extremely challenging for them to capture relations, despite their theoretical potency!
▶ This claim is extensively validated throughout the three papers.
=⇒ Seek a model inspired by symbolic AI, while empowered by neural networks (explicitly representing relations in a robust way).
Our task for today
[Figure: an input scene with four objects (♣, ♦, ♥, ♠) is encoded into object representations ~o♣, ~o♦, ~o♥, ~o♠; given the query "How many outlined objects are above the spade?", the system produces the output "two".]
The Relational Network
[Figure: the same task diagram, now with the RN module placed between the object representations and the output.]
Relational Networks
▶ Initially, we will assume that the objects are provided as input.
▶ Consider a set of n objects, O = {~o1, ~o2, . . . , ~on}, with each object represented by a feature vector ~oi ∈ R^m.
▶ A Relational Network (RN) summarises the relations between these objects as follows:

RN(O) = fφ( ∑_{i,j} gθ(~oi, ~oj) )

where gθ : R^m × R^m → R^k and fφ : R^k → R^l are functions with parameters θ and φ (usually MLPs).
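To make the definition concrete, here is a minimal NumPy sketch of the RN forward pass. All names are illustrative and the MLPs are untrained random weights—a sketch of the computation, not the authors' code:

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp(sizes):
    """Random weights for a small MLP (an illustrative stand-in for the
    learnable f_phi and g_theta; not trained)."""
    return [(rng.standard_normal((a, b)) * 0.1, np.zeros(b))
            for a, b in zip(sizes[:-1], sizes[1:])]

def forward(params, x):
    """ReLU MLP forward pass."""
    for i, (W, b) in enumerate(params):
        x = x @ W + b
        if i < len(params) - 1:
            x = np.maximum(x, 0.0)
    return x

def relational_network(O, g, f):
    """RN(O) = f_phi( sum_{i,j} g_theta(o_i, o_j) )."""
    n = O.shape[0]
    pair_sum = sum(forward(g, np.concatenate([O[i], O[j]]))
                   for i in range(n) for j in range(n))
    return forward(f, pair_sum)

m, k, l = 4, 8, 3              # object, relation and output dimensions
g_theta = mlp([2 * m, 16, k])  # relation function
f_phi = mlp([k, 16, l])        # readout function

O = rng.standard_normal((5, m))  # n = 5 objects
out = relational_network(O, g_theta, f_phi)
print(out.shape)                 # (3,)
```

Because the pairwise terms are summed before fφ, shuffling the rows of O leaves the output unchanged.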
Properties of RNs
▶ The central component of RNs is gθ, the relation function. Its role is to infer the nature of relations between objects i and j.
▶ An RN may be seen as a message-passing neural network over the complete graph of object nodes.
▶ RNs have several highly desirable properties:
  ▶ Relational inference—given the all-pairs nature of the computation, the module does not assume upfront knowledge of which pairs of objects are related, and how.
  ▶ Data efficiency—an MLP would need to learn and embed n² (identical) functions to replicate the behaviour of RNs.
  ▶ Permutation invariance—the summation operation ensures that the order of objects does not matter; therefore, RNs can be applied to arbitrary sets.
Dynamic physical systems
MuJoCo-simulated physical mass-spring systems with 10 objects.
Input: ~oi is RGB color and (x, y) coordinates across 16 time steps.
Tasks: (i) infer relations; (ii) count the number of systems (harder!).
Results on physical systems
▶ Relational Networks achieve 93% accuracy in predicting the existence/absence of relations between objects, and 95% accuracy in predicting the number of interacting systems.
▶ MLPs fail to predict better than chance on either task!
▶ The learnt function is transferable to unseen motion-capture data!
Conditioning in RNs
[Figure: the task diagram revisited—the query is now fed into the RN alongside the object representations.]
RN conditioning
▶ An RN may be seen as a module that “captures” the relations between objects in a set—this computation may be arbitrarily conditioned, e.g. to answer a specific relational query.
▶ Assuming we have a conditioning vector ~q, the RN architecture may be trivially modified to include it:

RN(O, ~q) = fφ( ∑_{i,j} gθ(~oi, ~oj, ~q) )
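A sketch of the conditioned variant, under the same illustrative assumptions (random untrained weights, and single linear layers in place of the MLPs for brevity):

```python
import numpy as np

rng = np.random.default_rng(1)

def linear(a, b):
    # Random (untrained) weights, standing in for learnable parameters.
    return rng.standard_normal((a, b)) * 0.1

def conditioned_rn(O, q, Wg, Wf):
    """RN(O, q) = f_phi( sum_{i,j} g_theta(o_i, o_j, q) ): the query q
    is simply appended to every object pair before applying g_theta."""
    pair_sum = sum(
        np.maximum(np.concatenate([oi, oj, q]) @ Wg, 0.0)  # g_theta
        for oi in O for oj in O)
    return pair_sum @ Wf                                   # f_phi

m, q_dim, k, l = 4, 6, 8, 3
Wg = linear(2 * m + q_dim, k)   # g_theta now also consumes the query
Wf = linear(k, l)

O = rng.standard_normal((5, m))
q = rng.standard_normal(q_dim)  # e.g. an LSTM encoding of a question
out = conditioned_rn(O, q, Wg, Wf)
print(out.shape)                # (3,)
```

The conditioning leaves permutation invariance intact, since ~q enters every pairwise term identically.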
The CLEVR dataset (Johnson et al., 2017)
Question Answering dataset on 3D-rendered objects.
[Figure: an example CLEVR image, with a non-relational question: "What is the size of the brown sphere?", and a relational question: "Are there any rubber things that have the same size as the yellow metallic cylinder?"]

Input: ~oi is RGB color, (x, y, z) coordinates, shape/material/size.
Queries: count, exist, compare numbers, query attribute, compare attribute.
Results on CLEVR
▶ The query sentence is encoded into ~q as the last-stage output of a word-level LSTM (with learned word embeddings).
▶ Relational Networks achieve an accuracy of 96.4% on CLEVR.
▶ Human performance is 92.6%! This sounds great!
▶ . . . OK, I lied to you. (Sorry!)
The actual CLEVR dataset (Johnson et al., 2017)
Visual Question Answering dataset on 3D-rendered objects.
[Figure: the same example image and questions as before.]

Input: the scene image—the ~oi vectors are not explicitly given!
Queries: count, exist, compare numbers, query attribute, compare attribute.
Object extraction
[Figure: the task diagram revisited—the focus is now on the step that extracts the object representations from the raw input.]
Object extraction from images
▶ In general, we should not assume the ~oi will be given!
▶ Arguably, obtaining the ~oi from raw input will be the most variable pipeline component.
▶ Often, we can obtain object representations as high-level outputs of neural networks specialised for such inputs.
▶ In the case of images (most common!) this will be a convolutional neural network.
CNN object extraction
▶ A convolutional architecture generally consists of interleaving convolutional and pooling layers—progressively building more sophisticated feature maps.
▶ At any point in a CNN, a feature map f may have the shape n × m × k, where n and m are the height and width of the feature map, and each pixel is represented by k features.
▶ Each pixel represents a summary of a certain region of the image. Without any further assumptions, it is safest to let each pixel constitute an object!
▶ Therefore, we will have an object set O = {~o1, . . . , ~on·m} with n·m objects, where each ~oi ∈ R^k corresponds to a pixel vector ~f_xy.
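A quick sketch of this flattening step, with random values standing in for real CNN activations (the coordinate tagging at the end follows the CLEVR RN paper, which appends each object's position to its features):

```python
import numpy as np

rng = np.random.default_rng(0)

# A final CNN feature map of shape n x m x k: an 8x8 spatial grid with
# 24 features per pixel (random placeholder values).
n, m, k = 8, 8, 24
f = rng.standard_normal((n, m, k))

# Each pixel f[x, y] becomes one object o_i in R^k.
O = f.reshape(n * m, k)
print(O.shape)                              # (64, 24)

# Object i corresponds to pixel (x, y) = (i // m, i % m):
assert np.array_equal(O[19], f[19 // m, 19 % m])

# Tag each object with its (x, y) position, so the relation function
# can reason about spatial arrangement.
xy = np.stack(np.meshgrid(np.arange(n), np.arange(m), indexing="ij"),
              axis=-1).reshape(n * m, 2)
O_tagged = np.concatenate([O, xy], axis=1)
print(O_tagged.shape)                       # (64, 26)
```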
CNN object extraction
[Figure: input → Conv. → Pool → Conv. → final feature map f; each pixel of f becomes one object (~o1, ~o2, ~o3, ~o4).]
Overall CLEVR architecture
[Figure: the overall CLEVR architecture: a CNN produces final feature maps whose pixels are the objects; an LSTM encodes the question ("What size is the cylinder that is left of the brown metal thing that is left of the big sphere?"); each object pair, together with the question, is processed by the gθ-MLP; the results are summed element-wise and passed through the fφ-MLP to produce the answer ("small").]
End-to-end trainable with gradient descent.
Actual results on CLEVR
First approach to achieve superhuman performance on this task!
Actual results on CLEVR
[Figure: per-category accuracies on CLEVR (overall; count; exist; compare numbers: equal/less than/more than; query attribute: size/shape/material/color; compare attribute: size/shape/material/color) for Human, Q-type baseline, LSTM, CNN+LSTM, CNN+LSTM+SA and CNN+LSTM+RN.]
The RN especially excels at compare attribute, the query type which heavily relies on relational reasoning.
Failure cases on CLEVR
Failure cases often occur under heavy occlusion—challenging for humans as well!
Results on Sort-of-CLEVR
A simple CLEVR-inspired dataset with a clear separation of relational vs. non-relational queries.
Demonstrates a clear advantage of RNs on relational queries.
LSTM object extraction: bAbI (Weston et al., 2015)
A text-based set of 20 question-answering tasks.
Let O = {~o1, . . . , ~o20} be the LSTM representations of up to 20 sentences preceding the question. ~q is once again obtained as the LSTM representation of the question.
Results on bAbI
▶ RN(O, ~q) passes (95%+ accuracy) 18/20 tasks after joint training—comparable with other state-of-the-art memory network architectures:
  ▶ Memory networks: 14/20
  ▶ DNC: 18/20
  ▶ Sparse DNC: 19/20
  ▶ EntNet: 16/20
▶ It does not catastrophically fail (91.9% and 83.5% accuracy) on the remaining two.
▶ Notably, it succeeds on the basic induction task (97.9%), where Sparse DNC (46%), DNC (44.9%) and EntNet (47.9%) all fail.
Self-attention
▶ Thus far, the building-block functions of a Relational Network (fφ, gθ) were simple MLPs.
▶ For more recent RN architectures, we focus instead on the self-attention operator.
▶ A self-attentional operator, Aθ, acts on a set of n entities, E = {~e1, ~e2, . . . , ~en}, producing higher-level representations:

E′ = Aθ(E)

where E′ = {~e′1, ~e′2, . . . , ~e′n}, and θ are learnable parameters.
Self-attention, cont’d
▶ Each component of E′ will be derived by examining all components of E (by way of linear combinations):

~e′i = ∑_j αij fψ(~ej)

where fψ : R^m → R^k is a learnable transformation.

▶ Here, the coefficients αij correspond to the importance of the features of entity j to entity i, and are derived by a learnable attention mechanism, aφ : R^m × R^m → R:

αij = aφ(~ei, ~ej)
Self-attention
[Figure: computing ~e′j: each entity ~e1, . . . , ~ej, . . . , ~en is transformed by fψ, multiplied by its attention coefficient α·,j produced by aφ, and the results are summed.]
The Transformer architecture
▶ In particular, the Transformer architecture (Vaswani et al., 2017) is used for Aθ.
▶ Here abbreviated as MHDPA (multi-head dot-product attention).
▶ First, derive queries, keys and values for the attention:

~qi = Wq~ei    ~ki = Wk~ei    ~vi = Wv~ei

▶ Now, use the queries and keys to derive the coefficients:

αij = exp(⟨~qi, ~kj⟩ / √dk) / ∑_m exp(⟨~qi, ~km⟩ / √dk)

where dk is the dimensionality of the keys.
The Transformer architecture, cont’d
▶ Now, we can use αij to recombine the values at each position:

~e′i = ∑_j αij ~vj

▶ This can be conveniently written in matrix form as:

E′ = softmax( QK^T / √dk ) V

▶ Further optimised by using multi-head attention: replicating this operation K times (each with independent parameters Wq, Wk, Wv) and feature-wise concatenating the results.
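The matrix form translates almost line-for-line into NumPy. A multi-head sketch with random, untrained projection matrices:

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x):
    z = np.exp(x - x.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)

def mhdpa(E, heads):
    """Multi-head dot-product attention: per head,
    E' = softmax(Q K^T / sqrt(d_k)) V, with the head outputs
    concatenated feature-wise. `heads` is a list of (Wq, Wk, Wv)."""
    outs = []
    for Wq, Wk, Wv in heads:
        Q, K, V = E @ Wq, E @ Wk, E @ Wv
        d_k = K.shape[-1]
        alpha = softmax(Q @ K.T / np.sqrt(d_k))  # n x n coefficients
        outs.append(alpha @ V)
    return np.concatenate(outs, axis=-1)

n, d_e, d_k = 5, 16, 8
heads = [tuple(rng.standard_normal((d_e, d_k)) * 0.1 for _ in range(3))
         for _ in range(4)]                      # K = 4 heads

E = rng.standard_normal((n, d_e))
E_new = mhdpa(E, heads)
print(E_new.shape)      # (5, 32): 4 heads x 8 features per entity
```

Like the RN's summation, the operator respects the set structure: permuting the entities permutes the outputs in the same way.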
Reinforcement learning: Box-World and StarCraft II
Box-World: a grid RL environment meant to stress relational reasoning while deciding how to act.
StarCraft II: mini-games (Vinyals et al., 2017).

In both cases, the agent receives pixel-structured inputs (minimaps, screens, etc.).
Relational deep reinforcement learning
▶ Empowers a standard CNN-based policy network (in an RL setting) with a relational module based on self-attention.
▶ The architectures for both tasks are very similar!
▶ Extract entities, ~ei, just as before (as separate pixels in a high-level feature map).
▶ Then perform several rounds of the Transformer self-attention over E (each round followed by a small MLP, fθ, and layer normalisation to introduce nonlinearity).
▶ Finally, perform global pooling and a small MLP to derive the policy for the RL algorithm (IMPALA; Espeholt et al., 2018).
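These rounds can be sketched as follows—single-head attention for brevity and random stand-in weights (the actual agents use MHDPA and their own trained parameters):

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x):
    z = np.exp(x - x.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    sd = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sd + eps)

def relational_module(E, params, rounds=2):
    """Several rounds of self-attention, each followed by a small MLP
    (f_theta) and layer normalisation, with a residual connection."""
    Wq, Wk, Wv, W1, W2 = params
    for _ in range(rounds):
        Q, K, V = E @ Wq, E @ Wk, E @ Wv
        A = softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V  # attention
        H = np.maximum(A @ W1, 0.0) @ W2                 # 2-layer MLP
        E = layer_norm(E + H)                            # residual + LN
    return E

d = 16
params = tuple(rng.standard_normal((d, d)) * 0.1 for _ in range(5))
E = rng.standard_normal((6, d))   # 6 entities from the CNN feature map

E_out = relational_module(E, params)
pooled = E_out.max(axis=0)        # global (max) pooling over entities
print(pooled.shape)               # (16,)
```

The pooled vector would then feed a small MLP producing the policy.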
The Box-World architecture
[Figure: the Box-World architecture: the input passes through 2 × 2 convolutions (stride 1) with ReLU activations, then through a relational module applying multi-head dot-product attention over queries, keys and values (repeated several times), followed by feature-wise max pooling and an FC 256 layer.]
Results on Box-World
[Figure: fraction of Box-World levels solved against environment steps, for branch lengths 1 and 3, comparing relational agents (1, 2 and 4 blocks) against baselines (3 and 6 blocks); example observations and their underlying graphs are shown alongside.]
Visualising the attentional coefficients
Zero-shot experiments in Box-World
Longer solution path lengthsa)
0.0
1.0
0.8
Frac
tion
solv
ed
0.6
0.4
0.2
Train.
Test
Train.
Test4 1086
Frac
tion
solv
ed
Train.
Train. 4 1086
Test Test
0.0
1.0
0.8
0.6
0.4
0.2
Not requiredduring training
Withheld keyb)
Relational BaselineRelational Baseline
The non-relational baseline (a ResNet CNN) fails to generalise!
Results on StarCraft II
Sets a new state of the art, often beating a human grandmaster.
Zero-shot experiments in StarCraft
[Figure: mean score attained with 1, 2, 3, 4, 5 and 10 marines, for the relational agent vs. the control agent.]

Exhibits higher—although not fully conclusive—generalisation ability from 2 marines to higher numbers.
Neural networks for sequential processing
▶ Finally, we turn our attention to architectures used for general sequential processing of data.
▶ In the general setting, we require a stateful system, Sθ, capable of processing incoming inputs ~xt and updating its internal state, ~st, appropriately:

~st = Sθ(~xt, ~st−1)

▶ Inference may then be performed by leveraging ~st.
Approaches to sequential processing
▶ Traditional approaches for modelling Sθ include recurrent neural networks (e.g. LSTM, GRU) and memory-augmented neural networks (e.g. NTM, DNC).
▶ Recurrent neural networks generally represent their state as a fixed-size vector, ~ct, which gets appropriately updated at each stage of input processing.
▶ Memory-augmented networks have a memory matrix, M ∈ R^{n×m}, which may be read from/written to by using a recurrent controller.
Analysis of approaches
▶ Both approaches have shortcomings when explicit relational reasoning through time is required:
  ▶ RNNs pack the entire representation into a single dense vector, making it hard to reason about entities (and therefore relations);
  ▶ Memory-augmented networks explicitly represent entities (as rows of M), but these cannot easily interact once written to.
▶ Relational recurrent neural networks address both shortcomings simultaneously, explicitly allowing the rows of M to interact using self-attention!
Relational memory
▶ Assume we have a memory matrix M = {~m1, ~m2, . . . , ~mn}.
▶ Applying (Transformer) self-attention to it, we obtain a new memory state M′ = {~m′1, ~m′2, . . . , ~m′n}, explicitly taking into account the relations between the memory rows ~mi:

M′ = softmax( QK^T / √dk ) V

where

Q = MWq    K = MWk    V = MWv

just as before.
Incorporating new inputs
▶ The interactions thus far are self-contained to what’s already in the memory; however, we’d like the memory to adapt to incoming inputs, ~x, appropriately.
▶ Simple extension: let the memory locations attend over ~x too!

M′ = softmax( Q [K ‖ Wk~x]^T / √dk ) [V ‖ Wv~x]

where ‖ denotes row-wise concatenation.
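In NumPy, this concatenation trick amounts to giving the keys and values one extra row for ~x, while the queries still come from the memory alone (random stand-in weights; an illustrative sketch):

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x):
    z = np.exp(x - x.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)

d = 8                        # feature size (= d_k here, for simplicity)
n = 4                        # number of memory rows
Wq, Wk, Wv = (rng.standard_normal((d, d)) * 0.1 for _ in range(3))

M = rng.standard_normal((n, d))   # memory matrix
x = rng.standard_normal((1, d))   # incoming input, as one extra row

Q = M @ Wq                        # queries come from memory only
K = np.vstack([M @ Wk, x @ Wk])   # keys over [M ; x]
V = np.vstack([M @ Wv, x @ Wv])   # values over [M ; x]

M_new = softmax(Q @ K.T / np.sqrt(d)) @ V
print(M_new.shape)   # (4, 8): same shape as M, but informed by x
```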
The LSTM controller
▶ We do not wish to fully overwrite M with M′—we can control this process with an LSTM:

~ii,t = σ(Wi~xt + Ui~hi,t−1 + ~bi)
~fi,t = σ(Wf~xt + Uf~hi,t−1 + ~bf)
~oi,t = σ(Wo~xt + Uo~hi,t−1 + ~bo)
~mi,t = gψ(~m′i,t) ⊙ ~ii,t + ~mi,t−1 ⊙ ~fi,t
~hi,t = tanh(~mi,t) ⊙ ~oi,t

where gψ is a learnable function (a 2-layer MLP with layer normalisation in the paper).
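A sketch of this row-wise gating (all tensors are random placeholders; M_att stands for the attended memory M′, and a plain tanh stands in for the 2-layer MLP gψ):

```python
import numpy as np

rng = np.random.default_rng(0)

sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))

n, d, dx = 4, 8, 8
# Random stand-ins for the learnable gate parameters (shared across rows).
Wi, Wf, Wo = (rng.standard_normal((dx, d)) * 0.1 for _ in range(3))
Ui, Uf, Uo = (rng.standard_normal((d, d)) * 0.1 for _ in range(3))
bi, bf, bo = (np.zeros(d) for _ in range(3))

x_t = rng.standard_normal(dx)          # current input
M_prev = rng.standard_normal((n, d))   # memory at t-1 (rows m_{i,t-1})
M_att = rng.standard_normal((n, d))    # attended memory M' (rows m'_{i,t})
H_prev = rng.standard_normal((n, d))   # previous outputs h_{i,t-1}

g_psi = np.tanh                        # stand-in for the 2-layer MLP

I = sigmoid(x_t @ Wi + H_prev @ Ui + bi)   # input gates, row-wise
F = sigmoid(x_t @ Wf + H_prev @ Uf + bf)   # forget gates
O = sigmoid(x_t @ Wo + H_prev @ Uo + bo)   # output gates

M_t = g_psi(M_att) * I + M_prev * F        # gated memory update
H_t = np.tanh(M_t) * O                     # per-row outputs
print(M_t.shape, H_t.shape)                # (4, 8) (4, 8)
```

Since the gate parameters are shared across memory rows, the same update is applied to every row independently.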
The Relational RNN
[Figure: (a) the Relational RNN core: the previous memory and the input pass through multi-head dot-product attention (A) with a residual connection, then an MLP with another residual connection, and finally gating, producing the next memory and the output (computation of the gates not depicted); (b) the attention block: queries come from the memory, while keys and values come from the memory concatenated with the input, yielding the updated memory; (c) attention weights are computed from the queries and keys, normalised with a row-wise softmax, and used to compute a weighted average of the values, returning the updated memory.]
Tasks under consideration
A suite of supervised and reinforcement learning tasks demanding explicit sequential relational reasoning:
▶ N-th farthest vector from a given vector ("What is the Nth farthest from vector m?");
▶ Program evaluation from characters (Learning to Execute);
▶ Language modelling (e.g. completing "Super Mario Land is a 1989 side scrolling platform video _____");
▶ Mini-PacMan and Box-World with a viewport!
Results on N-th farthest: LSTM/DNC
[Figure: batch accuracy curves for the LSTM (left) and DNC (right) baselines.]

Failing to surpass 30% batch accuracy!
Results on N-th farthest: RRNN
Attention weight visualisation
[Figure: attention weight matrices (attending from × attending to) over input vector ids and time. (a) Reference vector is the last in a sequence, e.g. "Choose the 5th furthest vector from vector 7"; (b) reference vector is the first in a sequence, e.g. "Choose the 3rd furthest vector from vector 4"; (c) reference vector comes in the middle of a sequence, e.g. "Choose the 6th furthest vector from vector 6".]
Results on LTE
The RRNN is again highly competitive, especially in scenarios where strong relational reasoning may be required (full programs).
Results on LTE
Results on Language Modelling
The RRNN obtains competitive perplexity levels, compared to several strong baselines.
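As a reminder of the metric, perplexity is the exponentiated mean negative log-likelihood per token (the standard definition, not tied to the paper's code):

```python
import math

def perplexity(token_probs):
    """Perplexity = exp of the mean negative log-probability per token."""
    nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(nll)

# a model assigning probability 0.25 to every token has perplexity 4
```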
Results on Language Modelling: WikiText-103
[Figure: validation perplexity on WikiText-103 against training steps (billions), comparing the RMC and an LSTM baseline.]
Results on Mini-PacMan
[Figure: Mini-PacMan returns with viewport (left) and without viewport (right).]

The RRNN outperforms an LSTM when used as a policy network (for IMPALA). Specifically, when the entire map is observed, it doubles the LSTM's performance!
Concluding remarks
I Empowering neural networks with various kinds of relational reasoning modules will likely be a necessary step towards strong and robust intelligent systems.
I This claim is clearly supported by several “failure modes” of the baseline architectures we considered today.
I One limitation going forward lies in the all-pairs interactions, which will limit scalability to larger object sets, especially if self-attention is used.
I The NRI (Kipf, Fetaya et al., 2018) offers one possible direction to address this, but probably not the ultimate solution. . .
I In my opinion, a particularly important avenue for future work is graph-structured memories, where we are not restricted to a matrix and relations between slots are not all-pairs.