The Math Behind K-Means Clustering: A Simplified Overview

The computer scientist Yann LeCun famously said that if intelligence were a cake, unsupervised learning would be the cake, supervised learning would be the icing, and reinforcement learning would be the cherry on top. Here, we are dealing with the cake: K-Means falls under the category of unsupervised learning algorithms, which means the data has no labels.

K-Means is a clustering method used to group similar items together. For example, imagine you have different flowers in your garden, like hibiscus, sunflowers, and roses. While you know their names, a child might not. However, the child can still sort the flowers into separate piles based on how they look, grouping all the roses together, all the hibiscus together, and so on. K-Means works in a similar way, organizing data into groups based on their similarities. It can be applied to 2D data as well as high-dimensional data.

Now, let’s look at a picture of K-Means to understand it better.


Image source: analyticsvidhya

Let’s dive deeper into the mathematics of K-Means.

Step 1: First, randomly choose the centroids. For example, let's say we pick three centroids, which are marked as crosses in the image.

Image source: Mathworks

Step 2: Next, we calculate the distance between each data point and the centroids using Euclidean distance.

The formula for Euclidean distance between two points P1 and P2 in an n-dimensional space is

$$d = \sqrt{\sum_{i=1}^n \left( p_{1i} - p_{2i} \right)^2}$$
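As a quick sanity check, this formula is just a few lines of NumPy. The helper name `euclidean_distance` below is mine, not part of any library:

```python
import numpy as np

def euclidean_distance(p1, p2):
    """Euclidean distance between two points in n-dimensional space."""
    p1, p2 = np.asarray(p1, dtype=float), np.asarray(p2, dtype=float)
    return np.sqrt(np.sum((p1 - p2) ** 2))

# Example: distance between (2, 3) and (7, 5)
print(euclidean_distance([2, 3], [7, 5]))  # ≈ 5.39
```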

Step 3: Update the centroids. Each centroid moves to the mean of the data points currently assigned to it.

Step 4: Recalculate the distances from every data point to the updated centroids and reassign each point to its closest centroid.

Step 5: Keep repeating steps 3 and 4 until the centroids stop moving (a compact code sketch of this loop follows below).
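To make the loop concrete, here is a minimal NumPy sketch of the whole procedure. The function name `kmeans` and its parameters are mine (not from a particular library), and the sketch assumes every cluster keeps at least one point after each assignment:

```python
import numpy as np

def kmeans(X, k, n_iters=100, seed=42):
    """Plain K-Means: random init, assign, update, repeat until centroids stop moving."""
    rng = np.random.default_rng(seed)
    # Step 1: pick k random data points as the initial centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iters):
        # Step 2: Euclidean distance from every point to every centroid
        distances = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = distances.argmin(axis=1)  # closest centroid for each point
        # Step 3: move each centroid to the mean of the points assigned to it
        new_centroids = np.array([X[labels == j].mean(axis=0) for j in range(k)])
        # Step 5: stop once the centroids no longer move
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return centroids, labels

# Usage with the example points A-F from below
X = np.array([[2, 3], [5, 4], [1, 8], [7, 5], [6, 9], [8, 7]])
centroids, labels = kmeans(X, k=2)
```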

Let’s do this by taking an example.

A(2,3) , B(5,4) , C(1,8) , D(7,5) , E(6,9) , F(8,7)

From the above data points, choose two as the initial centroids: C1(2,3) and C2(7,5). Since we've selected two centroids, k = 2.

Now, calculate the Euclidean distance from each data point to centroids C1 and C2.

$$d = \sqrt{(x_2 - x_1)^2 + (y_2 - y_1)^2}$$

| Point  | Distance to C1 (2,3) | Distance to C2 (7,5) | Closest Centroid |
|--------|----------------------|----------------------|------------------|
| A(2,3) | 0                    | 5.39                 | C1               |
| B(5,4) | 3.16                 | 2.24                 | C2               |
| C(1,8) | 5.10                 | 6.71                 | C1               |
| D(7,5) | 5.39                 | 0                    | C2               |
| E(6,9) | 7.21                 | 4.12                 | C2               |
| F(8,7) | 7.21                 | 2.24                 | C2               |
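If you want to double-check these numbers yourself, here is a small sketch (variable names are mine) that reproduces the table with NumPy:

```python
import numpy as np

points = {"A": (2, 3), "B": (5, 4), "C": (1, 8), "D": (7, 5), "E": (6, 9), "F": (8, 7)}
c1, c2 = np.array([2, 3]), np.array([7, 5])  # initial centroids C1 and C2

for name, p in points.items():
    d1 = np.linalg.norm(np.array(p) - c1)  # distance to C1
    d2 = np.linalg.norm(np.array(p) - c2)  # distance to C2
    print(f"{name}{p}: d(C1)={d1:.2f}, d(C2)={d2:.2f}, closest={'C1' if d1 <= d2 else 'C2'}")
```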

We determine the closest centroids by comparing the distances. The smaller the distance, the closer the data point is to the centroid.

Here we can see that A(2,3) is closest to C1(2,3), since d = 0.

Cluster 1 (C1): A(2,3),C(1,8)

Cluster 2 (C2): B(5,4),D(7,5),E(6,9),F(8,7)

Update the centroids

New C1 = (1.5,5.5), mean of A(2,3) and C(1,8)

$$x = \frac{2 + 1}{2} = 1.5 \quad \text{and} \quad y = \frac{3 + 8}{2} = 5.5$$

New C2 = (6.5,6.25), mean of B(5,4),D(7,5),E(6,9),F(8,7)

$$x = \frac{5 + 7 + 6 + 8}{4} = 6.5 \quad \text{and} \quad y = \frac{4 + 5 + 9 + 7}{4} = 6.25$$
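The same averaging can be written in a couple of NumPy lines; the array names below are just for illustration:

```python
import numpy as np

cluster1 = np.array([[2, 3], [1, 8]])                  # A and C
cluster2 = np.array([[5, 4], [7, 5], [6, 9], [8, 7]])  # B, D, E and F

print(cluster1.mean(axis=0))  # [1.5 5.5]  -> new C1
print(cluster2.mean(axis=0))  # [6.5 6.25] -> new C2
```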

Recalculate the distances using the new centroids

| Point  | Distance to C1 (1.5, 5.5) | Distance to C2 (6.5, 6.25) | Closest Centroid |
|--------|---------------------------|----------------------------|------------------|
| A(2,3) | 2.55                      | 5.55                       | C1               |
| B(5,4) | 3.81                      | 2.70                       | C2               |
| C(1,8) | 2.55                      | 5.77                       | C1               |
| D(7,5) | 5.52                      | 1.35                       | C2               |
| E(6,9) | 5.70                      | 2.80                       | C2               |
| F(8,7) | 6.67                      | 1.68                       | C2               |

The final clusters are:

Cluster 1 (C1): A(2,3),C(1,8)

Cluster 2 (C2): B(5,4),D(7,5),E(6,9),F(8,7)

Here we can see that the clusters have not changed. So we can finally say that the algorithm has converged and we don’t need to repeat the steps.
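As a cross-check (not part of the worked example above), scikit-learn's `KMeans` reaches the same result when seeded with the same initial centroids:

```python
import numpy as np
from sklearn.cluster import KMeans

X = np.array([[2, 3], [5, 4], [1, 8], [7, 5], [6, 9], [8, 7]])  # A, B, C, D, E, F
init = np.array([[2, 3], [7, 5]])                                # the C1 and C2 we started with

km = KMeans(n_clusters=2, init=init, n_init=1).fit(X)
print(km.cluster_centers_)  # ≈ [[1.5, 5.5], [6.5, 6.25]]
print(km.labels_)           # the two groups match {A, C} and {B, D, E, F}
```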