Mean-Shift Clustering Tutorial with Python Examples
The Mean-Shift algorithm is a hill-climbing algorithm based on kernel density estimation. It is widely used for tasks such as clustering, image segmentation, and object tracking. Its main advantage over an algorithm such as K-Means is that Mean-Shift does not require the user to specify the number of clusters; instead, the algorithm looks for clusters that occur naturally in the data.
The examples in this guide showcase scikit-learn (sklearn), a widely used Python machine learning library.
The Mean-Shift Algorithm Explained
The goal of Mean-Shift is to find clusters in the data. It does so by looking for centroids which are determined based on data point density. The algorithm will remove all near-duplicate centroids during post-processing.
Let’s check out how to find the centroids. First, pick an arbitrary initial centroid, x.
We now want to move this centroid in the direction of the highest data point density. This is done by first calculating the Mean-Shift direction and then updating our centroid, x.
Calculate Mean-Shift vector
Given any number of data points $x_i$ in $n$-dimensional space, the basic form of the mean shift vector for a point $x$ in space can be expressed as

$$v_s = \frac{1}{k} \sum_{x_i \in S_k} (x_i - x)$$

Here $v_s$ is the mean shift vector, and $S_k$ is the set of the $k$ data points whose distance to $x$ is less than the radius $r$ of the sphere. That is,

$$S_k = \{\, x_i : \lVert x_i - x \rVert \le r \,\}$$

The vector $v_s$ therefore points from $x$ toward the mean of the nearby data points.
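To make the formula concrete, the mean shift vector can be computed in a few lines of NumPy. This is a minimal sketch; the sample points, the starting point x, and the radius r are made up for illustration:

```python
import numpy as np

def mean_shift_vector(x, points, r):
    """Mean shift vector at x: average offset to the points within radius r."""
    in_sphere = points[np.linalg.norm(points - x, axis=1) <= r]
    return in_sphere.mean(axis=0) - x

# Three points clustered near (1, 1) and one far away
points = np.array([[1.0, 1.0], [1.2, 0.8], [0.9, 1.1], [5.0, 5.0]])
x = np.array([0.0, 0.0])
v = mean_shift_vector(x, points, r=2.0)
# v points toward the dense group near (1, 1); the far point at (5, 5)
# lies outside the sphere and is ignored
```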
We can now update the position of our centroid as follows:

$$x \leftarrow x + v_s$$

which gives us

$$x \leftarrow \frac{1}{k} \sum_{x_i \in S_k} x_i$$

In other words, the centroid moves to the mean of the data points inside the sphere.
This process is repeated until the change in $x$ is negligible. When the vector $v_s$ is sufficiently small, the algorithm has converged and we have found an optimal centroid at a point of locally maximal density.
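The whole iteration can be sketched as a small loop. This is a toy flat-kernel implementation for illustration, not the sklearn version; the radius, tolerance, and sample data are arbitrary choices:

```python
import numpy as np

def find_centroid(x, points, r=2.5, tol=1e-6, max_iter=100):
    """Iterate x <- mean of points within radius r until the shift is negligible."""
    for _ in range(max_iter):
        in_sphere = points[np.linalg.norm(points - x, axis=1) <= r]
        shift = in_sphere.mean(axis=0) - x  # the mean shift vector v_s
        x = x + shift
        if np.linalg.norm(shift) < tol:  # converged: v_s is sufficiently small
            break
    return x

# Synthetic dense region around (3, 3)
rng = np.random.default_rng(0)
points = rng.normal(loc=[3.0, 3.0], scale=0.5, size=(200, 2))
centroid = find_centroid(np.array([2.0, 2.0]), points)
# the centroid climbs toward the dense region around (3, 3)
```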
Pros and cons of Mean-Shift Clustering
The good thing about this algorithm is clearly the fact that you don’t have to specify the number of clusters in your data. The algorithm finds natural clusters in your data, which is often preferred. You can also be sure that the algorithm will converge and stop within a finite number of iterations.
One of the main weaknesses of Mean-Shift is its scalability. The scalability is limited by the fact that we have to do multiple nearest neighbour searches in each iteration.
How to do Mean-Shift clustering in Python with sklearn
Examples can sometimes be the easiest way to learn how to do things yourself. Here is a quick example of how to use sklearn to do Mean-Shift clustering in Python.
```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Generate sample data around three centers
centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)

# Estimate the bandwidth automatically
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

# Run the algorithm
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_

labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

# Print the number of clusters in the data
print("Clusters found: %d" % n_clusters_)
```
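A fitted MeanShift model can also assign new, unseen points to the discovered clusters with sklearn's predict method, which maps each point to the nearest learned cluster center. A small sketch with made-up toy data:

```python
import numpy as np
from sklearn.cluster import MeanShift

# Toy data: two well-separated blobs
rng = np.random.default_rng(42)
X = np.vstack([rng.normal([0, 0], 0.3, (50, 2)),
               rng.normal([5, 5], 0.3, (50, 2))])

ms = MeanShift(bandwidth=1.0)
ms.fit(X)

# New points are assigned to the nearest learned cluster center
new_points = np.array([[0.1, -0.2], [4.8, 5.1]])
labels = ms.predict(new_points)
# the two points end up in different clusters
```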
Feel free to ask questions below if you have any about the code!
Thank you for reading our tutorial on Mean-Shift clustering with Python examples. If you want more Python examples like this one, be sure to check out our other guides.