博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
K最近邻(KNN)算法原理和java实现
阅读量:2394 次
发布时间:2019-05-10

本文共 4352 字,大约阅读时间需要 14 分钟。

原理部分:

请参考:

 

 

代码实现:

 

KNN结点类,用来存储最近邻的k个元组相关的信息

/** * KNN结点类,用来存储最近邻的k个元组相关的信息 */public class KNNNode {	private int index; 			// 元组标号	private double distance; 	// 与测试元组的距离	private String c; 			// 所属类别	public KNNNode(int index, double distance, String c) {		super();		this.index = index;		this.distance = distance;		this.c = c;	}			public int getIndex() {		return index;	}	public void setIndex(int index) {		this.index = index;	}	public double getDistance() {		return distance;	}	public void setDistance(double distance) {		this.distance = distance;	}	public String getC() {		return c;	}	public void setC(String c) {		this.c = c;	}}

 

 

KNN算法主体类

/** * KNN算法主体类 */public class KNN {	/**	 * 设置优先级队列的比较函数,距离越大,优先级越高	 */	private Comparator
comparator = new Comparator
() { public int compare(KNNNode o1, KNNNode o2) { if (o1.getDistance() >= o2.getDistance()) { return 1; } else { return 0; } } }; /** * 获取K个不同的随机数 * @param k 随机数的个数 * @param max 随机数最大的范围 * @return 生成的随机数数组 */ public List
getRandKNum(int k, int max) { List
rand = new ArrayList
(k); for (int i = 0; i < k; i++) { int temp = (int) (Math.random() * max); if (!rand.contains(temp)) { rand.add(temp); } else { i--; } } return rand; } /** * 计算测试元组与训练元组之前的距离 * @param d1 测试元组 * @param d2 训练元组 * @return 距离值 */ public double calDistance(List
d1, List
d2) { double distance = 0.00; for (int i = 0; i < d1.size(); i++) { distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i)); } return distance; } /** * 执行KNN算法,获取测试元组的类别 * @param datas 训练数据集 * @param testData 测试元组 * @param k 设定的K值 * @return 测试元组的类别 */ public String knn(List
> datas, List
testData, int k) { PriorityQueue
pq = new PriorityQueue
(k, comparator); List
randNum = getRandKNum(k, datas.size()); for (int i = 0; i < k; i++) { int index = randNum.get(i); List
currData = datas.get(index); String c = currData.get(currData.size() - 1).toString(); KNNNode node = new KNNNode(index, calDistance(testData, currData), c); pq.add(node); } for (int i = 0; i < datas.size(); i++) { List
t = datas.get(i); double distance = calDistance(testData, t); KNNNode top = pq.peek(); if (top.getDistance() > distance) { pq.remove(); pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString())); } } return getMostClass(pq); } /** * 获取所得到的k个最近邻元组的多数类 * @param pq 存储k个最近近邻元组的优先级队列 * @return 多数类的名称 */ private String getMostClass(PriorityQueue
pq) { Map
classCount = new HashMap
(); for (int i = 0; i < pq.size(); i++) { KNNNode node = pq.remove(); String c = node.getC(); if (classCount.containsKey(c)) { classCount.put(c, classCount.get(c) + 1); } else { classCount.put(c, 1); } } int maxIndex = -1; int maxCount = 0; Object[] classes = classCount.keySet().toArray(); for (int i = 0; i < classes.length; i++) { if (classCount.get(classes[i]) > maxCount) { maxIndex = i; maxCount = classCount.get(classes[i]); } } return classes[maxIndex].toString(); }}

 

KNN算法测试类

/** * KNN算法测试类 */public class TestKNN {		/**	 * 从数据文件中读取数据	 * @param datas 存储数据的集合对象	 * @param path 数据文件的路径	 */	public void read(List
> datas, String path){ try { BufferedReader br = new BufferedReader(new FileReader(new File(path))); String data = br.readLine(); List
l = null; while (data != null) { String t[] = data.split(" "); l = new ArrayList
(); for (int i = 0; i < t.length; i++) { l.add(Double.parseDouble(t[i])); } datas.add(l); data = br.readLine(); } } catch (Exception e) { e.printStackTrace(); } } /** * 程序执行入口 * @param args */ public static void main(String[] args) { TestKNN t = new TestKNN(); String datafile = new File("").getAbsolutePath() + File.separator + "datafile"; String testfile = new File("").getAbsolutePath() + File.separator + "testfile"; try { List
> datas = new ArrayList
>(); List
> testDatas = new ArrayList
>(); t.read(datas, datafile); t.read(testDatas, testfile); KNN knn = new KNN(); for (int i = 0; i < testDatas.size(); i++) { List
test = testDatas.get(i); System.out.print("测试元组: "); for (int j = 0; j < test.size(); j++) { System.out.print(test.get(j) + " "); } System.out.print("类别为: "); System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3))))); } } catch (Exception e) { e.printStackTrace(); } }}

 

转载地址:http://upwob.baihongyu.com/

你可能感兴趣的文章
linux在下软件安装-jdk和tomcat安装
查看>>
java框架基础 静态代理和动态代理
查看>>
jQuery ajax开发基于json
查看>>
oracle数据库
查看>>
oracle中间的数据类型
查看>>
论文划分
查看>>
vscode利用cmake调试
查看>>
zcash挖矿
查看>>
zcash挖矿指南
查看>>
区块链术语解释
查看>>
./configure,make,make install的作用
查看>>
学术论文录用结果通知(Notification)
查看>>
Theorem等数学化的论述
查看>>
PKI和X509证书
查看>>
使用HttpClient爬取国内疫情数据
查看>>
引用传递和值传递有什么区别
查看>>
C++从入门到放肆!
查看>>
C++是什么?怎么学?学完了能得到什么?
查看>>
初学C语言没有项目练手怎么行,这17个小项目收下不谢
查看>>
学好C语言,你只需要这几句口诀!
查看>>