まぃふぇいばりっと

機械学習やってます.Julia大好きです.勉強したことの殴り書きです.

Julia 言語 テンソルリング分解による欠損補完

この論文に出てくる PTRC-RW を実装しました.

https://ieeexplore.ieee.org/document/9158539

Folding と Unfolding が肝で,よく分かんないので,stackoverflow の強い人に助けてもらいました.

arrays - The inverse operation of tensor circular unfolding in Julia - Stack Overflow

Stridedを使うとちょっと早くなります.本当は,permutedims は新しいテンソルを作っていて,実は view さえあれば計算はできると思うので,もうちょっと最適化できるかもしれないです.このあたりはややこしくて苦手です.ただ,実行時間は論文の図9くらいにはなったので,matlab と julia の環境の差はあれど,まぁ許容できるくらいには最適化できたのかなと納得して進むことにしました.

とりあえず,TCUとその逆作用.

using LinearAlgebra
using TensorToolbox
using Tullio
using Distributions
using Printf
using TransmuteDims
using Strided

function TCU(X,d,k)
    N = ndims(X)
    @assert d < N
    if d <= k
        a = k-d+1
    else
        a = k-d+1+N
    end
    @strided tmp = permutedims(X,circshift(1:N,-a+1))
    tenmat(tmp,row=1:d)
end

function invTCU(M,d,k, presize)
    N = length(presize)
    a = d<=k ? k-d+1 : k-d+1+N
    X = reshape(M,Tuple(circshift(collect(presize),1-a)))
    @strided permutedims(X,circshift(1:N,a-1))
end

アルゴリズム本体は以下.RWfalseにすると論文中のアルゴリズム1(PTRC)になります. Xが入力テンソルで,Mが欠損の位置を表すバイナリテンソルです.alphadはハイパラです.alphaは総和が1で大きさがテンソルのオーダーと同じベクトルです.dテンソルのオーダーよりも小さい整数です.

function PTRCRW!(X, M, R, alpha, d;RW=true, iter_max=1000, verbose=true, verbose_inval=5, Xgt=NaN, tol=1.0e-5)
    idxs_missing = findall( M .== 0 )
    N = ndims(X)
    J = size(X)

    Xd = Vector{Matrix{Float64}}(undef,N)
    W = Vector{Matrix{Float64}}(undef,N)
    H = Vector{Matrix{Float64}}(undef,N)
    if RW
        B = Vector{Matrix{Float64}}(undef,N)
    end

    for k = 1:N
        a = get_a(d,k,N)
        if a == 1
            s1 = R[k]*R[N]
        else
            s1 = R[k]*R[a-1]
        end

        if a == 1
            s2 = prod(J[k+1:N])
        elseif k+1 <= a-1
            s2 = prod(J[k+1:a-1])
        else
            s2 = prod(vcat(J[k+1:N]...,J[1:a-1]...))
        end
        H[k] = rand( s1, s2 )

        if RW
            Wdk = TCU(M,d,k)
            omega_kd = sum( Wdk )
            beta = zeros(size(Wdk)[1])
            for i = 1:size(Wdk)[1]
                omega_kdi = sum( Wdk[i,:] )
                beta[i] = max( omega_kdi/omega_kd ,1.0e-16)
            end
            B[k] = diagm(beta)
        end
    end

    X_pre = zeros( J )
    for iter = 1:iter_max
        for k = 1:N
            if iter == 1
                Xd[k] = TCU(X,d,k)
                W[k] = Xd[k] * H[k]'
            else
                Xd[k] .= TCU(X,d,k)
                W[k] .= Xd[k] * H[k]'
            end

            if RW
                H[k] .= ( W[k]'*B[k]*W[k] ) \ W[k]' * B[k] * Xd[k]
            else
                H[k] .= ( W[k]'*W[k] ) \ W[k]' * Xd[k]
            end

            Xd[k] .= W[k]*H[k]
        end

        foldM = zeros(J)
        for k = 1:N
            foldM .+= ( alpha[k] .* invTCU(Xd[k],d,k,J) )
        end
        X[idxs_missing] .= foldM[idxs_missing]

        if iter > 1
            diff_X = norm( X_pre .- X ) / norm(X_pre)

            if verbose && iter > 1
                if mod(iter, verbose_inval) == 0
                    if Xgt isa Array
                        rms = norm(X .- Xgt)/norm(Xgt)
                    else
                        rms = NaN
                    end
                    @printf("%4d %5f %5f \n", iter, diff_X, rms)
                end
            end

            if diff_X < tol
                return X
            end

            if iter == iter_max
                println("PTRCRW was not converged")
            end
        end
        X_pre .= X
    end

    return X
end

適当にダウンロードしてきた画像ファイルに欠損を90%与えてみたらこんな感じになりました.

using FileIO
function main()
    T = load("../../data/SIPI/tensor_SIPI.jld2")
    Xgt = T["Baboon"]
    #Xgt = reshape(Xgt, 16,16,16,16,3)
    Xgt = reshape(Xgt, 4,4,4,4,4,4,4,4,3)
    W = generate_weight(Xgt, sr=10)
    Xin = init_missing_val!(deepcopy(Xgt), W)

    D = ndims(Xin)
    alpha = ones(D)/D
    R = 5*ones(Int64,D)
    d = 3
    Xpre = PTRCRW!(deepcopy(Xin), W, R, alpha, d, RW=true, tol=1.0e-3, iter_max=1000, verbose=true, verbose_inval=20, Xgt=Xgt)

end

RSE(a,b) = norm(a - b) / norm(a)
main()

下がってます,下がってます.

iter     cost       RMSE
  20 0.007605 0.261927
  40 0.003536 0.215563
  60 0.002055 0.198852
  80 0.001495 0.190946
 100 0.001139 0.186659

使った,init_missing_val!generate_weightの定義はこちらをご参照ください.

https://genkaiphd.hatenablog.com/entry/2023/01/05/232531genkaiphd.hatenablog.com