
  • Inference in Probabilistic Graphical Models by Graph Neural Networks

    Authors: KiJung Yoon, Renjie Liao, Yuwen Xiong, Lisa Zhang, Ethan Fetaya, Raquel Urtasun, Richard Zemel, Xaq Pitkow

    Presenters: Shihao Niu, Zhe Qu, Siqi Liu, Jules Ahmar

  • TL;DR: Use Graph Neural Networks (GNNs) to learn a message-passing algorithm that solves inference tasks in probabilistic graphical models.

    Motivation
    ● Inference is difficult for probabilistic graphical models.
    ● Message-passing algorithms, such as belief propagation, struggle when the graph contains loops.
      ○ Loopy belief propagation: convergence is not guaranteed.

  • Why GNNs
    ● Essentially an extension of recurrent neural networks (RNNs) to graph-structured inputs.
    ● The central idea is to update the hidden state at each node iteratively by aggregating incoming messages.
    ● Have a structure similar to that of a message-passing algorithm.

  • Factor graph and belief propagation
    ● Recall that the distribution represented by a factor graph is a product of factors (see the equations below).
    ● Recall the message-update formulas of the belief propagation algorithm (see below).
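    For reference, the standard factor-graph distribution and the sum-product belief propagation updates can be written as follows; the notation (factors f_alpha, variable sets x_alpha, neighborhoods N(.)) is conventional rather than taken from the slides:

    p(x) = \frac{1}{Z} \prod_{\alpha} f_\alpha(x_\alpha)

    \mu_{i \to \alpha}(x_i) = \prod_{\beta \in N(i) \setminus \alpha} \mu_{\beta \to i}(x_i)
    \qquad
    \mu_{\alpha \to i}(x_i) = \sum_{x_\alpha \setminus x_i} f_\alpha(x_\alpha) \prod_{j \in N(\alpha) \setminus i} \mu_{j \to \alpha}(x_j)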

  • BP to GNNs: mapping the messages

    ● BP is recursive and graph-based. Naturally, we could map the BP messages to GNN nodes and use neural networks to describe the nonlinear updates.

  • BP to GNNs: mapping the variable nodes

  • BP to GNNs: mapping the variable nodes
    Marginal probability of a single variable in an MRF:
    Marginal joint probability of a factor's variables in a factor graph:
    (see the belief formulas below)

    ● All of the messages depend on only one variable node at a time.
    ● The nonlinear functions between GNN nodes can account for the factor interactions after equilibrium is reached.
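    In standard sum-product notation (not copied from the slides), the corresponding beliefs are:

    p(x_i) \propto \prod_{\alpha \in N(i)} \mu_{\alpha \to i}(x_i)
    \qquad
    p(x_\alpha) \propto f_\alpha(x_\alpha) \prod_{j \in N(\alpha)} \mu_{j \to \alpha}(x_j)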

  • Preliminaries for the model
    ● Binary MRFs, a.k.a. Ising models.
    ● The couplings J and biases b are specified randomly and are provided as input for GNN inference (model written out below).
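    A standard parameterization of such a binary MRF, with singleton biases b_i and pairwise couplings J_ij (the symbols are assumed here, following common Ising-model convention):

    p(x) = \frac{1}{Z} \exp\Big( \sum_i b_i x_i + \sum_{(i,j)} J_{ij}\, x_i x_j \Big), \qquad x_i \in \{-1, +1\}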

  • GNN Recap

    Update the state embedding of node v based on:
    - the feature of node v
    - the features of the edges incident to v
    - the state embeddings of the neighbors of v
    - the features of the neighbors of v

    Local output function: (see below)
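    In the notation of Scarselli et al. (reference 3), the state update and local output take the form below, where co[v] denotes the edges incident to v and ne[v] its neighbors:

    h_v = f\big(x_v,\; x_{co[v]},\; h_{ne[v]},\; x_{ne[v]}\big), \qquad o_v = g\big(h_v,\; x_v\big)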

  • GNN Recap (Cont.)
    Source: Scarselli, Franco, et al. "The graph neural network model."

    Decompose the state update function into a sum of per-edge terms (see below).
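    The per-edge decomposition from Scarselli et al.:

    h_v = \sum_{u \in ne[v]} f\big(x_v,\; x_{(v,u)},\; h_u,\; x_u\big)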

  • Message Passing Neural Networks

    An abstraction of several GNN variants.

    Phase 1: Message Passing

    Define the message from i to j at time t+1 as: (see below)

    Step 1: Aggregate all incoming messages into a single message at the destination node.

    Step 2: Update the hidden state based on the current hidden state and the aggregated message.
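    Following Gilmer et al. (reference 2), the message-passing phase runs for T steps with message function M_t and update function U_t:

    m_v^{t+1} = \sum_{w \in N(v)} M_t\big(h_v^t,\; h_w^t,\; e_{vw}\big), \qquad h_v^{t+1} = U_t\big(h_v^t,\; m_v^{t+1}\big)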

  • Message Passing Neural Networks (Cont.)

    Phase 2: Readout Phase (see below)

    The message function, node update function, and readout function can each be instantiated differently.

    MPNNs thereby generalize several different models.
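    The readout of Gilmer et al. computes a prediction from the final hidden states of all nodes:

    \hat{y} = R\big(\{\, h_v^{T} \mid v \in G \,\}\big)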

  • GG-NN (Gated Graph Neural Network)

    Source: Zhou, Jie, et al. "Graph neural networks: A review of methods and applications."

    Gated Recurrent Units (GRU)

  • GG-NN (Cont.)

    Readout Phase:

  • GG-NN (Cont.)

    Gated Recurrent Units (GRU); the update equations are written out below.

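    The GG-NN propagation step of Li et al. (reference 4) is a GRU whose input is the aggregated neighbor information a_v^t:

    a_v^{t} = A_v^{\top} \big[\, h_1^{t-1} \;\cdots\; h_{|V|}^{t-1} \,\big]^{\top} + b

    z_v^{t} = \sigma\big(W^{z} a_v^{t} + U^{z} h_v^{t-1}\big), \qquad r_v^{t} = \sigma\big(W^{r} a_v^{t} + U^{r} h_v^{t-1}\big)

    \tilde{h}_v^{t} = \tanh\big(W a_v^{t} + U (r_v^{t} \odot h_v^{t-1})\big), \qquad h_v^{t} = (1 - z_v^{t}) \odot h_v^{t-1} + z_v^{t} \odot \tilde{h}_v^{t}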

  • Two mappings between factor graphs and GNNs

    message-GNN and node-GNN perform similarly, and much better than belief propagation


  • Mapping I: Message-GNN

    Mapping (graphical model → GNN):
    ○ Message μij between nodes i and j → GNN node v
    ○ Message nodes ij and jk (which share a variable) → GNN nodes v and w, connected by an edge

    Motivation: conforms closely to the structure of conventional belief propagation, and reflects how messages depend on each other.

  • Mapping I: Message-GNN
    1. If connected, the message from GNN node v to w is computed by a multi-layer perceptron (MLP) with ReLU activations.
    2. Each GNN node then updates its hidden state with a recurrent neural network (a GRU).
    3. A readout function extracts the marginal or MAP estimate:
       a. First aggregate, by summation, all GNN nodes whose messages share the same target node in the graphical model.
       b. Then apply a shared readout function (another MLP with a sigmoid activation).
    (A sketch of these steps appears below.)
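    A minimal sketch of the three steps, assuming a PyTorch-style implementation; the class name, layer sizes, and the omission of edge features such as the couplings are illustrative choices, not the authors' code.

    import torch
    import torch.nn as nn

    class MessageGNN(nn.Module):
        # One GNN node per directed BP message (i -> j).
        def __init__(self, hidden_dim=64):
            super().__init__()
            # Step 1: MLP with ReLU that computes messages between connected message-nodes
            self.msg_mlp = nn.Sequential(
                nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim))
            # Step 2: GRU that updates each message-node's hidden state
            self.update = nn.GRUCell(hidden_dim, hidden_dim)
            # Step 3b: shared readout MLP with a sigmoid output
            self.readout = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, 1), nn.Sigmoid())

        def step(self, h, edges):
            # edges = (src, dst): pairs of message-nodes that share a variable node
            src, dst = edges
            m = self.msg_mlp(torch.cat([h[src], h[dst]], dim=-1))   # step 1: per-edge messages
            agg = torch.zeros_like(h).index_add_(0, dst, m)         # sum incoming messages
            return self.update(agg, h)                              # step 2: GRU state update

        def marginals(self, h, target_var, num_vars):
            # Step 3a: sum hidden states of message-nodes with the same target variable,
            # then (3b) apply the shared readout.
            pooled = torch.zeros(num_vars, h.size(1), device=h.device).index_add_(0, target_var, h)
            return self.readout(pooled).squeeze(-1)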

  • Mapping II: Node-GNN

    ● Mapping: variable nodes (graphical model) → nodes (GNN)

    1. Message function:

    2. Aggregate messages:

    3. Node update function:

    4. Readout generated directly from the hidden states:

    (A sketch of these steps appears below.)
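    A matching sketch for the node-GNN mapping, again assuming a PyTorch-style implementation with illustrative names; here each GNN node corresponds to a variable node, and the edge couplings enter as features.

    import torch
    import torch.nn as nn

    class NodeGNN(nn.Module):
        # One GNN node per variable node of the graphical model.
        def __init__(self, hidden_dim=64, edge_feat_dim=1):
            super().__init__()
            # 1. message function over (source state, destination state, edge coupling)
            self.msg_mlp = nn.Sequential(
                nn.Linear(2 * hidden_dim + edge_feat_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim))
            # 3. node update function (GRU driven by the aggregated message)
            self.update = nn.GRUCell(hidden_dim, hidden_dim)
            # 4. readout applied directly to the hidden states
            self.readout = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())

        def step(self, h, edges, coupling):
            src, dst = edges                                   # directed edges i -> j of the MRF
            inp = torch.cat([h[src], h[dst], coupling], dim=-1)
            m = self.msg_mlp(inp)                              # 1. per-edge messages
            agg = torch.zeros_like(h).index_add_(0, dst, m)    # 2. aggregate incoming messages
            return self.update(agg, h)                         # 3. update hidden states

        def marginals(self, h):
            return self.readout(h).squeeze(-1)                 # 4. per-variable readout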

  • Message-GNN and Node-GNN
    ● Objective: backpropagation to minimize the total cross-entropy loss between the ground-truth marginals p and the estimated marginals q (see the loss below).
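    Written generically (the symbols p and q are assumed names for the ground truth and the estimate):

    \mathcal{L} = -\sum_{i} \sum_{x_i} p(x_i) \log q(x_i)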

    Message passing function (general):
    ● Receives external inputs about the couplings on the edges.
    ● Depends on the hidden states of the source and destination nodes at the previous time step.

  • Experiments
    ● In each experiment, two types of GNNs are tested:
      ○ Variable nodes (node-GNN)
      ○ Message nodes (msg-GNN)
    ● Examine generalization of the model when...
      ○ testing on unseen graphs of the same structure
      ○ testing on completely random graphs
      ○ testing on graphs of the same size
      ○ testing on graphs of larger size
    ● Analyze performance in estimating both marginal probabilities and MAP states

  • Training Graphs

  • Larger, Novel Test Graphs

  • Marginal Inference Accuracy

  • Random Graphs

  • Generalization Performance on Random Graphs

  • Convergence of Inference Dynamics

  • MAP Estimation

  • Conclusion
    ● Experiments showed that GNNs provide a flexible learning method for inference in probabilistic graphical models.

    ● Showed that learned representations and nonlinear transformations on edges generalize to larger graphs with different structures.

    ● Examined two possible representations of graphical models within GNNs: variable nodes and message nodes

    ● Experimental results support GNNs as a promising framework for solving hard inference problems.

    ● Future work: train and test on larger and more diverse graphs, as well as broader classes of graphical models

  • References
    1. Zhou, Jie, et al. "Graph neural networks: A review of methods and applications." arXiv preprint arXiv:1812.08434 (2018).
    2. Gilmer, Justin, et al. "Neural message passing for quantum chemistry." Proceedings of the 34th International Conference on Machine Learning, Volume 70. JMLR.org, 2017.
    3. Scarselli, Franco, et al. "The graph neural network model." IEEE Transactions on Neural Networks 20.1 (2009): 61-80.
    4. Li, Yujia, et al. "Gated graph sequence neural networks." arXiv preprint arXiv:1511.05493 (2015).
    5. Wu, Zonghan, et al. "A comprehensive survey on graph neural networks." arXiv preprint arXiv:1901.00596 (2019).

  • Homework
    1. Where do GNNs outperform belief propagation? Where does belief propagation outperform GNNs?
    2. Given the following factor graph, draw the GNN using the Message-GNN mapping: