
In the CS229 Lecture Notes on the EM algorithm by Tengyu Ma and Andrew Ng (2019), the authors write that $$ \log(p(\mathbf{x};\theta)) = \log\left(\mathbb{E}_{q(\mathbf{z})}\left[\frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})}\right]\right) \geq \mathbb{E}_{q(\mathbf{z})}\left[\log\left(\frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})}\right)\right] $$ They also write that this inequality holds for all possible choices of $q(\mathbf{z})$, so we should choose the $q(\mathbf{z})$ that makes the right-hand side of the inequality equal to the left-hand side. This happens when $$ \frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})} = c $$ for some constant $c$ that does not depend on $\mathbf{z}$. The authors then write

This is easily accomplished by choosing $$ q(\mathbf{z}) \propto p(\mathbf{x},\mathbf{z};\theta) $$ Actually, since we know $\sum_{\mathbf{z}} q(\mathbf{z}) = 1$, this further tells us that \begin{align} q(\mathbf{z}) &= \frac{p(\mathbf{x},\mathbf{z};\theta)}{\sum_{\mathbf{z}} p(\mathbf{x},\mathbf{z};\theta)} \\ &= \frac{p(\mathbf{x},\mathbf{z};\theta)}{p(\mathbf{x};\theta)} \\ &= p(\mathbf{z}|\mathbf{x};\theta) \end{align}

However, I am not sure how they reasoned that $$ q(\mathbf{z}) = \frac{p(\mathbf{x},\mathbf{z};\theta)}{\sum_{\mathbf{z}} p(\mathbf{x},\mathbf{z};\theta)} $$ so would appreciate some clarification on this.

mhdadk

2 Answers


Another way to make the evidence lower-bound $$ \log(p(\mathbf{x};\theta)) = \log\left(\mathbb{E}_{q(\mathbf{z})}\left[\frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})}\right]\right) \geq \mathbb{E}_{q(\mathbf{z})}\left[\log\left(\frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})}\right)\right] $$ as tight as possible with respect to $q(\mathbf{z})$ is to minimize the difference $$ \log(p(\mathbf{x};\theta)) - \mathbb{E}_{q(\mathbf{z})}\left[\log\left(\frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})}\right)\right] $$ with respect to $q(\mathbf{z})$. Since \begin{align} \mathbb{E}_{q(\mathbf{z})}\left[\log\left(\frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})}\right)\right] &= \int_{\mathbf{z}} q(\mathbf{z}) \cdot \log\left(\frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})}\right) \text{d}\mathbf{z} \\ &= \int_{\mathbf{z}} q(\mathbf{z}) \cdot \log\left(\frac{p(\mathbf{z}|\mathbf{x};\theta) \cdot p(\mathbf{x};\theta)}{q(\mathbf{z})}\right) \text{d}\mathbf{z} \\ &= \int_{\mathbf{z}} q(\mathbf{z}) \cdot \log\left(\frac{p(\mathbf{z}|\mathbf{x};\theta)}{q(\mathbf{z})}\right) \text{d}\mathbf{z} + \int_{\mathbf{z}} q(\mathbf{z}) \cdot \log\left(p(\mathbf{x};\theta)\right) \text{d}\mathbf{z} \\ &= \int_{\mathbf{z}} q(\mathbf{z}) \cdot \log\left(\frac{p(\mathbf{z}|\mathbf{x};\theta)}{q(\mathbf{z})}\right) \text{d}\mathbf{z} + \log\left(p(\mathbf{x};\theta)\right) \cdot \int_{\mathbf{z}} q(\mathbf{z}) \text{d}\mathbf{z} \\ &= -D_{\text{KL}}(q(\mathbf{z}) \mid\mid p(\mathbf{z}|\mathbf{x};\theta)) + \log\left(p(\mathbf{x};\theta)\right) \end{align} where $D_{\text{KL}}(q(\mathbf{z}) \mid\mid p(\mathbf{z}|\mathbf{x};\theta))$ is the Kullback–Leibler divergence between $q(\mathbf{z})$ and $p(\mathbf{z}|\mathbf{x};\theta)$, then \begin{align} \log(p(\mathbf{x};\theta)) - \mathbb{E}_{q(\mathbf{z})}\left[\log\left(\frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})}\right)\right] &= \log(p(\mathbf{x};\theta)) + D_{\text{KL}}(q(\mathbf{z}) \mid\mid 
p(\mathbf{z}|\mathbf{x};\theta)) - \log\left(p(\mathbf{x};\theta)\right) \\ &= D_{\text{KL}}(q(\mathbf{z}) \mid\mid p(\mathbf{z}|\mathbf{x};\theta)) \end{align} Since $D_{\text{KL}}(q(\mathbf{z}) \mid\mid p(\mathbf{z}|\mathbf{x};\theta)) \geq 0$, with equality if and only if $q(\mathbf{z}) = p(\mathbf{z}|\mathbf{x};\theta)$, the difference $$ \log(p(\mathbf{x};\theta)) - \mathbb{E}_{q(\mathbf{z})}\left[\log\left(\frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})}\right)\right] $$ is minimized when $q(\mathbf{z}) = p(\mathbf{z}|\mathbf{x};\theta)$, which in turn means that the inequality $$ \log(p(\mathbf{x};\theta)) = \log\left(\mathbb{E}_{q(\mathbf{z})}\left[\frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})}\right]\right) \geq \mathbb{E}_{q(\mathbf{z})}\left[\log\left(\frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})}\right)\right] $$ becomes an equality.
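The identity above is easy to check numerically. The following sketch (not from the notes; the toy joint distribution is made up purely for illustration) verifies for a small discrete latent variable that $\log p(\mathbf{x};\theta) - \text{ELBO} = D_{\text{KL}}(q \mid\mid p(\mathbf{z}|\mathbf{x};\theta))$, and that the bound is tight when $q$ equals the posterior:

```python
import numpy as np

# Hypothetical toy setup: a discrete latent z ∈ {0, 1, 2} and a made-up
# joint p(x, z; θ) evaluated at one fixed observation x.
p_joint = np.array([0.10, 0.25, 0.15])  # p(x, z; θ) for z = 0, 1, 2
p_x = p_joint.sum()                     # p(x; θ) = Σ_z p(x, z; θ)
posterior = p_joint / p_x               # p(z | x; θ)

def elbo(q):
    """Evidence lower bound E_q[log(p(x, z; θ) / q(z))]."""
    return np.sum(q * np.log(p_joint / q))

def kl(q, p):
    """Kullback–Leibler divergence D_KL(q || p)."""
    return np.sum(q * np.log(q / p))

# For an arbitrary q(z), the gap log p(x; θ) − ELBO equals D_KL(q || posterior).
q = np.array([0.5, 0.3, 0.2])
assert np.isclose(np.log(p_x) - elbo(q), kl(q, posterior))

# Choosing q(z) = p(z | x; θ) makes the bound tight: ELBO = log p(x; θ).
assert np.isclose(elbo(posterior), np.log(p_x))
```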

mhdadk

Referring to the notes you posted, the choice of $q(\mathbf{z})$ is dictated by the strategy they use to prove/illustrate the method. Basically, you want to choose $q(\mathbf{z})$ proportional to $p(\mathbf{x}, \mathbf{z}; \theta)$ because you want the lower bound to be tight, that is, the inequality above should hold with equality. When does this happen? When you are taking the expectation of something that is constant with respect to $\mathbf{z}$.

Indeed, note that:

  • $\log\left(\mathbb{E}_{q(\mathbf{z})}[c]\right) = \log\left(\sum_{\mathbf{z}} c \cdot q(\mathbf{z})\right) = \log\left(c \sum_{\mathbf{z}} q(\mathbf{z})\right) = \log c$;
  • $\mathbb{E}_{q(\mathbf{z})}[\log c] = \sum_{\mathbf{z}} (\log c) \cdot q(\mathbf{z}) = \log c \cdot \sum_{\mathbf{z}} q(\mathbf{z}) = \log c$.

Thus, we want $\frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})}$ to be equal to a constant with respect to $\mathbf{z}$. The general way to achieve this is to take $q(\mathbf{z})$ proportional to $p(\mathbf{x},\mathbf{z};\theta)$, that is: $$q(\mathbf{z})=c\cdot p(\mathbf{x}, \mathbf{z}; \theta)$$ for some normalizing constant $c$.

Now, we also want $q(\mathbf{z})$ to be a proper distribution over $\mathbf{z}$, and to ensure this it must sum to $1$, that is: $$\sum_{\mathbf{z}} q(\mathbf{z})=1\implies \sum_{\mathbf{z}} c\cdot p(\mathbf{x}, \mathbf{z}; \theta)=1 $$ Rearranging the above determines $c$: $$\sum_{\mathbf{z}} c\cdot p(\mathbf{x}, \mathbf{z}; \theta)=1 \iff c =\frac{1}{\sum_{\mathbf{z}} p(\mathbf{x}, \mathbf{z}; \theta)}$$ which in turn gives: $$q(\mathbf{z})=\frac{p(\mathbf{x},\mathbf{z};\theta)}{\sum_{\mathbf{z}} p(\mathbf{x},\mathbf{z};\theta)}.$$
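The normalization step above can be sketched numerically (the joint below is a made-up example, not from the notes):

```python
import numpy as np

# Hypothetical joint p(x, z; θ) over z = 0, …, 3 at one fixed x
# (values made up for illustration).
p_joint = np.array([0.04, 0.12, 0.08, 0.16])

# Take q(z) ∝ p(x, z; θ) and normalize so that Σ_z q(z) = 1.
c = 1.0 / p_joint.sum()  # c = 1 / Σ_z p(x, z; θ)
q = c * p_joint          # q(z) = p(x, z; θ) / Σ_z p(x, z; θ)

assert np.isclose(q.sum(), 1.0)  # q is a proper distribution

# With this choice, p(x, z; θ) / q(z) is constant in z, equal to p(x; θ),
# so Jensen's inequality holds with equality.
assert np.allclose(p_joint / q, p_joint.sum())
```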

On page 5 of the notes it is shown that plugging this choice in attains the bound with equality. (Exactly the same reasoning applies to continuous distributions, with sums replaced by integrals.)

lcorag
  • (+1) thanks a lot! Just wanted to point out that if $$\frac{p(\mathbf{x},\mathbf{z};\theta)}{q(\mathbf{z})} = c$$ then summing over $\mathbf{z}$ gives $$c = \sum_{\mathbf{z}} p(\mathbf{x}, \mathbf{z}; \theta)$$ which recovers what the authors wrote in the notes. – mhdadk Jul 06 '21 at 13:10