logsumexp は人類の黒歴史

あえて言い切る。

ここでは主に計算量について言っているので、スクリプト言語でデモ的なものを作るような場合は除く。この記事では C++ を使って書く。

logsumexp でどのように計算の量が増えるかはunnonouno: logsumexpとスケーリング法に詳しい。

ちょっと引用。

linear-chain CRFのパラメタ推定に必要なのは対数尤度関数の微分です.これの計算に必要なのが,前向き・後ろ向きのスコアαとβです.時刻t(系列上での位置)とラベルiに対する前向きスコアαは,以下の式で計算されます.fは特徴ベクトル,wは重みベクトルです.

この後の話の流れは、

  1. 掛け算いっぱいで大変! オーバー/アンダーフローしちゃうよ!
  2. じゃあ log の世界で扱えばいいんじゃね?
  3. 足し算もあるよ! どうするの?
  4. 「logsumexp〜」
  5. すごい! これでオーバーフローしないし安心だね!
  6. でも重いよ! log とか exp 使わないでなんとかできないの?
  7. それじゃ、「スケーリング〜」
  8. すごい! 計算の量が減ったよ!
  9. でも実装大変…ソース見にくい…←イマココ

という感じ。

1. に戻って考える。

「オーバーフローしない実数型があったらいいんじゃね?」

こういうことを言うと、「そんなのあったら苦労しないんじゃアホボケカス氏ね」と言われそうだが、まあ実際にあったとしよう。そしたら誰も logsumexp なんて使わないよね?

そうは言っても、グラハム数を入れてもオーバーフローしない実数型なんていうのは難しいので、どのくらいの範囲が表せればいいかを考える。

まず、一般的な 64bit double の場合だが、指数部は 11bit。符号を抜くと 10bit。2のプラスマイナス 1024 乗までということなので、だいたい 10 のプラスマイナス 300乗。これではすぐにオーバー/アンダーフローするのが目に見えている。時刻ごとに確率が 10のマイナス 5乗程度になったとして、系列の長さが 60以上あるともうダメということだ。

じゃあ、指数部を増やせたらいいんじゃないの? もし指数部が 32bit あればどうだろう。符号を抜いて 31 bit、だいたい 10 のプラスマイナス 21億乗まで表せるということだ。これで不足することはあるだろうか? 同じ仮定でいうと、系列の長さが 4億まで大丈夫ということだ。これなら現実的に考えて問題ない。

そういうわけなので、もしも 指数部 32bit、仮数部 32bit といった浮動小数点型がサポートされていたら、誰も(少なくとも HMM や CRF の文脈では) logsumexp なんて話はしていないはずなのだ。

まあ、ないものはしょうがない。せっかく C++ には演算子オーバーロードという機能があるので、それを利用することを考える。

の式をよく見てみる。


は、

  1. (実数型)との掛け算をする。
  2. (Σで)お互いに足し算をする。

ここではこの二つさえできればいい。(のところは、すぐにオーバーフローすることはない)

じゃあ書いてみよう。クラス名は BigDouble。

#include <cmath>

class BigDouble {
public:
  BigDouble(const double d = 0.0) {
    *this = d;
  }

  int GetExponent() const {
    return exponent_;
  }

  double GetFraction() const {
    return fraction_;
  }

  BigDouble& operator=(double d) {
    fraction_ = frexp(d, &exponent_);
    return *this;
  }

  BigDouble operator*(const double to_multiply) {
    int exponent;
    BigDouble temp = *this;
    temp.fraction_ = frexp(to_multiply * temp.fraction_, &exponent);
    temp.exponent_ += exponent;
    return temp;
  }

  BigDouble& operator+=(const BigDouble& to_add) {
    bool this_is_bigger = (fraction_ != 0 && exponent_ > to_add.GetExponent());
    const BigDouble& smaller = this_is_bigger ? to_add : *this;
    const BigDouble& bigger = this_is_bigger ? *this : to_add;
    int exponent_diff = smaller.GetExponent() - bigger.GetExponent();
    exponent_ = bigger.GetExponent();
    fraction_ = bigger.fraction_ + ldexp(smaller.GetFraction(), exponent_diff);
  }
private:
  int exponent_;
  double fraction_;
};

frexp というのは浮動小数点数仮数部と指数部に分解する関数で、ldexp というのはその逆。

BigDouble と double との掛け算では、BigDouble の仮数部と double をそのまま掛け、返り値の指数部を BigDouble がそれまで持っていた指数部に足すということをしている。

BigDouble 同士の足し算では、指数部を比べて大きい方の仮数部に、小さい方の仮数部と指数部の差を組み合わせた数を足して、指数部は大きい方のものをそのまま使っている。こうすると仮数部が 1 を超えることもあるので純粋な意味での仮数部にはならないが、特に害はないはずなのでそのまま持っている。

実用的に使おうと思ったら、この他に前向き・後ろ向きの値を掛けるために BigDouble 同士の掛け算がいるし、比を取るための BigDouble 同士の割り算もいる。ここでは省略している。

で、こうすることにより logsumexp より早くなる。一応時間を計ってみた。(上のコードは bigdouble.h で保存)

#include <climits>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <ctime>

#include "bigdouble.h"

template <class T> inline T _min (T x, T y) { return (x < y) ? x : y; }
template <class T> inline T _max (T x, T y) { return (x > y) ? x : y; }

#define MINUS_LOG_EPSILON 50

inline double logsumexp(double x, double y, bool flg) {
  if (flg) return y;  // init mode
  double vmin = _min(x, y);
  double vmax = _max(x, y);
  if (vmax > vmin + MINUS_LOG_EPSILON) {
    return vmax;
  } else {
    return vmax + std::log(std::exp(vmin - vmax) + 1.0);
  }
}

#define kFeatureCount 1000
#define kSequenceLen 100
#define kLabelCount 10
#define kFeaturesPerTransition 5
#define kIteration 1000

int main(void) {
  double features[kFeatureCount];
  double exp_features[kFeatureCount];

  srand(0);
  for (int i = 0; i < kFeatureCount; ++i) {
    features[i] = (rand() / (RAND_MAX / 10.0)) - 5.0;
    exp_features[i] = exp(features[i]);
  }
  
  double trans_score_matrix[kLabelCount][kLabelCount] = {0};
  double trans_score_exp_matrix[kLabelCount][kLabelCount];

  for (int i = 0; i < kLabelCount; ++i) {
    for (int j = 0; j < kLabelCount; ++j) {
      trans_score_matrix[i][j] = trans_score_exp_matrix[i][j] = 0;
      for (int k = 0; k < kFeaturesPerTransition; ++k) {
	int feature_id = (rand() / (RAND_MAX / (double)kFeatureCount));
	trans_score_matrix[i][j] += features[feature_id];
      }
      trans_score_exp_matrix[i][j] = exp(trans_score_matrix[i][j]);
    }
  }

  clock_t start;
  clock_t end;
  BigDouble forward_score_exp_matrix[kSequenceLen][kLabelCount];
  BigDouble test(2.0);
  
  start = clock();
  for (int l = 0; l < kLabelCount; ++l) {
    forward_score_exp_matrix[0][l] = 1.0;
  }
  for (int i = 0; i < kIteration; ++i) {
    for (int t = 1; t < kSequenceLen; ++t) {
      for (int j = 0; j < kLabelCount; ++j) {
	forward_score_exp_matrix[t][j] = 0.0;
	for (int k = 0; k < kLabelCount; ++k) {
	  forward_score_exp_matrix[t][j] += forward_score_exp_matrix[t-1][k] *
	    trans_score_exp_matrix[k][j];
	}
      }
    }
  }
  end = clock();
  printf("BigDouble: %.2f secs%5Cn", (double)(end - start)/CLOCKS_PER_SEC);

  double forward_score_matrix[kSequenceLen][kLabelCount];

  start = clock();
  for (int l = 0; l < kLabelCount; ++l) {
    forward_score_matrix[0][l] = 0.0;
  }
  for (int i = 0; i < kIteration; ++i) {
    for (int t = 1; t < kSequenceLen; ++t) {
      for (int j = 0; j < kLabelCount; ++j) {
	for (int k = 0; k < kLabelCount; ++k) {
	  forward_score_matrix[t][j] = logsumexp(forward_score_matrix[t][j],
						 forward_score_matrix[t-1][k] +
						 trans_score_matrix[k][j],
						 k==0);
	}
      }
    }
  }
  end = clock();
  printf("logsumexp: %.2f secs%5Cn", (double)(end - start)/CLOCKS_PER_SEC);

  return 0;
}

CRF の計算もどき(前向きスコア計算のうち、計算量の多くを占めると思われる遷移素性の計算の真似事。キャッシュされていると考える)。

時間を計るようなプログラムを書いたことがほとんどないのであまり自信はないし、いろいろツッコミどころもあるだろうけど、とりあえずこんな感じで logsumexp と比べてみた。

> g++ -O2 test.cc
> ./a.out
BigDouble: 0.25 secs
logsumexp: 0.70 secs

うまくいってるんじゃないだろうか。 ちなみに Macbook Air (プロセッサ 2.13 GHz Intel Core 2 Duo, メモリ 4 GB)。

それにしても書くのが面倒だ。
logsumexp のほうがずっとシンプル。
結局、スケーリングのほうが早いわけだし(どのくらい違うかは試していない)。

まあ言いたいのは、logsumexp は結構遅くなるので、ちゃんとスケーリングするか、ロジックがわかりにくくなるのが嫌な場合はこんな感じのクラスでも作ってみるのもいいんじゃないのかな、ということ。