
簡單易學的機器學習算法—Mean Shift聚類算法
一、Mean Shift算法概述
Mean Shift算法,又稱為均值漂移算法,Mean Shift的概念最早是由Fukunage在1975年提出的,在后來由Yizong Cheng對其進行擴充,主要提出了兩點的改進:
定義了核函數;
增加了權重系數。
核函數的定義使得偏移值對偏移向量的貢獻隨之樣本與被偏移點的距離的不同而不同。權重系數使得不同樣本的權重不同。Mean Shift算法在聚類,圖像平滑、分割以及視頻跟蹤等方面有廣泛的應用。
二、Mean Shift算法的核心原理
2.1、核函數
在Mean Shift算法中引入核函數的目的是使得隨著樣本與被偏移點的距離不同,其偏移量對均值偏移向量的貢獻也不同。核函數是機器學習中常用的一種方式。核函數的定義如下所示:
并且滿足:
(1)、k是非負的
(2)、k是非增的
(3)、k是分段連續的
那么,函數K(x)就稱為核函數。
常用的核函數有高斯核函數。高斯核函數如下所示:
其中,h稱為帶寬(bandwidth),不同帶寬的核函數如下圖所示:
上圖的畫圖腳本如下所示:
'''
Date:201604026
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
import math
def cal_Gaussian(x, h=1):
molecule = x * x
denominator = 2 * h * h
left = 1 / (math.sqrt(2 * math.pi) * h)
return left * math.exp(-molecule / denominator)
x = []
for i in xrange(-40,40):
x.append(i * 0.5);
score_1 = []
score_2 = []
score_3 = []
score_4 = []
for i in x:
score_1.append(cal_Gaussian(i,1))
score_2.append(cal_Gaussian(i,2))
score_3.append(cal_Gaussian(i,3))
score_4.append(cal_Gaussian(i,4))
plt.plot(x, score_1, 'b--', label="h=1")
plt.plot(x, score_2, 'k--', label="h=2")
plt.plot(x, score_3, 'g--', label="h=3")
plt.plot(x, score_4, 'r--', label="h=4")
plt.legend(loc="upper right")
plt.xlabel("x")
plt.ylabel("N")
plt.show()
2.2、Mean Shift算法的核心思想
2.2.1、基本原理
對于Mean Shift算法,是一個迭代的步驟,即先算出當前點的偏移均值,將該點移動到此偏移均值,然后以此為新的起始點,繼續移動,直到滿足最終的條件。此過程可由下圖的過程進行說明(圖片來自參考文獻3):
步驟1:在指定的區域內計算偏移均值(如下圖的黃色的圈)
步驟2:移動該點到偏移均值點處
步驟3: 重復上述的過程(計算新的偏移均值,移動)
步驟4:滿足了最終的條件,即退出
從上述過程可以看出,在Mean Shift算法中,最關鍵的就是計算每個點的偏移均值,然后根據新計算的偏移均值更新點的位置。
2.2.2、基本的Mean Shift向量形式
對于給定的d維空間Rd中的n個樣本點,則對于x點,其Mean Shift向量的基本形式為:
其中,Sh指的是一個半徑為h的高維球區域,如上圖中的藍色的圓形區域。Sh的定義為:
這樣的一種基本的Mean Shift形式存在一個問題:在Sh的區域內,每一個點對x的貢獻是一樣的。而實際上,這種貢獻與x到每一個點之間的距離是相關的。同時,對于每一個樣本,其重要程度也是不一樣的。
2.2.3、改進的Mean Shift向量形式
基于以上的考慮,對基本的Mean Shift向量形式中增加核函數和樣本權重,得到如下的改進的Mean Shift向量形式:
其中:
G(x)是一個單位的核函數。H是一個正定的對稱d×d矩陣,稱為帶寬矩陣,其是一個對角陣。w(xi)?0是每一個樣本的權重。對角陣H的形式為:
上述的Mean Shift向量可以改寫成:
Mean Shift向量Mh(x)是歸一化的概率密度梯度。
2.3、Mean Shift算法的解釋
在Mean Shift算法中,實際上是利用了概率密度,求得概率密度的局部最優解。
2.3.1、概率密度梯度
對一個概率密度函數f(x),已知d維空間中n個采樣點xi,i=1,?,n,f(x)的核函數估計(也稱為Parzen窗估計)為:
其中
w(xi)?0是一個賦給采樣點xi的權重
K(x)是一個核函數
概率密度函數f(x)的梯度▽f(x)的估計為
令,則有:
其中,第二個方括號中的就是Mean Shift向量,其與概率密度梯度成正比。
2.3.2、Mean Shift向量的修正
Mh(x)=∑ni=1G(∥∥xi?xh∥∥2)w(xi)xi∑ni=1G(xi?xh)w(xi)?x
記:,則上式變成:
Mh(x)=mh(x)+x
這與梯度上升的過程一致。
2.4、Mean Shift算法流程
Mean Shift算法的算法流程如下:
計算mh(x)
令x=mh(x)
如果∥mh(x)?x∥<ε,結束循環,否則,重復上述步驟
三、實驗
3.1、實驗數據
實驗數據如下圖所示(來自參考文獻1):
畫圖的代碼如下:
'''
Date:20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
f = open("data")
x = []
y = []
for line in f.readlines():
lines = line.strip().split("\t")
if len(lines) == 2:
x.append(float(lines[0]))
y.append(float(lines[1]))
f.close()
plt.plot(x, y, 'b.', label="original data")
plt.title('Mean Shift')
plt.legend(loc="upper right")
plt.show()
3.2、實驗的源碼
#!/bin/python
#coding:UTF-8
'''
Date:20160426
@author: zhaozhiyong
'''
import math
import sys
import numpy as np
MIN_DISTANCE = 0.000001#mini error
def load_data(path, feature_num=2):
f = open(path)
data = []
for line in f.readlines():
lines = line.strip().split("\t")
data_tmp = []
if len(lines) != feature_num:
continue
for i in xrange(feature_num):
data_tmp.append(float(lines[i]))
data.append(data_tmp)
f.close()
return data
def gaussian_kernel(distance, bandwidth):
m = np.shape(distance)[0]
right = np.mat(np.zeros((m, 1)))
for i in xrange(m):
right[i, 0] = (-0.5 * distance[i] * distance[i].T) / (bandwidth * bandwidth)
right[i, 0] = np.exp(right[i, 0])
left = 1 / (bandwidth * math.sqrt(2 * math.pi))
gaussian_val = left * right
return gaussian_val
def shift_point(point, points, kernel_bandwidth):
points = np.mat(points)
m,n = np.shape(points)
#計算距離
point_distances = np.mat(np.zeros((m,1)))
for i in xrange(m):
point_distances[i, 0] = np.sqrt((point - points[i]) * (point - points[i]).T)
#計算高斯核
point_weights = gaussian_kernel(point_distances, kernel_bandwidth)
#計算分母
all = 0.0
for i in xrange(m):
all += point_weights[i, 0]
#均值偏移
point_shifted = point_weights.T * points / all
return point_shifted
def euclidean_dist(pointA, pointB):
#計算pointA和pointB之間的歐式距離
total = (pointA - pointB) * (pointA - pointB).T
return math.sqrt(total)
def distance_to_group(point, group):
min_distance = 10000.0
for pt in group:
dist = euclidean_dist(point, pt)
if dist < min_distance:
min_distance = dist
return min_distance
def group_points(mean_shift_points):
group_assignment = []
m,n = np.shape(mean_shift_points)
index = 0
index_dict = {}
for i in xrange(m):
item = []
for j in xrange(n):
item.append(str(("%5.2f" % mean_shift_points[i, j])))
item_1 = "_".join(item)
print item_1
if item_1 not in index_dict:
index_dict[item_1] = index
index += 1
for i in xrange(m):
item = []
for j in xrange(n):
item.append(str(("%5.2f" % mean_shift_points[i, j])))
item_1 = "_".join(item)
group_assignment.append(index_dict[item_1])
return group_assignment
def train_mean_shift(points, kenel_bandwidth=2):
#shift_points = np.array(points)
mean_shift_points = np.mat(points)
max_min_dist = 1
iter = 0
m, n = np.shape(mean_shift_points)
need_shift = [True] * m
#cal the mean shift vector
while max_min_dist > MIN_DISTANCE:
max_min_dist = 0
iter += 1
print "iter : " + str(iter)
for i in range(0, m):
#判斷每一個樣本點是否需要計算偏置均值
if not need_shift[i]:
continue
p_new = mean_shift_points[i]
p_new_start = p_new
p_new = shift_point(p_new, points, kenel_bandwidth)
dist = euclidean_dist(p_new, p_new_start)
if dist > max_min_dist:#record the max in all points
max_min_dist = dist
if dist < MIN_DISTANCE:#no need to move
need_shift[i] = False
mean_shift_points[i] = p_new
#計算最終的group
group = group_points(mean_shift_points)
return np.mat(points), mean_shift_points, group
if __name__ == "__main__":
#導入數據集
path = "./data"
data = load_data(path, 2)
#訓練,h=2
points, shift_points, cluster = train_mean_shift(data, 2)
for i in xrange(len(cluster)):
print "%5.2f,%5.2f\t%5.2f,%5.2f\t%i" % (points[i,0], points[i, 1], shift_points[i, 0], shift_points[i, 1], cluster[i])
3.3、實驗的結果
經過Mean Shift算法聚類后的數據如下所示:
'''
Date:20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
f = open("data_mean")
cluster_x_0 = []
cluster_x_1 = []
cluster_x_2 = []
cluster_y_0 = []
cluster_y_1 = []
cluster_y_2 = []
center_x = []
center_y = []
center_dict = {}
for line in f.readlines():
lines = line.strip().split("\t")
if len(lines) == 3:
label = int(lines[2])
if label == 0:
data_1 = lines[0].strip().split(",")
cluster_x_0.append(float(data_1[0]))
cluster_y_0.append(float(data_1[1]))
if label not in center_dict:
center_dict[label] = 1
data_2 = lines[1].strip().split(",")
center_x.append(float(data_2[0]))
center_y.append(float(data_2[1]))
elif label == 1:
data_1 = lines[0].strip().split(",")
cluster_x_1.append(float(data_1[0]))
cluster_y_1.append(float(data_1[1]))
if label not in center_dict:
center_dict[label] = 1
data_2 = lines[1].strip().split(",")
center_x.append(float(data_2[0]))
center_y.append(float(data_2[1]))
else:
data_1 = lines[0].strip().split(",")
cluster_x_2.append(float(data_1[0]))
cluster_y_2.append(float(data_1[1]))
if label not in center_dict:
center_dict[label] = 1
data_2 = lines[1].strip().split(",")
center_x.append(float(data_2[0]))
center_y.append(float(data_2[1]))
f.close()
plt.plot(cluster_x_0, cluster_y_0, 'b.', label="cluster_0")
plt.plot(cluster_x_1, cluster_y_1, 'g.', label="cluster_1")
plt.plot(cluster_x_2, cluster_y_2, 'k.', label="cluster_2")
plt.plot(center_x, center_y, 'r+', label="mean point")
plt.title('Mean Shift 2')數據分析師培訓
#plt.legend(loc="best")
plt.show()
數據分析咨詢請掃描二維碼
若不方便掃碼,搜微信號:CDAshujufenxi
解碼數據基因:從數字敏感度到邏輯思維 每當看到超市貨架上商品的排列變化,你是否會聯想到背后的銷售數據波動?三年前在零售行 ...
2025-05-23在本文中,我們將探討 AI 為何能夠加速數據分析、如何在每個步驟中實現數據分析自動化以及使用哪些工具。 數據分析中的AI是什么 ...
2025-05-20當數據遇見人生:我的第一個分析項目 記得三年前接手第一個數據分析項目時,我面對Excel里密密麻麻的銷售數據手足無措。那些跳動 ...
2025-05-20在數字化運營的時代,企業每天都在產生海量數據:用戶點擊行為、商品銷售記錄、廣告投放反饋…… 這些數據就像散落的拼圖,而相 ...
2025-05-19在當今數字化營銷時代,小紅書作為國內領先的社交電商平臺,其銷售數據蘊含著巨大的商業價值。通過對小紅書銷售數據的深入分析, ...
2025-05-16Excel作為最常用的數據分析工具,有沒有什么工具可以幫助我們快速地使用excel表格,只要輕松幾步甚至輸入幾項指令就能搞定呢? ...
2025-05-15數據,如同無形的燃料,驅動著現代社會的運轉。從全球互聯網用戶每天產生的2.5億TB數據,到制造業的傳感器、金融交易 ...
2025-05-15大數據是什么_數據分析師培訓 其實,現在的大數據指的并不僅僅是海量數據,更準確而言是對大數據分析的方法。傳統的數 ...
2025-05-14CDA持證人簡介: 萬木,CDA L1持證人,某電商中廠BI工程師 ,5年數據經驗1年BI內訓師,高級數據分析師,擁有豐富的行業經驗。 ...
2025-05-13CDA持證人簡介: 王明月 ,CDA 數據分析師二級持證人,2年數據產品工作經驗,管理學博士在讀。 學習入口:https://edu.cda.cn/g ...
2025-05-12CDA持證人簡介: 楊貞璽 ,CDA一級持證人,鄭州大學情報學碩士研究生,某上市公司數據分析師。 學習入口:https://edu.cda.cn/g ...
2025-05-09CDA持證人簡介 程靖 CDA會員大咖,暢銷書《小白學產品》作者,13年頂級互聯網公司產品經理相關經驗,曾在百度、美團、阿里等 ...
2025-05-07相信很多做數據分析的小伙伴,都接到過一些高階的數據分析需求,實現的過程需要用到一些數據獲取,數據清洗轉換,建模方法等,這 ...
2025-05-06以下的文章內容來源于劉靜老師的專欄,如果您想閱讀專欄《10大業務分析模型突破業務瓶頸》,點擊下方鏈接 https://edu.cda.cn/g ...
2025-04-30CDA持證人簡介: 邱立峰 CDA 數據分析師二級持證人,數字化轉型專家,數據治理專家,高級數據分析師,擁有豐富的行業經驗。 ...
2025-04-29CDA持證人簡介: 程靖 CDA會員大咖,暢銷書《小白學產品》作者,13年頂級互聯網公司產品經理相關經驗,曾在百度,美團,阿里等 ...
2025-04-28CDA持證人簡介: 居瑜 ,CDA一級持證人國企財務經理,13年財務管理運營經驗,在數據分析就業和實踐經驗方面有著豐富的積累和經 ...
2025-04-27數據分析在當今信息時代發揮著重要作用。單因素方差分析(One-Way ANOVA)是一種關鍵的統計方法,用于比較三個或更多獨立樣本組 ...
2025-04-25CDA持證人簡介: 居瑜 ,CDA一級持證人國企財務經理,13年財務管理運營經驗,在數據分析就業和實踐經驗方面有著豐富的積累和經 ...
2025-04-25在當今數字化時代,數據分析師的重要性與日俱增。但許多人在踏上這條職業道路時,往往充滿疑惑: 如何成為一名數據分析師?成為 ...
2025-04-24