Fit.h 6.9 KB


  1. #ifndef FIT_H
  2. #define FIT_H
  #include <cmath>
  #include <cstddef>
  #include <vector>
  4. using namespace std;
  /// \brief Least-squares fitting of a straight line or a polynomial.
  ///
  /// A fit stores the polynomial coefficients in \c factor (lowest order first)
  /// together with the error statistics of that fit. Every call to linearFit()
  /// or polyfit() replaces the results of the previous call.
  class Fit
  {
  public:
      std::vector<double> factor;   ///< Coefficients of the fitted equation, factor[i] multiplies x^i.
      double ssr;                   ///< Regression sum of squares: sum of (fitted - mean(y))^2.
      double sse;                   ///< Residual sum of squares: sum of (fitted - y)^2.
      double rmse;                  ///< Root mean square error: sqrt(sse / length).
      std::vector<double> fitedYs;  ///< Fitted y values; only filled when the fit is asked to save them.
      double r;                     ///< Reliability of the last linear fit (squared correlation of x and y).

  public:
      /// \brief Creates an object holding the zero line y = 0 + 0*x and zeroed statistics.
      Fit() : factor(2, 0.0), ssr(0.0), sse(0.0), rmse(0.0), r(0.0) {}

  public:
      /// \brief Straight-line fit y = factor[0] + factor[1]*x over two vectors.
      ///
      /// Only the first min(x.size(), y.size()) samples are used.
      /// \param x observed x values
      /// \param y observed y values
      /// \param isSaveFitYs whether the fitted y values are stored in \c fitedYs (default: no)
      /// \return true when a line was fitted, false when the input is empty or degenerate
      template <typename T>
      bool linearFit(const std::vector<T>& x, const std::vector<T>& y, bool isSaveFitYs = false)
      {
          const std::size_t length = getSeriesLength(x, y);
          // Taking &v[0] of an empty vector is undefined, so hand over null pointers instead.
          const T* xs = length ? &x[0] : 0;
          const T* ys = length ? &y[0] : 0;
          return linearFit(xs, ys, length, isSaveFitYs);
      }

      /// \brief Straight-line fit y = factor[0] + factor[1]*x over two arrays.
      ///
      /// The sums are accumulated in double so that integer sample types neither
      /// overflow nor lose the fractional part of the coefficients.
      /// \param x observed x values (at least \p length elements)
      /// \param y observed y values (at least \p length elements)
      /// \param length number of samples
      /// \param isSaveFitYs whether the fitted y values are stored in \c fitedYs (default: no)
      /// \return true when a line was fitted; false when a pointer is null, \p length is 0,
      ///         or the normal-equation denominator is zero (e.g. all x are equal).
      ///         On failure the coefficients and statistics are left at zero.
      template <typename T>
      bool linearFit(const T* x, const T* y, std::size_t length, bool isSaveFitYs = false)
      {
          factor.assign(2, 0.0);
          resetFitState();
          r = 0.0;
          if (!x || !y || length == 0)
          {
              return false;
          }

          double sumXX = 0.0;
          double sumX = 0.0;
          double sumXY = 0.0;
          double sumY = 0.0;
          for (std::size_t i = 0; i < length; ++i)
          {
              const double xi = static_cast<double>(x[i]);
              const double yi = static_cast<double>(y[i]);
              sumXX += xi * xi;
              sumX += xi;
              sumXY += xi * yi;
              sumY += yi;
          }

          const double count = static_cast<double>(length);
          const double denominator = sumXX * count - sumX * sumX;
          if (denominator == 0.0)
          {
              return false;  // a single sample or a vertical line has no unique solution
          }

          factor[1] = (sumXY * count - sumX * sumY) / denominator;
          factor[0] = (sumXX * sumY - sumX * sumXY) / denominator;

          // Error statistics and reliability of the fitted line.
          calcError(x, y, length, this->ssr, this->sse, this->rmse, isSaveFitYs);
          calcReliability(x, y, length, this->r);
          return true;
      }

      /// \brief Polynomial fit y = a0 + a1*x + a2*x^2 + ... + a_poly_n*x^poly_n over two vectors.
      ///
      /// Only the first min(x.size(), y.size()) samples are used.
      /// \param x observed x values
      /// \param y observed y values
      /// \param poly_n requested order; poly_n = 2 fits y = a0 + a1*x + a2*x^2
      /// \param isSaveFitYs whether the fitted y values are stored in \c fitedYs (default: yes)
      template <typename T>
      void polyfit(const std::vector<T>& x, const std::vector<T>& y, int poly_n, bool isSaveFitYs = true)
      {
          const std::size_t length = getSeriesLength(x, y);
          // Taking &v[0] of an empty vector is undefined, so hand over null pointers instead.
          const T* xs = length ? &x[0] : 0;
          const T* ys = length ? &y[0] : 0;
          polyfit(xs, ys, length, poly_n, isSaveFitYs);
      }

      /// \brief Polynomial fit y = a0 + a1*x + a2*x^2 + ... + a_poly_n*x^poly_n over two arrays.
      ///
      /// Builds the normal equations from the power sums of x and solves them by
      /// Gaussian elimination. \c r is not touched by this function.
      /// A negative \p poly_n empties \c factor; empty input leaves poly_n+1 zero
      /// coefficients. In both cases the statistics are zero.
      /// \param x observed x values (at least \p length elements)
      /// \param y observed y values (at least \p length elements)
      /// \param length number of samples
      /// \param poly_n requested order; poly_n = 2 fits y = a0 + a1*x + a2*x^2
      /// \param isSaveFitYs whether the fitted y values are stored in \c fitedYs (default: yes)
      template <typename T>
      void polyfit(const T* x, const T* y, std::size_t length, int poly_n, bool isSaveFitYs = true)
      {
          resetFitState();
          if (poly_n < 0)
          {
              factor.clear();
              return;
          }
          const std::size_t terms = static_cast<std::size_t>(poly_n) + 1;
          factor.assign(terms, 0.0);
          if (!x || !y || length == 0)
          {
              return;
          }

          // sumxx[p] = sum of x^p for p = 0 .. 2*poly_n.
          std::vector<double> xPower(length, 1.0);
          std::vector<double> sumxx(2 * terms - 1, 0.0);
          for (std::size_t p = 0; p < sumxx.size(); ++p)
          {
              for (std::size_t j = 0; j < length; ++j)
              {
                  sumxx[p] += xPower[j];
                  xPower[j] *= static_cast<double>(x[j]);
              }
          }

          // sumxy[p] = sum of y * x^p for p = 0 .. poly_n.
          std::vector<double> yTimesXPower(length);
          for (std::size_t j = 0; j < length; ++j)
          {
              yTimesXPower[j] = static_cast<double>(y[j]);
          }
          std::vector<double> sumxy(terms, 0.0);
          for (std::size_t p = 0; p < terms; ++p)
          {
              for (std::size_t j = 0; j < length; ++j)
              {
                  sumxy[p] += yTimesXPower[j];
                  yTimesXPower[j] *= static_cast<double>(x[j]);
              }
          }

          // Normal-equation matrix, row-major: ata[row][col] = sum of x^(row+col).
          std::vector<double> ata(terms * terms);
          for (std::size_t row = 0; row < terms; ++row)
          {
              for (std::size_t col = 0; col < terms; ++col)
              {
                  ata[row * terms + col] = sumxx[row + col];
              }
          }

          gauss_solve(static_cast<int>(terms), ata, factor, sumxy);

          // Evaluate the fitted polynomial on the samples and compute the error statistics.
          calcError(x, y, length, this->ssr, this->sse, this->rmse, isSaveFitYs);
      }

      /// \brief Copies the coefficients of the fitted equation into \p factor.
      void getFactor(std::vector<double>& factor) const { factor = this->factor; }

      /// \brief Copies the fitted y values into \p fitedYs.
      ///
      /// The values are only available when the fit was run with isSaveFitYs = true.
      void getFitedYs(std::vector<double>& fitedYs) const { fitedYs = this->fitedYs; }

      /// \brief Evaluates the fitted equation at \p x.
      /// \return the y value of the fitted equation for \p x
      template <typename T>
      double getY(const T x) const
      {
          const double value = static_cast<double>(x);
          double ans = 0.0;
          // Horner scheme: ((a_n*x + a_(n-1))*x + ...)*x + a_0.
          for (std::size_t i = factor.size(); i-- > 0;)
          {
              ans = ans * value + factor[i];
          }
          return ans;
      }

      /// \brief Slope, i.e. the first-order coefficient (0 when the fit has no such term).
      double getSlope() const { return factor.size() > 1 ? factor[1] : 0.0; }

      /// \brief Intercept, i.e. the zero-order coefficient (0 when there are no coefficients).
      /// \return the intercept
      double getIntercept() const { return factor.empty() ? 0.0 : factor[0]; }

      /// \brief Residual sum of squares.
      /// \return the residual sum of squares
      double getSSE() const { return sse; }

      /// \brief Regression sum of squares.
      /// \return the regression sum of squares
      double getSSR() const { return ssr; }

      /// \brief Root mean square error.
      /// \return the root mean square error
      double getRMSE() const { return rmse; }

      /// \brief Coefficient of determination, computed as 1 - sse/(ssr + sse).
      ///
      /// The result is NaN when ssr + sse is zero (for example before any fit).
      /// \return the coefficient of determination
      double getR_square() const { return 1 - (sse / (ssr + sse)); }

      /// \brief Number of samples usable from two series: the smaller of the two sizes.
      template <typename T>
      static std::size_t getSeriesLength(const std::vector<T>& x, const std::vector<T>& y)
      {
          return (x.size() > y.size() ? y.size() : x.size());
      }

      /// \brief Mean of a vector, computed in the element type T.
      /// \return the mean, or T(0) for an empty vector
      template <typename T>
      static T Mean(const std::vector<T>& v)
      {
          return v.empty() ? T(0) : Mean(&v[0], v.size());
      }

      /// \brief Mean of an array, computed in the element type T.
      /// \return the mean, or T(0) when \p length is 0
      template <typename T>
      static T Mean(const T* v, std::size_t length)
      {
          if (length == 0)
          {
              return T(0);
          }
          T total(0);
          for (std::size_t i = 0; i < length; ++i)
          {
              total += v[i];
          }
          // Convert the count to T first: dividing a negative integer total by an
          // unsigned size_t would wrap it around.
          return total / static_cast<T>(length);
      }

      /// \brief Number of coefficients of the fitted equation.
      /// \return the number of coefficients
      std::size_t getFactorSize() const { return factor.size(); }

      /// \brief Coefficient of the given order.
      ///
      /// getFactor(2) returns a2 of y = a0 + a1*x + a2*x^2 + ... + a_poly_n*x^poly_n.
      /// Uses vector::at, so an out-of-range order throws std::out_of_range.
      /// \return the coefficient of order \p i
      double getFactor(std::size_t i) const { return factor.at(i); }

      /// \brief Reliability value computed by the last successful linearFit().
      double getR() const { return r; }

  private:
      /// \brief Zeroes the error statistics and drops fitted values of an earlier fit.
      void resetFitState()
      {
          ssr = 0.0;
          sse = 0.0;
          rmse = 0.0;
          fitedYs.clear();
      }

      /// \brief Computes ssr, sse and rmse of the current coefficients on the samples.
      ///
      /// The outputs are overwritten (not accumulated) and \c fitedYs is cleared
      /// first, so repeated fits on one object do not mix their results.
      /// \param x observed x values
      /// \param y observed y values
      /// \param length number of samples
      /// \param r_ssr receives the regression sum of squares
      /// \param r_sse receives the residual sum of squares
      /// \param r_rmse receives the root mean square error
      /// \param isSaveFitYs whether the fitted y values are appended to \c fitedYs
      template <typename T>
      void calcError(const T* x
                     , const T* y
                     , std::size_t length
                     , double& r_ssr
                     , double& r_sse
                     , double& r_rmse
                     , bool isSaveFitYs = true
                     )
      {
          r_ssr = 0.0;
          r_sse = 0.0;
          r_rmse = 0.0;
          fitedYs.clear();
          if (length == 0)
          {
              return;
          }

          // The mean is kept in double; storing it in T would truncate it for integer samples.
          double sumY = 0.0;
          for (std::size_t i = 0; i < length; ++i)
          {
              sumY += static_cast<double>(y[i]);
          }
          const double meanY = sumY / static_cast<double>(length);

          if (isSaveFitYs)
          {
              fitedYs.reserve(length);
          }
          for (std::size_t i = 0; i < length; ++i)
          {
              const double fitted = getY(x[i]);
              const double fromMean = fitted - meanY;
              const double residual = fitted - static_cast<double>(y[i]);
              r_ssr += fromMean * fromMean;   // regression sum of squares
              r_sse += residual * residual;   // residual sum of squares
              if (isSaveFitYs)
              {
                  fitedYs.push_back(fitted);
              }
          }
          r_rmse = std::sqrt(r_sse / static_cast<double>(length));
      }

      /// \brief Solves the n-by-n system A*x = b; vector front end of the pointer overload.
      template <typename T>
      void gauss_solve(int n
                       , std::vector<T>& A
                       , std::vector<T>& x
                       , std::vector<T>& b)
      {
          if (n <= 0 || A.empty() || x.empty() || b.empty())
          {
              return;
          }
          gauss_solve(n, &A[0], &x[0], &b[0]);
      }

      /// \brief Solves the n-by-n system A*x = b by Gaussian elimination with partial pivoting.
      ///
      /// \p A is row-major and, like \p b, is modified in place. A singular system
      /// divides by a zero pivot, which for floating-point T yields non-finite values.
      /// \param n number of unknowns
      /// \param A coefficient matrix, n*n elements
      /// \param x receives the solution, n elements
      /// \param b right-hand side, n elements
      template <typename T>
      void gauss_solve(int n
                       , T* A
                       , T* x
                       , T* b)
      {
          for (int k = 0; k < n - 1; ++k)
          {
              // Pivot: the row at or below k with the largest magnitude in column k.
              int pivotRow = k;
              double largest = std::fabs(static_cast<double>(A[k * n + k]));
              for (int row = k + 1; row < n; ++row)
              {
                  const double candidate = std::fabs(static_cast<double>(A[row * n + k]));
                  if (largest < candidate)
                  {
                      largest = candidate;
                      pivotRow = row;
                  }
              }

              if (pivotRow != k)
              {
                  // Exchange rows k and pivotRow of A together with the matching entries of b.
                  for (int col = 0; col < n; ++col)
                  {
                      const T held = A[k * n + col];
                      A[k * n + col] = A[pivotRow * n + col];
                      A[pivotRow * n + col] = held;
                  }
                  const T heldB = b[k];
                  b[k] = b[pivotRow];
                  b[pivotRow] = heldB;
              }

              // Eliminate column k from the rows below. Entries below the diagonal
              // are not zeroed because back substitution never reads them.
              for (int row = k + 1; row < n; ++row)
              {
                  const T scale = A[row * n + k] / A[k * n + k];
                  for (int col = k + 1; col < n; ++col)
                  {
                      A[row * n + col] -= scale * A[k * n + col];
                  }
                  b[row] -= scale * b[k];
              }
          }

          // Back substitution on the upper triangle.
          for (int row = n - 1; row >= 0; --row)
          {
              T acc = b[row];
              for (int col = row + 1; col < n; ++col)
              {
                  acc -= A[row * n + col] * x[col];
              }
              x[row] = acc / A[row * n + row];
          }
      }

      /// \brief Computes the squared correlation coefficient of x and y.
      ///
      /// When every y equals y[0] (including an empty series) \p r is set to 1.
      /// All sums are kept in double so integer samples are not truncated.
      /// \param x observed x values
      /// \param y observed y values
      /// \param length number of samples
      /// \param r receives the result
      template <typename T>
      void calcReliability(const T* x
                           , const T* y
                           , std::size_t length
                           , double& r
                           )
      {
          bool yVaries = false;
          for (std::size_t i = 1; i < length; ++i)
          {
              if (y[i] != y[0])
              {
                  yVaries = true;
                  break;
              }
          }
          if (!yVaries)
          {
              r = 1;
              return;
          }

          double sumX = 0.0;
          double sumY = 0.0;
          for (std::size_t i = 0; i < length; ++i)
          {
              sumX += static_cast<double>(x[i]);
              sumY += static_cast<double>(y[i]);
          }
          const double meanX = sumX / static_cast<double>(length);
          const double meanY = sumY / static_cast<double>(length);

          double devXX = 0.0;
          double devYY = 0.0;
          double devXY = 0.0;
          for (std::size_t i = 0; i < length; ++i)
          {
              const double dx = static_cast<double>(x[i]) - meanX;
              const double dy = static_cast<double>(y[i]) - meanY;
              devXX += dx * dx;
              devYY += dy * dy;
              devXY += dx * dy;
          }

          const double correlation = devXY / (std::sqrt(devXX) * std::sqrt(devYY));
          r = correlation * correlation;
      }
  };
  285. #endif