// Fit.h (6.1 KB) - least-squares linear and polynomial curve fitting


#pragma once
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>
// NOTE(review): a header-scope using-directive leaks std into every includer;
// kept only for compatibility with existing code that may rely on it.
using namespace std;
  /// \brief Least-squares curve fitting of sampled data (x_i, y_i).
  ///
  /// linearFit() fits y = a0 + a1*x; polyfit() fits
  /// y = a0 + a1*x + a2*x^2 + ... + an*x^n by solving the normal equations.
  /// After a fit, factor[i] holds a_i, and ssr / sse / rmse (plus fitedYs when
  /// requested) describe that most recent fit only.
  class Fit{
  public:
    std::vector<double> factor;   ///< Fitted coefficients; factor[i] multiplies x^i.
    double ssr;                   ///< Regression sum of squares: sum((yfit_i - mean(y))^2).
    double sse;                   ///< Residual sum of squares: sum((yfit_i - y_i)^2).
    double rmse;                  ///< Root-mean-square error: sqrt(sse / n).
    std::vector<double> fitedYs;  ///< Fitted y values of the last fit; filled only when isSaveFitYs is true (saves memory otherwise).
  public:
    Fit():ssr(0),sse(0),rmse(0){factor.resize(2,0);}
  public:
    /// \brief Fits the straight line y = a0 + a1*x by least squares.
    /// \param x observed x values
    /// \param y observed y values; only the first min(x.size(), y.size()) pairs are used
    /// \param isSaveFitYs store the fitted y values in fitedYs (default: false)
    /// \return true on success; false when length < 2 or the x values have no
    ///         spread (denominator not positive). On failure factor is {0, 0}
    ///         and the error statistics are zero.
    template<typename T>
    bool linearFit(const std::vector<T>& x, const std::vector<T>& y, bool isSaveFitYs=false)
    {
      // data() is well-defined on an empty vector, unlike &x[0].
      return linearFit(x.data(), y.data(), getSeriesLength(x,y), isSaveFitYs);
    }

    /// \brief Pointer/length variant of linearFit().
    /// \param x pointer to at least `length` x values
    /// \param y pointer to at least `length` y values
    /// \param length number of (x, y) pairs to use
    /// \param isSaveFitYs store the fitted y values in fitedYs (default: false)
    /// \return true on success, false for degenerate input (see the vector overload)
    template<typename T>
    bool linearFit(const T* x, const T* y, std::size_t length, bool isSaveFitYs=false)
    {
      factor.assign(2, 0.0);
      resetStats();
      // Accumulate in double so integer sample types are not fitted with integer division.
      double sxx = 0.0, sx = 0.0, sxy = 0.0, sy = 0.0;
      for (std::size_t i = 0; i < length; ++i)
      {
        const double xi = static_cast<double>(x[i]);
        const double yi = static_cast<double>(y[i]);
        sxx += xi * xi;
        sx  += xi;
        sxy += xi * yi;
        sy  += yi;
      }
      const double n = static_cast<double>(length);
      // n * sum((x - mean(x))^2): never negative in exact arithmetic.
      const double denom = sxx * n - sx * sx;
      // !(denom > 0) also rejects NaN and the tiny negatives rounding can produce.
      if (length < 2 || !(denom > 0.0))
        return false;
      factor[1] = (sxy * n - sx * sy) / denom;
      factor[0] = (sxx * sy - sx * sxy) / denom;
      calcError(x, y, length, this->ssr, this->sse, this->rmse, isSaveFitYs);
      return true;
    }

    /// \brief Polynomial least-squares fit y = a0 + a1*x + a2*x^2 + ... + a_n*x^n.
    /// \param x observed x values
    /// \param y observed y values; only the first min(x.size(), y.size()) pairs are used
    /// \param poly_n polynomial degree; e.g. poly_n = 2 fits y = a0 + a1*x + a2*x^2
    /// \param isSaveFitYs store the fitted y values in fitedYs (default: true)
    /// \return true on success; false if poly_n is negative or the normal
    ///         equations are singular (e.g. no data)
    template<typename T>
    bool polyfit(const std::vector<T>& x, const std::vector<T>& y, int poly_n, bool isSaveFitYs=true)
    {
      return polyfit(x.data(), y.data(), getSeriesLength(x,y), poly_n, isSaveFitYs);
    }

    /// \brief Pointer/length variant of polyfit().
    /// \param x pointer to at least `length` x values
    /// \param y pointer to at least `length` y values
    /// \param length number of (x, y) pairs to use
    /// \param poly_n polynomial degree (must be >= 0)
    /// \param isSaveFitYs store the fitted y values in fitedYs (default: true)
    /// \return true on success. For a negative poly_n it returns false without
    ///         modifying the object; for a singular system it returns false with
    ///         factor set to poly_n + 1 zeros and the statistics reset.
    template<typename T>
    bool polyfit(const T* x, const T* y, std::size_t length, int poly_n, bool isSaveFitYs=true)
    {
      if (poly_n < 0)
        return false;
      const std::size_t terms = static_cast<std::size_t>(poly_n) + 1;
      factor.assign(terms, 0.0);
      resetStats();

      // Power sums: sumxx[k] = sum(x^k) for k = 0..2*poly_n,
      //             sumxy[k] = sum(x^k * y) for k = 0..poly_n.
      std::vector<double> sumxx(2 * terms - 1, 0.0);
      std::vector<double> sumxy(terms, 0.0);
      for (std::size_t j = 0; j < length; ++j)
      {
        const double xj = static_cast<double>(x[j]);
        const double yj = static_cast<double>(y[j]);
        double power = 1.0;  // x_j^k
        for (std::size_t k = 0; k < sumxx.size(); ++k)
        {
          sumxx[k] += power;
          if (k < terms)
            sumxy[k] += power * yj;
          power *= xj;
        }
      }

      // Normal-equation matrix, row-major: ata[r][c] = sum(x^(r + c)).
      std::vector<double> ata(terms * terms);
      for (std::size_t r = 0; r < terms; ++r)
        for (std::size_t c = 0; c < terms; ++c)
          ata[r * terms + c] = sumxx[r + c];

      if (!gauss_solve(static_cast<int>(terms), ata, factor, sumxy))
        return false;  // gauss_solve leaves factor (all zeros) untouched on failure

      calcError(x, y, length, this->ssr, this->sse, this->rmse, isSaveFitYs);
      return true;
    }

    /// \brief Copies the fitted coefficients (factor[i] multiplies x^i) into `factor`.
    void getFactor(std::vector<double>& factor) const {factor = this->factor;}

    /// \brief Copies the fitted y values of the last fit into `fitedYs`.
    /// Empty unless that fit was run with isSaveFitYs = true.
    void getFitedYs(std::vector<double>& fitedYs) const {fitedYs = this->fitedYs;}

    /// \brief Evaluates the fitted polynomial at x.
    /// \return sum(factor[i] * x^i), computed with Horner's method
    template<typename T>
    double getY(const T x) const
    {
      const double xv = static_cast<double>(x);
      double ans = 0.0;
      for (auto it = factor.rbegin(); it != factor.rend(); ++it)
        ans = ans * xv + *it;
      return ans;
    }

    /// \brief Slope a1 of the fitted equation.
    /// Throws std::out_of_range if there are fewer than two coefficients
    /// (e.g. after polyfit with poly_n = 0).
    double getSlope() const {return factor.at(1);}

    /// \brief Intercept a0 of the fitted equation.
    /// \return intercept value; throws std::out_of_range if factor is empty
    double getIntercept() const {return factor.at(0);}

    /// \brief Residual sum of squares of the last fit.
    double getSSE() const {return sse;}

    /// \brief Regression sum of squares of the last fit.
    double getSSR() const {return ssr;}

    /// \brief Root-mean-square error of the last fit.
    double getRMSE() const {return rmse;}

    /// \brief Coefficient of determination R^2 = 1 - SSE / (SSR + SSE).
    /// This lies in [0, 1] and measures goodness of fit. It is NaN when both
    /// sums are zero (e.g. before any successful fit).
    double getR_square() const {return 1-(sse/(ssr+sse));}

    /// \brief Number of (x, y) pairs usable from two series: min(x.size(), y.size()).
    template<typename T>
    static std::size_t getSeriesLength(const std::vector<T>& x, const std::vector<T>& y)
    {
      return (x.size() < y.size() ? x.size() : y.size());
    }

    /// \brief Arithmetic mean of a vector, computed in T.
    /// \return the mean, or T(0) for an empty vector
    template <typename T>
    static T Mean(const std::vector<T>& v)
    {
      return Mean(v.data(), v.size());
    }

    /// \brief Arithmetic mean of `length` values starting at v, computed in T.
    /// \return the mean, or T(0) when length == 0
    template <typename T>
    static T Mean(const T* v, std::size_t length)
    {
      if (length == 0)
        return T(0);
      T total(0);
      for (std::size_t i = 0; i < length; ++i)
        total += v[i];
      // Divide as T: a signed total divided by size_t would first be converted
      // to a huge unsigned value when negative.
      return total / static_cast<T>(length);
    }

    /// \brief Number of coefficients in the fitted equation.
    std::size_t getFactorSize() const {return factor.size();}

    /// \brief Coefficient of x^i in the fitted equation.
    /// For example, getFactor(2) returns a2 of y = a0 + a1*x + a2*x^2 + ...
    /// Throws std::out_of_range if i >= getFactorSize().
    double getFactor(std::size_t i) const {return factor.at(i);}

  private:
    /// \brief Clears the error statistics and stored fitted values before a new fit.
    void resetStats()
    {
      ssr = 0.0;
      sse = 0.0;
      rmse = 0.0;
      fitedYs.clear();
    }

    /// \brief Computes the goodness-of-fit statistics for the current factor.
    /// The outputs are overwritten, not accumulated. fitedYs is cleared and,
    /// if isSaveFitYs is set, refilled with the fitted values.
    /// All arithmetic is done in double so integer sample types do not truncate.
    template<typename T>
    void calcError(const T* x, const T* y, std::size_t length,
                   double& r_ssr, double& r_sse, double& r_rmse,
                   bool isSaveFitYs=true)
    {
      fitedYs.clear();
      if (length == 0)
      {
        r_ssr = 0.0;
        r_sse = 0.0;
        r_rmse = 0.0;
        return;
      }
      double mean_y = 0.0;
      for (std::size_t i = 0; i < length; ++i)
        mean_y += static_cast<double>(y[i]);
      mean_y /= static_cast<double>(length);

      if (isSaveFitYs)
        fitedYs.reserve(length);
      double ssrAcc = 0.0;
      double sseAcc = 0.0;
      for (std::size_t i = 0; i < length; ++i)
      {
        const double yi = getY(x[i]);
        const double regression = yi - mean_y;                    // contributes to SSR
        const double residual = yi - static_cast<double>(y[i]);   // contributes to SSE
        ssrAcc += regression * regression;
        sseAcc += residual * residual;
        if (isSaveFitYs)
          fitedYs.push_back(yi);
      }
      r_ssr = ssrAcc;
      r_sse = sseAcc;
      r_rmse = std::sqrt(sseAcc / static_cast<double>(length));
    }

    /// \brief Vector front-end for the pointer overload of gauss_solve().
    template<typename T>
    bool gauss_solve(int n, std::vector<T>& A, std::vector<T>& x, std::vector<T>& b)
    {
      return gauss_solve(n, A.data(), x.data(), b.data());
    }

    /// \brief Solves the dense n x n system A * x = b (A row-major) by Gaussian
    ///        elimination with partial pivoting.
    /// A and b are overwritten; x receives the solution.
    /// \return false, leaving x untouched, if a zero or NaN pivot shows A is singular
    template<typename T>
    bool gauss_solve(int n, T* A, T* x, T* b)
    {
      for (int k = 0; k < n; ++k)
      {
        // Pick the largest |A[i][k]| among rows i >= k of column k, so the
        // pivot is the best available for this elimination step.
        int pivot = k;
        double best = std::fabs(static_cast<double>(A[k*n + k]));
        for (int i = k + 1; i < n; ++i)
        {
          const double candidate = std::fabs(static_cast<double>(A[i*n + k]));
          if (candidate > best)
          {
            best = candidate;
            pivot = i;
          }
        }
        if (!(best > 0.0))
          return false;  // singular (or NaN) column
        if (pivot != k)
        {
          for (int j = 0; j < n; ++j)
            std::swap(A[k*n + j], A[pivot*n + j]);
          std::swap(b[k], b[pivot]);
        }
        for (int i = k + 1; i < n; ++i)
        {
          const T ratio = A[i*n + k] / A[k*n + k];
          for (int j = k + 1; j < n; ++j)
            A[i*n + j] -= ratio * A[k*n + j];
          A[i*n + k] = T(0);
          b[i] -= ratio * b[k];
        }
      }
      // Back substitution on the resulting upper-triangular system.
      for (int i = n - 1; i >= 0; --i)
      {
        T acc = b[i];
        for (int j = i + 1; j < n; ++j)
          acc -= A[i*n + j] * x[j];
        x[i] = acc / A[i*n + i];
      }
      return true;
    }
  };