まぃふぇいばりっと

退職D進して機械学習やってます.Julia大好きです.勉強したことの殴り書きです.

Julia言語 非負値行列因子分解

ここにある更新式を見て,実装した.白抜きの丸はelement-wiseの積だと理解した.
https://www.jjburred.com/research/pdf/jjburred_nmf_updates.pdf

収束判定条件はsklearnのNMFの標準設定とできるだけ揃えた.
scikit-learn.org

以下実装.

using LinearAlgebra
using Random
Random.seed!(123)

function KL(A, B)
    n, m = size(A)
    kl = 0.0
    for i = 1:n
        for j = 1:n
            kl += A[i,j] * log( A[i,j] / B[i,j] ) - A[i,j] + B[i,j]
        end
    end
    return kl
end

function nmf_euc(X, r ; max_iter=200, tol = 0.0001)
    m, n = size(X)
    W = rand(m, r)
    H = rand(r, n)

    error_at_init = norm(X - W*H)
    previous_error = error_at_init
    for iter in 1:max_iter
        H .=  H .* ( W' * X ) ./ ( W' * W * H  )
        W .=  W .* ( X  * H') ./ ( W  * H * H' )

        if tol > 0 && iter % 10 == 0
            error = norm(X - W*H)
            if (previous_error - error) / error_at_init < tol
                break
            end
            previous_error = error
        end
    end

    return W, H
end

function nmf_kl(X, r ; max_iter=200, tol = 0.0001)
    m, n = size(X)
    W = rand(m, r)
    H = rand(r, n)
    one_mn = ones(m, n)

    error_at_init = KL(X, W*H)
    previous_error = error_at_init
    for iter in 1:max_iter
        println(KL(X, W*H))
        H .=  H .* ( W' * ( X ./ (W*H))) ./ ( W' * one_mn )
        W .=  W .* ( (X ./ (W*H) ) * H') ./ ( one_mn * H' )

        if tol > 0 && iter % 10 == 0
            error = KL(X, W*H)
            if (previous_error - error) / error_at_init < tol
                break
            end
            previous_error = error
        end
    end

    return W, H
end

更新するところ,H = hogehoge より,H.= hogehoge の方が微妙に速いって強いツイッタラーに教えてもらった.
Performance Tips · The Julia Language


ちなみに,sklearnでKL-NMFはこんな感じ.デフォルトのsklearnの設定より,上のJulia実装の方が,コスト関数を小さくできる(なぜ?)

import numpy as np
from math import log
from sklearn.decomposition import NMF

rank = 3
nmf = NMF(rank, beta_loss='kullback-leibler', alpha=0, verbose=1,solver='mu') # NMF
W = nmf.fit_transform(X)
H = nmf.components_
X_nmf = np.dot(W, H)


追記

αダイバージェンスでNMFできるように拡張した.
genkaiphd.hatenablog.com