Mean-Shift Clustering Tutorial with Python Examples

The Mean-Shift algorithm is a hill-climbing algorithm based on kernel density estimation. It can be used for tasks such as clustering, image segmentation, and object tracking. Its main advantage over an algorithm such as K-Means is that Mean-Shift does not require the user to specify the number of clusters: the algorithm looks for clusters that occur naturally in the data.

The examples in this guide showcase the Python library scikit-learn (sklearn), a widely used machine learning library.

The Mean-Shift Algorithm Explained

The goal of Mean-Shift is to find clusters in the data. It does so by looking for centroids, which are placed in regions of high data-point density. Near-duplicate centroids are removed in a post-processing step.

Let’s check out how to find these centroids. First, pick an arbitrary initial centroid, x.

We now want to move this centroid in the direction of the highest data point density. This is done by first calculating the Mean-Shift direction and then updating our centroid, x.

Calculate Mean-Shift vector

Given a set of data points xi in n-dimensional space, the basic form of the mean-shift vector for any point x in space can be expressed as

 v_s = \frac{1}{K} \sum_{x_i \in S_k} (x_i - x)

Here v_s is the mean-shift vector, K is the number of points inside the window, and S_k is the set of data points whose distance to x is less than the radius r of the sphere. That is,

 S_k(x) = \{ y : (y - x)^T (y - x) < r^2 \}

We now have the situation shown below.
[Figure: mean-shift clustering principle]
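The windowed mean defined by the two equations above can be sketched directly with NumPy. This is an illustrative implementation, not sklearn's; the function name `mean_shift_vector` and the toy data are my own:

```python
import numpy as np

def mean_shift_vector(x, points, r):
    """Compute the mean-shift vector v_s for a point x.

    Averages the offsets (x_i - x) over the K points x_i that fall
    inside the sphere S_k of radius r around x.
    """
    offsets = points - x                            # x_i - x for every point
    in_window = np.sum(offsets ** 2, axis=1) < r ** 2
    if not np.any(in_window):
        return np.zeros_like(x)                     # empty window: no shift
    return offsets[in_window].mean(axis=0)          # (1/K) * sum of offsets

# Toy data: three points near (2, 2) and one far outlier
points = np.array([[2.0, 2.0], [2.1, 1.9], [1.9, 2.1], [5.0, 5.0]])
x = np.array([1.5, 1.5])
v = mean_shift_vector(x, points, r=1.5)  # points from x toward the dense region
```

Note that the outlier at (5, 5) falls outside the window and does not pull on the vector at all; only the points inside the sphere contribute.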

Update centroid

We can now update the position of our centroid as follows.

 x := x + v_s

This update moves the centroid toward the denser region.

The process is repeated until the change in x is negligible. We then know that we have found an optimal centroid with the highest density. The finished search looks like this:
[Figure: finished mean-shift search]

When the vector v_s is sufficiently small, the algorithm has converged and the optimal centroid for that cluster has been found.
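Putting the vector computation and the update rule together, the full hill-climb for a single centroid can be sketched as follows. This is a minimal illustration; the helper name, tolerance, and toy data are assumptions, not part of any library:

```python
import numpy as np

def find_centroid(x0, points, r, tol=1e-3, max_iter=100):
    """Repeatedly apply x := x + v_s until the shift is negligible."""
    x = np.asarray(x0, dtype=float)
    for _ in range(max_iter):
        offsets = points - x
        in_window = np.sum(offsets ** 2, axis=1) < r ** 2
        if not np.any(in_window):
            break                                   # empty window: stop
        v = offsets[in_window].mean(axis=0)         # mean-shift vector v_s
        x = x + v                                   # update step x := x + v_s
        if np.linalg.norm(v) < tol:                 # v_s sufficiently small
            break
    return x

# One synthetic cluster centred at (3, 3)
rng = np.random.default_rng(0)
points = rng.normal([3.0, 3.0], 0.2, size=(50, 2))
c = find_centroid([2.0, 2.0], points, r=1.5)  # climbs toward the cluster centre
```

Running this from several different starting points and merging the near-duplicate results is, in essence, what the full algorithm does.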

Pros and cons of Mean-Shift Clustering

The clear advantage of this algorithm is that you don’t have to specify the number of clusters in your data. The algorithm finds the natural clusters in your data, which is often preferred. You can also be sure that the algorithm will converge and stop within a finite number of iterations.

One of the main weaknesses of Mean-Shift is its scalability: the algorithm has to perform multiple nearest-neighbour searches in each iteration, which becomes expensive on large data sets.

How to do Mean-Shift clustering in Python with sklearn

Examples can sometimes be the easiest way to learn how to do things yourself. Here is a quick example of how to use sklearn to do Mean-Shift clustering in Python.

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Generate sample data
centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)

# Estimate the bandwidth automatically
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

# Run the algorithm
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_

labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

# Print the number of clusters in the data
print("Clusters found: %d" % n_clusters_)
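A fitted `MeanShift` model can also label unseen points with its `predict` method, which assigns each point to the nearest learned cluster centre. A short sketch continuing the example above (the `random_state` arguments are added here for reproducibility; they are not in the original example):

```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Same data setup as above, made reproducible
centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers,
                  cluster_std=0.6, random_state=0)

bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500, random_state=0)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(X)

# Assign new, unseen points to the nearest learned cluster
new_points = np.array([[1.0, 1.0], [-1.0, -1.0]])
new_labels = ms.predict(new_points)
```

Since (1, 1) and (-1, -1) sit on two different blob centres, they should receive two different labels.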

Feel free to ask questions below if you have any about the code!

Summary

Thank you for studying our tutorial on Mean-Shift clustering with Python examples. If you want more Python examples like this one, be sure to check out our other guides such as this.