package org.knuchel.playground

import scala.io.Source
import scala.util.Random

/**
 * K-nearest-neighbour classification of iris flowers.
 *
 * Feature ranges noted for the data set:
 * {{{
 * 0.0 < sepalLength < 7.9
 * 0.0 < sepalWidth  < 4.4
 * 0.0 < petalLength < 6.9
 * 0.0 < petalWidth  < 2.5
 * }}}
 */
object KNNScala {

  /**
   * Loads `data/iris.data.csv`, classifies one hard-coded sample with `k = 5`
   * and prints the predicted specie.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val dataset = loadDataset("data/iris.data.csv")
    val k = 5
    val features = (5.7, 2.6, 3.5, 1d)
    val predictedSpecie = predictSpecie(dataset, k, features)
    val (sepalLength, sepalWidth, petalLength, petalWidth) = features
    println(
      s"Iris with [sepalLength=$sepalLength, sepalWidth=$sepalWidth, " +
        s"petalLength=$petalLength, petalWidth=$petalWidth] should be a $predictedSpecie"
    )

    // try different values of k and see how prediction errors change
    // evaluate(dataset)
  }

  /**
   * Predicts the specie of a flower by majority vote among its `k` nearest samples.
   *
   * @param dataset  labelled samples to search
   * @param k        number of nearest neighbours that vote
   * @param features (sepalLength, sepalWidth, petalLength, petalWidth) of the flower to classify
   * @return the specie with the most votes; when several species have the same number of
   *         votes, the one owning the closest neighbour wins (then the smallest name)
   * @throws NoSuchElementException if there is no neighbour to vote (empty dataset or `k <= 0`)
   */
  def predictSpecie(dataset: List[Iris], k: Int, features: (Double, Double, Double, Double)): String = {
    // Distance to every sample, ascending, truncated to the k closest ones.
    val nearest = dataset
      .map(iris => (iris.distance(features), iris.getSpecie))
      .sorted
      .take(k)

    if (nearest.isEmpty)
      throw new NoSuchElementException("cannot predict a specie: no neighbours (empty dataset or k <= 0)")

    nearest
      .groupBy(_._2)
      // Leave the Map before building tuples: mapping a Map to pairs would yield
      // another Map and silently merge species that share the same vote count.
      .toList
      // `hits` keeps the ascending order of `nearest`, so its head is the closest sample.
      .map { case (specie, hits) => (specie, hits.length, hits.head._1) }
      // Most votes first; ties go to the specie with the closest neighbour, then by name.
      .sortBy { case (specie, votes, closest) => (-votes, closest, specie) }
      .head
      ._1
  }

  /**
   * Reads a comma-separated file with four numeric columns followed by the specie name.
   * Empty lines are skipped.
   *
   * @param csvFile path of the file to read
   * @return one [[Iris]] per non-empty line, in file order
   */
  def loadDataset(csvFile: String): List[Iris] = {
    val file = Source.fromFile(csvFile)
    // Close the file even when a line fails to parse.
    try {
      file
        .getLines()
        .filter(_.nonEmpty)
        .map { line =>
          val cell = line.split(",")
          new Iris(cell(0).toDouble, cell(1).toDouble, cell(2).toDouble, cell(3).toDouble, cell(4))
        }
        .toList // force the lazy iterator before the file is closed
    } finally {
      file.close()
    }
  }

  /**
   * A labelled sample: four measurements and the specie name.
   */
  class Iris(sepalLength: Double, sepalWidth: Double, petalLength: Double, petalWidth: Double, specie: String) {

    /** The specie label of this sample. */
    def getSpecie: String = specie

    /** The measurements as (sepalLength, sepalWidth, petalLength, petalWidth). */
    def getFeatures: (Double, Double, Double, Double) = (sepalLength, sepalWidth, petalLength, petalWidth)

    /**
     * Euclidean distance between this sample and the given measurements.
     *
     * @param that (sepalLength, sepalWidth, petalLength, petalWidth) to compare with
     */
    def distance(that: (Double, Double, Double, Double)): Double = {
      val (otherSepalLength, otherSepalWidth, otherPetalLength, otherPetalWidth) = that
      Math.sqrt(
        Math.pow(sepalLength - otherSepalLength, 2) +
          Math.pow(sepalWidth - otherSepalWidth, 2) +
          Math.pow(petalLength - otherPetalLength, 2) +
          Math.pow(petalWidth - otherPetalWidth, 2)
      )
    }

    /** Text representation listing the specie and the four measurements. */
    override def toString: String =
      s"Iris [specie=$specie, sepalLength=$sepalLength, sepalWidth=$sepalWidth, " +
        s"petalLength=$petalLength, petalWidth=$petalWidth]"
  }

  /**
   * Prints, for every `k` from 1 to 20, how many test samples are misclassified.
   *
   * Each of the three known species is shuffled and split in two: the first half
   * is used to learn, the second half to test.
   *
   * @param dataset labelled samples to split and evaluate
   */
  def evaluate(dataset: List[Iris]): Unit = {
    // One shuffled list per specie, in a fixed order.
    val shuffledBySpecie = List("Iris-versicolor", "Iris-virginica", "Iris-setosa")
      .map(name => Random.shuffle(dataset.filter(_.getSpecie == name)))

    val learningData = shuffledBySpecie.flatMap(samples => samples.take(samples.length / 2))
    val testData = shuffledBySpecie.flatMap(samples => samples.drop(samples.length / 2))

    // test errors for different values of k
    for (k <- 1 to 20) {
      val errors = testData.count(iris => iris.getSpecie != predictSpecie(learningData, k, iris.getFeatures))
      println(s"$errors errors on ${testData.length} tests with k=$k")
    }
  }
}