parametersearch.cpp 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. // Copyright 2008 Rarefied Technologies, Inc.
  2. // Distributed under the GPL v2 please see
  3. // LICENSE file for more information.
  4. #include "parametersearch.h"
  5. #include "parameterresult.h"
  6. #include "../thirdparty/libsvm/svm.h"
  7. #include <cmath>
  8. #include <iostream>
  9. #include <fstream>
  10. #define REFINED_RANGE 1.33
  11. CParameterSearch::CParameterSearch(svm_problem* pProb, svm_parameter* pSvmParam, string strFilename)
  12. {
  13. if(!pProb || !pSvmParam)
  14. return;
  15. m_pOA = NULL;
  16. m_pofs = NULL;
  17. m_strFilename = strFilename;
  18. ResetSerialization();
  19. m_pProblem = pProb;
  20. m_pSvmParam = pSvmParam;
  21. m_rangeParameters.fParam1Min = -15;
  22. //m_rangeParameters.fParam1Max = -1;
  23. m_rangeParameters.fParam1Max = 3;
  24. m_rangeParameters.fParam1Step = 4;
  25. m_rangeParameters.bParam1UseLog = true;
  26. m_rangeParameters.fParam1RefinementFactor = 2;
  27. m_rangeParameters.fParam2Min = -13;
  28. //m_rangeParameters.fParam2Max = -1;
  29. m_rangeParameters.fParam2Max = -5;
  30. m_rangeParameters.fParam2Step = 4;
  31. m_rangeParameters.bParam2UseLog = true;
  32. m_rangeParameters.fParam2RefinementFactor = 2;
  33. ParameterResult* pResult = new ParameterResult;
  34. const RangeParameters tempRP = m_rangeParameters;
  35. *m_pOA << tempRP;
  36. m_pofs->flush();
  37. SearchRange( pResult, m_rangeParameters);
  38. // semi-infinite, breakable loop
  39. while( pResult = GetNextResult())
  40. {
  41. RangeParameters Params;
  42. GetRefinedParameters(pResult->nLevel, pResult->fParam1, pResult->fParam2, Params);
  43. if(!SearchRange( pResult, Params))
  44. break;
  45. }
  46. }
  47. CParameterSearch::~CParameterSearch()
  48. {
  49. }
  50. ParameterResult* CParameterSearch::GetNextResult()
  51. {
  52. ParameterResult* pResult = NULL;
  53. ResultsSet::iterator it = m_searchResults.begin();
  54. while(it != m_searchResults.end())
  55. {
  56. pResult = *it;
  57. if(!pResult->bRefined)
  58. return pResult;
  59. ++it;
  60. }
  61. return NULL;
  62. }
  63. bool CParameterSearch::SearchRange(ParameterResult* pResult, RangeParameters& Params)
  64. {
  65. if(!m_pProblem || !m_pSvmParam || !pResult)
  66. return false;
  67. float fParam1 = 0;
  68. float fParam2 = 0;
  69. double* target = new double[m_pProblem->l];
  70. /* for(int i=0; i<m_pProblem->l; i++)
  71. {
  72. target[i] = 0;
  73. }*/
  74. for(fParam1=Params.fParam1Min; fParam1<=Params.fParam1Max; fParam1+=Params.fParam1Step)
  75. {
  76. if(Params.bParam1UseLog)
  77. m_pSvmParam->p = ::pow(2,fParam1);
  78. else
  79. m_pSvmParam->p = fParam1;
  80. for(fParam2=Params.fParam2Min; fParam2<=Params.fParam2Max; fParam2+=Params.fParam2Step)
  81. {
  82. if(Params.bParam2UseLog)
  83. m_pSvmParam->C = ::pow(2,fParam2);
  84. else
  85. m_pSvmParam->C = fParam2;
  86. int nFolds = 2;
  87. svm_cross_validation(m_pProblem, m_pSvmParam, nFolds, target);
  88. float fError = 0;
  89. float fWrong = 0;
  90. for(int i=0; i<m_pProblem->l; i++)
  91. {
  92. fError += abs(m_pProblem->y[i] - target[i]);
  93. if( m_pProblem->y[i] >= 0.5 && target[i] < 0.5)
  94. fWrong++;
  95. else if( m_pProblem->y[i] < 0.5 && target[i] >= 0.5)
  96. fWrong++;
  97. }
  98. fError = (float)fError/m_pProblem->l;
  99. fWrong = (float) fWrong/m_pProblem->l;
  100. float fStdDev = 0;
  101. for(int i=0; i<m_pProblem->l; i++)
  102. {
  103. fStdDev += pow(fError - abs(m_pProblem->y[i] - target[i]), 2) ;
  104. }
  105. fStdDev = pow(fStdDev, (float)0.5) / m_pProblem->l;
  106. std::cout << "\n****************" << std::endl;
  107. std::cout << "C = 2^" << fParam2 << ", epsilon = 2^" << fParam1 << std::endl;
  108. std::cout << "Avg Error: " << fError << " Std Dev: " << fStdDev << std::endl;
  109. std::cout << "Percent wrong: " << fWrong << std::endl;
  110. ParameterResult* pNewResult = new ParameterResult;
  111. pNewResult->fError = fError;
  112. pNewResult->fStdDev = fStdDev;
  113. pNewResult->fWrong = fWrong;
  114. pNewResult->fParam1 = fParam1;
  115. pNewResult->fParam2 = fParam2;
  116. pNewResult->nLevel = pResult->nLevel + 1;
  117. m_searchResults.insert(pNewResult);
  118. const ParameterResult* pConstResult = pNewResult;
  119. *m_pOA << *pConstResult;
  120. m_pofs->flush();
  121. //SaveTextResults();
  122. }
  123. }
  124. pResult->bRefined = true;
  125. SerializeData();
  126. return true;
  127. }
  128. // nLevel is desired level, not current level
  129. bool CParameterSearch::GetRefinedParameters(int nLevel, float fParam1, float fParam2, RangeParameters& paramsOut)
  130. {
  131. paramsOut.bParam1UseLog = m_rangeParameters.bParam1UseLog;
  132. float fParam1StepPrev = m_rangeParameters.fParam1Step / pow(m_rangeParameters.fParam1RefinementFactor, nLevel-1);
  133. paramsOut.fParam1Min = max( m_rangeParameters.fParam1Min, fParam1 - (float)REFINED_RANGE * fParam1StepPrev );
  134. paramsOut.fParam1Max = min( m_rangeParameters.fParam1Max, fParam1 + (float)REFINED_RANGE * fParam1StepPrev );
  135. paramsOut.fParam1Step = m_rangeParameters.fParam1Step / pow(m_rangeParameters.fParam1RefinementFactor, nLevel);
  136. paramsOut.bParam2UseLog = m_rangeParameters.bParam2UseLog;
  137. float fParam2StepPrev = m_rangeParameters.fParam2Step / pow(m_rangeParameters.fParam2RefinementFactor, nLevel-1);
  138. paramsOut.fParam2Min = max( m_rangeParameters.fParam2Min, fParam2 - (float)REFINED_RANGE * fParam2StepPrev );
  139. paramsOut.fParam2Max = min( m_rangeParameters.fParam2Max, fParam2 + (float)REFINED_RANGE * fParam2StepPrev );
  140. paramsOut.fParam2Step = m_rangeParameters.fParam2Step / pow(m_rangeParameters.fParam2RefinementFactor, nLevel);
  141. return true;
  142. }
  143. void CParameterSearch::ResetSerialization()
  144. {
  145. delete m_pOA;
  146. delete m_pofs;
  147. m_pofs = new std::ofstream(m_strFilename.c_str());
  148. m_pOA = new boost::archive::text_oarchive(*m_pofs);
  149. }
  150. void CParameterSearch::SerializeData()
  151. {
  152. ResetSerialization();
  153. const CParameterSearch* pSearch = this;
  154. *m_pOA << *pSearch;
  155. m_pofs->flush();
  156. }
  157. void CParameterSearch::SaveTextResults()
  158. {
  159. std::ofstream ofs("SearchResults.txt");
  160. ofs << *this;
  161. }
  162. std::ostream & operator<<(std::ostream &os, const CParameterSearch &ps)
  163. {
  164. os << ps.m_rangeParameters << std::endl << std::endl;
  165. ResultsSet::iterator it = ps.m_searchResults.begin();
  166. while(it != ps.m_searchResults.end())
  167. {
  168. const ParameterResult* pResult = *it;
  169. os << *pResult << '\n';
  170. ++it;
  171. }
  172. return os;
  173. }