Machine learning - K-means algorithm

1, K-means algorithm

1. Introduction

  • In Figure a above, every green point is a data point. To the naked eye the data clearly falls into two groups, one above and one below, but a computer does not know that. How can it find the groups?
  • In Figure b, the computer picks two initial centroids, one red and one blue. (In practice the initial centroids are usually chosen from the existing data points rather than generated as new random points.)
  • In Figure c, the computer uses the coordinates of each point to find the centroid it is closest to, and puts the point in that centroid's class. Some points end up in the blue class and the rest in the red class.
  • Next, the centroids are updated: the new centroid of each class is the mean of the points currently assigned to it. With the two new centroids in Figure d, the computer recalculates the distance from every point to each centroid.
  • The points are then reclassified, each joining the class of its nearest centroid, as shown in Figure e.
  • The centroids are updated again (Figure f), which moves each one toward the center of its class. Repeating these steps finally yields the two classes.
  • In short, each point is assigned to the class of the centroid it is closest to (similar in spirit to KNN), and each centroid is the mean of the points in its class.
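The assign-then-update cycle from Figures c to f can be sketched with NumPy. This is a minimal illustration, not the code behind the figures: the six points and the choice of the first two points as starting centroids are hypothetical, and the helper names `assign` and `update` are made up for this example (`update` assumes no class ends up empty).

```python
import numpy as np

# Hypothetical toy data: three points at the bottom, three at the top
points = np.array([[1.0, 1.0], [2.0, 1.0], [1.2, 2.0],
                   [1.0, 8.0], [2.0, 9.0], [3.0, 8.0]])
# Two starting centroids taken from the data itself (here, the first two points)
centroids = points[[0, 1]].copy()

def assign(points, centroids):
    # Euclidean distance from every point to every centroid, shape (n_points, n_centroids)
    dist = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
    # Each point joins the class of its nearest centroid
    return np.argmin(dist, axis=1)

def update(points, labels, k):
    # Each new centroid is the mean of the points currently in its class
    # (assumes every class still has at least one point)
    return np.array([points[labels == j].mean(axis=0) for j in range(k)])

for step in range(3):
    labels = assign(points, centroids)
    centroids = update(points, labels, k=2)
    print(step, labels, centroids.round(2))
```

In this toy run the first pass still mixes the classes (the point at (1, 8) joins the bottom group), the second pass separates top from bottom, and the third pass changes nothing, which is the point where k-means stops.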

2. K-means

  • K-means is one of the simplest and most efficient clustering algorithms; it is an unsupervised learning algorithm.
  • Core idea: the user specifies k initial centroids, one per cluster, and the algorithm iterates until it converges.
  • Basic algorithm flow
    1. Select K initial centroids (one for each initial cluster)
    2. Repeat the following two steps:
       (1) For each sample point, find the nearest centroid and label the point with that centroid's cluster
       (2) Recalculate the centroid of each of the k clusters
    3. Stop when the centroids no longer change or the number of iterations reaches the upper limit
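The flow above can be written as one loop with an explicit stopping test (step 3). The following is a minimal NumPy sketch under a few assumptions: Euclidean distance, an empty cluster keeps its previous centroid, and "no longer changes" is checked with `np.allclose`. The function name `kmeans` and the tolerance `tol` are illustrative. It also records the within-cluster sum of squared distances (SSE) on each pass; neither step can increase the SSE, which is one way to see why the iteration settles down.

```python
import numpy as np

def kmeans(data, init_centroids, max_iter=300, tol=1e-9):
    """Minimal k-means sketch. Returns (centroids, labels, sse_history)."""
    centroids = np.array(init_centroids, dtype=float)
    sse_history = []
    for _ in range(max_iter):
        # (1) Assign every point to its nearest centroid (Euclidean distance)
        dist = np.linalg.norm(data[:, None, :] - centroids[None, :, :], axis=2)
        labels = np.argmin(dist, axis=1)
        # (2) Move each centroid to the mean of its points; an empty cluster keeps its old centroid
        new_centroids = centroids.copy()
        for j in range(len(centroids)):
            members = data[labels == j]
            if len(members) > 0:
                new_centroids[j] = members.mean(axis=0)
        # Within-cluster sum of squared distances for these labels and the updated centroids
        sse_history.append(float(((data - new_centroids[labels]) ** 2).sum()))
        # 3. Stop once the centroids no longer change (up to tol)
        converged = np.allclose(new_centroids, centroids, rtol=0.0, atol=tol)
        centroids = new_centroids
        if converged:
            break
    return centroids, labels, sse_history

# Usage on hypothetical data: two well-separated blobs around (0, 0) and (5, 5)
rng = np.random.default_rng(0)
data = np.vstack([rng.normal(0.0, 0.5, size=(50, 2)),
                  rng.normal(5.0, 0.5, size=(50, 2))])
centroids, labels, sse = kmeans(data, data[[0, 50]])
print(centroids.round(2), len(sse))
```

The returned `labels` come from the last assignment step, so they match the final centroids once the loop has converged.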

3. Code example

import numpy as np
import matplotlib.pyplot as plt
#cdist from scipy computes pairwise distances; the default metric is Euclidean
from scipy.spatial.distance import cdist
#Generate clustering data directly with sklearn (make_blobs is part of the public sklearn.datasets API)
from sklearn.datasets import make_blobs

if __name__ == '__main__':
    #Load the data
    # n_samples: number of data points
    # centers: number of cluster centers, which is also the number of classes
    # random_state: random seed; the same seed always generates the same data
    # cluster_std: standard deviation of the points around each cluster center
    # x: the data set (point coordinates); y: the class label of each point
    x,y = make_blobs(n_samples=100,centers=6,random_state=1024,cluster_std=0.6)
    #x is a two-dimensional array: the first column holds the x coordinates of the points and the second column the y coordinates
    #Draw a scatter plot; c=y colors each point according to its class
    plt.scatter(x[:, 0], x[:, 1], c=y)
    plt.show()
  • Scatter plot of the generated data
    #Implementation of the k-means algorithm
    class K_Means(object):
        #Initialization: number of clusters n_clusters (K), maximum number of iterations max_iter, optional initial centroids
        def __init__(self, n_clusters=6, max_iter=300, centroids=None):
            self.n_clusters = n_clusters
            self.max_iter = max_iter
            #Store the centroids as floats so the mean updates in fit are not truncated to integers
            self.centroids = None if centroids is None else np.array(centroids, dtype=float)

        #Training method: run the k-means clustering process on the raw data
        def fit(self, data):
            #If no initial centroids were given, randomly pick points from data as the initial centroids
            if self.centroids is None or self.centroids.size == 0:
                #Draw n_clusters distinct row indices (without replacement, so no centroid is duplicated)
                idx = np.random.choice(data.shape[0], self.n_clusters, replace=False)
                self.centroids = data[idx, :].astype(float)

            #Start iteration
            for _ in range(self.max_iter):
                #1. Distance matrix: one row per point, one column per centroid (100 x 6 here)
                distance = cdist(data, self.centroids)
                #2. For each point, take the class of its nearest centroid (argmin along each row)
                c_ind = np.argmin(distance, axis=1)
                #3. Update each centroid to the mean of the points assigned to it
                for k in range(self.n_clusters):
                    #Skip classes that received no points; their centroid stays where it is
                    if k in c_ind:
                        self.centroids[k] = np.mean(data[c_ind == k], axis=0)

        #Implementation prediction method
        def predict(self,sample):
            #Compute the distance matrix, then assign each sample to the class of its nearest centroid
            distance = cdist(sample, self.centroids)
            c_ind = np.argmin(distance, axis=1)

            return c_ind

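    #Quick check of np.argmin: with axis=1 it returns the column index of the smallest value in each row
    dist = np.array([[121,221,32,43]])
    c_ind = np.argmin(dist, axis=1)
    print(c_ind)  # [2]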
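Step 3 of `fit` loops over the clusters and averages the points selected by a boolean mask. As an alternative sketch (not the author's code), all the per-cluster means can be computed at once with `np.bincount`; `update_centroids` is a hypothetical helper name, and the random data and labels below exist only to check it against the loop version.

```python
import numpy as np

def update_centroids(data, labels, centroids):
    """Vectorized mean update (sketch): empty clusters keep their old centroid."""
    k = len(centroids)
    # Number of points in each cluster
    counts = np.bincount(labels, minlength=k)
    # Per-cluster coordinate sums, one bincount per feature column
    sums = np.stack([np.bincount(labels, weights=data[:, d], minlength=k)
                     for d in range(data.shape[1])], axis=1)
    new = np.array(centroids, dtype=float)
    nonempty = counts > 0
    new[nonempty] = sums[nonempty] / counts[nonempty][:, None]
    return new

# Hypothetical check against the per-cluster loop used in fit
rng = np.random.default_rng(1)
data = rng.normal(size=(100, 2))
centroids = rng.normal(size=(6, 2))
labels = rng.integers(0, 5, size=100)   # labels 0..4 only, so cluster 5 is empty on purpose

fast = update_centroids(data, labels, centroids)
slow = centroids.copy()
for j in range(6):
    if j in labels:
        slow[j] = data[labels == j].mean(axis=0)
```

Both versions give the same centroids; the bincount form avoids the Python loop over clusters, which mainly matters when K is large.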

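    #Define a function that draws the data and the centroids in one subplot
    def plotKmeans(x, y, centroids, subplot, title):
        #Select the subplot: 121 means the first plot in a grid of 1 row and 2 columns
        plt.subplot(subplot)
        #Data points colored by label y (minimal body; colors and marker size are assumptions)
        plt.scatter(x[:, 0], x[:, 1], c=y)
        #Draw the centroids
        plt.scatter(centroids[:, 0], centroids[:, 1], c='red', s=100)
        plt.title(title)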

    kmeans = K_Means(max_iter=300, centroids=np.array([[2,1],[2,2],[2,3],[2,4],[2,5],[2,6]]))
    plt.figure(figsize=(12, 6))
    plotKmeans(x, y, kmeans.centroids, 121, 'Initial State')

    #Start clustering
    kmeans.fit(x)
    plotKmeans(x, y, kmeans.centroids, 122, 'Final State')

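    #Predict the category of new data points
    x_new = np.array([[0,0],[10,7]])
    y_pred = kmeans.predict(x_new)
    print(kmeans.centroids)
    print(y_pred)

    #Mark the new points in black on the final-state subplot
    plt.scatter(x_new[:,0], x_new[:,1], c='black', s=100)
    plt.show()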
  • Scatter plots of the initial and final states, with the two new points shown in black:

Posted by hob_goblin on Sun, 08 May 2022 07:22:00 +0300