すごく速くNMFができるらしい2012年の論文を実装しました.
epsilon を 0 にするとロスがNaNになるので注意.しばしばロスがNaNになるので使うときはverbose=trueにして様子を見ること推奨.
using LinearAlgebra using Random using Arpack Random.seed!(123) function lranmf_mu(X, r ; max_iter=200, tol = 1.0E-3, verbose = false) # lranmf_mu is effective when r << min(n,m) # input is a onnegative matrix # proposed by Guoxu Zhou in 2012 # https://ieeexplore.ieee.org/document/6166354 n, m = size(X) epsilon = 0.0001 # Step1 in section II-B. if r == min(n, m) svd_X = svd(X) else svd_X = svds(X; nsv=r, ritzvec=true)[1] end Achil = svd_X.U * diagm(svd_X.S) Bchil = svd_X.V # Step2 in section II-B. A = rand(n, r) B = rand(m, r) cost_at_init = norm(X - A*B') previous_cost = cost_at_init for iter = 1:max_iter B .= B .* ( max.( Bchil*(Achil' * A), epsilon ) ) ./ ( B*(A'*A) ) A .= A .* ( max.( Achil*(Bchil' * B), epsilon ) ) ./ ( A*(B'*B) ) if tol > 0 && iter % 10 == 0 cost = norm(X - A*B') if verbose println("iter: $iter cost: $cost") end if (previous_cost - cost) / cost_at_init < tol break end previous_cost = cost end end # A * B' is rank-r matrix return A, B' end
これで速くなるのかとこないだ実装した普通のNMF(論文中のNMF_MU)と比較してみた. 普通のNMFの実装は以下.ロスはL2.乱数のseedを固定した.
using LinearAlgebra using Random using Arpack Random.seed!(123) function nmf_euc(X, r ; max_iter=200, tol = 1.0E-3, verbose = false) m, n = size(X) W = rand(m, r) H = rand(r, n) error_at_init = norm(X - W*H) previous_error = error_at_init for iter = 1:max_iter H .= H .* ( W' * X ) ./ ( W' * W * H ) W .= W .* ( X * H') ./ ( W * H * H' ) if tol > 0 && iter % 10 == 0 error = norm(X - W*H) if verbose println("iter: $iter cost: $error") end if (previous_error - error) / error_at_init < tol break end previous_error = error end end return W, H end function main() X = rand(5000, 5000) r = 5 mxcnt = 10 runtime = 0.0 loss = 0.0 for i = 1:mxcnt runtime += @elapsed begin A, B = nmf_euc(X, r, verbose=false, tol=1.0E-4) end loss += norm(X-A * B) end ave_runtime = runtime / mxcnt ave_loss = loss / mxcnt println("nmf_euc runtime $ave_runtime") println("nmf_euc loss $ave_loss") runtime = 0.0 loss = 0.0 for i = 1:mxcnt runtime += @elapsed begin A, B = lranmf_mu(X, r, verbose=false, tol=1.0E-4) end loss += norm(X - A * B) end ave_runtime = runtime / mxcnt ave_loss = loss / mxcnt println("lranmf runtime $ave_runtime") println("lranmf loss $ave_loss") end main()
mxcnt回繰り返して,経過時間の平均をとってる.lossの評価は,本当はこんな雑にはかっちゃだめだけど,乱数シード固定しているから,まぁいいでしょう.結果.
nmf_euc runtime 35.48805575 nmf_euc loss 1443.2935317424924 lranmf runtime 9.62555351 lranmf loss 1444.506586518955
ターゲットランクrが小さいとlraNMF_muが十分に速いことがわかった.なので,rが小さい場合は,lra-NMFを使えば良いのではないでしょうか!ちなみにtol=1.0E-3
くらいだとあんまり大差ない.