まぃふぇいばりっと

機械学習やってます.Julia大好きです.勉強したことの殴り書きです.

Julia言語 二重確率行列分解

非負の対称行列を低ランクの二重確率行列で近似するタスクに関する論文を読んだ. ここでいう近似とは,KL情報量(カルバック・ライブラー情報量)の意味での近似である.

jmlr.org

論文中のAlgorithm 1を実装した.

using LinearAlgebra

"""
    D_KL(A, B)

Return the generalized (unnormalized) Kullback-Leibler divergence between two
equally sized non-negative arrays:

    sum(A .* log.(A ./ B)) - sum(A) + sum(B)

Entries where `A[i, j] == 0` contribute nothing to the logarithmic term
(the usual convention 0 * log 0 = 0) instead of producing `NaN`.

# Throws
- `DimensionMismatch` if `A` and `B` do not have the same size.
"""
function D_KL(A, B)
    size(A) == size(B) || throw(
        DimensionMismatch("A has size $(size(A)) but B has size $(size(B))")
    )

    dkl = 0.0
    # Column-major traversal: the first index varies fastest.
    for j in axes(A, 2), i in axes(A, 1)
        a = A[i, j]
        # 0 * log(0 / b) evaluates to NaN in floating point; skip such terms so
        # that zero entries of A follow the convention 0 * log 0 = 0.
        iszero(a) && continue
        dkl += a * log(a / B[i, j])
    end
    dkl -= sum(A)
    dkl += sum(B)

    return dkl
end

"""
    getAhat(W)

Build the `N x N` matrix `Ahat` from the `N x r` factor `W`:

    Ahat[i, j] = sum over k of  W[i, k] * W[j, k] / sum(W[:, k])

# Arguments
- `W`: an `N x r` matrix. Each column sum is used as a divisor, so columns
  should not sum to zero.

# Returns
A dense `N x N` `Float64` matrix.
"""
function getAhat(W)
    (N, r) = size(W)

    # The column sums do not depend on (i, j); compute them once instead of
    # re-summing (and copying) a column inside the innermost loop.
    colsum = [sum(@view W[:, k]) for k in 1:r]

    Ahat = zeros(N, N)
    for j in 1:N, i in 1:N
        term = 0.0  # Float64 accumulator keeps the loop type-stable
        for k in 1:r
            term += W[i, k] * W[j, k] / colsum[k]
        end
        Ahat[i, j] = term
    end
    return Ahat
end

"""
    DCD(A, W, r; alpha=1.0, cnt_max=20, verbose=true)

Run the multiplicative update (Algorithm 1 of the paper referenced in the
accompanying text) that fits `Ahat = getAhat(W)` to `A`.

# Arguments
- `A`: square `N x N` matrix to approximate.
- `W`: `N x r` initial factor. It is **updated in place** and also returned.
  Entries should be strictly positive, because the update divides by `W[i, k]`.
- `r`: number of columns of `W`.

# Keywords
- `alpha`: weight of the `1 / W[i, k]` term in the negative gradient part.
- `cnt_max`: index of the last sweep. Sweeps are numbered from 0, so
  `cnt_max + 1` sweeps are performed (at least one sweep is always run).
- `verbose`: when `true`, print `D_KL(A, Ahat)` after every sweep.

# Returns
The tuple `(W, Ahat)` where `Ahat = getAhat(W)` for the final `W`.

# Throws
- `DimensionMismatch` if `A` is not square or `W` is not `N x r`.
"""
function DCD(A, W, r; alpha=1.0, cnt_max=20, verbose=true)
    N = size(A, 1)
    size(A, 2) == N ||
        throw(DimensionMismatch("A must be square, got size $(size(A))"))
    size(W) == (N, r) ||
        throw(DimensionMismatch("W must have size ($N, $r), got $(size(W))"))

    # Work buffers are allocated once and overwritten on every sweep.
    Z = zeros(N, N)
    nabla_m = zeros(N, r)
    nabla_p = zeros(N, r)

    cnt = 0
    while true
        # s[k] is the sum of column k of W. It is needed both for Z and for the
        # gradient parts, so compute it once per sweep.
        s = sum(W; dims=1)

        # Z[i, j] = A[i, j] / Ahat[i, j] for the current W.
        for j in 1:N, i in 1:N
            term = 0.0
            for k in 1:r
                term += W[i, k] * W[j, k] / s[k]
            end
            Z[i, j] = inv(term) * A[i, j]
        end

        ZW = Z * W
        # Only the diagonal of W' * Z * W is used below:
        # d[k] = sum over i of W[i, k] * (Z * W)[i, k].
        d = sum(W .* ZW; dims=1)

        # Negative (nabla_m) and positive (nabla_p) parts of the gradient.
        for k in 1:r, i in 1:N
            nabla_m[i, k] = 2.0 * ZW[i, k] / s[k] + alpha / W[i, k]
            nabla_p[i, k] = d[k] / s[k]^2 + 1.0 / W[i, k]
        end

        # Row-wise update. a and b depend only on row i of the current W and on
        # the gradient parts computed above, so each row can be updated as soon
        # as its a and b are known.
        for i in 1:N
            a = 0.0
            b = 0.0
            for k in 1:r
                a += W[i, k] / nabla_p[i, k]
                b += W[i, k] * (nabla_m[i, k] / nabla_p[i, k])
            end
            for k in 1:r
                W[i, k] = W[i, k] * (nabla_m[i, k] * a + 1.0) / (nabla_p[i, k] * a + b)
            end
        end

        if verbose
            Ahat = getAhat(W)
            dkl = D_KL(A, Ahat)
            println("cnt $cnt D_KL(A; Ahat) $dkl")
        end

        # `>=` rather than `==`: a negative or non-integer cnt_max must still
        # terminate the loop instead of spinning forever.
        if cnt >= cnt_max
            break
        end
        cnt += 1
    end

    return W, getAhat(W)
end

"""
    main()

Demo driver: build a random symmetric `10 x 10` matrix scaled so that its
entries sum to 10, draw a random `10 x 2` starting factor, and run `DCD` on
them with `cnt_max = 10`. Returns the `(W, Ahat)` tuple produced by `DCD`.
"""
function main()
    n, r = 10, 2
    cnt_max = 10

    # Random symmetric matrix, rescaled so that all entries together sum to n.
    sym = Symmetric(rand(n, n))
    A = (n * sym) ./ sum(sym)

    # Random starting factor, each entry scaled by 0.5.
    W_init = 0.5 * rand(n, r)

    return DCD(A, W_init, r; cnt_max=cnt_max)
end

# Script entry point: run the demo defined in `main` whenever this file is loaded.
main()