# 1. K-means algorithm

## 1. Introduction

- As shown in Figure (a) above, all the green points are data samples. To the naked eye it is obvious that they fall into two groups, one above and one below, but the computer does not know this. How does the computer find the groups?
- As shown in Figure (b), the computer picks two initial centroids (in practice the centroids are usually selected from the existing data rather than generated at new random positions), one red and one blue.
- As shown in Figure (c), the computer computes the distance from each point to every centroid and assigns each point to the class of its nearest centroid. As the figure shows, part of the data is labeled blue and part is labeled red.
- Next, the centroids are updated: the mean of the points assigned to each class becomes that class's new centroid. As shown in Figure (d), the computer places the two new centroids and recomputes the distance from every point to them.
- The data are then reassigned: each point joins the class of whichever centroid is now nearest, as shown in Figure (e).
- The centroids are updated again, as shown in Figure (f), so that each centroid moves toward the center of its class. Eventually two stable classes are obtained.
- Note: the "means" in k-means refers to this averaging step, in which each class's centroid is the mean of the points assigned to it; assigning each point to its nearest centroid is similar in spirit to the distance-based classification used by KNN.
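A single assignment-and-update round from the figures can be sketched as follows (the four points and the two starting centroids below are made up for illustration):

```python
import numpy as np
from scipy.spatial.distance import cdist

# Hypothetical toy data: four points forming two obvious groups
data = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
# Two starting centroids, as in Figure (b), chosen by hand here
centroids = np.array([[0.0, 0.5], [9.0, 9.0]])

# Assignment step, as in Figure (c): label each point with its nearest centroid
labels = np.argmin(cdist(data, centroids), axis=1)  # labels is [0, 0, 1, 1]

# Update step, as in Figure (d): move each centroid to the mean of its points
centroids = np.array([data[labels == k].mean(axis=0) for k in range(2)])
print(centroids)
```

Repeating these two steps until the centroids stop moving is the whole algorithm.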

## 2. K-means

- k-means is one of the simplest and most efficient clustering algorithms, and it belongs to the family of unsupervised learning algorithms
- Core idea: the user specifies k initial centroids, one per cluster, and the algorithm iterates until it converges
- Basic algorithm flow:

1. Select k initial centroids (these form the initial clusters)

2. Repeat:

(1) For each sample point, find the nearest centroid and label the point with that centroid's cluster

(2) Recompute the centroids of the k clusters

3. Until the centroids no longer change or the iteration limit is reached
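The flow above can be sketched as a minimal NumPy implementation (the two synthetic blobs and the seeds below are hypothetical, for illustration; the class-based version in the next section is the tutorial's full code):

```python
import numpy as np
from scipy.spatial.distance import cdist

def kmeans(data, k, max_iter=300, seed=0):
    """Minimal sketch of the flow above: random init, assign, update, stop."""
    rng = np.random.default_rng(seed)
    # 1. Select k initial centroids from the data itself
    centroids = data[rng.choice(data.shape[0], size=k, replace=False)]
    for _ in range(max_iter):
        # 2(1). Assign each sample to the cluster of its nearest centroid
        labels = np.argmin(cdist(data, centroids), axis=1)
        # 2(2). Recompute each cluster's centroid (keep the old one if a cluster is empty)
        new = np.array([data[labels == j].mean(axis=0) if np.any(labels == j)
                        else centroids[j] for j in range(k)])
        # 3. Stop when the centroids no longer change
        if np.allclose(new, centroids):
            break
        centroids = new
    return labels, centroids

# Hypothetical data: two well-separated blobs of 20 points each
pts = np.vstack([np.random.default_rng(1).normal(0, 0.3, (20, 2)),
                 np.random.default_rng(2).normal(5, 0.3, (20, 2))])
labels, cents = kmeans(pts, k=2)
```

With well-separated blobs, the loop typically converges in a handful of iterations and each blob ends up in its own cluster.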

## 3. Code case

```python
import numpy as np
import matplotlib.pyplot as plt
# cdist computes pairwise distances; the default metric is Euclidean
from scipy.spatial.distance import cdist
# Generate clustering data directly from sklearn
from sklearn.datasets import make_blobs

# Data loading
# n_samples: number of data points
# centers: number of centroids, which is also the number of classes
# random_state: random seed, so the generated data is reproducible
# cluster_std: standard deviation of each cluster around its randomly placed center
# x: the data set (coordinates); y: the class each point belongs to
x, y = make_blobs(n_samples=100, centers=6, random_state=1024, cluster_std=0.6)

# x is a two-dimensional array: the first column is the x coordinate of each
# point and the second column is the y coordinate
# Draw a scatter plot; c=y colors the points by class
plt.scatter(x[:, 0], x[:, 1], c=y)
plt.show()
```

- Scatter plot display:

```python
# Implementation of the k-means algorithm
class K_Means(object):
    # Initialization: number of clusters n_clusters (k), iteration limit max_iter,
    # and optional initial centroids
    def __init__(self, n_clusters=6, max_iter=300, centroids=[]):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.centroids = np.array(centroids, dtype=float)

    # Training method: the k-means clustering process, given the raw data
    def fit(self, data):
        # If no initial centroids were specified, randomly pick points from data
        if self.centroids.shape == (0,):
            # Randomly draw n_clusters integers in [0, number of rows) as index values
            self.centroids = data[np.random.randint(0, data.shape[0], self.n_clusters), :]
        # Start iterating
        for i in range(self.max_iter):
            # 1. Compute the distance matrix; with 100 points and 6 centroids it is 100 x 6
            distance = cdist(data, self.centroids)
            # 2. Sort the distances from near to far and take the class of the
            #    nearest centroid as the class of the current point
            c_ind = np.argmin(distance, axis=1)
            # 3. Compute the mean of each class and update the centroid coordinates
            for k in range(self.n_clusters):
                # First rule out classes that do not appear in c_ind
                if k in c_ind:
                    # Select the points of class k, take the mean of their
                    # coordinates, and update the k-th centroid
                    self.centroids[k] = np.mean(data[c_ind == k], axis=0)

    # Prediction method
    def predict(self, sample):
        # Compute the distance matrix, then pick the class of the nearest centroid
        distance = cdist(sample, self.centroids)
        c_ind = np.argmin(distance, axis=1)
        return c_ind


# A small check of how np.argmin(axis=1) and boolean indexing behave
dist = np.array([[121, 221, 32, 43],
                 [121, 1, 12, 23],
                 [65, 21, 2, 43],
                 [1, 321, 32, 43],
                 [21, 11, 22, 3]])
c_ind = np.argmin(dist, axis=1)
print(c_ind)
x_new = x[0:5]
print(x_new)
print(c_ind == 2)
print(x_new[c_ind == 2])
np.mean(x_new[c_ind == 2], axis=0)

# Test
# Define a function to draw subplots
def plotKmeans(x, y, centroids, subplot, title):
    # Select the subplot; 121 means the first of a 1-row, 2-column grid
    plt.subplot(subplot)
    plt.scatter(x[:, 0], x[:, 1], c='r')
    # Draw the centroids
    plt.scatter(centroids[:, 0], centroids[:, 1], c=np.array(range(6)), s=100)
    plt.title(title)

kmeans = K_Means(max_iter=300,
                 centroids=np.array([[2, 1], [2, 2], [2, 3], [2, 4], [2, 5], [2, 6]]))
plt.figure(figsize=(16, 6))
plotKmeans(x, y, kmeans.centroids, 121, 'Initial State')
# Start clustering
kmeans.fit(x)
plotKmeans(x, y, kmeans.centroids, 122, 'Final State')

# Predict the class of new data points
x_new = np.array([[0, 0], [10, 7]])
y_pred = kmeans.predict(x_new)
print(kmeans.centroids)
print(y_pred)
plt.scatter(x_new[:, 0], x_new[:, 1], c='black')
plt.show()
```

- Initial state, final state, and predicted new points shown as scatter plots:
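As a sanity check, the same kind of data can be clustered with scikit-learn's built-in `KMeans`; the parameters below mirror the tutorial's `make_blobs` settings, while `n_init` and `random_state` are assumptions chosen for reproducibility:

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans

# Same synthetic data as in the tutorial
x, y = make_blobs(n_samples=100, centers=6, random_state=1024, cluster_std=0.6)

# Library implementation of the same algorithm
km = KMeans(n_clusters=6, n_init=10, random_state=0)
y_pred = km.fit_predict(x)

# cluster_centers_ plays the role of self.centroids in the class above
print(km.cluster_centers_.shape)  # (6, 2)
```

Note that the cluster labels it assigns may be a permutation of the original `y`, since cluster numbering is arbitrary.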