K-means clustering with PyTorch and GPU acceleration

罗毅
2023-12-01

Goal

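The k-means implementation in sklearn runs on the CPU by default, which is slow for large inputs. Sometimes features need to be clustered dynamically inside a network. A k-means implementation that works directly on PyTorch tensors can keep the data on the GPU and greatly improve efficiency.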

A search turned up a similar implementation in ContrastiveSceneContexts (see Reference), which the code below follows.

Environment

pytorch, pykeops

pip install pykeops  -i https://pypi.tuna.tsinghua.edu.cn/simple

Code

import time

import torch
import pykeops
from pykeops.torch import LazyTensor

# Clear the pykeops build cache so that formulas are recompiled on the next call.
# This is optional and makes the first call slower.
pykeops.clean_pykeops()

def kmeans(pointcloud, k=10, iterations=10, verbose=True):
    """Run k-means on an (N, D) tensor and return, for each cluster, the index of its nearest point."""
    n, dim = pointcloud.shape  # number of samples, dimension of the ambient space
    start = time.time()
    clusters = pointcloud[:k, :].clone()  # simplistic initialization: the first k points (not random)
    pointcloud_cuda = LazyTensor(pointcloud[:, None, :])  # (Npoints, 1, D)

    # K-means loop:
    for _ in range(iterations):
        clusters_previous = clusters.clone()
        clusters_gpu = LazyTensor(clusters[None, :, :])  # (1, Nclusters, D)
        # (Npoints, Nclusters) symbolic matrix of squared distances
        distance_matrix = ((pointcloud_cuda - clusters_gpu) ** 2).sum(-1)
        closest_clusters = distance_matrix.argmin(dim=1).long().view(-1)  # points -> nearest cluster

        # Number of points assigned to each cluster
        clusters_count = torch.bincount(closest_clusters, minlength=k).float()
        for d in range(dim):  # compute the cluster centroids with torch.bincount
            # Empty clusters give 0 / 0 = NaN here; they are restored just below.
            clusters[:, d] = torch.bincount(closest_clusters, weights=pointcloud[:, d], minlength=k) / clusters_count

        # Clusters with no points assigned keep their previous center
        mask = clusters_count == 0
        clusters[mask] = clusters_previous[mask]

    if pointcloud.is_cuda:
        torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before reading the clock
    end = time.time()

    if verbose:
        print("K-means example with {:,} points in dimension {:,}, K = {:,}:".format(n, dim, k))
        print('Timing for {} iterations: {:.5f}s = {} x {:.5f}s\n'.format(
                iterations, end - start, iterations, (end - start) / iterations))

    # Nearest-neighbour search: for each cluster, the index of the closest point.
    # This reuses the distance matrix of the last iteration, i.e. the centers before the final update.
    closest_points_to_centers = distance_matrix.argmin(dim=0).long().view(-1)
    return closest_points_to_centers
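
To sanity-check the routine above on small inputs, it can help to have a dense version that needs only PyTorch. The following is a minimal sketch, not part of the referenced implementation: the name kmeans_dense is hypothetical, it builds the full (N, K) distance matrix with torch.cdist (so memory grows as N * K, which is exactly what the LazyTensor version avoids), and it returns the centers and labels rather than the nearest-point indices. The toy data (two Gaussian blobs) is made up for illustration.

```python
import torch


def kmeans_dense(points, k=10, iterations=10):
    """Dense k-means in plain PyTorch (hypothetical helper for cross-checking)."""
    centers = points[:k].clone()  # same simplistic initialization: the first k points
    for _ in range(iterations):
        # Full (N, K) matrix of squared Euclidean distances.
        dist = torch.cdist(points, centers) ** 2
        labels = dist.argmin(dim=1)  # nearest center for each point
        counts = torch.bincount(labels, minlength=k)
        # Sum the points of each cluster in one call instead of looping over dimensions.
        sums = torch.zeros_like(centers).index_add_(0, labels, points)
        nonempty = counts > 0
        # Empty clusters keep their previous center.
        centers[nonempty] = sums[nonempty] / counts[nonempty].unsqueeze(1).to(points.dtype)
    return centers, labels


# Toy data: two well-separated blobs of 50 points each.
torch.manual_seed(0)
a = 0.5 * torch.randn(50, 2)
b = 0.5 * torch.randn(50, 2) + 10.0
# Put one point of each blob first so the first-k initialization covers both.
data = torch.cat([a[:1], b[:1], a[1:], b[1:]])

# The computation runs on whatever device the input tensor lives on.
device = "cuda" if torch.cuda.is_available() else "cpu"
centers, labels = kmeans_dense(data.to(device), k=2, iterations=5)
print(centers.cpu())
```

The same device rule applies to the pykeops version: it runs on the GPU only if the tensor passed in has already been moved there, for example with .cuda() or .to("cuda").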

Reference

  • Ji Hou, Benjamin Graham, Matthias Nießner, Saining Xie:
    Exploring Data-Efficient 3D Scene Understanding With Contrastive Scene Contexts. CVPR 2021: 15587-15597
  • https://github.com/facebookresearch/ContrastiveSceneContexts/blob/83515bef4754b3d90fc3b3a437fa939e0e861af8/downstream/semseg/lib/sampling_points.py#L28