Machine learning - the k-NN idea and implementation (in Java)

k-nearest neighbors (k-NN) is a basic classification and regression method.
Input: the feature vector of an instance, corresponding to a point in the feature space;
Output: the class of the instance.
For classification, the class of a new instance is predicted by majority vote over the classes of its k nearest training instances.
For regression, the prediction for a new instance is the average of the outputs of its k nearest training instances.
The code for the two is similar; this article implements the k-NN algorithm for the classification problem.
Data preparation: the Iris data set in ARFF format (iris.arff).
Class constructor and data read-in:

public KnnClassification(String paraFilename) {
		try {
			FileReader fileReader = new FileReader(paraFilename);
			dataset = new Instances(fileReader);
			System.out.println("The number of totall instances is " + dataset.numInstances());
			// The last attribute is the decision class.
			dataset.setClassIndex(dataset.numAttributes() - 1);
			System.out.println("The data set is: " + dataset.toString());
			fileReader.close();
		} catch (Exception ee) {
			System.out.println("Error occurred while trying to read \'" + paraFilename
					+ "\' in KnnClassification constructor.\r\n" + ee);
			System.exit(0);
		} // Of try
	}// Of the first constructor
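
The constructor above, and the methods that follow, rely on a few member fields and imports that are not shown explicitly in the post. The following is only a minimal sketch of those declarations; the names are taken from the code in this article, while the default values (for example the distance measure and k) are assumptions.

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

import weka.core.Instances;

public class KnnClassification {
	/** Manhattan distance. */
	public static final int MANHATTAN = 0;

	/** Euclidean distance. */
	public static final int EUCLIDEAN = 1;

	/** The distance measure in use (assumed default). */
	public int distanceMeasure = EUCLIDEAN;

	/** A random instance for shuffling. */
	public static final Random random = new Random();

	/** The number of neighbors, i.e., k (assumed default). */
	int numNeighbors = 7;

	/** The whole data set. */
	Instances dataset;

	/** The indices of the training instances. */
	int[] trainingSet;

	/** The indices of the testing instances. */
	int[] testingSet;

	/** The predicted labels of the testing instances. */
	int[] predictions;

	// ... the constructor and methods shown in this article ...
}// Of class KnnClassification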

Data printing:

The data set is: @relation iris

@attribute sepallength numeric
@attribute sepalwidth numeric
@attribute petallength numeric
@attribute petalwidth numeric
@attribute class {Iris-setosa,Iris-versicolor,Iris-virginica}

@data
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3,1.4,0.1,Iris-setosa
4.3,3,1.1,0.1,Iris-setosa
5.8,4,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5,3,1.6,0.2,Iris-setosa
5,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5,3.3,1.4,0.2,Iris-setosa
7,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5,2,3.5,1,Iris-versicolor
5.9,3,4.2,1.5,Iris-versicolor
6,2.2,4,1,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3,5,1.7,Iris-versicolor
6,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6,2.7,5.1,1.6,Iris-versicolor
5.4,3,4.5,1.5,Iris-versicolor
6,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3,4.1,1.3,Iris-versicolor
5.5,2.5,4,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3,4.6,1.4,Iris-versicolor
5.8,2.6,4,1.2,Iris-versicolor
5,2.3,3.3,1,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3,5.8,2.2,Iris-virginica
7.6,3,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3,5.5,2.1,Iris-virginica
5.7,2.5,5,2,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6,2.2,5,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2,Iris-virginica
7.7,2.8,6.7,2,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6,3,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3,5.2,2.3,Iris-virginica
6.3,2.5,5,1.9,Iris-virginica
6.5,3,5.2,2,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3,5.1,1.8,Iris-virginica

From the printed output we can see that the data is sorted by class. To obtain a more representative split, we need to shuffle the data set.
So let's first generate a shuffled array of indices:

/**
	 * 
	 *********************
	 * @Title: getRandomIndices
	 * @Description: TODO(Get a random indices for data randomization.)
	 *
	 * @param paraLength The length of the sequence.
	 * @return An array of indices, e.g., {4, 3, 1, 5, 0, 2} with length 6.
	 *********************
	 *
	 */
	public static int[] getRandomIndices(int paraLength) {
		int[] resultIndices = new int[paraLength];

		// Step 1. Initialize.
		for (int i = 0; i < paraLength; i++) {
			resultIndices[i] = i;
		} // Of for i

		// Step 2. Randomly swap.
		int tempFirst, tempSecond, tempValue;
		for (int i = 0; i < paraLength; i++) {
			// Generate two random indices.
			tempFirst = random.nextInt(paraLength);
			tempSecond = random.nextInt(paraLength);

			// Swap.
			tempValue = resultIndices[tempFirst];
			resultIndices[tempFirst] = resultIndices[tempSecond];
			resultIndices[tempSecond] = tempValue;
		} // Of for i
		return resultIndices;
	}// Of getRandomIndices

Split the shuffled data set into training and testing parts:

/**
	 * 
	 *********************
	 * @Title: splitTrainingTesting
	 * @Description: TODO(Split the data into training and testing parts.)
	 *
	 * @param paraTrainingFraction The fraction of the training set.
	 *********************
	 *
	 */
	public void splitTrainingTesting(double paraTrainingFraction) {
		int tempSize = dataset.numInstances();
		int[] tempIndices = getRandomIndices(tempSize);
		int tempTrainingSize = (int) (tempSize * paraTrainingFraction);

		trainingSet = new int[tempTrainingSize];
		testingSet = new int[tempSize - tempTrainingSize];

		for (int i = 0; i < tempTrainingSize; i++) {
			trainingSet[i] = tempIndices[i];
		} // Of for i

		for (int i = 0; i < tempSize - tempTrainingSize; i++) {
			testingSet[i] = tempIndices[i + tempTrainingSize];
		} // Of for i
	}// Of splitTrainingTesting

At this point the simple data processing is done, and the whole data set has been divided into a training set and a testing set.
Since k-NN has no explicit model-training phase, we can skip training and make predictions directly.
Prediction preparation:
The distance from a test instance to each training instance needs to be computed.
The code illustrates two distance measures: the Manhattan distance and the (squared) Euclidean distance.
Manhattan distance:
$L_1(x_i, x_j) = \sum\limits_{l=1}^{n} |x_i^{(l)} - x_j^{(l)}|$
Euclidean distance:
$L_2(x_i, x_j) = \left( \sum\limits_{l=1}^{n} |x_i^{(l)} - x_j^{(l)}|^2 \right)^{\frac{1}{2}}$
Supplement:
Minkowski distance:
$L_p(x_i, x_j) = \left( \sum\limits_{l=1}^{n} |x_i^{(l)} - x_j^{(l)}|^p \right)^{\frac{1}{p}}$
When $p = \infty$, it is called the Chebyshev distance:
$L_\infty(x_i, x_j) = \max\limits_{l} |x_i^{(l)} - x_j^{(l)}|$
Supplementary note on parameters:
Hyperparameters: parameters that must be fixed before the algorithm runs;
Model parameters: parameters learned while the algorithm runs.
The parameter k in the k-NN algorithm is a typical hyperparameter.
Because different values of p affect which instances are selected as nearest neighbors, p can also be treated as a hyperparameter of the k-NN algorithm.

Code for distance calculation:

	/**
	 * 
	 *********************
	 * @Title: distance
	 * @Description: TODO(The distance between two instances.)
	 *
	 * @param paraI The index of the first instance.
	 * @param paraJ The index of the second instance.
	 * @return The distance.
	 *********************
	 *
	 */
	public double distance(int paraI, int paraJ) {
		double resultDistance = 0;
		double tempDifference;
		switch (distanceMeasure) {
		case MANHATTAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
				if (tempDifference < 0) {
					resultDistance -= tempDifference;
				} else {
					resultDistance += tempDifference;
				} // Of if
			} // Of for i
			break;

		case EUCLIDEAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
				resultDistance += tempDifference * tempDifference;
			} // Of for i
			break;
		default:
			System.out.println("Unsupported distance measure: " + distanceMeasure);
			break;
		}// Of switch
		return resultDistance;
	}// Of distance
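
The switch above covers only the Manhattan and (squared) Euclidean measures. As a supplement for the Minkowski and Chebyshev distances mentioned earlier, here is a hedged sketch that is not part of the original code; the method names and the exponent parameter paraP are hypothetical.

	/**
	 * A sketch of the Minkowski distance (not in the original code).
	 *
	 * @param paraI The index of the first instance.
	 * @param paraJ The index of the second instance.
	 * @param paraP The exponent p.
	 * @return The distance.
	 */
	public double minkowskiDistance(int paraI, int paraJ, double paraP) {
		double resultDistance = 0;
		double tempDifference;
		for (int i = 0; i < dataset.numAttributes() - 1; i++) {
			tempDifference = Math.abs(dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i));
			resultDistance += Math.pow(tempDifference, paraP);
		} // Of for i
		return Math.pow(resultDistance, 1.0 / paraP);
	}// Of minkowskiDistance

	/**
	 * A sketch of the Chebyshev distance, i.e., the limit of the Minkowski
	 * distance as p goes to infinity (not in the original code).
	 */
	public double chebyshevDistance(int paraI, int paraJ) {
		double resultDistance = 0;
		double tempDifference;
		for (int i = 0; i < dataset.numAttributes() - 1; i++) {
			tempDifference = Math.abs(dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i));
			if (tempDifference > resultDistance) {
				resultDistance = tempDifference;
			} // Of if
		} // Of for i
		return resultDistance;
	}// Of chebyshevDistance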

After computing the distances, the next step is to select the k nearest instances:

	/**
	 * 
	 *********************
	 * @Title: computeNearests
	 * @Description: TODO(Compute the nearest k neighbors.)
	 *
	 * @param paraCurrent current instance.
	 * @return The indices of the nearest instances.
	 *********************
	 *
	 */
	public int[] computeNearests(int paraCurrent) {
		int[] resultNearests;

		// Compute all distances to avoid redundant compute.
		double[] tempDistances = new double[trainingSet.length];
		for (int i = 0; i < trainingSet.length; i++) {
			tempDistances[i] = distance(paraCurrent, trainingSet[i]);
		} // Of for i

		// resultNearests = simpleSelect(tempDistances);
		// System.out.println("The nearest of " + paraCurrent + " are: " +
		// Arrays.toString(resultNearests));

		resultNearests = selectWithHeap(tempDistances);
		System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));

		return resultNearests;
	}// Of computeNearests

Simple selection method:

/**
	 * 
	 *********************
	 * @Title: simpleSelect
	 * @Description: TODO(Select the nearest indices.)
	 *
	 * @param paraDistances The distance.
	 * @return A array of result.
	 *********************
	 *
	 */
	public int[] simpleSelect(double[] paraDistances) {
		int[] resultNearests = new int[numNeighbors];
		boolean[] tempSelected = new boolean[trainingSet.length];
		double tempMinimalDistance;
		int tempMinimalIndex = 0;

		for (int i = 0; i < numNeighbors; i++) {
			tempMinimalDistance = Double.MAX_VALUE;

			for (int j = 0; j < trainingSet.length; j++) {
				if (tempSelected[j])
					continue;

				if (paraDistances[j] < tempMinimalDistance) {
					tempMinimalDistance = paraDistances[j];
					tempMinimalIndex = j;
				} // Of if
			} // Of for j
			resultNearests[i] = trainingSet[tempMinimalIndex];
			tempSelected[tempMinimalIndex] = true;
		} // Of for i

		return resultNearests;
	}// Of simpleSelect

Heap selection method:

	/**
	 * 
	 *********************
	 * @Title: adjustHeap
	 * @Description: TODO(Adjust the heap.)
	 *
	 * @param paraStartIndex The start of the index that need to adjust.
	 * @param paraLength     The length of the adjusted sequence.
	 * @param paraDistances  The array of distance.
	 * @param paraIndexes    The index of distance.
	 *********************
	 *
	 */
	public void adjustHeap(int paraStartIndex, int paraLength, double[] paraDistances, int[] paraIndexes) {
		int tempParentIndex = paraStartIndex;
		double tempDistance = paraDistances[paraStartIndex];
		int tempIndex = paraIndexes[paraStartIndex];

		for (int i = paraStartIndex * 2 + 1; i < paraLength; i = i * 2 + 1) {
			// Select the smaller.
			if (i + 1 < paraLength && paraDistances[i + 1] < paraDistances[i])
				i++;
			if (tempDistance > paraDistances[i]) {
				// Update the index and distance.
				paraIndexes[tempParentIndex] = paraIndexes[i];
				paraDistances[tempParentIndex] = paraDistances[i];
				tempParentIndex = i;
			} else {
				break;
			} // Of if
		} // Of for i

		paraDistances[tempParentIndex] = tempDistance;
		paraIndexes[tempParentIndex] = tempIndex;
	}// Of adjustHeap

	/**
	 * 
	 *********************
	 * @Title: selectWithHeap
	 * @Description: TODO(Select the nearest indices.)
	 *
	 * @param paraDistances The distance.
	 * @return A array of result.
	 *********************
	 *
	 */
	public int[] selectWithHeap(double[] paraDistances) {
		int[] resultNearests = new int[numNeighbors];

		// Initialize the indexes.
		int[] tempIndexes = new int[trainingSet.length];
		for (int i = 0; i < trainingSet.length; i++) {
			tempIndexes[i] = i;
		} // Of for i

		// Build the heap.
		int tempLength = paraDistances.length;
		for (int i = trainingSet.length / 2 - 1; i >= 0; i--) {
			adjustHeap(i, tempLength, paraDistances, tempIndexes);
		}

		for (int i = 0; i < numNeighbors; i++) {
			resultNearests[i] = trainingSet[tempIndexes[0]];
			tempIndexes[0] = tempIndexes[tempLength - i - 1];
			paraDistances[0] = paraDistances[tempLength - i - 1];
			adjustHeap(0, tempLength - i - 1, paraDistances, tempIndexes);
		}
		return resultNearests;
	}// Of selectWithHeap

Time complexity of the heap-based selection: O(n + k log n).
Time complexity of the selection based on selection sort: O(kn).
Test result graph: (not reproduced here)

In general, the time overhead of the latter is greater than that of the former.
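
Since the test result graph is not reproduced above, here is a hedged sketch (not part of the original code) of one way to time the two selection strategies on a single test instance; compareSelectionTime is a hypothetical helper.

	/**
	 * A sketch of timing the two selection strategies for one test instance.
	 */
	public void compareSelectionTime(int paraCurrent) {
		double[] tempDistances = new double[trainingSet.length];
		for (int i = 0; i < trainingSet.length; i++) {
			tempDistances[i] = distance(paraCurrent, trainingSet[i]);
		} // Of for i

		// Clone the array because selectWithHeap modifies its argument.
		long tempStart = System.nanoTime();
		simpleSelect(tempDistances.clone());
		long tempSimpleTime = System.nanoTime() - tempStart;

		tempStart = System.nanoTime();
		selectWithHeap(tempDistances.clone());
		long tempHeapTime = System.nanoTime() - tempStart;

		System.out.println("simpleSelect: " + tempSimpleTime + " ns, selectWithHeap: " + tempHeapTime + " ns");
	}// Of compareSelectionTime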
After selecting the k nearest instances, voting is performed: for each selected index, the corresponding class is looked up and that class's counter is incremented; the class with the most votes wins:

	/**
	 * 
	 *********************
	 * @Title: simpleVoting
	 * @Description: TODO(Voting using the instances.)
	 *
	 * @param paraNeighbors The indices of the neighbors.
	 * @return The predicted label.
	 *********************
	 *
	 */
	public int simpleVoting(int[] paraNeighbors) {
		int[] tempVotes = new int[dataset.numClasses()];
		for (int i = 0; i < paraNeighbors.length; i++) {
			tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
		} // Of for i

		int tempMaximalVoting = 0;
		int tempMaximalVotingIndex = 0;
		for (int i = 0; i < dataset.numClasses(); i++) {
			if (tempVotes[i] > tempMaximalVoting) {
				tempMaximalVoting = tempVotes[i];
				tempMaximalVotingIndex = i;
			} // Of if
		} // Of for i

		return tempMaximalVotingIndex;
	}// Of simpleVoting

This fully embodies the proverb "one who stays near vermilion turns red, one who stays near ink turns black" - you take on the color of your neighbors.
Why is it called simple voting? If you think about it, you should not look only at the number of votes; you can also take the distances into account, since the closer a neighbor is, the more likely it belongs to the same class. So if you run into a tie in the vote counts, you know what to do.
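
As a hedged sketch of that idea (not part of the original code), a distance-weighted voting variant might look like the following; weightedVoting is a hypothetical method and assumes the caller also passes each neighbor's distance.

	/**
	 * Distance-weighted voting (sketch). paraDistances[i] is assumed to be the
	 * distance of the neighbor paraNeighbors[i].
	 */
	public int weightedVoting(int[] paraNeighbors, double[] paraDistances) {
		double[] tempWeights = new double[dataset.numClasses()];
		for (int i = 0; i < paraNeighbors.length; i++) {
			// Inverse-distance weight; the small constant avoids division by zero.
			double tempWeight = 1.0 / (paraDistances[i] + 1e-6);
			tempWeights[(int) dataset.instance(paraNeighbors[i]).classValue()] += tempWeight;
		} // Of for i

		int tempBestIndex = 0;
		for (int i = 1; i < tempWeights.length; i++) {
			if (tempWeights[i] > tempWeights[tempBestIndex]) {
				tempBestIndex = i;
			} // Of if
		} // Of for i

		return tempBestIndex;
	}// Of weightedVoting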

Finally, we predict on the testing set and compute the prediction accuracy:

	/**
	 * 
	 *********************
	 * @Title: predict
	 * @Description: TODO(Predict for the whole testing set. The results are stored
	 *               in predictions.) #see predictions.
	 *********************
	 *
	 */
	public void predict() {
		predictions = new int[testingSet.length];
		for (int i = 0; i < predictions.length; i++) {
			predictions[i] = predict(testingSet[i]);
		} // Of for i
	}// Of predict
	
	/**
	 * 
	 *********************
	 * @Title: predict
	 * @Description: TODO(Predict for given instance.)
	 *
	 * @return The prediction.
	 *********************
	 *
	 */
	public int predict(int paraIndex) {
		int[] tempNeighbors = computeNearests(paraIndex);
		int resultPrediction = simpleVoting(tempNeighbors);

		return resultPrediction;
	}// Of predict
	
	/**
	 * 
	 *********************
	 * @Title: getAccuracy
	 * @Description: TODO(Get the accuracy of the classifier.)
	 *
	 * @return The accuracy.
	 *********************
	 *
	 */
	public double getAccuracy() {
		// A double divided by an int gives another double.
		double tempCorrect = 0;
		for (int i = 0; i < predictions.length; i++) {
			if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
				tempCorrect++;
			} // Of if
		} // Of for i

		return tempCorrect / testingSet.length;
	}// Of getAccuracy

Finally, attach the main function:

/**
	 *********************
	 * The entrance of the program.
	 * 
	 * @param args Not used now.
	 *********************
	 */
	public static void main(String args[]) {
		KnnClassification tempClassifier = new KnnClassification("E:/Weka-3-8-6/data/iris.arff");
		tempClassifier.splitTrainingTesting(0.8);
		tempClassifier.predict();
		System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());

		// for (int i = 4; i < 9; i++) {
		// tempClassifier.setNumNeighbors(i);
		// System.out.println("------------The test with " + tempClassifier.numNeighbors
		// + " neighors------------");
		// tempClassifier.predict();
		// System.out.println("The accuracy of the classifier is: " +
		// tempClassifier.getAccuracy());
		// } // Of for i

	}// Of main

The commented-out for loop is a simple test over different values of k:

But the data set is too small, so doing this is not very meaningful. Still, a reminder: if the accuracy keeps increasing as k grows, it only means that the best value found so far lies at the boundary of the tested range of k; the range of k should be widened and the test continued.
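
The loop above calls a setter that is not shown in the article. A minimal sketch, assuming numNeighbors is the int field that holds k, could be:

	/**
	 * Set the number of neighbors, i.e., k (sketch; this setter is referenced
	 * by the commented-out loop in main but not shown in the article).
	 *
	 * @param paraNumNeighbors The new value of k.
	 */
	public void setNumNeighbors(int paraNumNeighbors) {
		numNeighbors = paraNumNeighbors;
	}// Of setNumNeighbors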

Summary: through this exercise I found many shortcomings of my own, such as not being able to plot in Java and not being familiar with Instance and Attribute; that is why my attempts at normalizing the data and adding the processed values as new attributes to the original data have not yet succeeded. I still need to explore the documentation slowly.
