// KNN.cpp - k-nearest-neighbour lookup over a CSV training set.
  1. //#include "stdafx.h"
  2. #include<iostream>
  3. #include<map>
  4. #include<vector>
  5. #include<stdio.h>
  6. #include<cmath>
  7. #include<cstdlib>
  8. #include<algorithm>
  9. #include<fstream>
  10. #include "./../CSVparser/CSVparser.h"
  11. #include "KNN.h"
  // The rest of this file spells standard-library names without the std:: prefix.
  using namespace std;

  // File-scope stream objects. They are default-constructed (not attached to any
  // file) and are not used in the visible part of this file; they are kept
  // because other translation units may refer to them.
  std::ifstream fin;
  std::ofstream fout;
  /*
   * Strip leading and trailing blanks (space, tab, CR, LF) from str.
   *
   * str is modified in place; a copy of the trimmed text is also returned.
   * A string made only of blanks (or an empty one) becomes empty.
   */
  std::string Trim(std::string& str)
  {
      static const char* const kBlanks = " \t\r\n";

      // Locate the last character worth keeping; none means nothing survives.
      const std::string::size_type last_kept = str.find_last_not_of(kBlanks);
      if (last_kept == std::string::npos)
      {
          str.clear();
          return str;
      }

      // Drop the tail first, then the head.
      str.erase(last_kept + 1);
      str.erase(0, str.find_first_not_of(kBlanks));
      return str;
  }
  21. KNN::KNN(int magic_x, int magic_y, int magic_z)
  22. {
  23. int k = 10;
  24. this->k = k;
  25. csv::Parser file = csv::Parser("d:/magic_map/trainning_data_magic.csv");
  26. std::cout << file[0][0] << std::endl; // display : 1997
  27. std::cout << file[0] << std::endl; // display : 1997 | Ford | E350
  28. cout<<file[1]<<endl;
  29. int row_count = file.rowCount();
  30. /* input the dataSet */
  31. for(int i=0;i<rowLen;i++)
  32. {
  33. for(int j=0;j<colLen;j++)
  34. {
  35. if(i < row_count)
  36. {
  37. dataSet[i][j] = atof((file[i][j]).c_str());
  38. }
  39. }
  40. if(i < row_count)
  41. {
  42. labels[i].x = atof((file[i][3]).c_str());
  43. labels[i].y = atof((file[i][4]).c_str());
  44. }
  45. }
  46. cout<<"please input the test data :"<<endl;
  47. /* inuput the test data */
  48. testData[0] = magic_x;
  49. testData[1] = magic_y;
  50. testData[2] = magic_z;
  51. }
  52. /*
  53. * calculate the distance between test data and dataSet[i]
  54. */
  55. double KNN:: get_distance(tData *d1,tData *d2)
  56. {
  57. double sum = 0;
  58. for(int i=0;i<colLen;i++)
  59. {
  60. sum += pow( (d1[i]-d2[i]) , 2 );
  61. }
  62. // cout<<"the sum is = "<<sum<<endl;
  63. return sqrt(sum);
  64. }
  65. /*
  66. * calculate all the distance between test data and each training data
  67. */
  68. void KNN:: get_all_distance()
  69. {
  70. double distance;
  71. int i;
  72. for(i=0;i<rowLen;i++)
  73. {
  74. distance = get_distance(dataSet[i],testData);
  75. //<key,value> => <i,distance>
  76. map_index_dis[i] = distance;
  77. }
  78. //traverse the map to print the index and distance
  79. map<int,double>::const_iterator it = map_index_dis.begin();
  80. while(it!=map_index_dis.end())
  81. {
  82. //cout<<"index = "<<it->first<<" distance = "<<it->second<<endl;
  83. it++;
  84. }
  85. }
  86. /*
  87. * check which label the test data belongs to to classify the test data
  88. */
  89. tLabel KNN:: get_max_freq_label()
  90. {
  91. //transform the map_index_dis to vec_index_dis
  92. vector<PAIR> vec_index_dis( map_index_dis.begin(),map_index_dis.end() );
  93. //sort the vec_index_dis by distance from low to high to get the nearest data
  94. sort(vec_index_dis.begin(),vec_index_dis.end(),CmpByValue());
  95. for(int i=0;i<k;i++)
  96. {
  97. cout<<"the index = "<<vec_index_dis[i].first<<" the distance = "<<vec_index_dis[i].second<<" the label x= "<<labels[vec_index_dis[i].first].x<<" the label y= "<<labels[vec_index_dis[i].first].y<<" the coordinate ( "<<dataSet[ vec_index_dis[i].first ][0]<<","<<dataSet[ vec_index_dis[i].first ][1]<<" )"<<endl;
  98. //calculate the count of each label
  99. map_label_freq[ labels[ vec_index_dis[i].first ]]++;
  100. }
  101. map<tLabel,int>::const_iterator map_it = map_label_freq.begin();
  102. tLabel label;
  103. int max_freq = 0;
  104. //find the most frequent label
  105. while( map_it != map_label_freq.end() )
  106. {
  107. if( map_it->second > max_freq )
  108. {
  109. max_freq = map_it->second;
  110. label = map_it->first;
  111. }
  112. map_it++;
  113. }
  114. cout<<"The test data belongs to the x:"<<label.x<<"y:"<<label.y<<" label"<<endl;
  115. return label;
  116. }
  #if 0
  /*
   * Stand-alone driver, compiled out by the surrounding #if 0.
   *
   * Reads the three components of the test sample, builds the classifier and
   * prints the predicted label. The constructor defined in this file takes the
   * test sample (not k) and fixes k itself, so the driver asks for the sample.
   * NOTE(review): KNN.h is not visible here - confirm the constructor
   * declaration before enabling this code.
   */
  int main()
  {
      int x = 0;
      int y = 0;
      int z = 0;
      std::cout << "please input the test data (x y z) : " << std::endl;
      if (!(std::cin >> x >> y >> z))
      {
          std::cerr << "invalid test data" << std::endl;
          return 1;
      }

      KNN knn(x, y, z);
      knn.get_all_distance();
      knn.get_max_freq_label();

      // Keeps the console window open (Windows shell command).
      std::system("pause");
      return 0;
  }
  #endif