まぃふぇいばりっと

機械学習やってます.Julia大好きです.勉強したことの殴り書きです.

Julia言語 確率行列分解 Left-Stochastic Matrix Factorization

左確率行列分解(Left-Stochastic Matrix Factorization)というタスクがある.NMFに対称性と列規格化を要請する.これでうまいことクラスタリングができる. 研究で比較実験を行うためにクラスタ数2の場合のみを実装した.

jmlr.org

アルゴリズム中に確率単体(simplex)への射影が必要になるんだけど,論文中には具体的にどういうアルゴリズムを使って,simplexへ射影したかは書いてなかったので,ひとまずこのアルゴリズムを実装しました. [1101.6081] Projection Onto A Simplex

ちなみに,このコードはMichael Friedlanderさんのリポジトリにありものを拝借しています.これでexactな射影になっているはず.

github.com

function projsplx!(b)
    n = length(b)
    bget = false

    idx = sortperm(b, rev=true)
    tsum = 0
    @inbounds for i = 1:n-1
        tsum += b[idx[i]]
        tmax = (tsum - 1.0)/i
        if tmax ≥ b[idx[i+1]]
            bget = true
            break
        end
    end

    if !bget
        tmax = (tsum + b[idx[n]] - 1.0) / n
    end

    @inbounds for i = 1:n
        b[i] = max(b[i] - tmax, 0)
    end
end

あとは論文通りに実装するだけ. 入力Kは対称な正方非負行列.サンプル同士のsimilarity(サンプル同士の距離とか)が入っている.行列の大きさはサンプル数nに関して,n×n. 出力されるQ_triは,(i,j)成分にサンプルjがクラスiにいる確率が格納されている.

function getM(K, r=2)
    svd = svds(K ; nsv=r, ritzvec=true)[1]
    #Ar = svd.U * diagm(svd.S) * svd.Vt
    M = sqrt.(diagm(svd.S)) * svd.U'
    return M
end

function LSD(K)
    """
    input : K is n \times n symmetric non-negative mtx
    """
    n = size(K)[1]
    r = 2

    # M \in R^( r \times n )
    # M'M is the best approximation of K in terms of Fnorm
    M = getM(K)
    display(M)

    #c = norm( inv( M*M' ) * M * ones(n) )^2 / r
    #K = c * K

    # m \in R^k
    # mhat \in R^(k \times n)
    m = inv(M*M') * M * ones(n)
    abs_m = norm(m)

    Mhat = ( Matrix{Float64}(I, r, r) - m * m'./ abs_m ) * M
    Mhat = Mhat + 1.0 / ( sqrt(r) * abs_m) * ( m * ones(1, n) )

    # get RG
    u = ones(r) ./ sqrt(r)
    utm = u' * m
    RG = utm * Matrix{Float64}(I, r, r)
    RG[1, 2] = -1 + utm^2
    RG[2, 1] = -RG[1, 2]

    # get RS
    v = u - utm * m ./ (abs_m)^2
    abs_v = norm(v)
    U = [ m ./ abs_m  v ./ abs_v ]
    Rs = U * RG * U'

    Q = Rs * Mhat
    Q_tri = zeros(r, n)
    for i = 1:n
        tmp = Q[:, i]
        projsplx!(tmp)
        Q_tri[:, i] = tmp
    end

    return Q_tri
end