// svm.cpp

  1. #include <math.h>
  2. #include <stdio.h>
  3. #include <stdlib.h>
  4. #include <ctype.h>
  5. #include <float.h>
  6. #include <string.h>
  7. #include <stdarg.h>
  8. #include "svm.h"
  9. typedef float Qfloat;
  10. typedef signed char schar;
  11. #ifndef min
  12. template <class T> inline T min(T x,T y) { return (x<y)?x:y; }
  13. #endif
  14. #ifndef max
  15. template <class T> inline T max(T x,T y) { return (x>y)?x:y; }
  16. #endif
  17. template <class T> inline void swap(T& x, T& y) { T t=x; x=y; y=t; }
  18. template <class S, class T> inline void clone(T*& dst, S* src, int n)
  19. {
  20. dst = new T[n];
  21. memcpy((void *)dst,(void *)src,sizeof(T)*n);
  22. }
  23. inline double powi(double base, int times)
  24. {
  25. double tmp = base, ret = 1.0;
  26. for(int t=times; t>0; t/=2)
  27. {
  28. if(t%2==1) ret*=tmp;
  29. tmp = tmp * tmp;
  30. }
  31. return ret;
  32. }
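// powi() computes base^times for a non-negative integer exponent by
// squaring: tmp holds base^(2^k) while the exponent is scanned bit by bit,
// and ret picks up the factors whose bit is set, so only O(log times)
// multiplications are needed.  For example, powi(2,10) multiplies ret by
// 4 (bit 1 of 10) and by 256 (bit 3 of 10) and returns 1024.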
  33. #define INF HUGE_VAL
  34. #define TAU 1e-12
  35. #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
  36. #if 1
  37. void info(const char *fmt,...)
  38. {
  39. va_list ap;
  40. va_start(ap,fmt);
  41. vprintf(fmt,ap);
  42. va_end(ap);
  43. }
  44. void info_flush()
  45. {
  46. fflush(stdout);
  47. }
  48. #else
  49. void info(const char *fmt,...) {}
  50. void info_flush() {}
  51. #endif
  52. //
  53. // Kernel Cache
  54. //
  55. // l is the total number of data items
  56. // size is the cache size limit in bytes
  57. //
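// Columns are kept on a circular LRU list anchored at lru_head; when the
// byte budget is exhausted, the least-recently requested columns are freed
// first.  The Q-matrix classes below use the cache as follows:
//
//     Qfloat *data;
//     int start = cache->get_data(i,&data,len);
//     for(int j=start;j<len;j++)   // only the part not already cached
//         data[j] = ...;           // needs to be recomputed
//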
  58. class Cache
  59. {
  60. public:
  61. Cache(int l,long int size);
  62. ~Cache();
  63. // request data [0,len)
  64. // return some position p where [p,len) needs to be filled
  65. // (p >= len if nothing needs to be filled)
  66. int get_data(const int index, Qfloat **data, int len);
  67. void swap_index(int i, int j); // future_option
  68. private:
  69. int l;
  70. long int size;
  71. struct head_t
  72. {
  73. head_t *prev, *next; // a circular list
  74. Qfloat *data;
  75. int len; // data[0,len) is cached in this entry
  76. };
  77. head_t *head;
  78. head_t lru_head;
  79. void lru_delete(head_t *h);
  80. void lru_insert(head_t *h);
  81. };
  82. Cache::Cache(int l_,long int size_):l(l_),size(size_)
  83. {
  84. head = (head_t *)calloc(l,sizeof(head_t)); // initialized to 0
  85. size /= sizeof(Qfloat);
  86. size -= l * sizeof(head_t) / sizeof(Qfloat);
  87. size = max(size, 2 * (long int) l); // cache must be large enough for two columns
  88. lru_head.next = lru_head.prev = &lru_head;
  89. }
  90. Cache::~Cache()
  91. {
  92. for(head_t *h = lru_head.next; h != &lru_head; h=h->next)
  93. free(h->data);
  94. free(head);
  95. }
  96. void Cache::lru_delete(head_t *h)
  97. {
  98. // delete from current location
  99. h->prev->next = h->next;
  100. h->next->prev = h->prev;
  101. }
  102. void Cache::lru_insert(head_t *h)
  103. {
  104. // insert to last position
  105. h->next = &lru_head;
  106. h->prev = lru_head.prev;
  107. h->prev->next = h;
  108. h->next->prev = h;
  109. }
  110. int Cache::get_data(const int index, Qfloat **data, int len)
  111. {
  112. head_t *h = &head[index];
  113. if(h->len) lru_delete(h);
  114. int more = len - h->len;
  115. if(more > 0)
  116. {
  117. // free old space
  118. while(size < more)
  119. {
  120. head_t *old = lru_head.next;
  121. lru_delete(old);
  122. free(old->data);
  123. size += old->len;
  124. old->data = 0;
  125. old->len = 0;
  126. }
  127. // allocate new space
  128. h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len);
  129. size -= more;
  130. swap(h->len,len);
  131. }
  132. lru_insert(h);
  133. *data = h->data;
  134. return len;
  135. }
  136. void Cache::swap_index(int i, int j)
  137. {
  138. if(i==j) return;
  139. if(head[i].len) lru_delete(&head[i]);
  140. if(head[j].len) lru_delete(&head[j]);
  141. swap(head[i].data,head[j].data);
  142. swap(head[i].len,head[j].len);
  143. if(head[i].len) lru_insert(&head[i]);
  144. if(head[j].len) lru_insert(&head[j]);
  145. if(i>j) swap(i,j);
  146. for(head_t *h = lru_head.next; h!=&lru_head; h=h->next)
  147. {
  148. if(h->len > i)
  149. {
  150. if(h->len > j)
  151. swap(h->data[i],h->data[j]);
  152. else
  153. {
  154. // give up
  155. lru_delete(h);
  156. free(h->data);
  157. size += h->len;
  158. h->data = 0;
  159. h->len = 0;
  160. }
  161. }
  162. }
  163. }
  164. //
  165. // Kernel evaluation
  166. //
  167. // the static method k_function is for doing single kernel evaluation
  168. // the constructor of Kernel prepares to calculate the l*l kernel matrix
  169. // the member function get_Q is for getting one column from the Q Matrix
  170. //
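// The kernels evaluated by the kernel_* members below are
//
//     LINEAR       K(u,v) = u'v
//     POLY         K(u,v) = (gamma*u'v + coef0)^degree
//     RBF          K(u,v) = exp(-gamma*|u-v|^2)
//     SIGMOID      K(u,v) = tanh(gamma*u'v + coef0)
//     PRECOMPUTED  K(u,v) read from the precomputed values stored in x
//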
  171. class QMatrix {
  172. public:
  173. virtual Qfloat *get_Q(int column, int len) const = 0;
  174. virtual Qfloat *get_QD() const = 0;
  175. virtual void swap_index(int i, int j) const = 0;
  176. virtual ~QMatrix() {}
  177. };
  178. class Kernel: public QMatrix {
  179. public:
  180. Kernel(int l, svm_node * const * x, const svm_parameter& param);
  181. virtual ~Kernel();
  182. static double k_function(const svm_node *x, const svm_node *y,
  183. const svm_parameter& param);
  184. virtual Qfloat *get_Q(int column, int len) const = 0;
  185. virtual Qfloat *get_QD() const = 0;
  186. virtual void swap_index(int i, int j) const // not really const...
  187. {
  188. swap(x[i],x[j]);
  189. if(x_square) swap(x_square[i],x_square[j]);
  190. }
  191. protected:
  192. double (Kernel::*kernel_function)(int i, int j) const;
  193. private:
  194. const svm_node **x;
  195. double *x_square;
  196. // svm_parameter
  197. const int kernel_type;
  198. const int degree;
  199. const double gamma;
  200. const double coef0;
  201. static double dot(const svm_node *px, const svm_node *py);
  202. double kernel_linear(int i, int j) const
  203. {
  204. return dot(x[i],x[j]);
  205. }
  206. double kernel_poly(int i, int j) const
  207. {
  208. return powi(gamma*dot(x[i],x[j])+coef0,degree);
  209. }
  210. double kernel_rbf(int i, int j) const
  211. {
  212. return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
  213. }
  214. double kernel_sigmoid(int i, int j) const
  215. {
  216. return tanh(gamma*dot(x[i],x[j])+coef0);
  217. }
  218. double kernel_precomputed(int i, int j) const
  219. {
  220. return x[i][(int)(x[j][0].value)].value;
  221. }
  222. };
  223. Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)
  224. :kernel_type(param.kernel_type), degree(param.degree),
  225. gamma(param.gamma), coef0(param.coef0)
  226. {
  227. switch(kernel_type)
  228. {
  229. case LINEAR:
  230. kernel_function = &Kernel::kernel_linear;
  231. break;
  232. case POLY:
  233. kernel_function = &Kernel::kernel_poly;
  234. break;
  235. case RBF:
  236. kernel_function = &Kernel::kernel_rbf;
  237. break;
  238. case SIGMOID:
  239. kernel_function = &Kernel::kernel_sigmoid;
  240. break;
  241. case PRECOMPUTED:
  242. kernel_function = &Kernel::kernel_precomputed;
  243. break;
  244. }
  245. clone(x,x_,l);
  246. if(kernel_type == RBF)
  247. {
  248. x_square = new double[l];
  249. for(int i=0;i<l;i++)
  250. x_square[i] = dot(x[i],x[i]);
  251. }
  252. else
  253. x_square = 0;
  254. }
  255. Kernel::~Kernel()
  256. {
  257. delete[] x;
  258. delete[] x_square;
  259. }
  260. double Kernel::dot(const svm_node *px, const svm_node *py)
  261. {
  262. double sum = 0;
  263. while(px->index != -1 && py->index != -1)
  264. {
  265. if(px->index == py->index)
  266. {
  267. sum += px->value * py->value;
  268. ++px;
  269. ++py;
  270. }
  271. else
  272. {
  273. if(px->index > py->index)
  274. ++py;
  275. else
  276. ++px;
  277. }
  278. }
  279. return sum;
  280. }
  281. double Kernel::k_function(const svm_node *x, const svm_node *y,
  282. const svm_parameter& param)
  283. {
  284. switch(param.kernel_type)
  285. {
  286. case LINEAR:
  287. return dot(x,y);
  288. case POLY:
  289. return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
  290. case RBF:
  291. {
  292. double sum = 0;
  293. while(x->index != -1 && y->index !=-1)
  294. {
  295. if(x->index == y->index)
  296. {
  297. double d = x->value - y->value;
  298. sum += d*d;
  299. ++x;
  300. ++y;
  301. }
  302. else
  303. {
  304. if(x->index > y->index)
  305. {
  306. sum += y->value * y->value;
  307. ++y;
  308. }
  309. else
  310. {
  311. sum += x->value * x->value;
  312. ++x;
  313. }
  314. }
  315. }
  316. while(x->index != -1)
  317. {
  318. sum += x->value * x->value;
  319. ++x;
  320. }
  321. while(y->index != -1)
  322. {
  323. sum += y->value * y->value;
  324. ++y;
  325. }
  326. return exp(-param.gamma*sum);
  327. }
  328. case SIGMOID:
  329. return tanh(param.gamma*dot(x,y)+param.coef0);
  330. case PRECOMPUTED: //x: test (validation), y: SV
  331. return x[(int)(y->value)].value;
  332. default:
  333. return 0; // Unreachable
  334. }
  335. }
  336. // An SMO algorithm in Fan et al., JMLR 6(2005), pp. 1889--1918
  337. // Solves:
  338. //
  339. // min 0.5(\alpha^T Q \alpha) + p^T \alpha
  340. //
  341. // y^T \alpha = \delta
  342. // y_i = +1 or -1
  343. // 0 <= alpha_i <= Cp for y_i = 1
  344. // 0 <= alpha_i <= Cn for y_i = -1
  345. //
  346. // Given:
  347. //
  348. // Q, p, y, Cp, Cn, and an initial feasible point \alpha
  349. // l is the size of vectors and matrices
  350. // eps is the stopping tolerance
  351. //
  352. // solution will be put in \alpha, objective value will be put in obj
  353. //
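// Each iteration selects a violating pair (i,j), solves the two-variable
// subproblem analytically (step = gradient difference / quad_coef, where
// quad_coef = Q_ii + Q_jj +/- 2*Q_ij and is replaced by TAU when it is not
// positive), clips alpha_i and alpha_j back into their boxes, and then
// updates the gradient G and the cached G_bar (the gradient with free
// variables treated as 0).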
  354. class Solver {
  355. public:
  356. Solver() {};
  357. virtual ~Solver() {};
  358. struct SolutionInfo {
  359. double obj;
  360. double rho;
  361. double upper_bound_p;
  362. double upper_bound_n;
  363. double r; // for Solver_NU
  364. };
  365. void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
  366. double *alpha_, double Cp, double Cn, double eps,
  367. SolutionInfo* si, int shrinking);
  368. protected:
  369. int active_size;
  370. schar *y;
  371. double *G; // gradient of objective function
  372. enum { LOWER_BOUND, UPPER_BOUND, FREE };
  373. char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE
  374. double *alpha;
  375. const QMatrix *Q;
  376. const Qfloat *QD;
  377. double eps;
  378. double Cp,Cn;
  379. double *p;
  380. int *active_set;
  381. double *G_bar; // gradient, if we treat free variables as 0
  382. int l;
  383. bool unshrinked; // XXX
  384. double get_C(int i)
  385. {
  386. return (y[i] > 0)? Cp : Cn;
  387. }
  388. void update_alpha_status(int i)
  389. {
  390. if(alpha[i] >= get_C(i))
  391. alpha_status[i] = UPPER_BOUND;
  392. else if(alpha[i] <= 0)
  393. alpha_status[i] = LOWER_BOUND;
  394. else alpha_status[i] = FREE;
  395. }
  396. bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
  397. bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
  398. bool is_free(int i) { return alpha_status[i] == FREE; }
  399. void swap_index(int i, int j);
  400. void reconstruct_gradient();
  401. virtual int select_working_set(int &i, int &j);
  402. virtual double calculate_rho();
  403. virtual void do_shrinking();
  404. private:
  405. bool be_shrunken(int i, double Gmax1, double Gmax2);
  406. };
  407. void Solver::swap_index(int i, int j)
  408. {
  409. Q->swap_index(i,j);
  410. swap(y[i],y[j]);
  411. swap(G[i],G[j]);
  412. swap(alpha_status[i],alpha_status[j]);
  413. swap(alpha[i],alpha[j]);
  414. swap(p[i],p[j]);
  415. swap(active_set[i],active_set[j]);
  416. swap(G_bar[i],G_bar[j]);
  417. }
  418. void Solver::reconstruct_gradient()
  419. {
  420. // reconstruct inactive elements of G from G_bar and free variables
  421. if(active_size == l) return;
  422. int i;
  423. for(i=active_size;i<l;i++)
  424. G[i] = G_bar[i] + p[i];
  425. for(i=0;i<active_size;i++)
  426. if(is_free(i))
  427. {
  428. const Qfloat *Q_i = Q->get_Q(i,l);
  429. double alpha_i = alpha[i];
  430. for(int j=active_size;j<l;j++)
  431. G[j] += alpha_i * Q_i[j];
  432. }
  433. }
  434. void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
  435. double *alpha_, double Cp, double Cn, double eps,
  436. SolutionInfo* si, int shrinking)
  437. {
  438. this->l = l;
  439. this->Q = &Q;
  440. QD=Q.get_QD();
  441. clone(p, p_,l);
  442. clone(y, y_,l);
  443. clone(alpha,alpha_,l);
  444. this->Cp = Cp;
  445. this->Cn = Cn;
  446. this->eps = eps;
  447. unshrinked = false;
  448. // initialize alpha_status
  449. {
  450. alpha_status = new char[l];
  451. for(int i=0;i<l;i++)
  452. update_alpha_status(i);
  453. }
  454. // initialize active set (for shrinking)
  455. {
  456. active_set = new int[l];
  457. for(int i=0;i<l;i++)
  458. active_set[i] = i;
  459. active_size = l;
  460. }
  461. // initialize gradient
  462. {
  463. G = new double[l];
  464. G_bar = new double[l];
  465. int i;
  466. for(i=0;i<l;i++)
  467. {
  468. G[i] = p[i];
  469. G_bar[i] = 0;
  470. }
  471. for(i=0;i<l;i++)
  472. if(!is_lower_bound(i))
  473. {
  474. const Qfloat *Q_i = Q.get_Q(i,l);
  475. double alpha_i = alpha[i];
  476. int j;
  477. for(j=0;j<l;j++)
  478. G[j] += alpha_i*Q_i[j];
  479. if(is_upper_bound(i))
  480. for(j=0;j<l;j++)
  481. G_bar[j] += get_C(i) * Q_i[j];
  482. }
  483. }
  484. // optimization step
  485. int iter = 0;
  486. int counter = min(l,1000)+1;
  487. while(1)
  488. {
  489. // show progress and do shrinking
  490. if(--counter == 0)
  491. {
  492. counter = min(l,1000);
  493. if(shrinking) do_shrinking();
  494. info("."); info_flush();
  495. }
  496. int i,j;
  497. if(select_working_set(i,j)!=0)
  498. {
  499. // reconstruct the whole gradient
  500. reconstruct_gradient();
  501. // reset active set size and check
  502. active_size = l;
  503. info("*"); info_flush();
  504. if(select_working_set(i,j)!=0)
  505. break;
  506. else
  507. counter = 1; // do shrinking next iteration
  508. }
  509. ++iter;
  510. // update alpha[i] and alpha[j], handle bounds carefully
  511. const Qfloat *Q_i = Q.get_Q(i,active_size);
  512. const Qfloat *Q_j = Q.get_Q(j,active_size);
  513. double C_i = get_C(i);
  514. double C_j = get_C(j);
  515. double old_alpha_i = alpha[i];
  516. double old_alpha_j = alpha[j];
  517. if(y[i]!=y[j])
  518. {
  519. double quad_coef = Q_i[i]+Q_j[j]+2*Q_i[j];
  520. if (quad_coef <= 0)
  521. quad_coef = TAU;
  522. double delta = (-G[i]-G[j])/quad_coef;
  523. double diff = alpha[i] - alpha[j];
  524. alpha[i] += delta;
  525. alpha[j] += delta;
  526. if(diff > 0)
  527. {
  528. if(alpha[j] < 0)
  529. {
  530. alpha[j] = 0;
  531. alpha[i] = diff;
  532. }
  533. }
  534. else
  535. {
  536. if(alpha[i] < 0)
  537. {
  538. alpha[i] = 0;
  539. alpha[j] = -diff;
  540. }
  541. }
  542. if(diff > C_i - C_j)
  543. {
  544. if(alpha[i] > C_i)
  545. {
  546. alpha[i] = C_i;
  547. alpha[j] = C_i - diff;
  548. }
  549. }
  550. else
  551. {
  552. if(alpha[j] > C_j)
  553. {
  554. alpha[j] = C_j;
  555. alpha[i] = C_j + diff;
  556. }
  557. }
  558. }
  559. else
  560. {
  561. double quad_coef = Q_i[i]+Q_j[j]-2*Q_i[j];
  562. if (quad_coef <= 0)
  563. quad_coef = TAU;
  564. double delta = (G[i]-G[j])/quad_coef;
  565. double sum = alpha[i] + alpha[j];
  566. alpha[i] -= delta;
  567. alpha[j] += delta;
  568. if(sum > C_i)
  569. {
  570. if(alpha[i] > C_i)
  571. {
  572. alpha[i] = C_i;
  573. alpha[j] = sum - C_i;
  574. }
  575. }
  576. else
  577. {
  578. if(alpha[j] < 0)
  579. {
  580. alpha[j] = 0;
  581. alpha[i] = sum;
  582. }
  583. }
  584. if(sum > C_j)
  585. {
  586. if(alpha[j] > C_j)
  587. {
  588. alpha[j] = C_j;
  589. alpha[i] = sum - C_j;
  590. }
  591. }
  592. else
  593. {
  594. if(alpha[i] < 0)
  595. {
  596. alpha[i] = 0;
  597. alpha[j] = sum;
  598. }
  599. }
  600. }
  601. // update G
  602. double delta_alpha_i = alpha[i] - old_alpha_i;
  603. double delta_alpha_j = alpha[j] - old_alpha_j;
  604. for(int k=0;k<active_size;k++)
  605. {
  606. G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
  607. }
  608. // update alpha_status and G_bar
  609. {
  610. bool ui = is_upper_bound(i);
  611. bool uj = is_upper_bound(j);
  612. update_alpha_status(i);
  613. update_alpha_status(j);
  614. int k;
  615. if(ui != is_upper_bound(i))
  616. {
  617. Q_i = Q.get_Q(i,l);
  618. if(ui)
  619. for(k=0;k<l;k++)
  620. G_bar[k] -= C_i * Q_i[k];
  621. else
  622. for(k=0;k<l;k++)
  623. G_bar[k] += C_i * Q_i[k];
  624. }
  625. if(uj != is_upper_bound(j))
  626. {
  627. Q_j = Q.get_Q(j,l);
  628. if(uj)
  629. for(k=0;k<l;k++)
  630. G_bar[k] -= C_j * Q_j[k];
  631. else
  632. for(k=0;k<l;k++)
  633. G_bar[k] += C_j * Q_j[k];
  634. }
  635. }
  636. }
  637. // calculate rho
  638. si->rho = calculate_rho();
  639. // calculate objective value
  640. {
  641. double v = 0;
  642. int i;
  643. for(i=0;i<l;i++)
  644. v += alpha[i] * (G[i] + p[i]);
  645. si->obj = v/2;
  646. }
  647. // put back the solution
  648. {
  649. for(int i=0;i<l;i++)
  650. alpha_[active_set[i]] = alpha[i];
  651. }
  652. // juggle everything back
  653. /*{
  654. for(int i=0;i<l;i++)
  655. while(active_set[i] != i)
  656. swap_index(i,active_set[i]);
  657. // or Q.swap_index(i,active_set[i]);
  658. }*/
  659. si->upper_bound_p = Cp;
  660. si->upper_bound_n = Cn;
  661. info("\noptimization finished, #iter = %d\n",iter);
  662. delete[] p;
  663. delete[] y;
  664. delete[] alpha;
  665. delete[] alpha_status;
  666. delete[] active_set;
  667. delete[] G;
  668. delete[] G_bar;
  669. }
  670. // return 1 if already optimal, return 0 otherwise
  671. int Solver::select_working_set(int &out_i, int &out_j)
  672. {
  673. // return i,j such that
  674. // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
  675. // j: minimizes the decrease of obj value
  676. // (if quadratic coefficient <= 0, replace it with tau)
  677. // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
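// The loops below implement this with second-order information: i is the
// maximal violating index found in the first pass, and j is chosen in the
// second pass to minimize obj_diff = -(grad_diff)^2/quad_coef among the
// indices that can still move (TAU is used when quad_coef <= 0).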
  678. double Gmax = -INF;
  679. double Gmax2 = -INF;
  680. int Gmax_idx = -1;
  681. int Gmin_idx = -1;
  682. double obj_diff_min = INF;
  683. for(int t=0;t<active_size;t++)
  684. if(y[t]==+1)
  685. {
  686. if(!is_upper_bound(t))
  687. if(-G[t] >= Gmax)
  688. {
  689. Gmax = -G[t];
  690. Gmax_idx = t;
  691. }
  692. }
  693. else
  694. {
  695. if(!is_lower_bound(t))
  696. if(G[t] >= Gmax)
  697. {
  698. Gmax = G[t];
  699. Gmax_idx = t;
  700. }
  701. }
  702. int i = Gmax_idx;
  703. const Qfloat *Q_i = NULL;
  704. if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1
  705. Q_i = Q->get_Q(i,active_size);
  706. for(int j=0;j<active_size;j++)
  707. {
  708. if(y[j]==+1)
  709. {
  710. if (!is_lower_bound(j))
  711. {
  712. double grad_diff=Gmax+G[j];
  713. if (G[j] >= Gmax2)
  714. Gmax2 = G[j];
  715. if (grad_diff > 0)
  716. {
  717. double obj_diff;
  718. double quad_coef=Q_i[i]+QD[j]-2*y[i]*Q_i[j];
  719. if (quad_coef > 0)
  720. obj_diff = -(grad_diff*grad_diff)/quad_coef;
  721. else
  722. obj_diff = -(grad_diff*grad_diff)/TAU;
  723. if (obj_diff <= obj_diff_min)
  724. {
  725. Gmin_idx=j;
  726. obj_diff_min = obj_diff;
  727. }
  728. }
  729. }
  730. }
  731. else
  732. {
  733. if (!is_upper_bound(j))
  734. {
  735. double grad_diff= Gmax-G[j];
  736. if (-G[j] >= Gmax2)
  737. Gmax2 = -G[j];
  738. if (grad_diff > 0)
  739. {
  740. double obj_diff;
  741. double quad_coef=Q_i[i]+QD[j]+2*y[i]*Q_i[j];
  742. if (quad_coef > 0)
  743. obj_diff = -(grad_diff*grad_diff)/quad_coef;
  744. else
  745. obj_diff = -(grad_diff*grad_diff)/TAU;
  746. if (obj_diff <= obj_diff_min)
  747. {
  748. Gmin_idx=j;
  749. obj_diff_min = obj_diff;
  750. }
  751. }
  752. }
  753. }
  754. }
  755. if(Gmax+Gmax2 < eps)
  756. return 1;
  757. out_i = Gmax_idx;
  758. out_j = Gmin_idx;
  759. return 0;
  760. }
  761. bool Solver::be_shrunken(int i, double Gmax1, double Gmax2)
  762. {
  763. if(is_upper_bound(i))
  764. {
  765. if(y[i]==+1)
  766. return(-G[i] > Gmax1);
  767. else
  768. return(-G[i] > Gmax2);
  769. }
  770. else if(is_lower_bound(i))
  771. {
  772. if(y[i]==+1)
  773. return(G[i] > Gmax2);
  774. else
  775. return(G[i] > Gmax1);
  776. }
  777. else
  778. return(false);
  779. }
  780. void Solver::do_shrinking()
  781. {
  782. int i;
  783. double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) }
  784. double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) }
  785. // find maximal violating pair first
  786. for(i=0;i<active_size;i++)
  787. {
  788. if(y[i]==+1)
  789. {
  790. if(!is_upper_bound(i))
  791. {
  792. if(-G[i] >= Gmax1)
  793. Gmax1 = -G[i];
  794. }
  795. if(!is_lower_bound(i))
  796. {
  797. if(G[i] >= Gmax2)
  798. Gmax2 = G[i];
  799. }
  800. }
  801. else
  802. {
  803. if(!is_upper_bound(i))
  804. {
  805. if(-G[i] >= Gmax2)
  806. Gmax2 = -G[i];
  807. }
  808. if(!is_lower_bound(i))
  809. {
  810. if(G[i] >= Gmax1)
  811. Gmax1 = G[i];
  812. }
  813. }
  814. }
  815. // shrink
  816. for(i=0;i<active_size;i++)
  817. if (be_shrunken(i, Gmax1, Gmax2))
  818. {
  819. active_size--;
  820. while (active_size > i)
  821. {
  822. if (!be_shrunken(active_size, Gmax1, Gmax2))
  823. {
  824. swap_index(i,active_size);
  825. break;
  826. }
  827. active_size--;
  828. }
  829. }
  830. // unshrink, check all variables again before final iterations
  831. if(unshrinked || Gmax1 + Gmax2 > eps*10) return;
  832. unshrinked = true;
  833. reconstruct_gradient();
  834. for(i=l-1;i>=active_size;i--)
  835. if (!be_shrunken(i, Gmax1, Gmax2))
  836. {
  837. while (active_size < i)
  838. {
  839. if (be_shrunken(active_size, Gmax1, Gmax2))
  840. {
  841. swap_index(i,active_size);
  842. break;
  843. }
  844. active_size++;
  845. }
  846. active_size++;
  847. }
  848. }
  849. double Solver::calculate_rho()
  850. {
  851. double r;
  852. int nr_free = 0;
  853. double ub = INF, lb = -INF, sum_free = 0;
  854. for(int i=0;i<active_size;i++)
  855. {
  856. double yG = y[i]*G[i];
  857. if(is_upper_bound(i))
  858. {
  859. if(y[i]==-1)
  860. ub = min(ub,yG);
  861. else
  862. lb = max(lb,yG);
  863. }
  864. else if(is_lower_bound(i))
  865. {
  866. if(y[i]==+1)
  867. ub = min(ub,yG);
  868. else
  869. lb = max(lb,yG);
  870. }
  871. else
  872. {
  873. ++nr_free;
  874. sum_free += yG;
  875. }
  876. }
  877. if(nr_free>0)
  878. r = sum_free/nr_free;
  879. else
  880. r = (ub+lb)/2;
  881. return r;
  882. }
  883. //
  884. // Solver for nu-svm classification and regression
  885. //
  886. // additional constraint: e^T \alpha = constant
  887. //
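// Because of this extra constraint, a working-set pair must come from the
// same label group: select_working_set() keeps separate maxima for the
// y = +1 and y = -1 indices (Gmaxp/Gmaxn), and calculate_rho() averages the
// two group-wise estimates, storing si->r = (r1+r2)/2 and returning (r1-r2)/2.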
  888. class Solver_NU : public Solver
  889. {
  890. public:
  891. Solver_NU() {}
  892. void Solve(int l, const QMatrix& Q, const double *p, const schar *y,
  893. double *alpha, double Cp, double Cn, double eps,
  894. SolutionInfo* si, int shrinking)
  895. {
  896. this->si = si;
  897. Solver::Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
  898. }
  899. private:
  900. SolutionInfo *si;
  901. int select_working_set(int &i, int &j);
  902. double calculate_rho();
  903. bool be_shrunken(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4);
  904. void do_shrinking();
  905. };
  906. // return 1 if already optimal, return 0 otherwise
  907. int Solver_NU::select_working_set(int &out_i, int &out_j)
  908. {
  909. // return i,j such that y_i = y_j and
  910. // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
  911. // j: minimizes the decrease of obj value
  912. // (if quadratic coefficient <= 0, replace it with tau)
  913. // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
  914. double Gmaxp = -INF;
  915. double Gmaxp2 = -INF;
  916. int Gmaxp_idx = -1;
  917. double Gmaxn = -INF;
  918. double Gmaxn2 = -INF;
  919. int Gmaxn_idx = -1;
  920. int Gmin_idx = -1;
  921. double obj_diff_min = INF;
  922. for(int t=0;t<active_size;t++)
  923. if(y[t]==+1)
  924. {
  925. if(!is_upper_bound(t))
  926. if(-G[t] >= Gmaxp)
  927. {
  928. Gmaxp = -G[t];
  929. Gmaxp_idx = t;
  930. }
  931. }
  932. else
  933. {
  934. if(!is_lower_bound(t))
  935. if(G[t] >= Gmaxn)
  936. {
  937. Gmaxn = G[t];
  938. Gmaxn_idx = t;
  939. }
  940. }
  941. int ip = Gmaxp_idx;
  942. int in = Gmaxn_idx;
  943. const Qfloat *Q_ip = NULL;
  944. const Qfloat *Q_in = NULL;
  945. if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1
  946. Q_ip = Q->get_Q(ip,active_size);
  947. if(in != -1)
  948. Q_in = Q->get_Q(in,active_size);
  949. for(int j=0;j<active_size;j++)
  950. {
  951. if(y[j]==+1)
  952. {
  953. if (!is_lower_bound(j))
  954. {
  955. double grad_diff=Gmaxp+G[j];
  956. if (G[j] >= Gmaxp2)
  957. Gmaxp2 = G[j];
  958. if (grad_diff > 0)
  959. {
  960. double obj_diff;
  961. double quad_coef = Q_ip[ip]+QD[j]-2*Q_ip[j];
  962. if (quad_coef > 0)
  963. obj_diff = -(grad_diff*grad_diff)/quad_coef;
  964. else
  965. obj_diff = -(grad_diff*grad_diff)/TAU;
  966. if (obj_diff <= obj_diff_min)
  967. {
  968. Gmin_idx=j;
  969. obj_diff_min = obj_diff;
  970. }
  971. }
  972. }
  973. }
  974. else
  975. {
  976. if (!is_upper_bound(j))
  977. {
  978. double grad_diff=Gmaxn-G[j];
  979. if (-G[j] >= Gmaxn2)
  980. Gmaxn2 = -G[j];
  981. if (grad_diff > 0)
  982. {
  983. double obj_diff;
  984. double quad_coef = Q_in[in]+QD[j]-2*Q_in[j];
  985. if (quad_coef > 0)
  986. obj_diff = -(grad_diff*grad_diff)/quad_coef;
  987. else
  988. obj_diff = -(grad_diff*grad_diff)/TAU;
  989. if (obj_diff <= obj_diff_min)
  990. {
  991. Gmin_idx=j;
  992. obj_diff_min = obj_diff;
  993. }
  994. }
  995. }
  996. }
  997. }
  998. if(max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps)
  999. return 1;
  1000. if (y[Gmin_idx] == +1)
  1001. out_i = Gmaxp_idx;
  1002. else
  1003. out_i = Gmaxn_idx;
  1004. out_j = Gmin_idx;
  1005. return 0;
  1006. }
  1007. bool Solver_NU::be_shrunken(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
  1008. {
  1009. if(is_upper_bound(i))
  1010. {
  1011. if(y[i]==+1)
  1012. return(-G[i] > Gmax1);
  1013. else
  1014. return(-G[i] > Gmax4);
  1015. }
  1016. else if(is_lower_bound(i))
  1017. {
  1018. if(y[i]==+1)
  1019. return(G[i] > Gmax2);
  1020. else
  1021. return(G[i] > Gmax3);
  1022. }
  1023. else
  1024. return(false);
  1025. }
  1026. void Solver_NU::do_shrinking()
  1027. {
  1028. double Gmax1 = -INF; // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
  1029. double Gmax2 = -INF; // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
  1030. double Gmax3 = -INF; // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
  1031. double Gmax4 = -INF; // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
  1032. // find maximal violating pair first
  1033. int i;
  1034. for(i=0;i<active_size;i++)
  1035. {
  1036. if(!is_upper_bound(i))
  1037. {
  1038. if(y[i]==+1)
  1039. {
  1040. if(-G[i] > Gmax1) Gmax1 = -G[i];
  1041. }
  1042. else if(-G[i] > Gmax4) Gmax4 = -G[i];
  1043. }
  1044. if(!is_lower_bound(i))
  1045. {
  1046. if(y[i]==+1)
  1047. {
  1048. if(G[i] > Gmax2) Gmax2 = G[i];
  1049. }
  1050. else if(G[i] > Gmax3) Gmax3 = G[i];
  1051. }
  1052. }
  1053. // shrinking
  1054. for(i=0;i<active_size;i++)
  1055. if (be_shrunken(i, Gmax1, Gmax2, Gmax3, Gmax4))
  1056. {
  1057. active_size--;
  1058. while (active_size > i)
  1059. {
  1060. if (!be_shrunken(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
  1061. {
  1062. swap_index(i,active_size);
  1063. break;
  1064. }
  1065. active_size--;
  1066. }
  1067. }
  1068. // unshrink, check all variables again before final iterations
  1069. if(unshrinked || max(Gmax1+Gmax2,Gmax3+Gmax4) > eps*10) return;
  1070. unshrinked = true;
  1071. reconstruct_gradient();
  1072. for(i=l-1;i>=active_size;i--)
  1073. if (!be_shrunken(i, Gmax1, Gmax2, Gmax3, Gmax4))
  1074. {
  1075. while (active_size < i)
  1076. {
  1077. if (be_shrunken(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
  1078. {
  1079. swap_index(i,active_size);
  1080. break;
  1081. }
  1082. active_size++;
  1083. }
  1084. active_size++;
  1085. }
  1086. }
  1087. double Solver_NU::calculate_rho()
  1088. {
  1089. int nr_free1 = 0,nr_free2 = 0;
  1090. double ub1 = INF, ub2 = INF;
  1091. double lb1 = -INF, lb2 = -INF;
  1092. double sum_free1 = 0, sum_free2 = 0;
  1093. for(int i=0;i<active_size;i++)
  1094. {
  1095. if(y[i]==+1)
  1096. {
  1097. if(is_upper_bound(i))
  1098. lb1 = max(lb1,G[i]);
  1099. else if(is_lower_bound(i))
  1100. ub1 = min(ub1,G[i]);
  1101. else
  1102. {
  1103. ++nr_free1;
  1104. sum_free1 += G[i];
  1105. }
  1106. }
  1107. else
  1108. {
  1109. if(is_upper_bound(i))
  1110. lb2 = max(lb2,G[i]);
  1111. else if(is_lower_bound(i))
  1112. ub2 = min(ub2,G[i]);
  1113. else
  1114. {
  1115. ++nr_free2;
  1116. sum_free2 += G[i];
  1117. }
  1118. }
  1119. }
  1120. double r1,r2;
  1121. if(nr_free1 > 0)
  1122. r1 = sum_free1/nr_free1;
  1123. else
  1124. r1 = (ub1+lb1)/2;
  1125. if(nr_free2 > 0)
  1126. r2 = sum_free2/nr_free2;
  1127. else
  1128. r2 = (ub2+lb2)/2;
  1129. si->r = (r1+r2)/2;
  1130. return (r1-r2)/2;
  1131. }
  1132. //
  1133. // Q matrices for various formulations
  1134. //
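// SVC_Q:       Q_ij = y_i*y_j*K(x_i,x_j)   (C-SVC and nu-SVC)
// ONE_CLASS_Q: Q_ij = K(x_i,x_j)           (one-class SVM)
// SVR_Q:       the 2l x 2l matrix [K -K; -K K], addressed through sign[] and
//              index[] so that alpha and alpha* share a single kernel cache.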
  1135. class SVC_Q: public Kernel
  1136. {
  1137. public:
  1138. SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_)
  1139. :Kernel(prob.l, prob.x, param)
  1140. {
  1141. clone(y,y_,prob.l);
  1142. cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
  1143. QD = new Qfloat[prob.l];
  1144. for(int i=0;i<prob.l;i++)
  1145. QD[i]= (Qfloat)(this->*kernel_function)(i,i);
  1146. }
  1147. Qfloat *get_Q(int i, int len) const
  1148. {
  1149. Qfloat *data;
  1150. int start;
  1151. if((start = cache->get_data(i,&data,len)) < len)
  1152. {
  1153. for(int j=start;j<len;j++)
  1154. data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
  1155. }
  1156. return data;
  1157. }
  1158. Qfloat *get_QD() const
  1159. {
  1160. return QD;
  1161. }
  1162. void swap_index(int i, int j) const
  1163. {
  1164. cache->swap_index(i,j);
  1165. Kernel::swap_index(i,j);
  1166. swap(y[i],y[j]);
  1167. swap(QD[i],QD[j]);
  1168. }
  1169. ~SVC_Q()
  1170. {
  1171. delete[] y;
  1172. delete cache;
  1173. delete[] QD;
  1174. }
  1175. private:
  1176. schar *y;
  1177. Cache *cache;
  1178. Qfloat *QD;
  1179. };
  1180. class ONE_CLASS_Q: public Kernel
  1181. {
  1182. public:
  1183. ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param)
  1184. :Kernel(prob.l, prob.x, param)
  1185. {
  1186. cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
  1187. QD = new Qfloat[prob.l];
  1188. for(int i=0;i<prob.l;i++)
  1189. QD[i]= (Qfloat)(this->*kernel_function)(i,i);
  1190. }
  1191. Qfloat *get_Q(int i, int len) const
  1192. {
  1193. Qfloat *data;
  1194. int start;
  1195. if((start = cache->get_data(i,&data,len)) < len)
  1196. {
  1197. for(int j=start;j<len;j++)
  1198. data[j] = (Qfloat)(this->*kernel_function)(i,j);
  1199. }
  1200. return data;
  1201. }
  1202. Qfloat *get_QD() const
  1203. {
  1204. return QD;
  1205. }
  1206. void swap_index(int i, int j) const
  1207. {
  1208. cache->swap_index(i,j);
  1209. Kernel::swap_index(i,j);
  1210. swap(QD[i],QD[j]);
  1211. }
  1212. ~ONE_CLASS_Q()
  1213. {
  1214. delete cache;
  1215. delete[] QD;
  1216. }
  1217. private:
  1218. Cache *cache;
  1219. Qfloat *QD;
  1220. };
  1221. class SVR_Q: public Kernel
  1222. {
  1223. public:
  1224. SVR_Q(const svm_problem& prob, const svm_parameter& param)
  1225. :Kernel(prob.l, prob.x, param)
  1226. {
  1227. l = prob.l;
  1228. cache = new Cache(l,(long int)(param.cache_size*(1<<20)));
  1229. QD = new Qfloat[2*l];
  1230. sign = new schar[2*l];
  1231. index = new int[2*l];
  1232. for(int k=0;k<l;k++)
  1233. {
  1234. sign[k] = 1;
  1235. sign[k+l] = -1;
  1236. index[k] = k;
  1237. index[k+l] = k;
  1238. QD[k]= (Qfloat)(this->*kernel_function)(k,k);
  1239. QD[k+l]=QD[k];
  1240. }
  1241. buffer[0] = new Qfloat[2*l];
  1242. buffer[1] = new Qfloat[2*l];
  1243. next_buffer = 0;
  1244. }
  1245. void swap_index(int i, int j) const
  1246. {
  1247. swap(sign[i],sign[j]);
  1248. swap(index[i],index[j]);
  1249. swap(QD[i],QD[j]);
  1250. }
  1251. Qfloat *get_Q(int i, int len) const
  1252. {
  1253. Qfloat *data;
  1254. int real_i = index[i];
  1255. if(cache->get_data(real_i,&data,l) < l)
  1256. {
  1257. for(int j=0;j<l;j++)
  1258. data[j] = (Qfloat)(this->*kernel_function)(real_i,j);
  1259. }
  1260. // reorder and copy
  1261. Qfloat *buf = buffer[next_buffer];
  1262. next_buffer = 1 - next_buffer;
  1263. schar si = sign[i];
  1264. for(int j=0;j<len;j++)
  1265. buf[j] = si * sign[j] * data[index[j]];
  1266. return buf;
  1267. }
  1268. Qfloat *get_QD() const
  1269. {
  1270. return QD;
  1271. }
  1272. ~SVR_Q()
  1273. {
  1274. delete cache;
  1275. delete[] sign;
  1276. delete[] index;
  1277. delete[] buffer[0];
  1278. delete[] buffer[1];
  1279. delete[] QD;
  1280. }
  1281. private:
  1282. int l;
  1283. Cache *cache;
  1284. schar *sign;
  1285. int *index;
  1286. mutable int next_buffer;
  1287. Qfloat *buffer[2];
  1288. Qfloat *QD;
  1289. };
  1290. //
  1291. // construct and solve various formulations
  1292. //
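// Each solve_* routine maps its dual onto the generic Solver form
// min 0.5*alpha^T Q alpha + p^T alpha.  C-SVC uses p = -e with boxes
// [0,Cp]/[0,Cn]; nu-SVC solves with p = 0 and C = 1 and then rescales
// alpha, rho and obj by the returned r; the SVR formulations double the
// variables and recover alpha[i] = alpha2[i] - alpha2[i+l] at the end.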
  1293. static void solve_c_svc(
  1294. const svm_problem *prob, const svm_parameter* param,
  1295. double *alpha, Solver::SolutionInfo* si, double Cp, double Cn)
  1296. {
  1297. int l = prob->l;
  1298. double *minus_ones = new double[l];
  1299. schar *y = new schar[l];
  1300. int i;
  1301. for(i=0;i<l;i++)
  1302. {
  1303. alpha[i] = 0;
  1304. minus_ones[i] = -1;
  1305. if(prob->y[i] > 0) y[i] = +1; else y[i]=-1;
  1306. }
  1307. Solver s;
  1308. s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
  1309. alpha, Cp, Cn, param->eps, si, param->shrinking);
  1310. double sum_alpha=0;
  1311. for(i=0;i<l;i++)
  1312. sum_alpha += alpha[i];
  1313. if (Cp==Cn)
  1314. info("nu = %f\n", sum_alpha/(Cp*prob->l));
  1315. for(i=0;i<l;i++)
  1316. alpha[i] *= y[i];
  1317. delete[] minus_ones;
  1318. delete[] y;
  1319. }
  1320. static void solve_nu_svc(
  1321. const svm_problem *prob, const svm_parameter *param,
  1322. double *alpha, Solver::SolutionInfo* si)
  1323. {
  1324. int i;
  1325. int l = prob->l;
  1326. double nu = param->nu;
  1327. schar *y = new schar[l];
  1328. for(i=0;i<l;i++)
  1329. if(prob->y[i]>0)
  1330. y[i] = +1;
  1331. else
  1332. y[i] = -1;
  1333. double sum_pos = nu*l/2;
  1334. double sum_neg = nu*l/2;
  1335. for(i=0;i<l;i++)
  1336. if(y[i] == +1)
  1337. {
  1338. alpha[i] = min(1.0,sum_pos);
  1339. sum_pos -= alpha[i];
  1340. }
  1341. else
  1342. {
  1343. alpha[i] = min(1.0,sum_neg);
  1344. sum_neg -= alpha[i];
  1345. }
  1346. double *zeros = new double[l];
  1347. for(i=0;i<l;i++)
  1348. zeros[i] = 0;
  1349. Solver_NU s;
  1350. s.Solve(l, SVC_Q(*prob,*param,y), zeros, y,
  1351. alpha, 1.0, 1.0, param->eps, si, param->shrinking);
  1352. double r = si->r;
  1353. info("C = %f\n",1/r);
  1354. for(i=0;i<l;i++)
  1355. alpha[i] *= y[i]/r;
  1356. si->rho /= r;
  1357. si->obj /= (r*r);
  1358. si->upper_bound_p = 1/r;
  1359. si->upper_bound_n = 1/r;
  1360. delete[] y;
  1361. delete[] zeros;
  1362. }
  1363. static void solve_one_class(
  1364. const svm_problem *prob, const svm_parameter *param,
  1365. double *alpha, Solver::SolutionInfo* si)
  1366. {
  1367. int l = prob->l;
  1368. double *zeros = new double[l];
  1369. schar *ones = new schar[l];
  1370. int i;
  1371. int n = (int)(param->nu*prob->l); // # of alpha's at upper bound
  1372. for(i=0;i<n;i++)
  1373. alpha[i] = 1;
  1374. if(n<prob->l)
  1375. alpha[n] = param->nu * prob->l - n;
  1376. for(i=n+1;i<l;i++)
  1377. alpha[i] = 0;
  1378. for(i=0;i<l;i++)
  1379. {
  1380. zeros[i] = 0;
  1381. ones[i] = 1;
  1382. }
  1383. Solver s;
  1384. s.Solve(l, ONE_CLASS_Q(*prob,*param), zeros, ones,
  1385. alpha, 1.0, 1.0, param->eps, si, param->shrinking);
  1386. delete[] zeros;
  1387. delete[] ones;
  1388. }
  1389. static void solve_epsilon_svr(
  1390. const svm_problem *prob, const svm_parameter *param,
  1391. double *alpha, Solver::SolutionInfo* si)
  1392. {
  1393. int l = prob->l;
  1394. double *alpha2 = new double[2*l];
  1395. double *linear_term = new double[2*l];
  1396. schar *y = new schar[2*l];
  1397. int i;
  1398. for(i=0;i<l;i++)
  1399. {
  1400. alpha2[i] = 0;
  1401. linear_term[i] = param->p - prob->y[i];
  1402. y[i] = 1;
  1403. alpha2[i+l] = 0;
  1404. linear_term[i+l] = param->p + prob->y[i];
  1405. y[i+l] = -1;
  1406. }
  1407. Solver s;
  1408. s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
  1409. alpha2, param->C, param->C, param->eps, si, param->shrinking);
  1410. double sum_alpha = 0;
  1411. for(i=0;i<l;i++)
  1412. {
  1413. alpha[i] = alpha2[i] - alpha2[i+l];
  1414. sum_alpha += fabs(alpha[i]);
  1415. }
  1416. info("nu = %f\n",sum_alpha/(param->C*l));
  1417. delete[] alpha2;
  1418. delete[] linear_term;
  1419. delete[] y;
  1420. }
  1421. static void solve_nu_svr(
  1422. const svm_problem *prob, const svm_parameter *param,
  1423. double *alpha, Solver::SolutionInfo* si)
  1424. {
  1425. int l = prob->l;
  1426. double C = param->C;
  1427. double *alpha2 = new double[2*l];
  1428. double *linear_term = new double[2*l];
  1429. schar *y = new schar[2*l];
  1430. int i;
  1431. double sum = C * param->nu * l / 2;
  1432. for(i=0;i<l;i++)
  1433. {
  1434. alpha2[i] = alpha2[i+l] = min(sum,C);
  1435. sum -= alpha2[i];
  1436. linear_term[i] = - prob->y[i];
  1437. y[i] = 1;
  1438. linear_term[i+l] = prob->y[i];
  1439. y[i+l] = -1;
  1440. }
  1441. Solver_NU s;
  1442. s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
  1443. alpha2, C, C, param->eps, si, param->shrinking);
  1444. info("epsilon = %f\n",-si->r);
  1445. for(i=0;i<l;i++)
  1446. alpha[i] = alpha2[i] - alpha2[i+l];
  1447. delete[] alpha2;
  1448. delete[] linear_term;
  1449. delete[] y;
  1450. }
  1451. //
  1452. // decision_function
  1453. //
  1454. struct decision_function
  1455. {
  1456. double *alpha;
  1457. double rho;
  1458. };
  1459. decision_function svm_train_one(
  1460. const svm_problem *prob, const svm_parameter *param,
  1461. double Cp, double Cn)
  1462. {
  1463. double *alpha = Malloc(double,prob->l);
  1464. Solver::SolutionInfo si;
  1465. switch(param->svm_type)
  1466. {
  1467. case C_SVC:
  1468. solve_c_svc(prob,param,alpha,&si,Cp,Cn);
  1469. break;
  1470. case NU_SVC:
  1471. solve_nu_svc(prob,param,alpha,&si);
  1472. break;
  1473. case ONE_CLASS:
  1474. solve_one_class(prob,param,alpha,&si);
  1475. break;
  1476. case EPSILON_SVR:
  1477. solve_epsilon_svr(prob,param,alpha,&si);
  1478. break;
  1479. case NU_SVR:
  1480. solve_nu_svr(prob,param,alpha,&si);
  1481. break;
  1482. }
  1483. info("obj = %f, rho = %f\n",si.obj,si.rho);
  1484. // output SVs
  1485. int nSV = 0;
  1486. int nBSV = 0;
  1487. for(int i=0;i<prob->l;i++)
  1488. {
  1489. if(fabs(alpha[i]) > 0)
  1490. {
  1491. ++nSV;
  1492. if(prob->y[i] > 0)
  1493. {
  1494. if(fabs(alpha[i]) >= si.upper_bound_p)
  1495. ++nBSV;
  1496. }
  1497. else
  1498. {
  1499. if(fabs(alpha[i]) >= si.upper_bound_n)
  1500. ++nBSV;
  1501. }
  1502. }
  1503. }
  1504. info("nSV = %d, nBSV = %d\n",nSV,nBSV);
  1505. decision_function f;
  1506. f.alpha = alpha;
  1507. f.rho = si.rho;
  1508. return f;
  1509. }
  1510. // Platt's binary SVM probabilistic output: an improvement from Lin et al.
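// The fitted model is P(y=1|f) = 1/(1+exp(A*f+B)).  A and B are found by
// Newton's method on the cross-entropy objective (the Hessian is stabilized
// with sigma*I) combined with a backtracking line search; the targets t_i
// are the smoothed values (prior1+1)/(prior1+2) and 1/(prior0+2) rather
// than hard 0/1 labels.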
  1511. void sigmoid_train(
  1512. int l, const double *dec_values, const double *labels,
  1513. double& A, double& B)
  1514. {
  1515. double prior1=0, prior0 = 0;
  1516. int i;
  1517. for (i=0;i<l;i++)
  1518. if (labels[i] > 0) prior1+=1;
  1519. else prior0+=1;
  1520. int max_iter=100; // Maximal number of iterations
  1521. double min_step=1e-10; // Minimal step taken in line search
  1522. double sigma=1e-12; // For numerically strict PD of Hessian
  1523. double eps=1e-5;
  1524. double hiTarget=(prior1+1.0)/(prior1+2.0);
  1525. double loTarget=1/(prior0+2.0);
  1526. double *t=Malloc(double,l);
  1527. double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
  1528. double newA,newB,newf,d1,d2;
  1529. int iter;
  1530. // Initial Point and Initial Fun Value
  1531. A=0.0; B=log((prior0+1.0)/(prior1+1.0));
  1532. double fval = 0.0;
  1533. for (i=0;i<l;i++)
  1534. {
  1535. if (labels[i]>0) t[i]=hiTarget;
  1536. else t[i]=loTarget;
  1537. fApB = dec_values[i]*A+B;
  1538. if (fApB>=0)
  1539. fval += t[i]*fApB + log(1+exp(-fApB));
  1540. else
  1541. fval += (t[i] - 1)*fApB +log(1+exp(fApB));
  1542. }
  1543. for (iter=0;iter<max_iter;iter++)
  1544. {
  1545. // Update Gradient and Hessian (use H' = H + sigma I)
  1546. h11=sigma; // numerically ensures strict PD
  1547. h22=sigma;
  1548. h21=0.0;g1=0.0;g2=0.0;
  1549. for (i=0;i<l;i++)
  1550. {
  1551. fApB = dec_values[i]*A+B;
  1552. if (fApB >= 0)
  1553. {
  1554. p=exp(-fApB)/(1.0+exp(-fApB));
  1555. q=1.0/(1.0+exp(-fApB));
  1556. }
  1557. else
  1558. {
  1559. p=1.0/(1.0+exp(fApB));
  1560. q=exp(fApB)/(1.0+exp(fApB));
  1561. }
  1562. d2=p*q;
  1563. h11+=dec_values[i]*dec_values[i]*d2;
  1564. h22+=d2;
  1565. h21+=dec_values[i]*d2;
  1566. d1=t[i]-p;
  1567. g1+=dec_values[i]*d1;
  1568. g2+=d1;
  1569. }
  1570. // Stopping Criteria
  1571. if (fabs(g1)<eps && fabs(g2)<eps)
  1572. break;
  1573. // Finding Newton direction: -inv(H') * g
  1574. det=h11*h22-h21*h21;
  1575. dA=-(h22*g1 - h21 * g2) / det;
  1576. dB=-(-h21*g1+ h11 * g2) / det;
  1577. gd=g1*dA+g2*dB;
  1578. stepsize = 1; // Line Search
  1579. while (stepsize >= min_step)
  1580. {
  1581. newA = A + stepsize * dA;
  1582. newB = B + stepsize * dB;
  1583. // New function value
  1584. newf = 0.0;
  1585. for (i=0;i<l;i++)
  1586. {
  1587. fApB = dec_values[i]*newA+newB;
  1588. if (fApB >= 0)
  1589. newf += t[i]*fApB + log(1+exp(-fApB));
  1590. else
  1591. newf += (t[i] - 1)*fApB +log(1+exp(fApB));
  1592. }
  1593. // Check sufficient decrease
  1594. if (newf<fval+0.0001*stepsize*gd)
  1595. {
  1596. A=newA;B=newB;fval=newf;
  1597. break;
  1598. }
  1599. else
  1600. stepsize = stepsize / 2.0;
  1601. }
  1602. if (stepsize < min_step)
  1603. {
  1604. info("Line search fails in two-class probability estimates\n");
  1605. break;
  1606. }
  1607. }
  1608. if (iter>=max_iter)
  1609. info("Reaching maximal iterations in two-class probability estimates\n");
  1610. free(t);
  1611. }
  1612. double sigmoid_predict(double decision_value, double A, double B)
  1613. {
  1614. double fApB = decision_value*A+B;
  1615. if (fApB >= 0)
  1616. return exp(-fApB)/(1.0+exp(-fApB));
  1617. else
  1618. return 1.0/(1+exp(fApB)) ;
  1619. }
  1620. // Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
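// Given pairwise estimates r[i][j] ~ P(y=i | y=i or j), this solves
// min_p 0.5*p^T Q p subject to e^T p = 1, with Q as constructed below, by a
// normalized fixed-point update of one component at a time; iteration stops
// when max_t |Qp_t - p^T Q p| falls below eps.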
  1621. void multiclass_probability(int k, double **r, double *p)
  1622. {
  1623. int t,j;
  1624. int iter = 0, max_iter=max(100,k);
  1625. double **Q=Malloc(double *,k);
  1626. double *Qp=Malloc(double,k);
  1627. double pQp, eps=0.005/k;
  1628. for (t=0;t<k;t++)
  1629. {
  1630. p[t]=1.0/k; // Valid if k = 1
  1631. Q[t]=Malloc(double,k);
  1632. Q[t][t]=0;
  1633. for (j=0;j<t;j++)
  1634. {
  1635. Q[t][t]+=r[j][t]*r[j][t];
  1636. Q[t][j]=Q[j][t];
  1637. }
  1638. for (j=t+1;j<k;j++)
  1639. {
  1640. Q[t][t]+=r[j][t]*r[j][t];
  1641. Q[t][j]=-r[j][t]*r[t][j];
  1642. }
  1643. }
  1644. for (iter=0;iter<max_iter;iter++)
  1645. {
  1646. // stopping condition: recalculate Qp and pQp for numerical accuracy
  1647. pQp=0;
  1648. for (t=0;t<k;t++)
  1649. {
  1650. Qp[t]=0;
  1651. for (j=0;j<k;j++)
  1652. Qp[t]+=Q[t][j]*p[j];
  1653. pQp+=p[t]*Qp[t];
  1654. }
  1655. double max_error=0;
  1656. for (t=0;t<k;t++)
  1657. {
  1658. double error=fabs(Qp[t]-pQp);
  1659. if (error>max_error)
  1660. max_error=error;
  1661. }
  1662. if (max_error<eps) break;
  1663. for (t=0;t<k;t++)
  1664. {
  1665. double diff=(-Qp[t]+pQp)/Q[t][t];
  1666. p[t]+=diff;
  1667. pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
  1668. for (j=0;j<k;j++)
  1669. {
  1670. Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
  1671. p[j]/=(1+diff);
  1672. }
  1673. }
  1674. }
  1675. if (iter>=max_iter)
  1676. info("Exceeds max_iter in multiclass_prob\n");
  1677. for(t=0;t<k;t++) free(Q[t]);
  1678. free(Q);
  1679. free(Qp);
  1680. }
  1681. // Cross-validation decision values for probability estimates
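// Five-fold cross-validation is run with probability estimates switched off
// and C fixed to 1 (class-weighted by Cp and Cn); the resulting decision
// values (multiplied by submodel->label[0] to keep a consistent +1/-1
// orientation) are then passed to sigmoid_train() above.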
  1682. void svm_binary_svc_probability(
  1683. const svm_problem *prob, const svm_parameter *param,
  1684. double Cp, double Cn, double& probA, double& probB)
  1685. {
  1686. int i;
  1687. int nr_fold = 5;
  1688. int *perm = Malloc(int,prob->l);
  1689. double *dec_values = Malloc(double,prob->l);
  1690. // random shuffle
  1691. for(i=0;i<prob->l;i++) perm[i]=i;
  1692. for(i=0;i<prob->l;i++)
  1693. {
  1694. int j = i+rand()%(prob->l-i);
  1695. swap(perm[i],perm[j]);
  1696. }
  1697. for(i=0;i<nr_fold;i++)
  1698. {
  1699. int begin = i*prob->l/nr_fold;
  1700. int end = (i+1)*prob->l/nr_fold;
  1701. int j,k;
  1702. struct svm_problem subprob;
  1703. subprob.l = prob->l-(end-begin);
  1704. subprob.x = Malloc(struct svm_node*,subprob.l);
  1705. subprob.y = Malloc(double,subprob.l);
  1706. k=0;
  1707. for(j=0;j<begin;j++)
  1708. {
  1709. subprob.x[k] = prob->x[perm[j]];
  1710. subprob.y[k] = prob->y[perm[j]];
  1711. ++k;
  1712. }
  1713. for(j=end;j<prob->l;j++)
  1714. {
  1715. subprob.x[k] = prob->x[perm[j]];
  1716. subprob.y[k] = prob->y[perm[j]];
  1717. ++k;
  1718. }
  1719. int p_count=0,n_count=0;
  1720. for(j=0;j<k;j++)
  1721. if(subprob.y[j]>0)
  1722. p_count++;
  1723. else
  1724. n_count++;
  1725. if(p_count==0 && n_count==0)
  1726. for(j=begin;j<end;j++)
  1727. dec_values[perm[j]] = 0;
  1728. else if(p_count > 0 && n_count == 0)
  1729. for(j=begin;j<end;j++)
  1730. dec_values[perm[j]] = 1;
  1731. else if(p_count == 0 && n_count > 0)
  1732. for(j=begin;j<end;j++)
  1733. dec_values[perm[j]] = -1;
  1734. else
  1735. {
  1736. svm_parameter subparam = *param;
  1737. subparam.probability=0;
  1738. subparam.C=1.0;
  1739. subparam.nr_weight=2;
  1740. subparam.weight_label = Malloc(int,2);
  1741. subparam.weight = Malloc(double,2);
  1742. subparam.weight_label[0]=+1;
  1743. subparam.weight_label[1]=-1;
  1744. subparam.weight[0]=Cp;
  1745. subparam.weight[1]=Cn;
  1746. struct svm_model *submodel = svm_train(&subprob,&subparam);
  1747. for(j=begin;j<end;j++)
  1748. {
  1749. svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]]));
  1750. // ensure +1 -1 order; reason not using CV subroutine
  1751. dec_values[perm[j]] *= submodel->label[0];
  1752. }
  1753. svm_destroy_model(submodel);
  1754. svm_destroy_param(&subparam);
  1755. }
  1756. free(subprob.x);
  1757. free(subprob.y);
  1758. }
  1759. sigmoid_train(prob->l,dec_values,prob->y,probA,probB);
  1760. free(dec_values);
  1761. free(perm);
  1762. }
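// The decision values gathered above come from 5-fold cross-validation, so
// sigmoid_train fits the Platt-scaling parameters probA/probB on values that the
// corresponding sub-models never trained on; at prediction time the probability of
// the positive class is then sigmoid_predict(f(x), probA, probB).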
// Return parameter of a Laplace distribution
double svm_svr_probability(
	const svm_problem *prob, const svm_parameter *param)
{
	int i;
	int nr_fold = 5;
	double *ymv = Malloc(double,prob->l);
	double mae = 0;

	svm_parameter newparam = *param;
	newparam.probability = 0;
	svm_cross_validation(prob,&newparam,nr_fold,ymv);
	for(i=0;i<prob->l;i++)
	{
		ymv[i]=prob->y[i]-ymv[i];
		mae += fabs(ymv[i]);
	}
	mae /= prob->l;
	double std=sqrt(2*mae*mae);
	int count=0;
	mae=0;
	for(i=0;i<prob->l;i++)
		if (fabs(ymv[i]) > 5*std)
			count=count+1;
		else
			mae+=fabs(ymv[i]);
	mae /= (prob->l-count);
	info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae);
	free(ymv);
	return mae;
}
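// For a zero-mean Laplace distribution with scale sigma, E|z| = sigma, so the mean
// absolute cross-validation residual serves as the sigma estimate; residuals larger
// than 5*sqrt(2)*mae are treated as outliers and excluded from the final average.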
// label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
// perm, length l, must be allocated before calling this subroutine
void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)
{
	int l = prob->l;
	int max_nr_class = 16;
	int nr_class = 0;
	int *label = Malloc(int,max_nr_class);
	int *count = Malloc(int,max_nr_class);
	int *data_label = Malloc(int,l);
	int i;

	for(i=0;i<l;i++)
	{
		int this_label = (int)prob->y[i];
		int j;
		for(j=0;j<nr_class;j++)
		{
			if(this_label == label[j])
			{
				++count[j];
				break;
			}
		}
		data_label[i] = j;
		if(j == nr_class)
		{
			if(nr_class == max_nr_class)
			{
				max_nr_class *= 2;
				label = (int *)realloc(label,max_nr_class*sizeof(int));
				count = (int *)realloc(count,max_nr_class*sizeof(int));
			}
			label[nr_class] = this_label;
			count[nr_class] = 1;
			++nr_class;
		}
	}

	int *start = Malloc(int,nr_class);
	start[0] = 0;
	for(i=1;i<nr_class;i++)
		start[i] = start[i-1]+count[i-1];
	for(i=0;i<l;i++)
	{
		perm[start[data_label[i]]] = i;
		++start[data_label[i]];
	}
	start[0] = 0;
	for(i=1;i<nr_class;i++)
		start[i] = start[i-1]+count[i-1];

	*nr_class_ret = nr_class;
	*label_ret = label;
	*start_ret = start;
	*count_ret = count;
	free(data_label);
}
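// Illustrative example (made-up labels): for y = {3, 1, 3, 1, 1} the routine returns
// nr_class = 2, label = {3, 1} (order of first appearance), count = {2, 3},
// start = {0, 2} and perm = {0, 2, 1, 3, 4}, i.e. perm lists the indices of the
// class-3 points first, then those of the class-1 points.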
//
// Interface functions
//
svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)
{
	svm_model *model = Malloc(svm_model,1);
	model->param = *param;
	model->free_sv = 0;	// XXX

	if(param->svm_type == ONE_CLASS ||
	   param->svm_type == EPSILON_SVR ||
	   param->svm_type == NU_SVR)
	{
		// regression or one-class-svm
		model->nr_class = 2;
		model->label = NULL;
		model->nSV = NULL;
		model->probA = NULL; model->probB = NULL;
		model->sv_coef = Malloc(double *,1);

		if(param->probability &&
		   (param->svm_type == EPSILON_SVR ||
		    param->svm_type == NU_SVR))
		{
			model->probA = Malloc(double,1);
			model->probA[0] = svm_svr_probability(prob,param);
		}

		decision_function f = svm_train_one(prob,param,0,0);
		model->rho = Malloc(double,1);
		model->rho[0] = f.rho;

		int nSV = 0;
		int i;
		for(i=0;i<prob->l;i++)
			if(fabs(f.alpha[i]) > 0) ++nSV;
		model->l = nSV;
		model->SV = Malloc(svm_node *,nSV);
		model->sv_coef[0] = Malloc(double,nSV);
		int j = 0;
		for(i=0;i<prob->l;i++)
			if(fabs(f.alpha[i]) > 0)
			{
				model->SV[j] = prob->x[i];
				model->sv_coef[0][j] = f.alpha[i];
				++j;
			}

		free(f.alpha);
	}
	else
	{
		// classification
		int l = prob->l;
		int nr_class;
		int *label = NULL;
		int *start = NULL;
		int *count = NULL;
		int *perm = Malloc(int,l);

		// group training data of the same class
		svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
		svm_node **x = Malloc(svm_node *,l);
		int i;
		for(i=0;i<l;i++)
			x[i] = prob->x[perm[i]];

		// calculate weighted C
		double *weighted_C = Malloc(double, nr_class);
		for(i=0;i<nr_class;i++)
			weighted_C[i] = param->C;
		for(i=0;i<param->nr_weight;i++)
		{
			int j;
			for(j=0;j<nr_class;j++)
				if(param->weight_label[i] == label[j])
					break;
			if(j == nr_class)
				fprintf(stderr,"warning: class label %d specified in weight is not found\n", param->weight_label[i]);
			else
				weighted_C[j] *= param->weight[i];
		}

		// train k*(k-1)/2 models
		bool *nonzero = Malloc(bool,l);
		for(i=0;i<l;i++)
			nonzero[i] = false;
		decision_function *f = Malloc(decision_function,nr_class*(nr_class-1)/2);

		double *probA=NULL,*probB=NULL;
		if (param->probability)
		{
			probA=Malloc(double,nr_class*(nr_class-1)/2);
			probB=Malloc(double,nr_class*(nr_class-1)/2);
		}

		int p = 0;
		for(i=0;i<nr_class;i++)
			for(int j=i+1;j<nr_class;j++)
			{
				svm_problem sub_prob;
				int si = start[i], sj = start[j];
				int ci = count[i], cj = count[j];
				sub_prob.l = ci+cj;
				sub_prob.x = Malloc(svm_node *,sub_prob.l);
				sub_prob.y = Malloc(double,sub_prob.l);
				int k;
				for(k=0;k<ci;k++)
				{
					sub_prob.x[k] = x[si+k];
					sub_prob.y[k] = +1;
				}
				for(k=0;k<cj;k++)
				{
					sub_prob.x[ci+k] = x[sj+k];
					sub_prob.y[ci+k] = -1;
				}

				if(param->probability)
					svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);

				f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);
				for(k=0;k<ci;k++)
					if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)
						nonzero[si+k] = true;
				for(k=0;k<cj;k++)
					if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)
						nonzero[sj+k] = true;
				free(sub_prob.x);
				free(sub_prob.y);
				++p;
			}

		// build output
		model->nr_class = nr_class;

		model->label = Malloc(int,nr_class);
		for(i=0;i<nr_class;i++)
			model->label[i] = label[i];

		model->rho = Malloc(double,nr_class*(nr_class-1)/2);
		for(i=0;i<nr_class*(nr_class-1)/2;i++)
			model->rho[i] = f[i].rho;

		if(param->probability)
		{
			model->probA = Malloc(double,nr_class*(nr_class-1)/2);
			model->probB = Malloc(double,nr_class*(nr_class-1)/2);
			for(i=0;i<nr_class*(nr_class-1)/2;i++)
			{
				model->probA[i] = probA[i];
				model->probB[i] = probB[i];
			}
		}
		else
		{
			model->probA=NULL;
			model->probB=NULL;
		}

		int total_sv = 0;
		int *nz_count = Malloc(int,nr_class);
		model->nSV = Malloc(int,nr_class);
		for(i=0;i<nr_class;i++)
		{
			int nSV = 0;
			for(int j=0;j<count[i];j++)
				if(nonzero[start[i]+j])
				{
					++nSV;
					++total_sv;
				}
			model->nSV[i] = nSV;
			nz_count[i] = nSV;
		}

		info("Total nSV = %d\n",total_sv);

		model->l = total_sv;
		model->SV = Malloc(svm_node *,total_sv);
		p = 0;
		for(i=0;i<l;i++)
			if(nonzero[i]) model->SV[p++] = x[i];

		int *nz_start = Malloc(int,nr_class);
		nz_start[0] = 0;
		for(i=1;i<nr_class;i++)
			nz_start[i] = nz_start[i-1]+nz_count[i-1];

		model->sv_coef = Malloc(double *,nr_class-1);
		for(i=0;i<nr_class-1;i++)
			model->sv_coef[i] = Malloc(double,total_sv);

		p = 0;
		for(i=0;i<nr_class;i++)
			for(int j=i+1;j<nr_class;j++)
			{
				// classifier (i,j): coefficients with
				// i are in sv_coef[j-1][nz_start[i]...],
				// j are in sv_coef[i][nz_start[j]...]

				int si = start[i];
				int sj = start[j];
				int ci = count[i];
				int cj = count[j];

				int q = nz_start[i];
				int k;
				for(k=0;k<ci;k++)
					if(nonzero[si+k])
						model->sv_coef[j-1][q++] = f[p].alpha[k];
				q = nz_start[j];
				for(k=0;k<cj;k++)
					if(nonzero[sj+k])
						model->sv_coef[i][q++] = f[p].alpha[ci+k];
				++p;
			}

		free(label);
		free(probA);
		free(probB);
		free(count);
		free(perm);
		free(start);
		free(x);
		free(weighted_C);
		free(nonzero);
		for(i=0;i<nr_class*(nr_class-1)/2;i++)
			free(f[i].alpha);
		free(f);
		free(nz_count);
		free(nz_start);
	}
	return model;
}
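// A minimal usage sketch for svm_train/svm_predict (illustrative parameter values,
// error handling omitted; the data below are made up):
//
//	struct svm_parameter param;
//	param.svm_type = C_SVC;     param.kernel_type = RBF;
//	param.gamma = 0.5;          param.C = 1;
//	param.degree = 3;           param.coef0 = 0;
//	param.nu = 0.5;             param.p = 0.1;
//	param.cache_size = 100;     param.eps = 1e-3;
//	param.shrinking = 1;        param.probability = 0;
//	param.nr_weight = 0;        param.weight_label = NULL; param.weight = NULL;
//
//	// two 1-dimensional points with labels +1 and -1; index -1 terminates each row
//	svm_node x0[] = {{1, 0.0}, {-1, 0.0}}, x1[] = {{1, 1.0}, {-1, 0.0}};
//	svm_node *xs[] = {x0, x1};
//	double ys[] = {+1, -1};
//	svm_problem prob = {2, ys, xs};
//
//	const char *err = svm_check_parameter(&prob, &param);
//	svm_model *model = err ? NULL : svm_train(&prob, &param);
//	double pred = svm_predict(model, x0);	// should classify x0 back as +1
//	svm_destroy_model(model);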
// Stratified cross validation
void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target)
{
	int i;
	int *fold_start = Malloc(int,nr_fold+1);
	int l = prob->l;
	int *perm = Malloc(int,l);
	int nr_class;

	// stratified cv may not give leave-one-out rate
	// Each class to l folds -> some folds may have zero elements
	if((param->svm_type == C_SVC ||
	    param->svm_type == NU_SVC) && nr_fold < l)
	{
		int *start = NULL;
		int *label = NULL;
		int *count = NULL;
		svm_group_classes(prob,&nr_class,&label,&start,&count,perm);

		// random shuffle and then data grouped by fold using the array perm
		int *fold_count = Malloc(int,nr_fold);
		int c;
		int *index = Malloc(int,l);
		for(i=0;i<l;i++)
			index[i]=perm[i];
		for (c=0; c<nr_class; c++)
			for(i=0;i<count[c];i++)
			{
				int j = i+rand()%(count[c]-i);
				swap(index[start[c]+j],index[start[c]+i]);
			}
		for(i=0;i<nr_fold;i++)
		{
			fold_count[i] = 0;
			for (c=0; c<nr_class;c++)
				fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
		}
		fold_start[0]=0;
		for (i=1;i<=nr_fold;i++)
			fold_start[i] = fold_start[i-1]+fold_count[i-1];
		for (c=0; c<nr_class;c++)
			for(i=0;i<nr_fold;i++)
			{
				int begin = start[c]+i*count[c]/nr_fold;
				int end = start[c]+(i+1)*count[c]/nr_fold;
				for(int j=begin;j<end;j++)
				{
					perm[fold_start[i]] = index[j];
					fold_start[i]++;
				}
			}
		fold_start[0]=0;
		for (i=1;i<=nr_fold;i++)
			fold_start[i] = fold_start[i-1]+fold_count[i-1];
		free(start);
		free(label);
		free(count);
		free(index);
		free(fold_count);
	}
	else
	{
		for(i=0;i<l;i++) perm[i]=i;
		for(i=0;i<l;i++)
		{
			int j = i+rand()%(l-i);
			swap(perm[i],perm[j]);
		}
		for(i=0;i<=nr_fold;i++)
			fold_start[i]=i*l/nr_fold;
	}

	for(i=0;i<nr_fold;i++)
	{
		int begin = fold_start[i];
		int end = fold_start[i+1];
		int j,k;
		struct svm_problem subprob;

		subprob.l = l-(end-begin);
		subprob.x = Malloc(struct svm_node*,subprob.l);
		subprob.y = Malloc(double,subprob.l);

		k=0;
		for(j=0;j<begin;j++)
		{
			subprob.x[k] = prob->x[perm[j]];
			subprob.y[k] = prob->y[perm[j]];
			++k;
		}
		for(j=end;j<l;j++)
		{
			subprob.x[k] = prob->x[perm[j]];
			subprob.y[k] = prob->y[perm[j]];
			++k;
		}
		struct svm_model *submodel = svm_train(&subprob,param);
		if(param->probability &&
		   (param->svm_type == C_SVC || param->svm_type == NU_SVC))
		{
			double *prob_estimates=Malloc(double,svm_get_nr_class(submodel));
			for(j=begin;j<end;j++)
				target[perm[j]] = svm_predict_probability(submodel,prob->x[perm[j]],prob_estimates);
			free(prob_estimates);
		}
		else
			for(j=begin;j<end;j++)
				target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]);
		svm_destroy_model(submodel);
		free(subprob.x);
		free(subprob.y);
	}
	free(fold_start);
	free(perm);
}
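// Illustrative call (assumes prob and param are set up as in the svm_train sketch above):
//	double *target = (double *)malloc(prob.l*sizeof(double));
//	svm_cross_validation(&prob, &param, 5, target);	// target[i] = CV prediction for point i
//	free(target);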
int svm_get_svm_type(const svm_model *model)
{
	return model->param.svm_type;
}

int svm_get_nr_class(const svm_model *model)
{
	return model->nr_class;
}

void svm_get_labels(const svm_model *model, int* label)
{
	if (model->label != NULL)
		for(int i=0;i<model->nr_class;i++)
			label[i] = model->label[i];
}

double svm_get_svr_probability(const svm_model *model)
{
	if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
	    model->probA!=NULL)
		return model->probA[0];
	else
	{
		info("Model doesn't contain information for SVR probability inference\n");
		return 0;
	}
}
void svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values)
{
	if(model->param.svm_type == ONE_CLASS ||
	   model->param.svm_type == EPSILON_SVR ||
	   model->param.svm_type == NU_SVR)
	{
		double *sv_coef = model->sv_coef[0];
		double sum = 0;
		for(int i=0;i<model->l;i++)
			sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param);
		sum -= model->rho[0];
		*dec_values = sum;
	}
	else
	{
		int i;
		int nr_class = model->nr_class;
		int l = model->l;

		double *kvalue = Malloc(double,l);
		for(i=0;i<l;i++)
			kvalue[i] = Kernel::k_function(x,model->SV[i],model->param);

		int *start = Malloc(int,nr_class);
		start[0] = 0;
		for(i=1;i<nr_class;i++)
			start[i] = start[i-1]+model->nSV[i-1];

		int p=0;
		for(i=0;i<nr_class;i++)
			for(int j=i+1;j<nr_class;j++)
			{
				double sum = 0;
				int si = start[i];
				int sj = start[j];
				int ci = model->nSV[i];
				int cj = model->nSV[j];

				int k;
				double *coef1 = model->sv_coef[j-1];
				double *coef2 = model->sv_coef[i];
				for(k=0;k<ci;k++)
					sum += coef1[si+k] * kvalue[si+k];
				for(k=0;k<cj;k++)
					sum += coef2[sj+k] * kvalue[sj+k];
				sum -= model->rho[p];
				dec_values[p] = sum;
				p++;
			}

		free(kvalue);
		free(start);
	}
}
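// Each pairwise decision value computed above is
//	f_ij(x) = sum_t alpha_t * y_t * K(x_t, x) - rho_ij,
// with the sum running over the support vectors of classes i and j (sv_coef stores
// alpha_t * y_t). svm_predict below turns sign(f_ij) into one vote for class i
// (positive) or class j (non-positive).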
double svm_predict(const svm_model *model, const svm_node *x)
{
	if(model->param.svm_type == ONE_CLASS ||
	   model->param.svm_type == EPSILON_SVR ||
	   model->param.svm_type == NU_SVR)
	{
		double res;
		svm_predict_values(model, x, &res);

		if(model->param.svm_type == ONE_CLASS)
			return (res>0)?1:-1;
		else
			return res;
	}
	else
	{
		int i;
		int nr_class = model->nr_class;
		double *dec_values = Malloc(double, nr_class*(nr_class-1)/2);
		svm_predict_values(model, x, dec_values);

		int *vote = Malloc(int,nr_class);
		for(i=0;i<nr_class;i++)
			vote[i] = 0;
		int pos=0;
		for(i=0;i<nr_class;i++)
			for(int j=i+1;j<nr_class;j++)
			{
				if(dec_values[pos++] > 0)
					++vote[i];
				else
					++vote[j];
			}

		int vote_max_idx = 0;
		for(i=1;i<nr_class;i++)
			if(vote[i] > vote[vote_max_idx])
				vote_max_idx = i;
		free(vote);
		free(dec_values);
		return model->label[vote_max_idx];
	}
}
double svm_predict_probability(
	const svm_model *model, const svm_node *x, double *prob_estimates)
{
	if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
	    model->probA!=NULL && model->probB!=NULL)
	{
		int i;
		int nr_class = model->nr_class;
		double *dec_values = Malloc(double, nr_class*(nr_class-1)/2);
		svm_predict_values(model, x, dec_values);

		double min_prob=1e-7;
		double **pairwise_prob=Malloc(double *,nr_class);
		for(i=0;i<nr_class;i++)
			pairwise_prob[i]=Malloc(double,nr_class);
		int k=0;
		for(i=0;i<nr_class;i++)
			for(int j=i+1;j<nr_class;j++)
			{
				pairwise_prob[i][j]=min(max(sigmoid_predict(dec_values[k],model->probA[k],model->probB[k]),min_prob),1-min_prob);
				pairwise_prob[j][i]=1-pairwise_prob[i][j];
				k++;
			}
		multiclass_probability(nr_class,pairwise_prob,prob_estimates);

		int prob_max_idx = 0;
		for(i=1;i<nr_class;i++)
			if(prob_estimates[i] > prob_estimates[prob_max_idx])
				prob_max_idx = i;
		for(i=0;i<nr_class;i++)
			free(pairwise_prob[i]);
		free(dec_values);
		free(pairwise_prob);
		return model->label[prob_max_idx];
	}
	else
		return svm_predict(model, x);
}
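// Illustrative call (assumes the model was trained with param.probability = 1):
//	int nc = svm_get_nr_class(model);
//	double *pe = (double *)malloc(nc*sizeof(double));
//	double lbl = svm_predict_probability(model, x, pe);	// pe[i] corresponds to model->label[i]
//	free(pe);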
const char *svm_type_table[] =
{
	"c_svc","nu_svc","one_class","epsilon_svr","nu_svr",NULL
};

const char *kernel_type_table[]=
{
	"linear","polynomial","rbf","sigmoid","precomputed",NULL
};
int svm_save_model(const char *model_file_name, const svm_model *model)
{
	FILE *fp = fopen(model_file_name,"w");
	if(fp==NULL) return -1;

	const svm_parameter& param = model->param;

	fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]);
	fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]);

	if(param.kernel_type == POLY)
		fprintf(fp,"degree %d\n", param.degree);

	if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID)
		fprintf(fp,"gamma %g\n", param.gamma);

	if(param.kernel_type == POLY || param.kernel_type == SIGMOID)
		fprintf(fp,"coef0 %g\n", param.coef0);

	int nr_class = model->nr_class;
	int l = model->l;
	fprintf(fp, "nr_class %d\n", nr_class);
	fprintf(fp, "total_sv %d\n",l);

	{
		fprintf(fp, "rho");
		for(int i=0;i<nr_class*(nr_class-1)/2;i++)
			fprintf(fp," %g",model->rho[i]);
		fprintf(fp, "\n");
	}

	if(model->label)
	{
		fprintf(fp, "label");
		for(int i=0;i<nr_class;i++)
			fprintf(fp," %d",model->label[i]);
		fprintf(fp, "\n");
	}

	if(model->probA) // regression has probA only
	{
		fprintf(fp, "probA");
		for(int i=0;i<nr_class*(nr_class-1)/2;i++)
			fprintf(fp," %g",model->probA[i]);
		fprintf(fp, "\n");
	}
	if(model->probB)
	{
		fprintf(fp, "probB");
		for(int i=0;i<nr_class*(nr_class-1)/2;i++)
			fprintf(fp," %g",model->probB[i]);
		fprintf(fp, "\n");
	}

	if(model->nSV)
	{
		fprintf(fp, "nr_sv");
		for(int i=0;i<nr_class;i++)
			fprintf(fp," %d",model->nSV[i]);
		fprintf(fp, "\n");
	}

	fprintf(fp, "SV\n");
	const double * const *sv_coef = model->sv_coef;
	const svm_node * const *SV = model->SV;

	for(int i=0;i<l;i++)
	{
		for(int j=0;j<nr_class-1;j++)
			fprintf(fp, "%.16g ",sv_coef[j][i]);

		const svm_node *p = SV[i];

		if(param.kernel_type == PRECOMPUTED)
			fprintf(fp,"0:%d ",(int)(p->value));
		else
			while(p->index != -1)
			{
				fprintf(fp,"%d:%.8g ",p->index,p->value);
				p++;
			}
		fprintf(fp, "\n");
	}
	if (ferror(fp) != 0 || fclose(fp) != 0) return -1;
	else return 0;
}
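// A model file written above therefore looks like the following (illustrative
// two-class RBF example; all numbers are made up):
//
//	svm_type c_svc
//	kernel_type rbf
//	gamma 0.25
//	nr_class 2
//	total_sv 3
//	rho 0.1
//	label 1 -1
//	nr_sv 2 1
//	SV
//	1 1:0.5 2:1
//	0.5 1:0.1 2:0.3
//	-1.5 1:0.9 2:0.2
//
// Each SV line holds the nr_class-1 coefficients for that support vector followed
// by its index:value pairs, exactly the layout svm_load_model expects below.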
svm_model *svm_load_model(const char *model_file_name)
{
	FILE *fp = fopen(model_file_name,"r");
	if(fp==NULL) return NULL;

	// read parameters

	svm_model *model = Malloc(svm_model,1);
	svm_parameter& param = model->param;
	model->rho = NULL;
	model->probA = NULL;
	model->probB = NULL;
	model->label = NULL;
	model->nSV = NULL;

	char cmd[81];
	while(1)
	{
		fscanf(fp,"%80s",cmd);

		if(strcmp(cmd,"svm_type")==0)
		{
			fscanf(fp,"%80s",cmd);
			int i;
			for(i=0;svm_type_table[i];i++)
			{
				if(strcmp(svm_type_table[i],cmd)==0)
				{
					param.svm_type=i;
					break;
				}
			}
			if(svm_type_table[i] == NULL)
			{
				fprintf(stderr,"unknown svm type.\n");
				free(model->rho);
				free(model->label);
				free(model->nSV);
				free(model);
				return NULL;
			}
		}
		else if(strcmp(cmd,"kernel_type")==0)
		{
			fscanf(fp,"%80s",cmd);
			int i;
			for(i=0;kernel_type_table[i];i++)
			{
				if(strcmp(kernel_type_table[i],cmd)==0)
				{
					param.kernel_type=i;
					break;
				}
			}
			if(kernel_type_table[i] == NULL)
			{
				fprintf(stderr,"unknown kernel function.\n");
				free(model->rho);
				free(model->label);
				free(model->nSV);
				free(model);
				return NULL;
			}
		}
		else if(strcmp(cmd,"degree")==0)
			fscanf(fp,"%d",&param.degree);
		else if(strcmp(cmd,"gamma")==0)
			fscanf(fp,"%lf",&param.gamma);
		else if(strcmp(cmd,"coef0")==0)
			fscanf(fp,"%lf",&param.coef0);
		else if(strcmp(cmd,"nr_class")==0)
			fscanf(fp,"%d",&model->nr_class);
		else if(strcmp(cmd,"total_sv")==0)
			fscanf(fp,"%d",&model->l);
		else if(strcmp(cmd,"rho")==0)
		{
			int n = model->nr_class * (model->nr_class-1)/2;
			model->rho = Malloc(double,n);
			for(int i=0;i<n;i++)
				fscanf(fp,"%lf",&model->rho[i]);
		}
		else if(strcmp(cmd,"label")==0)
		{
			int n = model->nr_class;
			model->label = Malloc(int,n);
			for(int i=0;i<n;i++)
				fscanf(fp,"%d",&model->label[i]);
		}
		else if(strcmp(cmd,"probA")==0)
		{
			int n = model->nr_class * (model->nr_class-1)/2;
			model->probA = Malloc(double,n);
			for(int i=0;i<n;i++)
				fscanf(fp,"%lf",&model->probA[i]);
		}
		else if(strcmp(cmd,"probB")==0)
		{
			int n = model->nr_class * (model->nr_class-1)/2;
			model->probB = Malloc(double,n);
			for(int i=0;i<n;i++)
				fscanf(fp,"%lf",&model->probB[i]);
		}
		else if(strcmp(cmd,"nr_sv")==0)
		{
			int n = model->nr_class;
			model->nSV = Malloc(int,n);
			for(int i=0;i<n;i++)
				fscanf(fp,"%d",&model->nSV[i]);
		}
		else if(strcmp(cmd,"SV")==0)
		{
			while(1)
			{
				int c = getc(fp);
				if(c==EOF || c=='\n') break;
			}
			break;
		}
		else
		{
			fprintf(stderr,"unknown text in model file: [%s]\n",cmd);
			free(model->rho);
			free(model->label);
			free(model->nSV);
			free(model);
			return NULL;
		}
	}

	// read sv_coef and SV

	int elements = 0;
	long pos = ftell(fp);

	while(1)
	{
		int c = fgetc(fp);
		switch(c)
		{
			case '\n':
				// count the '-1' element
			case ':':
				++elements;
				break;
			case EOF:
				goto out;
			default:
				;
		}
	}
out:
	fseek(fp,pos,SEEK_SET);

	int m = model->nr_class - 1;
	int l = model->l;
	model->sv_coef = Malloc(double *,m);
	int i;
	for(i=0;i<m;i++)
		model->sv_coef[i] = Malloc(double,l);
	model->SV = Malloc(svm_node*,l);
	svm_node *x_space=NULL;
	if(l>0) x_space = Malloc(svm_node,elements);

	int j=0;
	for(i=0;i<l;i++)
	{
		model->SV[i] = &x_space[j];
		for(int k=0;k<m;k++)
			fscanf(fp,"%lf",&model->sv_coef[k][i]);
		while(1)
		{
			int c;
			do {
				c = getc(fp);
				if(c=='\n') goto out2;
			} while(isspace(c));
			ungetc(c,fp);
			fscanf(fp,"%d:%lf",&(x_space[j].index),&(x_space[j].value));
			++j;
		}
out2:
		x_space[j++].index = -1;
	}

	if (ferror(fp) != 0 || fclose(fp) != 0) return NULL;

	model->free_sv = 1;	// XXX
	return model;
}
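// Note on the pre-scan above: every "index:value" pair contributes one ':' and every
// SV line contributes one '\n' (which accounts for that row's terminating index = -1
// node), so `elements` is exactly the number of svm_node entries needed for x_space
// before the SV section is re-read after the fseek.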
void svm_destroy_model(svm_model* model)
{
	if(model->free_sv && model->l > 0)
		free((void *)(model->SV[0]));
	for(int i=0;i<model->nr_class-1;i++)
		free(model->sv_coef[i]);
	free(model->SV);
	free(model->sv_coef);
	free(model->rho);
	free(model->label);
	free(model->probA);
	free(model->probB);
	free(model->nSV);
	free(model);
}

void svm_destroy_param(svm_parameter* param)
{
	free(param->weight_label);
	free(param->weight);
}
const char *svm_check_parameter(const svm_problem *prob, const svm_parameter *param)
{
	// svm_type

	int svm_type = param->svm_type;
	if(svm_type != C_SVC &&
	   svm_type != NU_SVC &&
	   svm_type != ONE_CLASS &&
	   svm_type != EPSILON_SVR &&
	   svm_type != NU_SVR)
		return "unknown svm type";

	// kernel_type, degree

	int kernel_type = param->kernel_type;
	if(kernel_type != LINEAR &&
	   kernel_type != POLY &&
	   kernel_type != RBF &&
	   kernel_type != SIGMOID &&
	   kernel_type != PRECOMPUTED)
		return "unknown kernel type";

	if(param->degree < 0)
		return "degree of polynomial kernel < 0";

	// cache_size,eps,C,nu,p,shrinking

	if(param->cache_size <= 0)
		return "cache_size <= 0";

	if(param->eps <= 0)
		return "eps <= 0";

	if(svm_type == C_SVC ||
	   svm_type == EPSILON_SVR ||
	   svm_type == NU_SVR)
		if(param->C <= 0)
			return "C <= 0";

	if(svm_type == NU_SVC ||
	   svm_type == ONE_CLASS ||
	   svm_type == NU_SVR)
		if(param->nu <= 0 || param->nu > 1)
			return "nu <= 0 or nu > 1";

	if(svm_type == EPSILON_SVR)
		if(param->p < 0)
			return "p < 0";

	if(param->shrinking != 0 &&
	   param->shrinking != 1)
		return "shrinking != 0 and shrinking != 1";

	if(param->probability != 0 &&
	   param->probability != 1)
		return "probability != 0 and probability != 1";

	if(param->probability == 1 &&
	   svm_type == ONE_CLASS)
		return "one-class SVM probability output not supported yet";

	// check whether nu-svc is feasible

	if(svm_type == NU_SVC)
	{
		int l = prob->l;
		int max_nr_class = 16;
		int nr_class = 0;
		int *label = Malloc(int,max_nr_class);
		int *count = Malloc(int,max_nr_class);

		int i;
		for(i=0;i<l;i++)
		{
			int this_label = (int)prob->y[i];
			int j;
			for(j=0;j<nr_class;j++)
				if(this_label == label[j])
				{
					++count[j];
					break;
				}
			if(j == nr_class)
			{
				if(nr_class == max_nr_class)
				{
					max_nr_class *= 2;
					label = (int *)realloc(label,max_nr_class*sizeof(int));
					count = (int *)realloc(count,max_nr_class*sizeof(int));
				}
				label[nr_class] = this_label;
				count[nr_class] = 1;
				++nr_class;
			}
		}

		for(i=0;i<nr_class;i++)
		{
			int n1 = count[i];
			for(int j=i+1;j<nr_class;j++)
			{
				int n2 = count[j];
				if(param->nu*(n1+n2)/2 > min(n1,n2))
				{
					free(label);
					free(count);
					return "specified nu is infeasible";
				}
			}
		}
		free(label);
		free(count);
	}

	return NULL;
}
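// The nu-SVC feasibility test above rejects nu whenever nu*(n1+n2)/2 > min(n1,n2)
// for some pair of classes. For example (made-up class counts), with n1 = 90 and
// n2 = 10 the largest feasible nu is 2*10/(90+10) = 0.2.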
int svm_check_probability_model(const svm_model *model)
{
	return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
		model->probA!=NULL && model->probB!=NULL) ||
		((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
		 model->probA!=NULL);
}