可変長オーダーCRF

修士論文のテーマである「可変長オーダー Linear-Chain CRF」(以下「可変長CRF」)についてのメモ。

まず、自分の考える「可変長CRF」の定義。

その前に、Linear-Chain CRF(以下 CRF)について。

一般的な CRF では、あらかじめ Markov オーダーを決める。1次なり 2次なり。素性関数は、その範囲内で設定される(1次なら f(X, y_{i-1}, y_{i})、2次なら f(X, y_{i-2}, y_{i-1}, y_{i})というように)。

この次数を可変長にして、複数次数の素性を混ぜて使えるようにするのが目的。

単純に複数次数の素性を混ぜるだけなら、普通の CRF でもできる。
最大次数の CRF を考えて、低次の素性関数は高次の引数を無視すればいい。
ただ、それでは最大次数に対して指数的な時間がかかる。
最大次数が 5次なら、ラベル数の 6乗。

高次の素性がスパースである場合、この計算方法は無駄が多い。それで、「可変長 CRF」というものを考える。

まず記号列に対する条件を設定し、それぞれの位置でアクティブになるもの(「属性」とする)を調べる。これは一般の CRF と同じ。(一般のCRFで、属性を設定するやり方はテンプレートによるもの( CRF++ 等)とユーザに任せるもの(CRFSuite 等)がある。後者のほうがユーザが自由に条件を設定できるというメリットがあるが、プログラミングの知識が必要とされる)

一般の CRF では、属性とすべてのラベルを組み合わせて「状態素性関数」を作り、すべてのラベルとラベルの組み合わせに対して「遷移素性関数」を作る。(ここも実装によってバリエーションがある。属性とラベル遷移を組み合わせることができるもの(CRF++等)もあれば、純粋な遷移素性しか扱わないもの(CRFSuite等)もある)

可変長CRFでは、属性ごとに訓練データ中でそれと共に出現するラベル列を記録し、それに何らかの基準で足切りをして(出現回数・長さ等)、属性とそれと共に出現したラベル列(可変長)をセットにして素性関数を作る(ラベル列の長さが 1 ならば一般の CRF でいうところの「状態素性関数」となり、2 以上ならば「遷移素性関数」となる)。このように素性関数を作ることで、次数をいくら大きくしても、素性関数の数は訓練データのサイズに限定される。

このためには、可変長の素性に対する前向き・後ろ向きアルゴリズムが必要。それを自分なりに考えてみた。

先日研究会発表したので、そこで使ったスライドをアップロードしておく。Variable-Order CRFs

また、プロトタイプを作って、jsdo.itにアップロードした。

Variable length forward-backward algorithm - jsdo.it - Share JavaScript, HTML5 and CSS

JavaScript 1.7 を使っているため、対応ブラウザでないと見られない(FireFox で確認)。バージョン指定をするため、HTML領域にコードを書いている。また、大きな画面でないとレイアウトが崩れる。Python版ソースもこの記事の末尾に載せる。

クリックで 1ステップずつ進む。Ctrl を押しながらクリックすると速く進む。

スライドのほうに、このデモについての説明を書いてある。ここでも簡単に説明する。

"(BOS)", "This", "is", "good", "(EOS)" が記号列、"a1", "a2", "a3", "a4" が属性列、"X", "Y", "Z" がラベル。

ここでは、素性関数は属性とラベル列を適当に選んで設定している。実際の応用では、属性と出現したラベル列に何らかの方法(出現回数・長さなど)で足切りをして、組み合わせて素性関数を決める。

水色の背景の四角が「チェックポイント」、ピンクの背景が「パス」。パスは、その位置でアクティブになる素性関数を、それらが持つラベル列でまとめたもの。チェックポイントは、次の位置のパスが持つラベル列から、次の位置のラベルを取り除いたもの。

前向き部分の処理は次のようになる。

まず、今見ている記号の位置で終わる素性について、それぞれパスを作る(長さ 1 のものを含む)。

それから、次の記号の位置で終わる素性について、最後の 1記号分を切り落として短くして、それぞれチェックポイントを作る(重複するものはまとめる、長さゼロのものも作る)。

それらのチェックポイントとパスを、後ろから見たラベル順にソートする。片方が他方に含まれる場合は、長いものが先に来るようにする。

ソートされたチェックポイントとパスのうち、まずパスだけを逆順にたどる。

短いものから見ていくことになるので、短いパスの倍率を覚えておいて、それを含む長いパスについては、短いパスの倍率をかけて倍率を更新する(長いパスの中では短いパスでアクティブになる素性関数も同時にアクティブになるため。具体的な手続きはもう少し複雑なのでソース参照)。

パスの倍率の更新が終わったら、ソートされたチェックポイントとパスの両方を、今度は正順に(片方を含むものは長いものから)たどる。

パスの処理:

今見ている列のラベルを除いて、それに対応する前列で記録されたスコアを取得。そのパスを含む、より長いパスによってすでにスコアが使われている場合は、それを差し引く(これも複雑なのでソース参照)。

差し引いた残りのスコアにパスの倍率をかけたものを、パスのスコアとして記録。それを現在のスコアに加算。

チェックポイントの処理:

そのチェックポイントを含むより長いチェックポイントのスコアを累積したものに現在のスコアを加えて、それに対応するスコアとして記録。現在のスコアをクリア。

これを最後まで行う。

その次に、後ろ向き部分の処理。

ソートされたチェックポイントとパスを逆順にたどる。

チェックポイントの処理:

チェックポイントに対応する後列で記録されたスコアを取得、現在のスコアに加算。

パスの処理:

前向きで求めたスコアに現在のスコア(後ろ向きスコア)をかけ、それをパスのスコア(正規化すると確率になる)として記録。

現在のスコアにパスの倍率をかけたものを求める。そのスコアから、そのパスに含まれる、より短いパスを記録した時のスコアを差し引いて、そのパスから今見ている列のラベルを除いたものに対応するスコアに加算。

これらの処理が終わったら、パスだけを正順に(片方を含むものは長いものから)たどる。

長いパスのスコアを、それに含まれる短いパスに足し込んでいく(ここまでは、短いパスを含む長いパスのスコアは独立に求めたが、実際には長いパスのスコア(→後で正規化定数で割って確率にする)の分だけ、それに含まれる短いパスに対応する素性関数もアクティブになっているため、足し込んでやらないといけない。

これを先頭まで行う。

最後に、すべてのパスのスコアを正規化定数(前向き時の最終位置で、空チェックポイントに対応するスコア)で割ることによって、それぞれのパスの期待値が得られる。その期待値を、それらに関連づけられている素性関数の期待値に足す。

これを、対数尤度が収束するまで続ける。

デコードのアルゴリズムは次の通り。

可変長前向きアルゴリズムと同様にパスとチェックポイントを作る

チェックポイントでは、前のチェックポイントからの最大スコアのみを記録する

パスでは、対応する前のチェックポイントと、それを含むより長いチェックポイントの中で使用済みでないものから最大スコアを求め、パスの指数重みをかけてスコアを求める(使用済みフラグは、現在位置のラベルが変わるたびにクリア)

ただ、このアルゴリズムはどうもあまりエレガントじゃない気がする。考え中。

(読まれた方へ)このアルゴリズムPython 版を貼り付けておきます。JavaScript版、Python版ともに前向き・後ろ向きアルゴリズムの部分だけです。属性・素性関数は適当にダミーを与えています。また、デコード部分は作っていません。自分一人で考えていると不安なので、誰かに意見をいただければうれしく思います。

import sys, math, random, itertools

class Checkpoint(object):
    def __init__(self, labelSequence):
        self._labelSequenceTuple = tuple(labelSequence)

    def __repr__(self):
        return '<%s>' % (str(self._labelSequenceTuple))

    def getLabelSequenceTuple(self):
        return self._labelSequenceTuple

class Path(object):
    def __init__(self, labelSequence):
        self._labelSequenceTuple = tuple(labelSequence)
        self._rate = 1.0
        self._score = 0.0
        self._features = []

    def __repr__(self):
        return '<%s:%.1f/%.1f>' % (str(self._labelSequenceTuple),self._rate,self._score)

    def getLabelSequenceTuple(self):
        return self._labelSequenceTuple

    def getRate(self):
        return self._rate

    def setRate(self, rate):
        self._rate = rate

    def setScore(self, score):
        self._score = score

    def getScore(self):
        return self._score

    def addFeature(self, feature):
        self._rate *= feature.getExpWeight()
        self._features.append(feature)

def compareLabelSequence(a, b):
    labelSequenceTupleA = a.getLabelSequenceTuple()
    labelSequenceTupleB = b.getLabelSequenceTuple()
    lenA = len(labelSequenceTupleA)
    lenB = len(labelSequenceTupleB)
    for i in range(min(lenA, lenB)):
        if labelSequenceTupleA[-i-1] != labelSequenceTupleB[-i-1]:
            return cmp(labelSequenceTupleA[-i-1], labelSequenceTupleB[-i-1])
    if lenA != lenB:
        return cmp(lenB, lenA)
    return cmp(isinstance(a, Checkpoint), isinstance(b, Checkpoint))

def getCommonSuffixLen(labelSequenceTupleA, labelSequenceTupleB):
    matchLen = 0
    for i in range(min(len(labelSequenceTupleA), len(labelSequenceTupleB))):
        if labelSequenceTupleA[-i-1] != labelSequenceTupleB[-i-1]:
            break
        else:
            matchLen += 1
    return matchLen

class Feature(object):
    def __init__(self, attr, labelSequence, observationCount):
        self._attr= attr
        self._labelSequenceTuple = tuple(labelSequence)
        self.setWeight(0)
        self._observationCount = observationCount
        self._expectation = 0.0

    def __repr__(self):
        return '<%s:%s:%f>' % (self._attr, str(self._labelSequenceTuple), self._weight)

    def getAttr(self):
        return self._attr

    def getLabelSequenceTuple(self):
        return self._labelSequenceTuple

    def getWeight(self):
        return self._weight

    def getExpWeight(self):
        return self._expWeight

    def setWeight(self, weight):
        self._weight = weight
        self._expWeight = math.exp(weight)

    def setExpWeight(self, expWeight):
        self._expWeight = expWeight
        self._weight = math.log(expWeight)

    def setExpectation(self, expectation):
        self._expectation = expectation

class Accumulator(object):
    def __init__(self, multiplyMode=False, isFromLongerToShorter=True, discountMode=False):
        self._prevLabelSequenceTuple = None
        self._prev = 0.0
        self._stack = []
        self._multiplyMode = multiplyMode
        self._isFromLongerToShorter = isFromLongerToShorter
        self._discountMode = discountMode
        self._add = (lambda a, b : a * b) if multiplyMode else (lambda a, b : a + b)
        self._sub = (lambda a, b : a / b) if multiplyMode else (lambda a, b : a - b)
        self._calc = self._sub if discountMode else self._add

    def accumulate(self, labelSequenceTuple, value):
        sequenceLen = len(labelSequenceTuple)
        current = value
        if self._prevLabelSequenceTuple is not None:
            commonLen = getCommonSuffixLen(self._prevLabelSequenceTuple, labelSequenceTuple)

            if self._isFromLongerToShorter:
                if commonLen == sequenceLen:
                    while len(self._stack) > 0 and self._stack[-1][0] >= commonLen:
                        current = self._calc(current, (self._stack.pop())[1])
                    current = self._calc(current, self._prev)
                else:
                    self._stack.append((commonLen, self._prev))
            else:
                if commonLen == self._prevSequenceLen:
                    self._stack.append((commonLen, self._prev))
                while len(self._stack) > 0 and self._stack[-1][0] > commonLen:
                    self._stack.pop()
                if len(self._stack) > 0 and sequenceLen > commonLen:
                    current = self._calc(current, self._stack[-1][1])

        if self._discountMode:
            self._prev = value
        else:
            self._prev = current

        self._prevSequenceLen = sequenceLen
        self._prevLabelSequenceTuple = labelSequenceTuple

        return current

class DataSet(object):
    def __init__(self, text):
        self._data = []
        sequence = []
        for line in text.split("\n"):
            line = line.strip()
            if len(line) == 0:
                if len(sequence) > 0:
                    self._data.append(sequence)
                sequence = []
            else:
                attrs = line.split()
                label = attrs.pop(0)
                sequence.append((label, attrs))

        self._allFeatureList = []
        self._attrToFeatureMap = {}

        for feature in self.featureGenerator():
            feature.setWeight(random.uniform(-1, 1))
            self._allFeatureList.append(feature)
            attr = feature.getAttr()
            if attr not in self._attrToFeatureMap:
                self._attrToFeatureMap[attr] = []
            self._attrToFeatureMap[attr].append(feature)
        return

    def sequenceGenerator(self):
        for sequence in self._data:
            yield sequence

    def featureGenerator(self):
        # placeholder / should be extracted from the data set in real-world applications
        mapData = {
             "" : [["X"], ["Y"], ["Z"], ["EOS"], ["BOS", "X"], ["BOS", "X", "Y"],
              ["X", "Y"], ["X", "Y", "Z"], ["X", "Z"], ["Y", "X"], ["Y", "Z"], ["Z", "Y"],
             ["Y", "EOS"], ["Z", "EOS"]],
             "a1" : [["X"], ["Y"], ["Z"], ["BOS", "X"], ["X", "Y"]],
             "a2" : [["X"], ["Y"], ["Z"]],
             "a3" : [["Y"], ["Z"], ["Y", "Z"]],
             "a4" : [["Z", "EOS"]]
        }
        for (attr, labelSequences) in mapData.items():
            for labelSequence in labelSequences:
                yield Feature(attr, labelSequence, 1.0)

    def updateFeatureExpectation(self):
        for sequence in self.sequenceGenerator():
            self.forwardBackward(sequence)

    def forwardBackward(self, sequence):
        pathsAndCheckpoints = [0] * len(sequence)
        pathsAndCheckpoints[0] = []
        scoreDict = {("BOS",) : 1.0, () : 1.0}
        newScoreDict = {}

        for sequencePos in range(1, len(sequence) + 1):
            prevCheckpointLabelSequenceSet = set(
                [x.getLabelSequenceTuple() for x in pathsAndCheckpoints[sequencePos - 1]
                if len(x.getLabelSequenceTuple()) == 1])
            if sequencePos < len(sequence):
                pathsAndCheckpoints[sequencePos] = []
                isEOS = (sequencePos == len(sequence) - 1)
                attrs = sequence[sequencePos][1]
                pathDict = {}
                for attr in itertools.chain(("",), attrs):
                    for feature in self._attrToFeatureMap.get(attr, ()):
                        labelSequenceTuple = feature.getLabelSequenceTuple()
                        if labelSequenceTuple[-1] == "EOS" and not isEOS or \
                           labelSequenceTuple[-1] != "EOS" and isEOS or \
                           len(labelSequenceTuple) > sequencePos and labelSequenceTuple[-sequencePos-1] != "BOS" or \
                           len(labelSequenceTuple) <= sequencePos and labelSequenceTuple[0] == "BOS":
                            continue
                        path = pathDict.get(labelSequenceTuple, Path(labelSequenceTuple))
                        pathDict[labelSequenceTuple] = path
                        path.addFeature(feature)
                        prevCheckpointLabelSequenceSet.add(labelSequenceTuple[:-1])

                for path in pathDict.itervalues():
                    pathsAndCheckpoints[sequencePos].append(path)
            else:
                prevCheckpointLabelSequenceSet.add(())

            for labelSequence in prevCheckpointLabelSequenceSet:
                pathsAndCheckpoints[sequencePos - 1].append(Checkpoint(labelSequence))

            pathsAndCheckpoints[sequencePos - 1].sort(compareLabelSequence)

        # forward
        for sequencePos in range(1, len(sequence)):

            pathRateAccumulator = Accumulator(multiplyMode=True, isFromLongerToShorter=False,discountMode=False)

            for pathOrCheckpoint in reversed(pathsAndCheckpoints[sequencePos]):
                if not isinstance(pathOrCheckpoint, Path):
                    continue
                path = pathOrCheckpoint
                path.setRate(pathRateAccumulator.accumulate(path.getLabelSequenceTuple(), path.getRate()))

            pathScoreDiscounter = Accumulator(isFromLongerToShorter=True, discountMode=True)
            checkpointScoreAccumulator = Accumulator(isFromLongerToShorter=True, discountMode=False)

            score = 0.0
            for pathOrCheckpoint in pathsAndCheckpoints[sequencePos]:
                labelSequence = pathOrCheckpoint.getLabelSequenceTuple()
                if isinstance(pathOrCheckpoint, Path):
                    path = pathOrCheckpoint
                    path.setScore(pathScoreDiscounter.accumulate(labelSequence, scoreDict[labelSequence[:-1]]) * path.getRate())
                    score += path.getScore()
                else:
                    checkpoint = pathOrCheckpoint
                    newScoreDict[labelSequence] = checkpointScoreAccumulator.accumulate(labelSequence, score)
                    score = 0.0

            scoreDict = newScoreDict
            newScoreDict = {}

        # backward
        scoreDict = {() : 1.0}
        for sequencePos in reversed(range(len(sequence))):
            score = 0.0

            pathScoreDiscounter = Accumulator(isFromLongerToShorter=False, discountMode=True)
            checkpointScoreAccumulator = Accumulator(isFromLongerToShorter=False, discountMode=False)

            for pathOrCheckpoint in reversed(pathsAndCheckpoints[sequencePos]):
                labelSequence = pathOrCheckpoint.getLabelSequenceTuple()
                if isinstance(pathOrCheckpoint, Path):
                    path = pathOrCheckpoint
                    path.setScore(path.getScore() * score)
                    oldScore = newScoreDict.get(labelSequence[:-1], 0.0)
                    newScoreDict[labelSequence[:-1]] = oldScore + pathScoreDiscounter.accumulate(labelSequence[:-1], score * path.getRate())
                else:
                    score = checkpointScoreAccumulator.accumulate(labelSequence, scoreDict.get(labelSequence, 0.0))

            pathScoreAccumulator = Accumulator(isFromLongerToShorter=True, discountMode=False)

            for pathOrCheckpoint in pathsAndCheckpoints[sequencePos]:
                if not isinstance(pathOrCheckpoint, Path):
                    continue
                path = pathOrCheckpoint
                path.setScore(pathScoreAccumulator.accumulate(path.getLabelSequenceTuple(), path.getScore()))

            scoreDict = newScoreDict
            newScoreDict = {}

def main(argv):
    dataSet = DataSet("""
BOS
X a1 a2
Y a1
Z a3
EOS a4
    """)
    dataSet.updateFeatureExpectation()

main(sys.argv)