FSNLP7章のEMが怪しい件

FSNLP(Foundations of Statistical Natural Language Processing)は自然言語処理業界の中では知らない人はないというほど有名な本(英語)。出版年度は古く、内容もかなり時代遅れになってきつつあるのだが、自然言語処理の広い範囲を網羅した本ということで、英語を読む訓練をかねて新入生が輪読をするのはこの業界の風物詩。

うちの研究室(黒橋研)でもその輪読をしていて、B4(学部4年)とM1(修士1年)の他に、M2(修士2年)も復習と新入生の指導という意味合いで参加している。ぼくはM2なので、やはり参加している。

(ここから先は本の内容にからむことなので、FSNLPと合わせて読んでください)

7章 "Word Sense Disambiguation" の Figure 7.8、EMアルゴリズムのところ(ここは errata があるので、修正されたバージョンが前提)。

M-step の中で、P(v_j \mid s_k)=\frac{\sum_{\{c_i:v_j \in c_i\}}h_{ik}}{Z_k}となっている。ここで違和感。v_jが含まれる文脈c_iの確率を足していくというところ。

EMアルゴリズムのM-stepでは、最尤推定でパラメータを決めるというのが直感的な理解。ここで仮定しているのは bag-of-words モデルなので、一つの文に複数回v_jが出現するなら、それを考慮に入れないといけないのでは。今の式では、それがまったく考慮されない。

ここらで直感の限界を感じたので、一般的な EM の式から導出してみる。EMのM-stepは、現在推測されている確率分布を元にして、対数尤度の期待値を最大化するステップ。対数尤度の期待値はQ関数と呼ばれ、内部状態を\mathbf{Z}、与えられたデータを\mathbf{C}、旧パラメータを\theta、新パラメータを\bar{\theta}として、一般的に次のように表せる。

Q(\bar{\theta} \mid \theta)=\sum_{\mathbf{Z}}P(\mathbf{Z} \mid \theta, \mathbf{C})\log{P(\mathbf{Z}, \mathbf{C} \mid \bar{\theta})}

これを今回の例に当てはめる。旧パラメータと新パラメータを区別するために、パラメータにそれぞれ記号を割り振る。ある語義 s_k からある単語 v_j が生成される確率 P(v_j \mid s_k)x_{jk} \in \mathbf{X}とし、対象とする曖昧な単語が語義s_kを持つ確率P(s_k)y_k \in \mathbf{Y}、更新後のパラメータをそれぞれ\bar{x}, \bar{y}, \bar{\mathbf{X}}, \bar{\mathbf{Y}}とおき、それぞれの文脈c_iについて、その文脈での多義語の語義をk_iという内部状態とする。その他は原文通りの表記を使うと、この場合のQ関数は次のように書ける。

Q(\bar{\mathbf{X}}, \bar{\mathbf{Y}} \mid \mathbf{X}, \mathbf{Y}) = \sum_{k_1^I} P(k_1^I \mid C, \mathbf{X}, \mathbf{Y})\log{P(C, k_1^I \mid \bar{\mathbf{X}}, \bar{\mathbf{Y}})}

ここで、原文と同じようにナイーブベイズ推定を使うと、対数尤度は次のように書ける。

\begin{eqnarray} \log P(C, k_1^I \mid \bar{\mathbf{X}}, \bar{\mathbf{Y}}) &=& \log(\prod_{i}\bar{y}_{k_i}\prod_{v_j \in c_i}\bar{x}_{jk}) \nonumber \\ &=& \sum_i \log{\bar{y}_{k_i}} + \sum_{i}\sum_{v_j \in c_i}\log{\bar{x}_{jk}} \end{eqnarray}

また、\begin{eqnarray}P(k_1^I \mid C, \mathbf{X}, \mathbf{Y}) &=&\frac{P(k_1^I, C \mid \mathbf{X}, \mathbf{Y})}{P(C \mid \mathbf{X}, \mathbf{Y})} \\ &=&\frac{\prod_{i} y_{k_i} \prod_{v_j \in c_i}x_{jk}}{\sum_{k_1^I} \prod_{i}y_{k_i}\prod_{v_j \in c_i}x_{jk}} \end{eqnarray}

簡単にするためにu_{ik} = y_{k}\prod_{v_j \in c_i}x_{jk}とおくと、これは\frac{\prod_i u_{i k_i}}{\sum_{k_1^I}\prod_i u_{i k_i}}となる。

これらを使うと、Q関数は次のように書ける。

\begin{eqnarray} Q(\bar{\mathbf{X}}, \bar{\mathbf{Y}} \mid \mathbf{X}, \mathbf{Y}) &=& \sum_{k_1^I} \left(\frac{\prod_i u_{i k_i}}{\sum_{k_1^I}\prod_i u_{i k_i}} \sum_i \log{\bar{y}_{k_i}}\right) + \\ && \sum_{k_1^I} \left(\frac{\prod_i u_{i k_i}}{\sum_{k_1^I}\prod_i u_{i k_i}} \sum_{i}\sum_{v_j \in c_i}\log{\bar{x}_{jk}}\right) \end{eqnarray}

これを最大化する。まず、\bar{y}_kについて。\bar{y}_kには、\sum_k \bar{y}_k = 1という制約条件がある。書き換えると1-\sum_k \bar{y}_k = 0となる。ラグランジュの未定乗数法を適用し、関数F(\bar{\mathbf{X}}, \bar{\mathbf{Y}}, \lambda)を作り、それぞれの\bar{y}_kについて偏微分が 0 になるようにする。

F(\bar{\mathbf{X}}, \bar{\mathbf{Y}}, \lambda) = Q(\bar{\mathbf{X}}, \bar{\mathbf{Y}} \mid \mathbf{X}, \mathbf{Y}) + \lambda \left(1 - \sum_k \bar{y}_k\right)

それぞれの\bar{y}_kについて:

\begin{eqnarray}\frac{\partial F(\bar{\mathbf{X}}, \bar{\mathbf{Y}}, \lambda)}{\partial \bar{y}_{k}} &=& \frac{\partial}{\partial \bar{y}_{k}} \left( \frac{\sum_{k_1^I} \prod_i u_{i k_i} \sum_i \log \bar{y}_{k_i}}{\sum_{k_1^I}\prod_i u_{i k_i}} + \lambda\left(1 - \sum_k \bar{y}_k\right) \right) \\ &=& \frac{\partial}{\partial \bar{y}_{k}} \left( \sum_i\frac{\sum_{k_1^I} \prod_{i'} u_{i'} k_{i'} \log \bar{y}_{k_i}}{\sum_{k_1^I}\prod_{i'} u_{i'} k_{i'}}\right) - \lambda \\ &=&  \frac{\partial}{\partial \bar{y}_{k}} \left( \sum_i\frac{\sum_{k_1}\sum_{k_2} \ldots \sum_{k_I} \prod_{i'} u_{i'} k_{i'} \log \bar{y}_{k_i}}{\sum_{k_1}\sum_{k_2} \ldots \sum_{k_I} \prod_{i'} u_{i'} k_{i'}}\right) - \lambda \\ &=&  \frac{\partial}{\partial \bar{y}_{k}} \left( \sum_i\frac{\sum_{k_i}u_{i k_i} \log \bar{y}_{k_i} \sum_{k_1} \ldots \sum_{k_{i-1}}\sum_{k_{i+1}}\ldots \sum_{k_I} \prod_{i'=1 \ldots i-1,\\ i+1 \ldots I} u_{i'} k_{i'} }{\sum_{k_i}u_{i k_i} \sum_{k_1} \ldots \sum_{k_{i-1}}\sum_{k_{i+1}}\ldots \sum_{k_I} \prod_{i'=1 \ldots i-1,\\ i+1 \ldots I} u_{i'} k_{i'}}\right) - \lambda \\ &=&  \sum_i\frac{u_{i k} \frac{1}{\bar{y}_{k}} \sum_{k_1} \ldots \sum_{k_{i-1}}\sum_{k_{i+1}}\ldots \sum_{k_I} \prod_{i'=1 \ldots i-1,\\ i+1 \ldots I} u_{i'} k_{i'}}{\sum_{k_i}u_{i k_i} \sum_{k_1} \ldots \sum_{k_{i-1}}\sum_{k_{i+1}}\ldots \sum_{k_I} \prod_{i'=1 \ldots i-1,\\ i+1 \ldots I} u_{i'} k_{i'}} - \lambda \\ &=& \sum_i\frac{u_{ik}}{\sum_{k'} u_{ik'}} \frac{1}{\bar{y}_{k}} - \lambda\\ &=& 0 \end{eqnarray}が成り立つようにする。

それを満たすのは、\bar{y}_k = \sum_i\frac{u_{ik}}{\sum_{k'} u_{ik'}}\frac{1}{\lambda}のとき。

すべてのkに対するこれらの両辺を足し合わせると、\sum_k \bar{y}_k = \sum_k \sum_i\frac{u_{ik}}{\sum_{k'} u_{ik'}}\frac{1}{\lambda}だが、\sum_k \bar{y}_k = 1であるので\lambda = Iを得る。

よって、それぞれの\bar{y}_kについて:

Q関数を最大にするのは、\bar{y}_k = \sum_i\frac{u_{ik}}{\sum_k u_{ik}}\frac{1}{I}h_{ik} = \frac{u_{ik}}{\sum_{k'} u_{ik'}}とおくと、これは\bar{y}_k = \frac{\sum_i{h_{ik}}}{I}となり、本文のものと一致する。

次は\bar{x}_{jk}についてQ関数を最大化する。
それぞれのkについて:

\sum_j \bar{x}_{jk} = 1という制約条件がある。書き換えると1-\sum_j \bar{x}_{jk} = 0 となる。
ラグランジュの未定乗数法を適用し、関数G(\bar{\mathbf{X}}, \bar{\mathbf{Y}}, \lambda)を作り、それぞれの\bar{x}_{jk}について偏微分が 0 になるようにする。
G(\bar{\mathbf{X}}, \bar{\mathbf{Y}}, \lambda) = Q(\bar{\mathbf{X}}, \bar{\mathbf{Y}} \mid \mathbf{X}, \mathbf{Y}) + \lambda\left(1 - \sum_j \bar{x}_{jk}\right)
それぞれの\bar{x}_{jk}について:

文脈c_iに含まれるv_jの数をn_{ij}とおく。

\begin{eqnarray}\frac{\partial G(\bar{\mathbf{X}}, \bar{\mathbf{Y}}, \lambda)}{\partial \bar{x}_{jk}} &=& \frac{\partial}{\partial \bar{x}_{jk}} \left( \frac{\sum_{k_1^I} \prod_i u_{i k_i} \sum_{i}\sum_{v_j \in c_i}\log{\bar{x}_{jk}}}{\sum_{k_1^I}\prod_i u_{i k_i}} + \lambda\left(1 - \sum_j \bar{x}_{jk} \right) \right) \\ &=& \frac{\partial}{\partial \bar{x}_{jk}} \left( \sum_i\frac{\sum_{k_1^I} \prod_{i'} u_{i'} k_{i'} \sum_{v_j \in c_i}\log{\bar{x}_{jk}}}{\sum_{k_1^I}\prod_{i'} u_{i'} k_{i'}}\right) - \lambda \\ &=&  \frac{\partial}{\partial \bar{x}_{jk}} \left( \sum_i\frac{\sum_{k_1}\sum_{k_2} \ldots \sum_{k_I} \prod_{i'} u_{i'} k_{i'} \sum_{v_j \in c_i}\log{\bar{x}_{jk}}}{\sum_{k_1}\sum_{k_2} \ldots \sum_{k_I} \prod_{i'} u_{i'} k_{i'}}\right) - \lambda \\ &=&  \frac{\partial}{\partial \bar{x}_{jk}} \left( \sum_i\frac{\sum_{k_i}u_{i k_i} \sum_{v_j \in c_i}\log{\bar{x}_{jk}} \sum_{k_1} \ldots \sum_{k_{i-1}}\sum_{k_{i+1}}\ldots \sum_{k_I} \prod_{i'=1 \ldots i-1,\\ i+1 \ldots I} u_{i'} k_{i'}}{\sum_{k_i}u_{i k_i} \sum_{k_1} \ldots \sum_{k_{i-1}}\sum_{k_{i+1}}\ldots \sum_{k_I} \prod_{i'=1 \ldots i-1,\\ i+1 \ldots I} u_{i'} k_{i'}}\right) - \lambda \\ &=&  \sum_i\frac{u_{ik} n_{ij} \frac{1}{\bar{x}_{jk}} \sum_{k_1} \ldots \sum_{k_{i-1}}\sum_{k_{i+1}}\ldots \sum_{k_I} \prod_{i'=1 \ldots i-1,\\ i+1 \ldots I} u_{i'} k_{i'}}{\sum_{k_i}u_{i k_i} \sum_{k_1} \ldots \sum_{k_{i-1}}\sum_{k_{i+1}}\ldots \sum_{k_I} \prod_{i'=1 \ldots i-1,\\ i+1 \ldots I} u_{i'} k_{i'}} - \lambda \\ &=& \sum_i\left( \frac{u_{ik}}{\sum_{k'} u_{i k'}}n_{ij} \frac{1}{\bar{x}_{jk}}\right) - \lambda \\ &=& 0 \end{eqnarray}が成り立つようにする。
それを満たすのは、\bar{x}_{jk} = \frac{1}{\lambda}\sum_i\left(n_{ij}\frac{u_{ik}}{\sum_{k'} u_{i k'}}\right)のとき。

これらの両辺をすべて足し合わせると、\sum_j \bar{x}_{jk} = \frac{1}{\lambda} \sum_j \sum_i \left(n_{ij} \frac{u_{ik}}{\sum_{k'} u_{i k'}}\right)だが、\sum_j \bar{x}_{jk} = 1であるので\lambda = \sum_j \sum_i \left(n_{ij}\frac{u_{ik}}{\sum_{k'} u_{i k'}}\right)を得る。
よって、それぞれの\bar{x}_{jk}について:

Q関数を最大にするのは、\bar{x}_{jk} = \frac{\sum_i\left(n_{ij}\frac{u_{ik}}{\sum_{k'} u_{i k'}}\right)}{\sum_{j'} \sum_i \left(n_{ij'}\frac{u_{ik}}{\sum_{k'} u_{i k'}}\right)}。再びh_{ik} = \frac{u_{ik}}{\sum_{k'} u_{i k'}}とおくと、これは\bar{x}_{jk} = \frac{\sum_i n_{ij}h_{ik}}{\sum_{j'} \sum_i n_{ij'}h_{ik}}となる。

出現回数は更新式に現れるべきだという直観は間違っていなかった(はず)。そうでなければ、bag of words でなく、set になってしまい、それまでの記述と矛盾する。

Errata では、

In the last line of the E-step, the product must be interpreted as j = 1, ..., J. This is a bit confusing, since we otherwise mainly iterate over the words in the context. It could alternatively have been expressed as Product v_j in c_i, and the exponent dropped.

としている。要するに、単語が 2回出てきても 1回しか掛けるな、ということ。しかしそれは(7.4)の Naive Bayes assumption に反するし、そもそも P(v_j \mid s_k) の意味がおかしなことになる。

やっぱり、本は少し間違っているんじゃないだろうか。