用pytorch及numpy计算成对余弦相似性矩阵,并用numpy实现kmeans聚类

艾成益
2023-12-01

  sklearn和scipy里面都提供了kmeans聚类的库,但是它们都是根据向量直接进行计算欧氏距离、闵氏距离或余弦相似度,如果使用其他的度量函数或者向量维度非常高需要先计算好度量距离然后再聚类时,似乎这些库函数都不能直接实现,于是我用numpy自己写了一个,运行也非常快。这里记录下来以后备用:

import numpy as np
import matplotlib.pyplot as plt
import time
t0 = time.time()

Num = 512
corr = np.load('corrs20000.npy')  #相关系数矩阵
u = np.arange(Num)                #设置初始中心点
#u = np.random.choice(20000,Num,replace=False)
for n in range(1000):                  #设置1000次循环
    cluster = [[v] for v in u]         #每个簇放在一个列表中,总体再有一个大列表存放
    others = np.array([v for v in range(20000) if v not in u])  #其他未归类的点
    temp = corr[:,u]
    temp = temp[others,:]              #通过两步提取出所有其他未归类点和各中心点的子相关矩阵
    inds = temp.argmax(axis=1)         #计算每个未归类点与各中心点的最大关系那个点的序号
    new_u = []
    for i in range(Num):              #对每个簇分别计算(暂未想到矢量化方法)
        ind = np.where(inds==i)[0]     #提取各簇中所有新点在未归类点中的序号
        points = others[ind]           #根据序号查找对应的未归类点实际编号
        cluster[i] = cluster[i] + points.tolist()  #把本簇未归类点加入到簇中
        temp = corr[cluster[i],:]
        temp = temp[:,cluster[i]]      #通过两步计算提取本簇各点子相关矩阵
        ind_ = temp.sum(axis=0).argmax()   #计算各点和其他各点的总相关系数之和,取最大的一个的序号
        ind_new_center = cluster[i][ind_]  #根据序号转换为实际编号,得到新的本簇中心点
        new_u.append(ind_new_center)       #加入到新中心点向量
    new_u = np.asarray(new_u,dtype=np.int32)
    if (new_u==u).sum() == Num:            #如果新的中心点向量已不再变化,停止循环
        break
    print(n,(new_u==u).sum(),time.time()-t0)
    u = new_u.copy()                   #计算全部结束后得到cluster就是各簇的点集和,u是中心点向量

--------------------------后续补充:
  然而,快速计算一组向量的自相关性矩阵或者两组向量的相互成对相关系数矩阵也是很常用的,在pytorch中用torch.cosine_similarity只能计算两个向量间的,不能批量整体处理,如果循环计算,或者把向量通过repeat方法扩展显然计算速度比较慢。这里给出一种使用torch.matmul批量计算的方法,可以在cuda中计算,速度非常快。记录备用:

def pairs_cosineSimilarity_matrix_pytorch(v1,v2):
    #v.shape = (N, vector_dims)
    v1 = v1.permute(1,0).unsqueeze(2).float()
    v2 = v2.permute(1,0).unsqueeze(1).float()
    
    part1 = torch.matmul(v1,v2).sum(0)
    part2 = torch.matmul(v1.pow(2).sum(0).pow(0.5),v2.pow(2).sum(0).pow(0.5))
    
    return part1 / (part2+1e-15)
 类似资料: