Differential Inference: A Criminally Underused Tool

@srush

Style

This talk is a live working PyTorch notebook.

https://github.com/srush/ProbTalk

Preface

It is bizarre that the main technical contribution of so many papers seems to be something that computers can do for us automatically. We would be better off just considering autodiff part of the optimization procedure, and directly plugging in the objective function. In my opinion, this is actually harmful to the field.

  • Justin Domke, 2009

Differential Inference

Goal: Use differentiation to perform complex probabilistic inference.

Disclaimer

This talk contains no new research:

Also:

Part 1: Counting the Hard Way

Problem

I have two coins; how many different ways can I place them?

Answer: $2 \times 2 = 4$.

Observed Coins

Let $\lambda^1$ represent Coin 1: if the coin is observed in state $i$, set $\lambda^1 = \delta_i$, a one-hot indicator.

Latent Coins

If we do not know the state, we use $\lambda^1 = \mathbf{1}$.

Counting

We can use this to count.

$$f(\lambda) = \lambda_0^1 \lambda_0^2 + \lambda_0^1 \lambda_1^2 + \lambda_1^1 \lambda_0^2 + \lambda_1^1 \lambda_1^2$$
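
A minimal PyTorch sketch of this counting function (not the notebook's exact code): evaluating it with all-ones indicators multiplies out every configuration.

```python
import torch

# Each lambda is a length-2 vector, one entry per coin state.
def f(lam1, lam2):
    # Sum the product lambda^1_i * lambda^2_j over every configuration.
    return sum(lam1[i] * lam2[j] for i in range(2) for j in range(2))

ones = torch.ones(2)
print(f(ones, ones))  # tensor(4.) -- four ways to place two coins
```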

Constrained Counting

We can also count under a known constraint. Starting from:

$$f(\lambda) = \lambda_0^1 \lambda_0^2 + \lambda_0^1 \lambda_1^2 + \lambda_1^1 \lambda_0^2 + \lambda_1^1 \lambda_1^2$$

Setting $\lambda^2 = \delta_0$ gives $$f(\lambda) = \lambda_0^1 + \lambda_1^1$$
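
The same sketch with Coin 2 clamped by a one-hot indicator (tensor names are illustrative).

```python
import torch

def f(lam1, lam2):
    return sum(lam1[i] * lam2[j] for i in range(2) for j in range(2))

ones = torch.ones(2)
delta0 = torch.tensor([1.0, 0.0])   # Coin 2 clamped to state 0
print(f(ones, delta0))              # tensor(2.) -- two configurations remain
```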

Differential Counting

Even better, we can count under all of the constraints at once. Starting from:

$$f(\lambda)=\lambda_0^1\lambda_0^2 + \lambda_0^1 \lambda_1^2 + \lambda_1^1 \lambda_0^2 + \lambda_1^1 \lambda_1^2$$

The derivative with respect to $\lambda_0^1$ gives

$$f'_{\lambda_0^1}(\lambda)=\lambda_0^2+\lambda_1^2$$

Differential Counting 2

The derivatives with respect to $\lambda^2$ give

$$f'_{\lambda_0^2}(\lambda)=\lambda_0^1+\lambda_1^1$$

$$f'_{\lambda_1^2}(\lambda)=\lambda_0^1+\lambda_1^1$$
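
A sketch of the same computation with autograd, so every constrained count comes out of a single backward pass (variable names are illustrative).

```python
import torch

lam1 = torch.ones(2, requires_grad=True)
lam2 = torch.ones(2, requires_grad=True)

# Multiply out all configurations of the two coins.
f = (lam1[:, None] * lam2[None, :]).sum()
f.backward()

print(lam1.grad)  # tensor([2., 2.]) -- configurations with Coin 1 fixed to each state
print(lam2.grad)  # tensor([2., 2.]) -- likewise for Coin 2
```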

Problem: Counting with Branching

Place Coin 1.

  • If tails, Coin 2 must be heads.
  • If heads, Coin 2 can be either.

Answer: $1 + 2 = 3$.

Counting Function

The counting function for this generative process:

$$f(\lambda) = \lambda_0^1 \lambda_1^2 + (\sum_j \lambda_1^1 \lambda_j^2)$$
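
A sketch of this branching count in PyTorch; taking state 0 as tails and 1 as heads is an assumption about the encoding.

```python
import torch

lam1 = torch.ones(2, requires_grad=True)  # state 0 = tails, 1 = heads
lam2 = torch.ones(2, requires_grad=True)

# Tails on Coin 1 forces Coin 2 to heads; heads allows either outcome.
f = lam1[0] * lam2[1] + lam1[1] * (lam2[0] + lam2[1])
f.backward()

print(f.item())   # 3.0 -- total number of outcomes
print(lam1.grad)  # tensor([1., 2.]) -- outcomes for each value of Coin 1
```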

Counting

Number of ways the coins can land.

Query

Number of ways the coins can land.

Constrained Query

Number of ways the coins can land, conditioned on the first coin.

Part 2: Probabilistic Inference

Differential Inference

We specify the joint distribution $$p(x_1,x_2)$$

For observed evidence $e$, we get for free: the evidence probability $p(e)$, marginals, and conditionals $p(x \mid e)$.

Problem: More Coins the Hard Way

Flip two fair coins.

Joint

Function for joint probability $$p(x_1,x_2)$$

$$f(\lambda) = \sum_{i,j} \lambda^1_i \lambda^2_j\ p(x_1=i, x_2=j)$$
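
A possible PyTorch encoding of this function, assuming a $2 \times 2$ table `p` for the fair-coin joint (the table name is illustrative).

```python
import torch

p = torch.full((2, 2), 0.25)  # p(x1=i, x2=j) for two fair coins

def f(lam1, lam2):
    # Contract the indicator vectors against the joint table.
    return (lam1[:, None] * lam2[None, :] * p).sum()

delta0, delta1 = torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0])
print(f(delta1, delta0))         # 0.25 = p(x1=1, x2=0)
print(f(torch.ones(2), delta0))  # 0.50 = p(x2=0)
```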

Joint Probability

Evaluating the function at $\lambda^1 = \delta_1$ and $\lambda^2 = \delta_0$ gives $$p(x_1=1, x_2=0)$$

Marginal Probability

Using the function with $\lambda^1 = \mathbf{1}$ to marginalize out $x_1$ gives $$p(x_2=0)$$

$$f(\lambda^1 =\mathbf{1}, \lambda^2 = \delta_0) = \sum_{i} p(x_1=i, x_2=0) = p(x_2=0)$$

Constrained Joint

$$f(\lambda^1 =\mathbf{1}, \lambda^2 = \delta_0) = \sum_{i} \lambda^1_i\, p(x_1=i, x_2=0)$$

$$f'_{\lambda^1_0}(\mathbf{1}, \delta_0)= p(x_1=0, x_2=0) \qquad f'_{\lambda^1_1}(\mathbf{1}, \delta_0) = p(x_1=1, x_2=0)$$
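
A sketch of reading off the constrained joints from the gradient (fair coins assumed).

```python
import torch

p = torch.full((2, 2), 0.25)                      # joint table for two fair coins
lam1 = torch.ones(2, requires_grad=True)
delta0 = torch.tensor([1.0, 0.0])                 # evidence: x2 = 0

f = (lam1[:, None] * delta0[None, :] * p).sum()   # f(1, delta_0) = p(x2=0)
f.backward()
print(lam1.grad)  # tensor([0.25, 0.25]) = [p(x1=0, x2=0), p(x1=1, x2=0)]
```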

Conditional

By Bayes' rule, $$p(x_1 | x_2=e) = \frac{p(x_1, x_2=e)}{p(x_2=e)}$$

which is exactly the ratio $$f'_{\lambda^1}(\lambda) / f(\lambda)$$ evaluated at the evidence.

Conditional Computation

Using the log trick, $$(\log f)'_{\lambda^1} = f'_{\lambda^1}(\mathbf{1}, \delta_1) / f(\mathbf{1}, \delta_1) = p(x_1 | x_2=1)$$
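
A sketch of the log trick with autograd (fair coins assumed): differentiating $\log f$ gives the conditional directly.

```python
import torch

p = torch.full((2, 2), 0.25)
lam1 = torch.ones(2, requires_grad=True)
delta1 = torch.tensor([0.0, 1.0])                 # evidence: x2 = 1

f = (lam1[:, None] * delta1[None, :] * p).sum()
torch.log(f).backward()                           # differentiate log f instead of f
print(lam1.grad)  # tensor([0.5, 0.5]) = p(x1 | x2=1)
```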

Part 3: Fancy Coins

Problem: Conditional Coins

Flip Coin 1.

  • If tails, place Coin 2 as heads.
  • If heads, flip Coin 2.

Conditional Inference

$$p(x_1 | x_2=0)$$
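
One way to encode this model and query it; taking state 0 as tails and 1 as heads is an assumption about the encoding.

```python
import torch

lam1 = torch.ones(2, requires_grad=True)          # Coin 1 indicator (0 = tails, 1 = heads)
delta0 = torch.tensor([1.0, 0.0])                 # evidence: Coin 2 landed tails

# Tails on Coin 1 places Coin 2 as heads; heads on Coin 1 flips Coin 2 fairly.
f = 0.5 * lam1[0] * delta0[1] + 0.5 * lam1[1] * (0.5 * delta0[0] + 0.5 * delta0[1])

torch.log(f).backward()
print(lam1.grad)  # tensor([0., 1.]) = p(x1 | x2=0)
```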

Problem: Coins and Dice

I flipped a fair coin; if it was heads, I rolled a fair die; otherwise, I rolled a weighted die.

Generative Story:

  • Flip a fair coin.
  • If heads, roll a fair die; otherwise, roll a weighted die.

Dice from Coin 1

Coin from Dice 1

Coin from Dice 2

Dice Marginal

Problem: Summing Up

We can construct more complex operations by introducing derived variables.

Flip two coins; how many heads?

$$C = X_1 + X_2$$

Sum of Variables

Let $\lambda^c$ represent the sum $C$ of two uniform variables.

$$f(\lambda)=\lambda^c_0 \lambda^1_0 \lambda^2_0\, p(x_1=0, x_2=0) + \dots$$
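
A sketch with an explicit indicator for the sum variable; `lamc` is an illustrative name for $\lambda^c$, and the two coins are assumed fair.

```python
import torch

p = torch.full((2, 2), 0.25)                  # joint of two fair coins
lam1 = torch.ones(2, requires_grad=True)
lam2 = torch.ones(2, requires_grad=True)
lamc = torch.ones(3, requires_grad=True)      # indicator for the sum C in {0, 1, 2}

f = sum(lamc[i + j] * lam1[i] * lam2[j] * p[i, j]
        for i in range(2) for j in range(2))

torch.log(f).backward()
print(lamc.grad)  # tensor([0.25, 0.50, 0.25]) = p(C = c)
```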

Sum of Coins

Let $\lambda^c$ represent the sum of the two coins.

Sum of Dice

Dice Conditioned on Sum

Part 4: Real Models

Problem: Graphical Models

Conditional Probabilities

Bayes Net

Construct the joint probability of the system.

Joint Probability

$$p(R=r, S=s, W=w)$$

Marginal Inference

$$p(R)$$

Conditional Inference

$$p(R | W=1)$$
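
A sketch of the full pipeline for a small Bayes net over binary $R$, $S$, $W$; the probability tables below are made up for illustration, not taken from the notebook.

```python
import torch

# Illustrative (made-up) tables; p(R, S, W) = p(R) p(S) p(W | R, S).
pR = torch.tensor([0.8, 0.2])
pS = torch.tensor([0.9, 0.1])
pW = torch.tensor([[[1.00, 0.00],    # p(W | R=0, S=0)
                    [0.10, 0.90]],   # p(W | R=0, S=1)
                   [[0.20, 0.80],    # p(W | R=1, S=0)
                    [0.01, 0.99]]])  # p(W | R=1, S=1)

lamR = torch.ones(2, requires_grad=True)
lamS = torch.ones(2, requires_grad=True)
lamW = torch.tensor([0.0, 1.0])      # evidence: W = 1

f = sum(lamR[r] * lamS[s] * lamW[w] * pR[r] * pS[s] * pW[r, s, w]
        for r in range(2) for s in range(2) for w in range(2))

torch.log(f).backward()
print(lamR.grad)  # p(R | W = 1)
print(lamS.grad)  # p(S | W = 1)
```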

Problem: Gaussian Mixture Model

Generative Model

Generate a class, then generate a point from that class's Gaussian.

GMM - Expectation-Maximization
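
The E-step responsibilities fall out of the same log trick: differentiate the log marginal likelihood with respect to the class indicator. A sketch with made-up parameters for a two-component 1D mixture:

```python
import torch

pi = torch.tensor([0.5, 0.5])             # mixing weights (illustrative)
mu = torch.tensor([-2.0, 2.0])            # component means
sigma = torch.tensor([1.0, 1.0])          # component standard deviations
x = torch.tensor(1.5)                     # one observed point

lam = torch.ones(2, requires_grad=True)   # indicator for the latent class
components = torch.distributions.Normal(mu, sigma)
f = (lam * pi * components.log_prob(x).exp()).sum()   # marginal likelihood p(x)

torch.log(f).backward()
print(lam.grad)   # p(class | x): the E-step responsibilities
```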

Problem: Hidden Markov Models (HMM)

Joint probability ($f$) of hidden states and observations.

Example: HMM

A simple HMM with circulant transitions

Differential Inference

Inference over states with some known observations
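
A sketch of posterior state marginals for a small HMM: run the forward recursion with a $\lambda$ vector gating each step, then differentiate $\log f$. The transition and emission tables below are illustrative, not the notebook's.

```python
import torch

K = 3
init = torch.full((K,), 1.0 / K)                   # uniform initial distribution
trans = torch.tensor([[0.8, 0.1, 0.1],             # circulant-style transition matrix
                      [0.1, 0.8, 0.1],
                      [0.1, 0.1, 0.8]])
emit = torch.tensor([[0.7, 0.2, 0.1],              # p(observation | state)
                     [0.1, 0.7, 0.2],
                     [0.2, 0.1, 0.7]])
obs = [0, 0, 2, 1]                                 # a known observation sequence
T = len(obs)

lam = torch.ones(T, K, requires_grad=True)         # one indicator vector per time step

# Forward recursion, gating each step by its lambda vector.
alpha = init * emit[:, obs[0]] * lam[0]
for t in range(1, T):
    alpha = (alpha @ trans) * emit[:, obs[t]] * lam[t]
f = alpha.sum()                                    # p(observations)

torch.log(f).backward()
print(lam.grad)                                    # row t is p(z_t | observations)
```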

Conclusion

"Counting with Style"

What comes next?

So much more...

Thanks!