均值漂移聚类(Mean-shift)
均值漂移算法:一种基于密度梯度上升的聚类算法(沿着密度上升方向寻找聚类中心点)
公式
\[M(x) = {1 \over k}\sum_{x_i\in S_h}(u-x_i) \] 均值偏移
\[u^{t+1} = M^t+u^t \] 中心更新
\(S_h\):以u为中心点,半径为h的高维球区域;
k:包含在\(S_h\)范围内点的个数;
\(x_i\):包含在\(S_h\)范围内的点;
\(M^t\)为t状态下求得的偏移均值;
\(u^t\):为t状态下的中心
算法流程
- 随机选择未分类点作为中心点
- 找出离中心点距离在带宽之内的点,记作集合S
- 计算从中心点到集合S中每个元素的偏移向量M
- 中心点以向量M移动
- 重复步骤2-4,直到收敛
- 重复步骤1-5,直到所有的点都被归类
- 分类:根据每个类,对每个点的访问频率,取访问频率最大的那个类,作为当前点集的所属类。
模型训练
自动计算带宽(区域半径)
from sklearn.cluster import MeanShift,estimate_bandwidth
# 计算半径
bandwidth = estimate_bandwidth(X, n_samples=500)
## 模型建立与训练
ms = MeansShift(bandwidth=bandwidth)
ms.fit(X)