
I was trying to understand the mathematical proof of the chain rule for KL divergence:

$D(p(x,y)||q(x,y)) = D(p(x)||q(x)) + D(p(y|x)||q(y|x))$

I'm a bit lost in the last step (https://www.cs.princeton.edu/courses/archive/fall11/cos597D/L03.pdf, 1.3 Conditional Divergence). What I don't understand is why the following is true:

$D(p(x)||q(x)) = \sum_x \sum_y p(x, y) \log \frac{p(x)}{q(x)}$

The definition says something slightly different:

$D(p(x)||q(x)) = \sum_x p(x) \log \frac{p(x)}{q(x)}$

I have also seen this written in another way that I still don't understand (https://homes.cs.washington.edu/~anuprao/pubs/CSE533Autumn2010/lecture3.pdf, 2.3 Conditional Divergence):

$D(p(x)||q(x)) = \sum_x p(x) \log \frac{p(x)}{q(x)} \sum_y p(y|x)$

Why is that last part $\sum_y p(y|x)$ also absorbed into the KL definition?

kuonb

2 Answers


First of all, $\sum_y{p(x,y)}=p(x)$: for each fixed $x$, you sum the joint probability over all possible $y$, which yields the marginal. Therefore, $$\begin{align}D(p(x)||q(x)) &= \sum_x \sum_y p(x, y) \log \frac{p(x)}{q(x)}\\ &= \sum_x \log \frac{p(x)}{q(x)} \sum_y p(x, y) \\ &= \sum_x p(x) \log \frac{p(x)}{q(x)},\end{align}$$ which is exactly what the definition says.
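The two steps above can be checked numerically. A minimal sketch with a made-up $2\times 3$ joint table (the numbers are arbitrary, chosen only to form valid distributions):

```python
import numpy as np

# Hypothetical joint distribution p(x, y) and reference marginal q(x).
p_xy = np.array([[0.10, 0.20, 0.10],
                 [0.25, 0.15, 0.20]])
q_x = np.array([0.3, 0.7])

p_x = p_xy.sum(axis=1)  # marginalization: sum_y p(x, y) = p(x)

# Double sum: sum_x sum_y p(x, y) log(p(x)/q(x))
double_sum = np.sum(p_xy * np.log(p_x / q_x)[:, None])
# Standard definition: sum_x p(x) log(p(x)/q(x))
single_sum = np.sum(p_x * np.log(p_x / q_x))

assert np.isclose(double_sum, single_sum)
```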

For the latter part, the chain rule of probability gives $p(x,y)=p(x)p(y|x)$. So $$\sum_y{p(x,y)}=\sum_y p(x)p(y|x)=p(x)\sum_y p(y|x),$$ which explains where the factor $\sum_y p(y|x)$ in that KL formulation comes from.
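The factorization can also be illustrated numerically; a sketch with the same kind of made-up joint table:

```python
import numpy as np

# Hypothetical joint distribution p(x, y).
p_xy = np.array([[0.10, 0.20, 0.10],
                 [0.25, 0.15, 0.20]])

p_x = p_xy.sum(axis=1)             # marginal p(x)
p_y_given_x = p_xy / p_x[:, None]  # conditional p(y|x) = p(x,y) / p(x)

# Each conditional distribution sums to one over y...
assert np.allclose(p_y_given_x.sum(axis=1), 1.0)
# ...and the factorization recovers the joint: p(x,y) = p(x) p(y|x)
assert np.allclose(p_x[:, None] * p_y_given_x, p_xy)
```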

gunes

What I don't understand is why is this true? $$D(p(x) \parallel q(x)) = \sum_x \sum_y p(x, y) \log \frac{p(x)}{q(x)}$$

$\log \frac{p(x)}{q(x)}$ doesn't depend on $y$, so it can be moved out of the inner sum:

$$D(p(x) \parallel q(x)) = \sum_x \log \frac{p(x)}{q(x)} \sum_y p(x, y)$$

$\sum_y p(x, y) = p(x)$ by the definition of marginal probability. Plugging this in gives the canonical expression for KL divergence:

$$D(p(x) \parallel q(x)) = \sum_x p(x) \log \frac{p(x)}{q(x)}$$
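The three-step derivation above can be verified numerically. A small sketch with an arbitrary made-up joint table (the numbers just need to sum to one):

```python
import numpy as np

# Hypothetical joint p(x, y) over a 2x3 grid and a reference marginal q(x).
p_xy = np.array([[0.05, 0.15, 0.20],
                 [0.30, 0.10, 0.20]])
q_x = np.array([0.5, 0.5])

p_x = p_xy.sum(axis=1)            # sum_y p(x, y) = p(x)
log_ratio = np.log(p_x / q_x)     # log(p(x)/q(x)), independent of y

# Double sum with log(p(x)/q(x)) inside the inner sum over y
lhs = np.sum(p_xy * log_ratio[:, None])
# Canonical KL divergence: sum_x p(x) log(p(x)/q(x))
rhs = np.sum(p_x * log_ratio)

assert np.isclose(lhs, rhs)
```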

Why is that last part $\sum_y p(y \mid y)$ also absorbed into the KL definition?

I don't see that in the notes you linked. Maybe a typo? Assuming you meant $\sum_y p(y \mid x)$, then this is equal to one, since a probability distribution must sum to one by definition.
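For completeness, the full chain rule from the question can also be checked numerically. A sketch with two made-up joint distributions (arbitrary tables, each normalized to sum to one):

```python
import numpy as np

# Hypothetical joint distributions p(x, y) and q(x, y).
p_xy = np.array([[0.10, 0.20, 0.10],
                 [0.25, 0.15, 0.20]])
q_xy = np.array([[0.10, 0.15, 0.05],
                 [0.25, 0.25, 0.20]])

p_x, q_x = p_xy.sum(axis=1), q_xy.sum(axis=1)          # marginals
p_cond = p_xy / p_x[:, None]                            # p(y|x)
q_cond = q_xy / q_x[:, None]                            # q(y|x)

kl_joint = np.sum(p_xy * np.log(p_xy / q_xy))           # D(p(x,y)||q(x,y))
kl_marginal = np.sum(p_x * np.log(p_x / q_x))           # D(p(x)||q(x))
kl_conditional = np.sum(p_xy * np.log(p_cond / q_cond)) # D(p(y|x)||q(y|x))

# Chain rule: joint divergence = marginal + conditional divergence
assert np.isclose(kl_joint, kl_marginal + kl_conditional)
```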

user20160