まぃふぇいばりっと

機械学習やってます.Julia大好きです.勉強したことの殴り書きです.

Julia言語 Non-Negative Tucker Decomposition の実装 [完全版] LSエラー, KLエラー

昨日の記事の完全版.Nonnegative Tucker Decomposition の Julia 実装. 重要なアルゴリズムと思うけど,Julia版は探しても見つからなかったので,自分で書いた.

ieeexplore.ieee.org

Python版はすでに実装があるみたい.これはSVDで初期化しているっぽいね.

github.com

論文中にいくつかトラップがある.(追記)

・式(17)は間違い. An-1 ... A1 \Otimes AN ... An+1が正しい.(まぁ,これは実装には関係ないけど.)

・Table4の S←Rand(R,m)はS←Rand(R,l)が正しいと思われる.

・Table4, Table5に無定義で出てくる行列の肩の+は転置Tの誤植と思われる.(じゃないと行列の積が演算できない)

・更新は,A1→A2→…→AN→S→A1...の順でやる.A1→S→A2→S→A3→....なのか迷ったけど,論文にはIn updating A, the rest of parameters denoted by Sn are fixedとあるので,前者で正しいと思う.

・式(27)のεは入力Xと同じサイズのテンソル

あと,式(22)を使うと計算効率がよくなるらしいが,これ,両辺,行列のサイズ違くないか? 何度か計算したがつじつまが合いそうになかったで,とりあえず式(22)を使わない実装をした.(僕が勘違いしているようだったらコメントで教えてほしい)

初期化アルゴリズムは,論文に倣った”NMF"と"Random"の二つを用意した. 適当にrand(48,48,24)でテンソルつくって,(24,24,30)にランクを落としてみたら,NMFで初期化するとコスト関数は初めからかなり下がりきっているようだった.Figure1.を見ると,Iterations in initialization step are also added in these error traceとある.Figure1.のstep 1-50 くらいでかなりコスト関数が落ちているように思うので,初期化でかなり収束に近づいているっぽい?

using LinearAlgebra
using TensorToolbox

function cost_function(cost, X, Xr)
    if cost == "LS"
        return norm(X-Xr)
    elseif cost == "KL"
        return sum( X .* log.( X ./ Xr ) ) - sum(X) + sum(Xr)
    end
end

function reconst(S, A)
    N = ndims(S)
    Xr = S
    for n=1:N
        Xr = ttm(Xr, A[n], n)
    end
    return Xr
end

function init_random(tensor_size, reqrank)
    A = []
    N = length(reqrank)
    for n = 1:N
        An = rand(tensor_size[n], reqrank[n])
        push!(A, An)
    end
    S = rand( reqrank... )
    return S, A
end

"""
NMF Initializer. See Table 4 in the paper.
"""
function init_for_NMF(X, R; eps=1.0e-10, pre_iter_max_inloop=40)
    (m, l) = size(X)
    A = rand(m, R)
    S = rand(R, l)

    for iter = 1:pre_iter_max_inloop
        A .= max.(X*S', eps)
        A .= A .* ((X*S') ./ (A*S*S'))
        S .= max.(A'*X, eps)
        S .= S .* ( (A'*X) ./ (A'*A*S) )
    end
    return S, A
end

"""
Initializer based on NMF. See Table 5 in the paper.
"""
function init_based_NMF(X, reqrank; eps=1.0e-10, pre_iter_max=40, pre_iter_max_inloop=40)
    N = ndims(X)
    A = []
    for n=1:N
        _, An = init_for_NMF(tenmat(X, n), reqrank[n], pre_iter_max_inloop=pre_iter_max_inloop)
        push!(A, An)
    end

    S = rand( reqrank... )
    for iter = 1:pre_iter_max
        for n=1:N
            SAn = S
            for m in [1:n-1; n+1:N]
                SAn = ttm(SAn, A[m], m)
            end
            SAn = tenmat(SAn, n)
            A[n] = A[n] .* ( ( tenmat(X,n) * SAn' ) ./ (A[n] * SAn * SAn' ))
        end
        S = X
        for m=1:N
            S = ttm(S, A[m]', m)
        end
        S .= max.(S, eps)

        numerator = X
        denominator = S
        for m=1:N
            numerator = ttm(numerator, A[m]', m)
            denominator = ttm(denominator, A[m]'*A[m], m)
        end
        S .= S .* ( numerator ./ denominator )
    end
    return S, A
end

function update_An_LS(A, n, SAn, X)
    An = A[n] .* ( ( tenmat(X,n) * SAn' ) ./ (A[n] * SAn * SAn' ))
   return An
end

function update_An_KL(A, n, SAn, X)
    tensor_size = size(X)
    z = sum(SAn, dims=2)
    An = A[n] .* ((( tenmat(X,n) ./ ( A[n] * SAn ) ) * SAn' ) ./ ( ones(tensor_size[n]) * z' ))
    return An
end

function update_S_LS(S, A, X)
    N = ndims(S)
    numerator = X
    denominator = S
    for m=1:N
        numerator = ttm(numerator, A[m]', m)
        denominator = ttm(denominator, A[m]'*A[m], m)
    end
    S = S .* ( numerator ./ denominator )
    return S
end

function update_S_KL(S, A, X)
    N = ndims(S)
    numerator = X ./ reconst(S, A)
    denominator = ones( size(X)... )
    for m=1:N
        numerator = ttm(numerator, A[m]', m)
        denominator = ttm(denominator, A[m]', m)
    end
    S = S .* ( numerator ./ denominator )
    return S
end

function update(X, S, A, cost, max_iter=100, verbose=true, verbose_interval=20)
    N = ndims(S)

    cnt_iter = 0
    while(cnt_iter < max_iter)
        ############
        # update A #
        ############
        for n=1:N
            # get SAn
            SAn = S
            for m in [1:n-1;n+1:N]
                SAn = ttm(SAn, A[m], m)
            end
            SAn = tenmat(SAn, n)
            if cost == "LS"
                A[n] .= update_An_LS(A, n, SAn, X)
            elseif cost == "KL"
                A[n] .= update_An_KL(A, n, SAn, X)
            end
        end

        ########################
        # update Core tensor S #
        ########################
        if cost == "LS"
            S .= update_S_LS(S, A, X)
        elseif cost == "KL"
            S .= update_S_KL(S, A, X)
        end

        if verbose && (cnt_iter % verbose_interval == 0)
            Xr = reconst(S, A)
            error = cost_function(cost, X, Xr)
            println("$cnt_iter $error")
        end
        cnt_iter += 1
    end
    return S, A
end

"""
Non-Negative Tucker Decomposition
proposed by Young-Deok Kim et al.
See [original paper](https://ieeexplore.ieee.org/document/4270403)

# Aruguments
- `X` : input non-negative tensor
- `reqrank` : Target Tucker rank, array
- `init_method` : initial values, "NMF" or "random"
- `cost` : cost function, "LS" or "KL"
- `verbose` : true or false
- `pre_iter_max` : iter_max of initialization based on "NMF"
- `pre_iter_max_inloop` : iter_max of initialization of "NMF"
"""
function NTD(X, reqrank ;
        cost="LS", init_method="NMF", max_iter=500, verbose=true, verbose_interval=50,
        pre_iter_max=40, pre_iter_max_inloop=40)

    @assert ndims(X) == length(reqrank)
    tensor_size = size(X)

    #A[n] \in R^( tensor_size[n] \times reqrank[n] )
    #S \in R^(reqrank[1] \times ... \times reqrank[N])
    if init_method == "random"
        S, A = init_random(tensor_size, reqrank)
    elseif init_method == "NMF"
        S, A = init_based_NMF(X, reqrank, pre_iter_max=pre_iter_max, pre_iter_max_inloop=pre_iter_max_inloop)
    else
        error("no init method ", init_method)
    end

    S, A = update(X, S, A, cost, max_iter, verbose, verbose_interval)
    Xr = reconst(S, A)
    return S, A, Xr
end

Xrが再構成後の低ランクテンソルmrank(Xr)でXrのランクが落ちていることを確認できる.