Building the KNN Algorithm with JavaScript
k-Nearest Neighbor (KNN)
KNN is a simple, fast, and straightforward classification algorithm. It is very useful for categorized numerical datasets in which the data is naturally clustered. In some ways it will feel similar to the k-means clustering algorithm. The major distinction is that k-means is an unsupervised algorithm, while KNN is a supervised learning algorithm.
If you wish to perform a KNN analysis manually, here’s how it should go: first, plot all your training data on a graph and label each point with its category or label. When you wish to classify a new, unknown point, put it on the graph and find the k closest points to it (the nearest neighbors).
When there are two categories, k should be an odd number so that votes cannot tie. With three or more categories, ties are still possible. Three is a good starting point, but some applications will need more and some can get away with one. The label held by the majority of the k nearest neighbors is the result of the algorithm.
Finding the k nearest neighbors of a test point is straightforward, but you can apply some optimizations if your training data is very large. Typically, when evaluating a new point, you calculate the Euclidean distance (the familiar distance measure from high school geometry) between your test point and every training point, then sort the training points by distance. This brute-force approach is fast in practice because training sets are generally no larger than 10,000 points or so.
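For reference, the Euclidean distance between a test point p and a training point q with n features is given below. For example, two points that differ by 2 units in each of two dimensions are √8 ≈ 2.83 apart.

```latex
d(\mathbf{p}, \mathbf{q}) = \sqrt{\sum_{i=1}^{n} (q_i - p_i)^2}
```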
If you have many training examples (on the order of millions) or you really need the algorithm to be lightning-fast, there are two optimizations you can make. The first is to skip the square root operation in the distance measure and use the squared distance instead. Because the square root is monotonically increasing, squared distances rank the neighbors in exactly the same order, so the result does not change. Modern CPUs are very fast, but a square root is still slower than multiplication and addition, so across millions of distance calculations you can save a few milliseconds by avoiding it.
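The following minimal sketch illustrates the first optimization. The names squaredDistance, euclidean, and rankBy, along with the sample points, are hypothetical and serve only as illustration. It ranks the same points by squared distance and by true Euclidean distance, and both rankings come out identical.

```javascript
// Squared Euclidean distance: the usual formula without Math.sqrt.
// Assumes a and b are arrays of numbers with the same length.
const squaredDistance = (a, b) =>
  a.reduce((sum, ai, i) => sum + (b[i] - ai) ** 2, 0);

// True Euclidean distance, for comparison.
const euclidean = (a, b) => Math.sqrt(squaredDistance(a, b));

// Hypothetical training points and test point, for illustration only.
const points = [[3, 4], [1, 1], [6, 8], [0, 2]];
const testPoint = [0, 0];

// Return the point indexes ordered from nearest to farthest under `fn`.
const rankBy = (fn) =>
  points
    .map((p, index) => ({ index, d: fn(testPoint, p) }))
    .sort((x, y) => x.d - y.d)
    .map((entry) => entry.index);

console.log(rankBy(squaredDistance)); // [ 1, 3, 0, 2 ]
console.log(rankBy(euclidean));       // [ 1, 3, 0, 2 ]
```

The ordering is all that KNN needs, so dropping the square root leaves the chosen neighbors unchanged. Only the reported distance values differ.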
The second optimization is to consider only points inside a bounding box around your test point. For instance, you might only consider points within ±5 units of the test point's location in each dimension. Any point outside that box is more than 5 units away. So if your training data is dense enough that at least k training points lie within 5 units (by Euclidean distance) of the test point, this optimization cannot change the results. It will still speed up the algorithm, because you skip the distance calculation for many points.
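The bounding-box prefilter might look like the sketch below. The withinBox helper, its radius parameter, and the sample [weight, height] points are all hypothetical. The filter still visits every training point, but it only performs cheap comparisons on them. The full distance calculation would then run only on the points that survive. Each survivor keeps its original index so that you can look up its label afterward.

```javascript
// Keep only training points that lie within +/- radius of the test point
// in every dimension. Returns { candidate, index } pairs so the original
// index (and therefore the label) is not lost.
// Assumes points are arrays of numbers; `radius` is a hypothetical parameter.
const withinBox = (point, data, radius) =>
  data
    .map((candidate, index) => ({ candidate, index }))
    .filter(({ candidate }) =>
      candidate.every((value, i) => Math.abs(value - point[i]) <= radius)
    );

// Hypothetical [weight, height] points, for illustration only.
const data = [[200, 75], [203, 74], [120, 60], [198, 79], [150, 66]];
const nearby = withinBox([200, 75], data, 5);
console.log(nearby.map(({ index }) => index)); // [ 0, 1, 3 ]
```

If fewer than k points survive the filter, you would have to widen the box or fall back to a full scan.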
The following is the KNN algorithm as a high-level description:
- Record all training data and their labels
- Given a new point to evaluate, generate a list of its distances to all training points
- Sort the list of distances in the order of closest to farthest
- Throw out all but the k nearest distances
- Determine which label represents the majority of your k nearest neighbors; this is the result of the algorithm
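The five steps above translate almost line for line into a brute-force classifier. This is a minimal sketch, not the implementation built later in this article. The classify function, its parameters, and the toy training set are hypothetical.

```javascript
// Brute-force KNN following the five high-level steps.
// Assumes data is an array of numeric arrays and labels[i] labels data[i].
const euclideanDistance = (a, b) =>
  Math.sqrt(a.reduce((sum, ai, i) => sum + (b[i] - ai) ** 2, 0));

function classify(point, data, labels, k) {
  // Step 2: distance from the new point to every training point
  const distances = data.map((other, index) => ({
    distance: euclideanDistance(point, other),
    label: labels[index],
  }));
  // Step 3: sort closest to farthest
  distances.sort((a, b) => a.distance - b.distance);
  // Step 4: keep only the k nearest
  const nearest = distances.slice(0, k);
  // Step 5: majority vote among the labels of the k nearest
  const counts = new Map();
  for (const { label } of nearest) {
    counts.set(label, (counts.get(label) || 0) + 1);
  }
  let best = null;
  let bestCount = 0;
  for (const [label, count] of counts) {
    if (count > bestCount) {
      best = label;
      bestCount = count;
    }
  }
  return best;
}

// Step 1: record hypothetical training data and labels (illustration only).
const trainData = [[1, 1], [1, 2], [2, 1], [8, 8], [8, 9], [9, 8]];
const trainLabels = ["A", "A", "A", "B", "B", "B"];

console.log(classify([1.5, 1.5], trainData, trainLabels, 3)); // A
console.log(classify([8.5, 8.5], trainData, trainLabels, 3)); // B
```

A tied vote goes to the label that appears first among the sorted neighbors. This is because a Map iterates in insertion order, and only a strictly higher count replaces the current leader.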
A more efficient version avoids maintaining a large list of distances that need to be sorted by limiting the list of distances to k items. Now get started with your implementation of the KNN algorithm.
Building the KNN algorithm
Since the KNN algorithm is quite simple, you can build your own implementation:
- Create a new folder and name it Ch5-knn.
- Add the following package.json file to the folder. Among other things, it declares a dependency on the jimp library, which is an image processing library:
```json
{
  "name": "Ch5-knn",
  "version": "1.0.0",
  "description": "ML in JS Example for Chapter 5 - k-nearest-neighbor",
  "main": "src/index.js",
  "author": "Burak Kanber",
  "license": "MIT",
  "scripts": {
    "build-web": "browserify src/index.js -o dist/index.js -t [ babelify --presets [ env ] ]",
    "build-cli": "browserify src/index.js --node -o dist/index.js -t [ babelify --presets [ env ] ]",
    "start": "yarn build-cli && node dist/index.js"
  },
  "dependencies": {
    "babel-core": "^6.26.0",
    "babel-plugin-transform-object-rest-spread": "^6.26.0",
    "babel-preset-env": "^1.6.1",
    "babelify": "^8.0.0",
    "browserify": "^15.1.0",
    "jimp": "^0.2.28"
  }
}
```
- Run the yarn install command to download and install all the dependencies, and then create subfolders called src, dist, and files.
- Inside the src folder, create an index.js file and a knn.js file.
You will also need a data.js file. The dataset used in these examples is too large to print here, so take a minute to download the Ch5-knn/src/data.js file from GitHub. You can also find the complete code for this article at https://github.com/PacktPublishing/Hands-On-Machine-Learning-with-JavaScript/tree/master/Chapter05/Ch5-knn.
- Start with the knn.js file. You'll need a distance-measuring function. Add the following to the beginning of knn.js:
```javascript
/**
 * Calculate the Euclidean distance between two points.
 * Points must be given as arrays of numbers of equal length.
 * @param {Array.<number>} a
 * @param {Array.<number>} b
 * @return {number}
 */
const distance = (a, b) => Math.sqrt(
  a.map((aPoint, i) => b[i] - aPoint)
    .reduce((sumOfSquares, diff) => sumOfSquares + (diff * diff), 0)
);
```
If you really need a performance optimization for your KNN implementation, this is where you might omit the Math.sqrt operation and return just the squared distance. However, since this is such a fast algorithm by nature, you need to do this only if you’re working on an extreme problem with a lot of data or with very strict speed requirements.
- Next, add the stub of your KNN class. Add the following to knn.js, beneath the distance function:
```javascript
class KNN {
  constructor(k = 1, data, labels) {
    this.k = k;
    this.data = data;
    this.labels = labels;
  }
}

export default KNN;
```
The constructor accepts three arguments: the k or the number of neighbors to consider when classifying your new point, the training data split up into the data points alone, and a corresponding array of their labels.
- Next, you need to add an internal method that considers a test point and calculates a sorted list of distances from the test point to the training points. You can call this a distance map. Add the following to the body of the KNN class:
```javascript
generateDistanceMap(point) {
  const map = [];
  let maxDistanceInMap;

  for (let index = 0, len = this.data.length; index < len; index++) {
    const otherPoint = this.data[index];
    const otherPointLabel = this.labels[index];
    const thisDistance = distance(point, otherPoint);

    /**
     * Keep at most k items in the map.
     * Much more efficient for large sets, because this
     * avoids storing and then sorting a million-item map.
     * This adds many more sort operations, but hopefully k is small.
     */
    // Add an item if the map is not yet full,
    // or if it's closer than the farthest of the candidates
    if (map.length < this.k || thisDistance < maxDistanceInMap) {
      map.push({ index, distance: thisDistance, label: otherPointLabel });

      // Sort the map so the closest is first
      map.sort((a, b) => a.distance - b.distance);

      // If the map became too long, drop the farthest item
      if (map.length > this.k) {
        map.pop();
      }

      // Update this value for the next comparison
      maxDistanceInMap = map[map.length - 1].distance;
    }
  }

  return map;
}
```
This method could be easier to read, but the simpler version is not efficient for very large training sets. What you’re doing here is maintaining a list of points that might be the KNNs and storing them in map.
By maintaining a variable called maxDistanceInMap, you can loop over every training point and make a simple comparison to see whether the point should be added to your candidate list. If the point you're iterating over is closer than the farthest of your candidates, you add it to the list, re-sort the list, remove the farthest point to keep the list small, and then update maxDistanceInMap.
If that sounds like a lot of work, a simpler version might loop over all points, add each one with its distance measurement to the map, sort the map, and then return the first k items. The downside of this implementation is that for a dataset of a million points, you'd need to build a distance map of a million points and then sort that giant list in memory.
In your version, you only ever hold k items as candidates, so you never need to store a separate million-point map. Your version does require a call to Array.sort whenever an item is added to the map. This is inefficient in its own way, as the sort function is called for each addition to the map. Fortunately, the sort operation is only for k items, where k might be something like 3 or 5.
The computational complexity of the sorting algorithm is most likely O(n log n) (for a quicksort or mergesort implementation). Assume the worst case, in which every training point triggers a re-sort of the candidate list. Under that assumption, the bounded version becomes more efficient than the simple version after only about 30 data points when k = 3, and after around 3,000 data points when k = 5. However, both versions are so fast that for a dataset smaller than 3,000 points, you won't notice the difference.
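One way to arrive at these crossover points is a rough cost comparison. It ignores constant factors and assumes the worst case, in which each of the n training points triggers a sort of roughly k candidates:

```latex
\underbrace{n \log n}_{\text{sort all } n \text{ distances}}
\;=\;
\underbrace{n \, k \log k}_{n \text{ sorts of } k \text{ candidates}}
\quad\Longrightarrow\quad
\log n = k \log k = \log\left(k^{k}\right)
\quad\Longrightarrow\quad
n = k^{k}
```

For k = 3 this gives 3^3 = 27 points, and for k = 5 it gives 5^5 = 3,125 points. Both agree with the rough figures quoted above. Below these sizes the single large sort is cheaper. Above them the bounded candidate list wins, because its cost grows only linearly in n.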
- Finally, tie the algorithm together with the predict method. The predict method must accept a test point and, at the very least, return the determined label for the point. You can also add some additional output to the method and report the labels of the k nearest neighbors as well as the number of votes each label contributed. Add the following to the body of the KNN class:
```javascript
predict(point) {
  const map = this.generateDistanceMap(point);
  const votes = map.slice(0, this.k);

  // Reduces into an object like {label: voteCount}
  const voteCounts = votes.reduce(
    (obj, vote) => Object.assign({}, obj, { [vote.label]: (obj[vote.label] || 0) + 1 }),
    {}
  );

  // Order labels from most to fewest votes
  const sortedVotes = Object.keys(voteCounts)
    .map(label => ({ label, count: voteCounts[label] }))
    .sort((a, b) => b.count - a.count);

  return {
    label: sortedVotes[0].label,
    voteCounts,
    votes
  };
}
```
This method requires a little bit of datatype juggling in JavaScript but is simple in concept. First, generate your distance map using the method you just implemented. Then, remove all data except for the k nearest points and store that in a votes variable. If you’re using 3 as k, then votes will be an array of length three.
Now that you have your k nearest neighbors, you need to figure out which label represents the majority of the neighbors. You can do this by reducing your votes array into an object called voteCounts. To picture what voteCounts should look like, imagine that you're looking for the three nearest neighbors and the possible categories are Male or Female. The voteCounts variable might look like this: {"Female": 2, "Male": 1}.
The job is still not done, however. After reducing your votes into a vote-count object, you still need to sort that object and determine the majority label. You can do this by mapping the vote-count object back into an array and then sorting the array by vote count.
There are other ways to approach this problem of tallying votes; any method you can think of will work, as long as you can return the majority vote at the end of the day. That’s all you need to do in the knn.js file. The algorithm is complete, requiring fewer than 70 lines of code.
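As one alternative to the reduce-and-sort approach in predict, the sketch below tallies the votes with a plain loop. It also reports whether the top label is tied, which can happen with an even k or with more than two categories. The tallyVotes function and its tie field are hypothetical additions and are not part of the KNN class built above.

```javascript
// Tally votes into { label: count } and flag a tie for first place.
// Assumes each vote is an object with a `label` property, like the
// entries in the votes array returned by predict().
function tallyVotes(votes) {
  const voteCounts = {};
  for (const { label } of votes) {
    voteCounts[label] = (voteCounts[label] || 0) + 1;
  }
  // [label, count] pairs, most votes first
  const ranked = Object.entries(voteCounts).sort((a, b) => b[1] - a[1]);
  const [topLabel, topCount] = ranked[0];
  // A tie exists if the runner-up has as many votes as the leader
  const tie = ranked.length > 1 && ranked[1][1] === topCount;
  return { label: topLabel, voteCounts, tie };
}

// Hypothetical votes, for illustration only.
console.log(tallyVotes([{ label: "Female" }, { label: "Male" }, { label: "Female" }]));
// { label: 'Female', voteCounts: { Female: 2, Male: 1 }, tie: false }
console.log(tallyVotes([{ label: "Female" }, { label: "Male" }]).tie); // true
```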
Now set up your index.js file and get ready to run some examples. Remember that you need to download the data.js file first. You can do this by downloading the file from https://github.com/bkanber/MLinJSBook. Now add the following to the top of index.js:
```javascript
import KNN from './knn.js';
import { weight_height } from './data.js';
```
You can now try out the algorithm using a simple example.
Example – Height, weight, and gender
Like k-means, KNN can work on high-dimensional data. Also like k-means, though, example data can only be graphed on a two-dimensional plane, so keep your example simple. The first question you'll tackle is: can you predict a person's biological sex given only their height and weight?
The data for this example has been downloaded from a national longitudinal survey on people’s perception of their weight. Included in the data are the respondents’ height, weight, and gender. This is what the data looks like when graphed:
Just by looking at the preceding chart, you can get a sense of why KNN is so effective at evaluating clustered data. It's true that there's no neat boundary between male and female. But if you were to evaluate a new data point for a 200-pound, 72-inch-tall person, it's clear that all the training data around that point is male, and it's likely your new point is male, too.
Conversely, a new respondent at 125 pounds and a height of 62 inches is well into the female area of the graph, though there are a couple of males with those characteristics as well. The middle of the graph, around 145 pounds and 65 inches tall, is the most ambiguous, with an even split of male and female training points. Expect the algorithm to be uncertain about the new points in that area. As there is no clear dividing line in this dataset, you would need more features or more dimensions to get a better resolution of the boundaries.
In any case, try out a few examples. Pick five points that you would expect to be definitely male, definitely female, probably male, probably female, and indeterminable. Add the following code to index.js, beneath the two import lines:
```javascript
console.log("Testing height and weight with k=5");
console.log("==========================");

const solver1 = new KNN(5, weight_height.data, weight_height.labels);

// Each test point is [weight in pounds, height in inches]
console.log("Testing a 'definitely male' point:");
console.log(solver1.predict([200, 75]));

console.log("\nTesting a 'probably male' point:");
console.log(solver1.predict([170, 70]));

console.log("\nTesting a 'totally uncertain' point:");
console.log(solver1.predict([140, 64]));

console.log("\nTesting a 'probably female' point:");
console.log(solver1.predict([130, 63]));

console.log("\nTesting a 'definitely female' point:");
console.log(solver1.predict([120, 60]));
```
Run yarn start from the command line and you should see the following output. KNN is not stochastic: it does not use any randomness in its evaluation. You should therefore see exactly the same output. The one possible exception is the ordering of votes and their indexes when two votes have the same distance.
If you get an error when you run yarn start, make sure that your data.js file has been correctly downloaded and installed.
Here’s the output from the preceding code:
```
Testing height and weight with k=5
======================================================================
Testing a 'definitely male' point:
{ label: 'Male',
  voteCounts: { Male: 5 },
  votes:
   [ { index: 372, distance: 0, label: 'Male' },
     { index: 256, distance: 1, label: 'Male' },
     { index: 291, distance: 1, label: 'Male' },
     { index: 236, distance: 2.8284271247461903, label: 'Male' },
     { index: 310, distance: 3, label: 'Male' } ] }

Testing a 'probably male' point:
{ label: 'Male',
  voteCounts: { Male: 5 },
  votes:
   [ { index: 463, distance: 0, label: 'Male' },
     { index: 311, distance: 0, label: 'Male' },
     { index: 247, distance: 1, label: 'Male' },
     { index: 437, distance: 1, label: 'Male' },
     { index: 435, distance: 1, label: 'Male' } ] }

Testing a 'totally uncertain' point:
{ label: 'Male',
  voteCounts: { Male: 3, Female: 2 },
  votes:
   [ { index: 329, distance: 0, label: 'Male' },
     { index: 465, distance: 0, label: 'Male' },
     { index: 386, distance: 0, label: 'Male' },
     { index: 126, distance: 0, label: 'Female' },
     { index: 174, distance: 1, label: 'Female' } ] }

Testing a 'probably female' point:
{ label: 'Female',
  voteCounts: { Female: 4, Male: 1 },
  votes:
   [ { index: 186, distance: 0, label: 'Female' },
     { index: 90, distance: 0, label: 'Female' },
     { index: 330, distance: 0, label: 'Male' },
     { index: 51, distance: 1, label: 'Female' },
     { index: 96, distance: 1, label: 'Female' } ] }

Testing a 'definitely female' point:
{ label: 'Female',
  voteCounts: { Female: 5 },
  votes:
   [ { index: 200, distance: 0, label: 'Female' },
     { index: 150, distance: 0, label: 'Female' },
     { index: 198, distance: 1, label: 'Female' },
     { index: 147, distance: 1, label: 'Female' },
     { index: 157, distance: 1, label: 'Female' } ] }
```
The algorithm has determined genders just as you would have done, visually, by looking at the chart. Feel free to play with this example more and experiment with different values of k to see how results might differ for any given test point.
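The sketch below shows how k can change a result without needing the full dataset. It is small and self-contained, and it runs on a hypothetical toy training set in which a single "A" point sits inside a cluster of "B" points. The predictWithK function and the data are illustrative assumptions. With the real solver, you would instead construct new KNN instances with different k values.

```javascript
// Hypothetical 2-D toy data: one "A" point sits inside a "B" cluster,
// and the other "A" points are far away.
const toyData = [[5, 5], [4, 5], [6, 5], [5, 6], [0, 0], [1, 0]];
const toyLabels = ["A", "B", "B", "B", "A", "A"];

const dist = (a, b) => Math.hypot(...a.map((ai, i) => b[i] - ai));

// Majority label among the k nearest training points.
// Ties go to the label that appears first among the nearest points.
function predictWithK(point, k) {
  const nearest = toyData
    .map((p, i) => ({ d: dist(point, p), label: toyLabels[i] }))
    .sort((x, y) => x.d - y.d)
    .slice(0, k);
  const counts = {};
  for (const { label } of nearest) counts[label] = (counts[label] || 0) + 1;
  return Object.keys(counts).reduce((best, l) => (counts[l] > counts[best] ? l : best));
}

for (const k of [1, 3, 5]) {
  console.log(`k=${k}:`, predictWithK([5.1, 5.1], k));
}
// k=1: A
// k=3: B
// k=5: B
```

With k = 1, the single nearby "A" point decides the result. With k = 3 or k = 5, the surrounding "B" points outvote it.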
If you found this article interesting, you can explore Burak Kanber’s Hands-on Machine Learning with JavaScript to gain hands-on knowledge on evaluating and implementing the right model, along with choosing from different JS libraries, such as NaturalNode, brain, harthur, and classifier to design smarter applications. This book is a definitive guide to creating an intelligent web application with the best of machine learning and JavaScript.