SparrowRecSys

Question about how transitionCountMatrix is computed in graph embedding

Open vcicii opened this issue 2 years ago • 0 comments

Source code:

def generateTransitionMatrix(samples):
    pairSamples = samples.flatMap(lambda x: generate_pair(x))
    pairCountMap = pairSamples.countByValue()
    pairTotalCount = 0
    transitionCountMatrix = defaultdict(dict)
    itemCountMap = defaultdict(int)
    for key, cnt in pairCountMap.items():
        key1, key2 = key

        # Should this be changed to += cnt?
        transitionCountMatrix[key1][key2] = cnt
        itemCountMap[key1] += cnt
        pairTotalCount += cnt
        ......
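To make the loop above concrete, here is a minimal plain-Python sketch of what it computes. It assumes a hypothetical `generate_pair` that emits consecutive (previous, next) item pairs from each sequence, and uses `collections.Counter` as a local stand-in for Spark's `RDD.countByValue()`, which likewise returns a mapping from each distinct value to its number of occurrences. The steps hidden behind the elided `......` are not reproduced.

```python
from collections import Counter, defaultdict


def generate_pair(seq):
    """Hypothetical helper: emit (previous_item, next_item) for each adjacent pair."""
    return [(seq[i], seq[i + 1]) for i in range(len(seq) - 1)]


def count_transitions(samples):
    # Local stand-in for samples.flatMap(generate_pair).countByValue():
    # each distinct (key1, key2) pair appears exactly once as a key.
    pair_count_map = Counter(p for seq in samples for p in generate_pair(seq))

    transition_count_matrix = defaultdict(dict)
    item_count_map = defaultdict(int)
    pair_total_count = 0
    for (key1, key2), cnt in pair_count_map.items():
        transition_count_matrix[key1][key2] = cnt  # same assignment as the original
        item_count_map[key1] += cnt                # outgoing transitions from key1
        pair_total_count += cnt                    # total number of pairs seen
    return transition_count_matrix, item_count_map, pair_total_count


samples = [["a", "b", "c"], ["a", "b"], ["b", "c"]]
matrix, item_counts, total = count_transitions(samples)
print(dict(matrix))       # {'a': {'b': 2}, 'b': {'c': 2}}
print(dict(item_counts))  # {'a': 2, 'b': 2}
print(total)              # 4
```

Here the pair ("a", "b") occurs in two sequences, yet it reaches the loop only once, already carrying the count 2.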

Proposed change:

    for key, cnt in pairCountMap.items():
        key1, key2 = key

        if key1 not in transitionCountMatrix or key2 not in transitionCountMatrix[key1]:
            transitionCountMatrix[key1][key2] = cnt
        else:
            transitionCountMatrix[key1][key2] += cnt

        itemCountMap[key1] += cnt
        pairTotalCount += cnt
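Whether `+=` changes the result depends on whether the same `(key1, key2)` can be visited more than once in the loop. The sketch below assumes `pairCountMap` is the output of `countByValue()`, which returns one entry per distinct value. Under that assumption each pair is visited exactly once, so the `else` branch of the proposed change never runs and both versions build the same matrix. Accumulation only matters when iterating over raw, un-aggregated pairs, and there a nested `defaultdict(int)` handles `+=` without the explicit membership check. `collections.Counter` stands in for Spark, and all function names are hypothetical.

```python
from collections import Counter, defaultdict


def build_with_assignment(pair_count_map):
    # Original version: plain assignment.
    matrix = defaultdict(dict)
    for (key1, key2), cnt in pair_count_map.items():
        matrix[key1][key2] = cnt
    return matrix


def build_with_branch(pair_count_map):
    # Proposed version: assign on first sight, accumulate otherwise.
    matrix = defaultdict(dict)
    for (key1, key2), cnt in pair_count_map.items():
        if key1 not in matrix or key2 not in matrix[key1]:
            matrix[key1][key2] = cnt
        else:
            matrix[key1][key2] += cnt
    return matrix


def build_from_raw_pairs(pairs):
    # Accumulating form for un-aggregated input (duplicates allowed).
    # A nested defaultdict(int) makes "+=" safe without the explicit branch.
    matrix = defaultdict(lambda: defaultdict(int))
    for key1, key2 in pairs:
        matrix[key1][key2] += 1
    return matrix


raw_pairs = [("a", "b"), ("b", "c"), ("a", "b"), ("b", "c")]
pair_count_map = Counter(raw_pairs)  # stand-in for countByValue(): keys are distinct

assert build_with_assignment(pair_count_map) == build_with_branch(pair_count_map)
assert build_from_raw_pairs(raw_pairs) == build_with_assignment(pair_count_map)
```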

vcicii, Jun 16 '22 15:06