まぃふぇいばりっと

機械学習やってます.Julia大好きです.勉強したことの殴り書きです.

Julia言語 Non-Negative Tucker Decomposition の実装 [LSエラー版]

↓ こっちのほうが役に立つと思う

genkaiphd.hatenablog.com

↑ こっちのほうが役に立つと思う

研究に必要なので,非負テンソルのタッカー分解の論文をjuliaで実装した.とりあえず,コスト関数がLSエラーのものを実装した.KLダイバージェンスの場合についてはまた明日,実装する予定.

ieeexplore.ieee.org

論文中にいくつかトラップがある.

・式(17)は間違い.(まぁ,これは実装には関係ないけど.)

・Table4の S←Rand(R,m)はS←Rand(R,l)が正しいと思われる.

・Table4, Table5に無定義で出てくる行列の肩の+は転置Tの誤植と思われる.

あと,式(22)を使うと計算効率がよくなるっぽい.

using LinearAlgebra
using TensorToolbox

# Reconstruct the full tensor from a Tucker core `S` and factor matrices `A`:
# Xr = S ×_1 A[1] ×_2 A[2] … ×_N A[N]  (ttm = tensor-times-matrix product).
function reconst(S, A)
    Xr = S
    for mode in 1:ndims(S)
        Xr = ttm(Xr, A[mode], mode)
    end
    return Xr
end

"""
    init_random(tensor_size, reqrank)

Draw random non-negative initial values for NTD: factor matrices
`A[n] ∈ R^(tensor_size[n] × reqrank[n])` and a core tensor
`S ∈ R^(reqrank[1] × … × reqrank[N])`, all entries uniform in [0, 1).

Return `(S, A)`.
"""
function init_random(tensor_size, reqrank)
    # Typed comprehension instead of push! into an untyped `[]` (Vector{Any}):
    # yields Vector{Matrix{Float64}}, which downstream code can specialize on.
    A = [rand(tensor_size[n], reqrank[n]) for n in eachindex(reqrank)]
    S = rand(reqrank...)
    return S, A
end

"""
    init_for_NMF(X, R, eps=1.0e-10, iter_max_inloop=40)

Rank-`R` non-negative matrix factorization `X ≈ A*S` (least-squares cost)
via Lee–Seung multiplicative updates; used to warm-start the NTD factors.

# Arguments
- `X` : non-negative matrix (m × l)
- `R` : target rank
- `eps` : small floor keeping the iterates strictly positive
- `iter_max_inloop` : number of multiplicative-update sweeps

Return `(S, A)` with `A` of size (m, R) and `S` of size (R, l).
"""
function init_for_NMF(X, R, eps=1.0e-10, iter_max_inloop=40)
    (m, l) = size(X)
    A = rand(m, R)
    S = rand(R, l)

    for iter in 1:iter_max_inloop
        # Clamp the CURRENT ITERATE away from zero before each multiplicative
        # step. The original code did `A .= max.(X*S', eps)` (and likewise for
        # S), which overwrites the iterate with the update's numerator and
        # breaks the Lee–Seung scheme; the intended pattern is the
        # `S .= max.(S, eps)` clamping also used in init_based_NMF.
        A .= max.(A, eps)
        A .= A .* ((X * S') ./ (A * S * S'))
        S .= max.(S, eps)
        S .= S .* ((A' * X) ./ (A' * A * S))
    end
    return S, A
end

# NMF-based initialization for NTD: seed each factor A[n] with a
# rank-reqrank[n] NMF of the mode-n matricization of X, then refine the
# factors and the core with a few multiplicative least-squares sweeps.
# Returns (S, A): core tensor and list of factor matrices.
# NOTE(review): A is built as an untyped `[]` (Vector{Any}); callers only
# index/assign into it, so this works but prevents type specialization.
function init_based_NMF(X, reqrank, eps=1.0e-10, iter_max=40)
    N = ndims(X)
    A = []
    for n=1:N
        # NMF of the mode-n unfolding; only the basis factor An is kept.
        _, An = init_for_NMF(tenmat(X, n), reqrank[n])
        push!(A, An)
    end

    S = rand( reqrank... )
    for iter = 1:iter_max
        SAns = []
        for n=1:N
            # SAn = mode-n matricization of S ×_{m≠n} A[m]
            # (S multiplied along every mode except the n-th).
            SAn = S
            for m=1:N
                if m != n
                    SAn = ttm(SAn, A[m], m)
                end
            end
            SAn = tenmat(SAn, n)
            push!(SAns, SAn)
            # Multiplicative LS update of the n-th factor:
            # A[n] ∘ (X_(n) SAn') ⊘ (A[n] SAn SAn')
            A[n] = A[n] .* ( ( tenmat(X,n) * SAns[n]' ) ./ (A[n] * SAns[n] * SAns[n]' ))
        end
        # Refresh the core by projecting X onto the updated factors:
        # S = X ×_1 A[1]' ×_2 … ×_N A[N]'
        S = X
        for m=1:N
            S = ttm(S, A[m]', m)
        end
        # Keep the core strictly positive before the multiplicative step.
        S .= max.(S, eps)

        # Multiplicative LS update of the core using the efficient form
        # (Eq. (22) of the paper, per the surrounding blog text):
        #   numerator   = X ×_m A[m]'        over all modes
        #   denominator = S ×_m (A[m]'A[m])  over all modes
        numerator = X
        denominator = S
        for m=1:N
            numerator = ttm(numerator, A[m]', m)
            denominator = ttm(denominator, A[m]'*A[m], m)
        end
        S .= S .* ( numerator ./ denominator )
    end
    return S, A
end

# Alternating multiplicative least-squares updates for NTD: each iteration
# sweeps every factor matrix A[n], then the core tensor S. When `verbose`,
# the reconstruction cost ‖X − Xr‖ is printed every `verbose_interval`
# iterations (iteration count starts at 0). Returns the refined (S, A).
function update(X, S, A, max_iter=100, verbose=true, verbose_interval=20)
    N = ndims(S)

    for iter in 0:(max_iter - 1)
        # ----- factor matrices --------------------------------------------
        for n in 1:N
            # Bn = mode-n matricization of S multiplied along every mode
            # except the n-th.
            Bn = S
            for m in 1:N
                m == n && continue
                Bn = ttm(Bn, A[m], m)
            end
            Bn = tenmat(Bn, n)
            # Multiplicative LS update: A[n] ∘ (X_(n) Bn') ⊘ (A[n] Bn Bn')
            A[n] = A[n] .* ((tenmat(X, n) * Bn') ./ (A[n] * Bn * Bn'))
        end

        # ----- core tensor ------------------------------------------------
        num = X
        den = S
        for m in 1:N
            num = ttm(num, A[m]', m)
            den = ttm(den, A[m]' * A[m], m)
        end
        S .= S .* (num ./ den)

        if verbose && iter % verbose_interval == 0
            cost = norm(X - reconst(S, A))
            println("$iter $cost")
        end
    end

    return S, A
end

"""
    NTD(X, reqrank; init_method="NMF", max_iter=400, verbose=true, verbose_interval=20)

Non-Negative Tucker Decomposition
proposed by Young-Deok Kim et al.
See [original paper](https://ieeexplore.ieee.org/document/4270403)

# Arguments
- `X` : input non-negative tensor
- `reqrank` : target Tucker rank, array (one entry per mode of `X`)

# Keywords
- `init_method` : initial values, "NMF" or "random"
- `max_iter` : number of alternating-update iterations
- `verbose`, `verbose_interval` : print the LS cost every `verbose_interval` iterations

Return `(S, A, Xr)`: core tensor, factor matrices, and the reconstruction.
"""
function NTD(X, reqrank ; init_method="NMF", max_iter=400, verbose=true, verbose_interval=20)
    # Validate user input with a proper exception instead of `@assert`,
    # which is meant for internal invariants and may be compiled out.
    ndims(X) == length(reqrank) ||
        throw(ArgumentError("length(reqrank) = $(length(reqrank)) must equal ndims(X) = $(ndims(X))"))
    tensor_size = size(X)

    #A[n] \in R^( tensor_size[n] \times reqrank[n] )
    #S \in R^(reqrank[1] \times ... \times reqrank[N])
    if init_method == "random"
        S, A = init_random(tensor_size, reqrank)
    elseif init_method == "NMF"
        S, A = init_based_NMF(X, reqrank)
    else
        error("no init method ", init_method)
    end

    S, A = update(X, S, A, max_iter, verbose, verbose_interval)
    Xr = reconst(S, A)

    return S, A, Xr
end

実装に丸一日かけてしまった.