まぃふぇいばりっと

機械学習やってます.Julia大好きです.勉強したことの殴り書きです.

Julia言語 補助行列が2つある場合のNon-negative Multiple Matrix Factorization (NM2F)の実装

前回,補助行列が一つだけのNMMFの実装を公開した. 今回は同様に,これらの論文にのっとり,補助行列が2つある場合のNMMFを実装した. 行列X,Y,Zを,WH, AH,WBに同時に分解するというタスクである.コスト関数はKL情報量.(LSの場合の更新式ってどこかにない??)

www.ijcai.org

ci.nii.ac.jp

原著論文に忠実に,一様離散分布で初期化する実装は以下.

function cost(X,Y,Z,W,H,A,B,alpha,beta)
    d_kl(P,Q) = sum( P .* log.( P ./ Q ) ) - sum(P) + sum(Q)
    total_cost = d_kl(X, W*H) + alpha*d_kl(Y, A*H) + beta*d_kl(Z, W*B)
    return total_cost
end

"""
Non-negative Multiple Matrix Factorization (NM2F)
See [original paper 1](https://www.ijcai.org/proceedings/2017/407),
[original paper 2](https://ci.nii.ac.jp/naid/110009691643)
# Aruguments
- 'X' target matrix. Factorized to the bases W and the coefficients H.
- 'Y' auxiliary feature. Factorized to A*H
- 'Z' auxiliary data. Factorized to W*B
- 'K' target rank, int.
- 'alpha' scale of auxiliary feature
- 'beta'  scale of auxiliary data
"""
function NM2F(X, Y, Z, K; alpha=1.0, beta=1.0, max_iter = 200, tol = 1.0E-4, verbose = true)
    W = rand(I, K)
    H = rand(K, J)
    A = rand(N, K)
    B = rand(K, M)
    Xhat = W*H
    Yhat = A*H
    Zhat = W*B
    
    error_at_init = cost(X,Y,Z,W,H,A,B,alpha,beta)
    previous_error = error_at_init
    for iter = 1 : max_iter
        W .*= ( (X ./ Xhat) * H' .+ beta * (Z ./ Zhat) * B') ./ ( ones(I,J) * H' .+ beta * ones(I, M) * B'  )
        Xhat = W*H
        Zhat = W*B
        
        H .*= ( W' * (X ./ Xhat) .+ alpha * A'*(Y ./ Yhat) ) ./ ( W' * ones(I,J) .+ alpha * A' * ones(N, J) )
        Xhat = W*H
        Yhat = A*H
        
        A .*= ( (Y ./ Yhat) * H'  ) ./ ( ones(N,J) * H'  )
        Yhat = A*H
        
        B .*= (  W' * (Z ./ Zhat) ) ./ ( W' * ones(I, M) )
        Zhat = W*B
        
        error = cost(X,Y,Z,W,H,A,B,alpha,beta)
        if verbose && iter % 20 == 1
            println("iter:$iter cost:$error")
        end
        if (previous_error - error) / error_at_init < tol
            break
        end
        previous_error = error
    end
    return W, H, A, B
end

I = 6
J = 5
M = 4
N = 3

X = rand(I, J)
Y = rand(N, J)
Z = rand(I, M)
W,H,A,B = NM2F(X,Y,Z,2);

Kはターゲットランク. このプログラムを実行すると,コスト関数がどんどん下がっていく様子が分かる.

iter:1 cost:6.1250592119187575
iter:21 cost:4.615593687122641
iter:41 cost:4.465480179778419
iter:61 cost:4.150666236462589
iter:81 cost:3.854919025274069