I looked at the update rules in the note below and implemented them. I read the open circle as the element-wise (Hadamard) product.
https://www.jjburred.com/research/pdf/jjburred_nmf_updates.pdf
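For reference, the multiplicative updates as I read them from that note (with $\circ$ denoting element-wise multiplication, $\oslash$ element-wise division):

```latex
% Squared Euclidean loss:
H \leftarrow H \circ \frac{W^{\top} X}{W^{\top} W H}, \qquad
W \leftarrow W \circ \frac{X H^{\top}}{W H H^{\top}}

% Generalized KL loss (\mathbf{1} is the all-ones m-by-n matrix):
H \leftarrow H \circ \frac{W^{\top}\left(X \oslash (W H)\right)}{W^{\top} \mathbf{1}}, \qquad
W \leftarrow W \circ \frac{\left(X \oslash (W H)\right) H^{\top}}{\mathbf{1} H^{\top}}
```

The fraction bars are also element-wise; these are exactly the expressions that appear in the code below.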
The convergence criterion is matched as closely as possible to the defaults of scikit-learn's NMF.
scikit-learn.org
The implementation follows.
```julia
using LinearAlgebra
using Random

Random.seed!(123)

# Generalized KL divergence between nonnegative matrices A and B
function KL(A, B)
    n, m = size(A)
    kl = 0.0
    for i = 1:n, j = 1:m
        kl += A[i, j] * log(A[i, j] / B[i, j]) - A[i, j] + B[i, j]
    end
    return kl
end

# NMF with squared Euclidean loss (multiplicative updates)
function nmf_euc(X, r; max_iter = 200, tol = 0.0001)
    m, n = size(X)
    W = rand(m, r)
    H = rand(r, n)
    error_at_init = norm(X - W * H)
    previous_error = error_at_init
    for iter in 1:max_iter
        H .= H .* (W' * X) ./ (W' * W * H)
        W .= W .* (X * H') ./ (W * H * H')
        # Check convergence every 10 iterations, as sklearn does
        if tol > 0 && iter % 10 == 0
            err = norm(X - W * H)
            if (previous_error - err) / error_at_init < tol
                break
            end
            previous_error = err
        end
    end
    return W, H
end

# NMF with generalized KL loss (multiplicative updates)
function nmf_kl(X, r; max_iter = 200, tol = 0.0001)
    m, n = size(X)
    W = rand(m, r)
    H = rand(r, n)
    one_mn = ones(m, n)
    error_at_init = KL(X, W * H)
    previous_error = error_at_init
    for iter in 1:max_iter
        println(KL(X, W * H))  # print the KL cost every iteration
        H .= H .* (W' * (X ./ (W * H))) ./ (W' * one_mn)
        W .= W .* ((X ./ (W * H)) * H') ./ (one_mn * H')
        if tol > 0 && iter % 10 == 0
            err = KL(X, W * H)
            if (previous_error - err) / error_at_init < tol
                break
            end
            previous_error = err
        end
    end
    return W, H
end
```
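As a cross-check on the KL update rules, here is a minimal NumPy sketch of the same multiplicative updates (the function name `nmf_kl_np` and the small `EPS` guard against division by zero are my additions, not part of the Julia code above):

```python
import numpy as np

EPS = 1e-12  # guard against division by zero (the Julia version has no such guard)

def kl_div(A, B):
    """Generalized KL divergence: sum(A*log(A/B) - A + B)."""
    return float(np.sum(A * np.log(A / B) - A + B))

def nmf_kl_np(X, r, max_iter=200, seed=123):
    """Multiplicative-update NMF under generalized KL, mirroring nmf_kl above."""
    rng = np.random.default_rng(seed)
    m, n = X.shape
    W = rng.random((m, r))
    H = rng.random((r, n))
    ones = np.ones((m, n))
    costs = [kl_div(X, W @ H)]
    for _ in range(max_iter):
        H *= (W.T @ (X / (W @ H + EPS))) / (W.T @ ones)
        W *= ((X / (W @ H + EPS)) @ H.T) / (ones @ H.T)
        costs.append(kl_div(X, W @ H))
    return W, H, costs

X = np.random.default_rng(0).random((20, 15)) + 0.1  # strictly positive toy data
W, H, costs = nmf_kl_np(X, 3)
print(costs[0], "->", costs[-1])  # the cost should not increase across iterations
```

The multiplicative updates are guaranteed not to increase the objective, so the printed final cost should be below the initial one.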
For the update step, a strong Twitter user taught me that `H .= ...` is slightly faster than `H = ...` (the in-place broadcast assignment reuses `H`'s memory instead of allocating a new array).
Performance Tips · The Julia Language
Incidentally, KL-NMF in sklearn looks like the following. With sklearn's default settings, the Julia implementation above reaches a lower cost (why? one plausible factor is initialization: the Julia code starts from uniform random `W` and `H`, while sklearn defaults to an NNDSVD-based init, so the two start from different points and multiplicative updates only reach local optima).
```python
import numpy as np
from sklearn.decomposition import NMF

rank = 3
nmf = NMF(rank, beta_loss='kullback-leibler', alpha=0, verbose=1, solver='mu')  # NMF
W = nmf.fit_transform(X)  # X: the nonnegative data matrix
H = nmf.components_
X_nmf = np.dot(W, H)
```
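To make the cost comparison concrete, the same generalized KL divergence can be evaluated on sklearn's output. A sketch on random toy data (the `kl` helper mirrors the Julia `KL` function; `init='random'` and the toy matrix are my choices, not sklearn defaults):

```python
import numpy as np
from sklearn.decomposition import NMF

def kl(A, B):
    # Generalized KL divergence, same definition as the Julia KL function
    return float(np.sum(A * np.log(A / B) - A + B))

rng = np.random.default_rng(0)
X = rng.random((30, 20)) + 0.1  # strictly positive toy data

nmf = NMF(n_components=3, beta_loss='kullback-leibler', solver='mu',
          init='random', random_state=0, max_iter=200)
W = nmf.fit_transform(X)
H = nmf.components_
print("KL cost of sklearn's factorization:", kl(X, W @ H))
```

Running both this and the Julia code on the same matrix (e.g. exported with `np.savetxt`) gives a like-for-like cost comparison.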
Addendum
I extended the code so that NMF can also be run with the α-divergence.
genkaiphd.hatenablog.com