Network.php 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324
  1. <?php
  2. /**
  3. * Artificial Neural Network - Version 2.2
  4. *
  5. * For updates and changes visit the project page at http://ann.thwien.de/
  6. *
  7. *
  8. *
  9. * <b>LICENCE</b>
  10. *
  11. * The BSD 2-Clause License
  12. *
  13. * http://opensource.org/licenses/bsd-license.php
  14. *
  15. * Copyright (c) 2002, Eddy Young
  16. * Copyright (c) 2007 - 2012, Thomas Wien
  17. * All rights reserved.
  18. *
  19. * Redistribution and use in source and binary forms, with or without
  20. * modification, are permitted provided that the following conditions
  21. * are met:
  22. *
  23. * 1. Redistributions of source code must retain the above copyright
  24. * notice, this list of conditions and the following disclaimer.
  25. *
  26. * 2. Redistributions in binary form must reproduce the above copyright
  27. * notice, this list of conditions and the following disclaimer in the
  28. * documentation and/or other materials provided with the distribution.
  29. *
  30. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  31. * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  32. * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
  33. * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
  34. * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
  35. * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
  36. * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  37. * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  38. * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
  39. * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
  40. * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  41. * POSSIBILITY OF SUCH DAMAGE.
  42. *
  43. * @author Eddy Young <jeyoung_at_priscimon_dot_com>
  44. * @author Thomas Wien <info_at_thwien_dot_de>
  45. * @version ANN Version 1.0 by Eddy Young
  46. * @version ANN Version 2.2 by Thomas Wien
  47. * @copyright Copyright (c) 2002 by Eddy Young
  48. * @copyright Copyright (c) 2007-2012 by Thomas Wien
  49. * @package ANN
  50. */
  51. namespace ANN;
  52. /**
  53. * @package ANN
  54. * @access public
  55. */
  56. class Network extends Filesystem implements InterfaceLoadable
  57. {
  58. /**#@+
  59. * @ignore
  60. */
  61. /**
  62. * @var Layer
  63. */
  64. protected $objOutputLayer = null;
  65. /**
  66. * @var array
  67. */
  68. protected $arrHiddenLayers = array();
  69. /**
  70. * @var array
  71. */
  72. protected $arrInputs = null;
  73. /**
  74. * @var array
  75. */
  76. protected $arrOutputs = null;
  77. /**
  78. * @var integer
  79. */
  80. protected $intTotalLoops = 0;
  81. /**
  82. * @var integer
  83. */
  84. protected $intTotalTrainings = 0;
  85. /**
  86. * @var integer
  87. */
  88. protected $intTotalActivations = 0;
  89. /**
  90. * @var integer
  91. */
  92. protected $intTotalActivationsRequests = 0;
  93. /**
  94. * @var integer
  95. */
  96. protected $intNumberOfHiddenLayers = null;
  97. /**
  98. * @var integer
  99. */
  100. protected $intNumberOfHiddenLayersDec = null; // decremented value
  101. /**
  102. * @var integer
  103. */
  104. protected $intMaxExecutionTime = 0;
  105. /**
  106. * @var integer
  107. */
  108. protected $intNumberEpoch = 0;
  109. /**
  110. * @var boolean
  111. */
  112. protected $boolLoggingWeights = FALSE;
  113. /**
  114. * @var boolean
  115. */
  116. protected $boolLoggingNetworkErrors = FALSE;
  117. /**
  118. * @var boolean
  119. */
  120. protected $boolTrained = FALSE;
  121. /**
  122. * @var integer
  123. */
  124. protected $intTrainingTime = 0; // Seconds
  125. /**
  126. * @var Logging
  127. */
  128. protected $objLoggingWeights = null;
  129. /**
  130. * @var Logging
  131. */
  132. protected $objLoggingNetworkErrors = null;
  133. /**
  134. * @var boolean
  135. */
  136. protected $boolNetworkActivated = FALSE;
  137. /**
  138. * @var array
  139. */
  140. protected $arrTrainingComplete = array();
  141. /**
  142. * @var integer
  143. */
  144. protected $intNumberOfNeuronsPerLayer = 0;
  145. /**
  146. * @var float
  147. */
  148. protected $floatOutputErrorTolerance = 0.02;
  149. /**
  150. * @var float
  151. */
  152. public $floatMomentum = 0.95;
  153. /**
  154. * @var array
  155. */
  156. private $arrInputsToTrain = array();
  157. /**
  158. * @var integer
  159. */
  160. private $intInputsToTrainIndex = -1;
  161. /**
  162. * @var integer
  163. */
  164. public $intOutputType = self::OUTPUT_LINEAR;
  165. /**
  166. * @var float
  167. */
  168. public $floatLearningRate = 0.7;
  169. /**
  170. * @var boolean
  171. */
  172. public $boolFirstLoopOfTraining = TRUE;
  173. /**
  174. * @var boolean
  175. */
  176. public $boolFirstEpochOfTraining = TRUE;
  177. /**#@-*/
  178. /**
  179. * Linear output type
  180. */
  181. const OUTPUT_LINEAR = 1;
  182. /**
  183. * Binary output type
  184. */
  185. const OUTPUT_BINARY = 2;
  /**
   * Build a network with the requested topology.
   *
   * The output layer is created before the hidden layers because
   * createHiddenLayers() links the last hidden layer to it.
   *
   * @param integer $intNumberOfHiddenLayers (Default: 1) integer >= 1
   * @param integer $intNumberOfNeuronsPerLayer (Default: 6) integer >= 2
   * @param integer $intNumberOfOutputs (Default: 1) integer >= 1
   * @uses Exception::__construct()
   * @uses setMaxExecutionTime()
   * @uses createHiddenLayers()
   * @uses createOutputLayer()
   * @throws Exception if an argument violates its constraint, or if
   *                   setMaxExecutionTime() rejects the configured time limit
   */
  public function __construct($intNumberOfHiddenLayers = 1, $intNumberOfNeuronsPerLayer = 6, $intNumberOfOutputs = 1)
  {
      $boolValidLayers = is_integer($intNumberOfHiddenLayers) && $intNumberOfHiddenLayers >= 1;
      if(!$boolValidLayers)
      {
          throw new Exception('Constraints: $intNumberOfHiddenLayers must be a positiv integer >= 1');
      }

      $boolValidNeurons = is_integer($intNumberOfNeuronsPerLayer) && $intNumberOfNeuronsPerLayer >= 2;
      if(!$boolValidNeurons)
      {
          throw new Exception('Constraints: $intNumberOfNeuronsPerLayer must be a positiv integer number >= 2');
      }

      $boolValidOutputs = is_integer($intNumberOfOutputs) && $intNumberOfOutputs >= 1;
      if(!$boolValidOutputs)
      {
          throw new Exception('Constraints: $intNumberOfOutputs must be a positiv integer number >= 1');
      }

      // Order matters: hidden layers reference the output layer.
      $this->createOutputLayer($intNumberOfOutputs);
      $this->createHiddenLayers($intNumberOfHiddenLayers, $intNumberOfNeuronsPerLayer);

      $this->intNumberOfHiddenLayers = $intNumberOfHiddenLayers;
      // Highest hidden-layer index, cached for the backward loops in training().
      $this->intNumberOfHiddenLayersDec = $this->intNumberOfHiddenLayers - 1;
      $this->intNumberOfNeuronsPerLayer = $intNumberOfNeuronsPerLayer;

      $this->setMaxExecutionTime();
  }
  /**
   * Store the input patterns and derive the epoch length from them.
   *
   * The epoch length is the number of input patterns. Replacing the inputs
   * invalidates any previous activation of the network.
   *
   * The pattern cursor used during training is reset by
   * getNextIndexInputsToTrain(TRUE) in train(); this method deliberately does
   * not create any additional (undeclared) cursor property.
   *
   * @param array $arrInputs list of input patterns
   * @uses Exception::__construct()
   * @throws Exception if $arrInputs is not an array
   */
  protected function setInputs($arrInputs)
  {
      if(!is_array($arrInputs))
      {
          throw new Exception('Constraints: $arrInputs should be an array');
      }

      $this->arrInputs = $arrInputs;
      $this->intNumberEpoch = count($arrInputs);
      $this->boolNetworkActivated = FALSE;
  }
  /**
   * Store the desired output patterns.
   *
   * Only the first pattern (key 0) is checked: when it is an array, its
   * length must equal the neuron count of the output layer.
   *
   * @param array $arrOutputs list of desired output patterns
   * @uses Exception::__construct()
   * @uses Layer::getNeuronsCount()
   * @throws Exception if the first pattern has the wrong number of values
   */
  protected function setOutputs($arrOutputs)
  {
      $boolHasFirstPattern = isset($arrOutputs[0]) && is_array($arrOutputs[0]);

      if($boolHasFirstPattern && count($arrOutputs[0]) != $this->objOutputLayer->getNeuronsCount())
      {
          throw new Exception('Count of arrOutputs doesn\'t fit to number of arrOutputs on instantiation of \\'. __NAMESPACE__ .'\\Network');
      }

      $this->arrOutputs = $arrOutputs;
      // New targets: force a fresh activation on next use.
      $this->boolNetworkActivated = FALSE;
  }
  /**
   * Hand training or query values to the network
   *
   * Takes the inputs and outputs held by a Values object. For training both
   * are needed; for querying an already trained network only the inputs
   * are used.
   *
   * <code>
   * $objNetwork = new \ANN\Network(2, 4, 1);
   *
   * $objValues = new \ANN\Values;
   *
   * $objValues->train()
   * ->input(0.12, 0.11, 0.15)
   * ->output(0.56);
   *
   * $objNetwork->setValues($objValues);
   * </code>
   *
   * @param Values $objValues
   * @uses Values::getInputsArray()
   * @uses Values::getOutputsArray()
   * @uses setInputs()
   * @uses setOutputs()
   * @throws Exception if setInputs() or setOutputs() rejects the data
   * @since 2.0.6
   */
  public function setValues(Values $objValues)
  {
      $arrNewInputs = $objValues->getInputsArray();
      $this->setInputs($arrNewInputs);

      $arrNewOutputs = $objValues->getOutputsArray();
      $this->setOutputs($arrNewOutputs);
  }
  /**
   * Feed one input pattern into the first hidden layer.
   *
   * Also clears the activation flag so the next activate() call recomputes
   * the network.
   *
   * @param array $arrInputs a single input pattern
   * @uses Layer::setInputs()
   */
  protected function setInputsToTrain($arrInputs)
  {
      $objFirstHiddenLayer = $this->arrHiddenLayers[0];
      $objFirstHiddenLayer->setInputs($arrInputs);

      $this->boolNetworkActivated = FALSE;
  }
  /**
   * Compute the network outputs for all stored inputs
   *
   * Every input pattern set by setValues() is fed through the network in
   * turn. Patterns are read by numeric position 0 .. count-1. Depending on
   * the output type the raw or the threshold outputs of the output layer
   * are collected; for any other output type nothing is collected.
   *
   * @return array two-dimensional array, one row per input pattern
   * @uses activate()
   * @uses getCountInputs()
   * @uses Layer::getOutputs()
   * @uses Layer::getThresholdOutputs()
   * @uses setInputsToTrain()
   */
  public function getOutputs()
  {
      $arrCollected = array();
      $intTotal = $this->getCountInputs();
      $intPosition = 0;

      while($intPosition < $intTotal)
      {
          $this->setInputsToTrain($this->arrInputs[$intPosition]);
          $this->activate();

          // Read the type once per pattern; loose comparison as in a switch.
          $intType = $this->intOutputType;

          if($intType == self::OUTPUT_LINEAR)
          {
              $arrCollected[] = $this->objOutputLayer->getOutputs();
          }
          elseif($intType == self::OUTPUT_BINARY)
          {
              $arrCollected[] = $this->objOutputLayer->getThresholdOutputs();
          }

          $intPosition++;
      }

      return $arrCollected;
  }
  /**
   * Compute the network outputs for a single stored input pattern.
   *
   * @param integer $intKeyInput key of the pattern in the stored inputs
   * @return array outputs of the output layer (raw for linear, threshold for
   *               binary); nothing is returned for any other output type
   * @uses activate()
   * @uses Layer::getOutputs()
   * @uses Layer::getThresholdOutputs()
   * @uses setInputsToTrain()
   */
  public function getOutputsByInputKey($intKeyInput)
  {
      $this->setInputsToTrain($this->arrInputs[$intKeyInput]);
      $this->activate();

      // Loose comparison, as in a switch statement.
      $intType = $this->intOutputType;

      if($intType == self::OUTPUT_LINEAR)
      {
          return $this->objOutputLayer->getOutputs();
      }

      if($intType == self::OUTPUT_BINARY)
      {
          return $this->objOutputLayer->getThresholdOutputs();
      }
  }
  /**
   * Create the hidden layers from last to first.
   *
   * The last hidden layer is linked to the output layer, every earlier one
   * to the hidden layer created just before it. Afterwards the layers are
   * sorted by key so that index 0 is the first hidden layer.
   *
   * @param integer $intNumberOfHiddenLayers
   * @param integer $intNumberOfNeuronsPerLayer
   * @uses Layer::__construct()
   */
  protected function createHiddenLayers($intNumberOfHiddenLayers, $intNumberOfNeuronsPerLayer)
  {
      $intLayerKey = $intNumberOfHiddenLayers;
      // Successor of the layer about to be created.
      $objSuccessor = $this->objOutputLayer;

      for ($intCreated = 0; $intCreated < $intNumberOfHiddenLayers; $intCreated++)
      {
          $intLayerKey--;

          $this->arrHiddenLayers[$intLayerKey] = new Layer($this, $intNumberOfNeuronsPerLayer, $objSuccessor);

          $objSuccessor = $this->arrHiddenLayers[$intLayerKey];
      }

      // Keys were inserted in descending order; restore ascending order.
      ksort($this->arrHiddenLayers);
  }
  /**
   * Create the output layer. It is constructed without a following layer.
   *
   * @param integer $intNumberOfOutputs number of neurons in the output layer
   * @uses Layer::__construct()
   */
  protected function createOutputLayer($intNumberOfOutputs)
  {
      $objLayer = new Layer($this, $intNumberOfOutputs);

      $this->objOutputLayer = $objLayer;
  }
  /**
   * Activate the network unless it is already activated.
   *
   * Every call is counted as a request. The actual activation starts at the
   * first hidden layer (presumably propagating to the following layers;
   * Layer::activate() is not visible here) and is counted separately.
   *
   * @uses Layer::activate()
   */
  protected function activate()
  {
      $this->intTotalActivationsRequests++;

      if(!$this->boolNetworkActivated)
      {
          $this->arrHiddenLayers[0]->activate();
          $this->boolNetworkActivated = TRUE;
          $this->intTotalActivations++;
      }
  }
  /**
   * Train the network with the values given by setValues().
   *
   * Patterns are presented one at a time until either an epoch ends with
   * every pattern within tolerance or the time budget is used up. Loop
   * count and elapsed seconds are added to the running totals.
   *
   * @return boolean TRUE if all patterns are reproduced within tolerance
   * @uses Exception::__construct()
   * @uses hasTimeLeftForTraining()
   * @uses isTrainingComplete()
   * @uses isTrainingCompleteByEpoch()
   * @uses setInputsToTrain()
   * @uses training()
   * @uses isEpoch()
   * @uses logWeights()
   * @uses logNetworkErrors()
   * @uses getNextIndexInputsToTrain()
   * @uses isTrainingCompleteByInputKey()
   * @uses setDynamicLearningRate()
   * @uses detectOutputType()
   * @throws Exception if inputs or outputs have not been set
   */
  public function train()
  {
      if(!$this->arrInputs)
      {
          throw new Exception('No arrInputs defined. Use \\'. __NAMESPACE__ .'\\Network::setValues().');
      }

      if(!$this->arrOutputs)
      {
          throw new Exception('No arrOutputs defined. Use \\'. __NAMESPACE__ .'\\Network::setValues().');
      }

      $this->detectOutputType();

      // Nothing to do if every pattern is already within tolerance.
      if($this->isTrainingComplete())
      {
          return $this->boolTrained = TRUE;
      }

      $intStartedAt = date('U');

      // Reset the pattern cursor before the first loop.
      $this->getNextIndexInputsToTrain(TRUE);

      $this->boolFirstLoopOfTraining = TRUE;
      $this->boolFirstEpochOfTraining = TRUE;

      $intLoopsDone = 0;

      while($this->hasTimeLeftForTraining())
      {
          $intLoopsDone++;

          $this->setDynamicLearningRate($intLoopsDone);

          $intPatternKey = $this->getNextIndexInputsToTrain();

          $this->setInputsToTrain($this->arrInputs[$intPatternKey]);

          // Remember per pattern whether it is learned; train it only if not.
          $boolPatternLearned = $this->isTrainingCompleteByInputKey($intPatternKey);
          $this->arrTrainingComplete[$intPatternKey] = $boolPatternLearned;

          if(!$boolPatternLearned)
          {
              $this->training($this->arrOutputs[$intPatternKey]);
          }

          if($this->isEpoch())
          {
              if($this->boolLoggingWeights)
              {
                  $this->logWeights();
              }

              if($this->boolLoggingNetworkErrors)
              {
                  $this->logNetworkErrors();
              }

              // Leaving here keeps the first-loop/first-epoch flags untouched.
              if($this->isTrainingCompleteByEpoch())
              {
                  break;
              }

              $this->boolFirstEpochOfTraining = FALSE;
          }

          $this->boolFirstLoopOfTraining = FALSE;
      }

      $intStoppedAt = date('U');

      $this->intTotalLoops += $intLoopsDone;
      $this->intTrainingTime += $intStoppedAt - $intStartedAt;

      return $this->boolTrained = $this->isTrainingComplete();
  }
  /**
   * Tell whether the training time budget is not yet used up.
   *
   * The budget is measured from the start of the request
   * ($_SERVER['REQUEST_TIME']), not from the start of train().
   *
   * @return boolean
   */
  protected function hasTimeLeftForTraining()
  {
      $intDeadline = $_SERVER['REQUEST_TIME'] + $this->intMaxExecutionTime;

      return ($intDeadline > date('U'));
  }
  /**
   * Deliver the key of the next input pattern to train.
   *
   * With $boolReset the key list is rebuilt from the stored inputs and the
   * cursor is rewound; nothing is returned in that case. Otherwise the
   * cursor advances by one. When it runs past the end of the list the keys
   * are shuffled and the cursor starts again at the beginning, so the
   * first pass uses the original order and later passes a random one.
   *
   * @param boolean $boolReset (Default: FALSE)
   * @return integer key into the stored inputs (nothing on reset)
   */
  protected function getNextIndexInputsToTrain($boolReset = FALSE)
  {
      if($boolReset)
      {
          $this->arrInputsToTrain = array_keys($this->arrInputs);
          $this->intInputsToTrainIndex = -1;

          return;
      }

      $intCursor = $this->intInputsToTrainIndex + 1;

      if(!isset($this->arrInputsToTrain[$intCursor]))
      {
          shuffle($this->arrInputsToTrain);
          $intCursor = 0;
      }

      $this->intInputsToTrainIndex = $intCursor;

      return $this->arrInputsToTrain[$intCursor];
  }
  /**
   * Number of training loops accumulated over all train() calls.
   *
   * @return integer
   */
  public function getTotalLoops()
  {
      $intLoops = $this->intTotalLoops;

      return $intLoops;
  }
  /**
   * Count one training loop and report whether an epoch has just ended.
   *
   * An epoch ends once the counter reaches the number of input patterns;
   * the counter is then reset.
   *
   * NOTE(review): the counter is a static local, so it is shared by all
   * Network instances in the process and is not reset between train() calls.
   *
   * @return boolean TRUE on the loop that completes an epoch
   */
  protected function isEpoch()
  {
      static $intLoopsInEpoch = 0;

      $intLoopsInEpoch++;

      if($intLoopsInEpoch < $this->intNumberEpoch)
      {
          return FALSE;
      }

      $intLoopsInEpoch = 0;

      return TRUE;
  }
  /**
   * Set the learning rate
   *
   * The value must be a float strictly greater than 0 and strictly less
   * than 1. (The exception text speaks of 0.1 .. 0.9, but the check itself
   * accepts the whole open interval.)
   *
   * @param float $floatLearningRate (Default: 0.7)
   * @uses Exception::__construct()
   * @throws Exception if the value is not a float or is out of range
   */
  protected function setLearningRate($floatLearningRate = 0.7)
  {
      $boolInvalid = !is_float($floatLearningRate)
          || $floatLearningRate <= 0
          || $floatLearningRate >= 1;

      if($boolInvalid)
      {
          throw new Exception('$floatLearningRate should be between 0.1 and 0.9');
      }

      $this->floatLearningRate = $floatLearningRate;
  }
  /**
   * Check whether every desired output is reproduced by the network.
   *
   * Linear type: each desired value must lie within the actual output plus
   * or minus the error tolerance (bounds rounded to 3 decimals).
   * Binary type: each desired value must equal the threshold output.
   * For any other output type nothing is returned.
   *
   * @return boolean
   * @uses getOutputs()
   */
  protected function isTrainingComplete()
  {
      $arrActual = $this->getOutputs();

      // Loose comparison, as in a switch statement.
      $intType = $this->intOutputType;

      if($intType == self::OUTPUT_LINEAR)
      {
          foreach($this->arrOutputs as $intPattern => $arrDesired)
          {
              foreach($arrDesired as $intNeuron => $floatDesired)
              {
                  if(($floatDesired > round($arrActual[$intPattern][$intNeuron] + $this->floatOutputErrorTolerance, 3)) || ($floatDesired < round($arrActual[$intPattern][$intNeuron] - $this->floatOutputErrorTolerance, 3)))
                  {
                      return FALSE;
                  }
              }
          }

          return TRUE;
      }

      if($intType == self::OUTPUT_BINARY)
      {
          foreach($this->arrOutputs as $intPattern => $arrDesired)
          {
              foreach($arrDesired as $intNeuron => $floatDesired)
              {
                  if($floatDesired != $arrActual[$intPattern][$intNeuron])
                  {
                      return FALSE;
                  }
              }
          }

          return TRUE;
      }
  }
  /**
   * Check the per-pattern completion flags recorded by train().
   *
   * @return boolean TRUE if no recorded flag is falsy (also TRUE when no
   *                 flags have been recorded)
   */
  protected function isTrainingCompleteByEpoch()
  {
      foreach($this->arrTrainingComplete as $boolPatternLearned)
      {
          if($boolPatternLearned)
          {
              continue;
          }

          return FALSE;
      }

      return TRUE;
  }
  /**
   * Check whether one pattern is reproduced by the network.
   *
   * The network is always evaluated for the pattern first. A pattern without
   * desired outputs counts as complete. The comparison rules are the same
   * as for the whole training set: tolerance band for linear outputs,
   * equality with the threshold outputs for binary ones. For any other
   * output type nothing is returned.
   *
   * @param integer $intKeyInput key of the pattern
   * @return boolean
   * @uses getOutputsByInputKey()
   */
  protected function isTrainingCompleteByInputKey($intKeyInput)
  {
      $arrActual = $this->getOutputsByInputKey($intKeyInput);

      if(!isset($this->arrOutputs[$intKeyInput]))
      {
          return TRUE;
      }

      // Loose comparison, as in a switch statement.
      $intType = $this->intOutputType;

      if($intType == self::OUTPUT_LINEAR)
      {
          foreach($this->arrOutputs[$intKeyInput] as $intNeuron => $floatDesired)
          {
              if(($floatDesired > round($arrActual[$intNeuron] + $this->floatOutputErrorTolerance, 3)) || ($floatDesired < round($arrActual[$intNeuron] - $this->floatOutputErrorTolerance, 3)))
              {
                  return FALSE;
              }
          }

          return TRUE;
      }

      if($intType == self::OUTPUT_BINARY)
      {
          foreach($this->arrOutputs[$intKeyInput] as $intNeuron => $floatDesired)
          {
              if($floatDesired != $arrActual[$intNeuron])
              {
                  return FALSE;
              }
          }

          return TRUE;
      }
  }
  /**
   * Number of stored input patterns.
   *
   * @return integer 0 if no inputs have been set
   */
  protected function getCountInputs()
  {
      $boolHasInputs = isset($this->arrInputs) && is_array($this->arrInputs);

      return $boolHasInputs ? count($this->arrInputs) : 0;
  }
  /**
   * Run one training step for the currently loaded input pattern.
   *
   * Sequence: activate, compute output-layer deltas, compute hidden-layer
   * deltas from the last hidden layer to the first, then adjust the weights
   * of the output layer and of the hidden layers in the same backward
   * order. The activation flag is cleared afterwards because the weights
   * have changed.
   *
   * @param array $arrOutputs desired outputs for the current pattern
   * @uses activate()
   * @uses Layer::calculateHiddenDeltas()
   * @uses Layer::adjustWeights()
   * @uses Layer::calculateOutputDeltas()
   */
  protected function training($arrOutputs)
  {
      $this->activate();

      $intLastHidden = $this->intNumberOfHiddenLayersDec;

      // Deltas, output layer first, then backwards through the hidden layers.
      $this->objOutputLayer->calculateOutputDeltas($arrOutputs);

      for ($intLayer = $intLastHidden; $intLayer >= 0; $intLayer--)
      {
          $this->arrHiddenLayers[$intLayer]->calculateHiddenDeltas();
      }

      // Weight updates in the same order.
      $this->objOutputLayer->adjustWeights();

      for ($intLayer = $intLastHidden; $intLayer >= 0; $intLayer--)
      {
          $this->arrHiddenLayers[$intLayer]->adjustWeights();
      }

      $this->intTotalTrainings++;
      $this->boolNetworkActivated = FALSE;
  }
  /**
   * Derive the default data filename from the running script.
   *
   * Takes the base name of $_SERVER['PHP_SELF'] and replaces a trailing
   * ".php" with ".dat".
   *
   * @return string Filename
   */
  protected static function getDefaultFilename()
  {
      $strScriptName = basename($_SERVER['PHP_SELF']);

      return preg_replace('/\.php$/', '.dat', $strScriptName);
  }
  /**
   * Set the output type of the network.
   *
   * The argument is cast to integer before it is checked.
   *
   * @param integer $intType (Default: Network::OUTPUT_LINEAR)
   *                         Network::OUTPUT_LINEAR or Network::OUTPUT_BINARY
   * @uses Exception::__construct()
   * @throws Exception if $intType is neither of the two output types
   */
  protected function setOutputType($intType = self::OUTPUT_LINEAR)
  {
      settype($intType, 'integer');

      switch($intType)
      {
          case self::OUTPUT_LINEAR:
          case self::OUTPUT_BINARY:
              $this->intOutputType = $intType;
              break;

          default:
              // The message names the actual parameter ($intType).
              throw new Exception('$intType must be \\'. __NAMESPACE__ .'\\Network::OUTPUT_LINEAR or \\'. __NAMESPACE__ .'\\Network::OUTPUT_BINARY');
      }
  }
  /**
   * Determine the time budget available for training.
   *
   * Combines the values of getMaxExecutionTime() and getCPULimit(): a value
   * of 0 on one side means the other side is used, otherwise the smaller
   * of the two wins. A resulting budget of 0 is rejected unless the request
   * carries XDEBUG_SESSION_START.
   *
   * @uses getCPULimit()
   * @uses getMaxExecutionTime()
   * @throws Exception if the resulting budget is 0 outside a debug session
   */
  protected function setMaxExecutionTime()
  {
      $intScriptLimit = $this->getMaxExecutionTime();
      $intCPULimit = $this->getCPULimit();

      if($intScriptLimit == 0)
      {
          $intBudget = $intCPULimit;
      }
      elseif($intCPULimit == 0)
      {
          $intBudget = $intScriptLimit;
      }
      else
      {
          $intBudget = min($intScriptLimit, $intCPULimit);
      }

      $this->intMaxExecutionTime = $intBudget;

      $boolDebugSession = isset($_REQUEST['XDEBUG_SESSION_START']);

      if($this->intMaxExecutionTime == 0 && !$boolDebugSession)
      {
          throw new Exception('max_execution_time is 0');
      }
  }
  /**
   * Re-initialise runtime state after unserialization.
   *
   * The time budget is recomputed for the current environment, and the
   * activation flag is cleared.
   *
   * @uses setMaxExecutionTime()
   * @throws Exception if setMaxExecutionTime() rejects the time limit
   */
  public function __wakeup()
  {
      // Must run first: if it throws, the flag below stays untouched.
      $this->setMaxExecutionTime();

      $this->boolNetworkActivated = FALSE;
  }
  /**
   * Load a network from a file.
   *
   * Without a filename the default filename derived from the running
   * script is used.
   *
   * @param string $strFilename (Default: null)
   * @return Network
   * @uses parent::loadFromFile()
   * @uses getDefaultFilename()
   */
  public static function loadFromFile($strFilename = null)
  {
      $strSource = ($strFilename === null) ? self::getDefaultFilename() : $strFilename;

      return parent::loadFromFile($strSource);
  }
  /**
   * Save this network to a file.
   *
   * Without a filename the default filename derived from the running
   * script is used.
   *
   * @param string $strFilename (Default: null)
   * @uses parent::saveToFile()
   * @uses getDefaultFilename()
   */
  public function saveToFile($strFilename = null)
  {
      $strTarget = ($strFilename === null) ? self::getDefaultFilename() : $strFilename;

      parent::saveToFile($strTarget);
  }
  /**
   * Number of values in the first stored input pattern (key 0).
   *
   * @return integer 0 if no inputs are set or there is no pattern at key 0
   */
  public function getNumberInputs()
  {
      $boolHasFirstPattern = isset($this->arrInputs)
          && is_array($this->arrInputs)
          && isset($this->arrInputs[0]);

      if(!$boolHasFirstPattern)
      {
          return 0;
      }

      return count($this->arrInputs[0]);
  }
  /**
   * Number of hidden layers currently held by the network.
   *
   * @return integer
   */
  public function getNumberHiddenLayers()
  {
      $boolHasLayers = isset($this->arrHiddenLayers) && is_array($this->arrHiddenLayers);

      return $boolHasLayers ? count($this->arrHiddenLayers) : 0;
  }
  /**
   * Number of neurons in the first hidden layer.
   *
   * @return integer 0 if there is no hidden layer at index 0
   * @uses Layer::getNeuronsCount()
   */
  public function getNumberHiddens()
  {
      $boolHasFirstLayer = isset($this->arrHiddenLayers)
          && is_array($this->arrHiddenLayers)
          && isset($this->arrHiddenLayers[0]);

      if(!$boolHasFirstLayer)
      {
          return 0;
      }

      return $this->arrHiddenLayers[0]->getNeuronsCount();
  }
  /**
   * Number of values in the first stored desired-output pattern (key 0).
   *
   * This is taken from the stored output values, not from the output layer.
   *
   * @return integer 0 if there is no array at key 0 of the stored outputs
   */
  public function getNumberOutputs()
  {
      $boolHasFirstPattern = isset($this->arrOutputs[0]) && is_array($this->arrOutputs[0]);

      return $boolHasFirstPattern ? count($this->arrOutputs[0]) : 0;
  }
  /**
   * Switch on weight logging during training
   *
   * train() calls logWeights() at the end of each epoch while this is
   * enabled. The output format is determined by the Logging class.
   *
   * @param string $strFilename file to log to
   * @uses Logging::__construct()
   * @uses Logging::setFilename()
   */
  public function logWeightsToFile($strFilename)
  {
      $this->boolLoggingWeights = TRUE;

      $this->objLoggingWeights = new Logging();
      $this->objLoggingWeights->setFilename($strFilename);
  }
  /**
   * Switch on network-error logging during training
   *
   * train() calls logNetworkErrors() at the end of each epoch while this is
   * enabled. The output format is determined by the Logging class.
   *
   * @param string $strFilename file to log to
   * @uses Logging::__construct()
   * @uses Logging::setFilename()
   */
  public function logNetworkErrorsToFile($strFilename)
  {
      $this->boolLoggingNetworkErrors = TRUE;

      $this->objLoggingNetworkErrors = new Logging();
      $this->objLoggingNetworkErrors->setFilename($strFilename);
  }
  /**
   * Write one log record with the network error and all weights.
   *
   * Record keys: "E" for the network error, "H<layer>-N<neuron>-W<weight>"
   * for hidden-layer weights and "O-N<neuron>-W<weight>" for output-layer
   * weights. Weights are rounded to 5 decimals.
   *
   * @uses Layer::getNeurons()
   * @uses Logging::logData()
   * @uses Neuron::getWeights()
   * @uses getNetworkError()
   */
  protected function logWeights()
  {
      $arrRecord = array('E' => $this->getNetworkError());

      // Hidden layers
      foreach($this->arrHiddenLayers as $intLayer => $objLayer)
      {
          foreach($objLayer->getNeurons() as $intNeuron => $objNeuron)
          {
              foreach($objNeuron->getWeights() as $intWeight => $floatWeight)
              {
                  $strKey = 'H'. $intLayer .'-N'. $intNeuron .'-W'. $intWeight;
                  $arrRecord[$strKey] = round($floatWeight, 5);
              }
          }
      }

      // Output layer
      foreach($this->objOutputLayer->getNeurons() as $intNeuron => $objNeuron)
      {
          foreach($objNeuron->getWeights() as $intWeight => $floatWeight)
          {
              $strKey = 'O-N'. $intNeuron .'-W'. $intWeight;
              $arrRecord[$strKey] = round($floatWeight, 5);
          }
      }

      $this->objLoggingWeights->logData($arrRecord);
  }
  /**
   * Write one log record with the network error and the learning rate.
   *
   * The error is formatted with 8 decimals, a comma as decimal separator and
   * no thousands separator.
   *
   * @uses getNetworkError()
   * @uses Logging::logData()
   */
  protected function logNetworkErrors()
  {
      $strError = number_format($this->getNetworkError(), 8, ',', '');

      $arrRecord = array(
          'network error' => $strError,
          'learning rate' => $this->floatLearningRate
      );

      $this->objLoggingNetworkErrors->logData($arrRecord);
  }
  /**
   * Compute the network error over all patterns.
   *
   * The error is half the sum of the squared differences between actual
   * and desired outputs.
   *
   * @return float
   * @uses getOutputs()
   */
  protected function getNetworkError()
  {
      $arrActual = $this->getOutputs();
      $floatSquaredSum = 0;

      foreach($this->arrOutputs as $intPattern => $arrDesired)
      {
          foreach($arrDesired as $intNeuron => $floatDesired)
          {
              $floatDifference = $arrActual[$intPattern][$intNeuron] - $floatDesired;
              $floatSquaredSum += pow($floatDifference, 2);
          }
      }

      return $floatSquaredSum / 2;
  }
  /**
   * Send this network to a remote host for training and return the result.
   *
   * The network is serialized and posted together with the credentials as
   * an application/x-www-form-urlencoded body. All fields are URL-encoded
   * so that characters such as "&", "=", "+", "%" or NUL bytes in the
   * credentials or in the serialized data cannot corrupt the request.
   *
   * @param string $strUsername
   * @param string $strPassword
   * @param string $strHost URL of the remote host
   * @return Network|null the network returned by the host, or null if the
   *                      response is not a serialized Network
   * @uses Exception::__construct()
   * @throws Exception if the curl extension is not loaded
   */
  public function trainByHost($strUsername, $strPassword, $strHost)
  {
      if(!extension_loaded('curl'))
          throw new Exception('Curl extension is not installed or active on this system');

      $handleCurl = curl_init();

      settype($strUsername, 'string');
      settype($strPassword, 'string');
      settype($strHost, 'string');

      // Explicit "&" separator: do not depend on the arg_separator.output ini setting.
      $strPostFields = http_build_query(array(
          'mode' => 'trainbyhost',
          'username' => $strUsername,
          'password' => $strPassword,
          'network' => serialize($this)
      ), '', '&');

      curl_setopt($handleCurl, CURLOPT_URL, $strHost);
      curl_setopt($handleCurl, CURLOPT_POST, TRUE);
      curl_setopt($handleCurl, CURLOPT_POSTFIELDS, $strPostFields);
      curl_setopt($handleCurl, CURLOPT_RETURNTRANSFER, 1);

      $strResult = curl_exec($handleCurl);

      curl_close($handleCurl);

      // NOTE(review): unserialize() on a remote response can instantiate
      // arbitrary classes; only use this with a trusted host.
      $objNetwork = @unserialize($strResult);

      if($objNetwork instanceof Network)
          return $objNetwork;

      return null;
  }
  /**
   * Send this network to a remote host for storage.
   *
   * The network is serialized and posted together with the credentials as
   * an application/x-www-form-urlencoded body. All fields are URL-encoded
   * so that characters such as "&", "=", "+", "%" or NUL bytes in the
   * credentials or in the serialized data cannot corrupt the request.
   *
   * CURLOPT_RETURNTRANSFER is not set, so curl writes the response of the
   * host directly to the output.
   *
   * @param string $strUsername
   * @param string $strPassword
   * @param string $strHost URL of the remote host
   * @uses Exception::__construct()
   * @throws Exception if the curl extension is not loaded
   */
  public function saveToHost($strUsername, $strPassword, $strHost)
  {
      if(!extension_loaded('curl'))
          throw new Exception('Curl extension is not installed or active on this system');

      $handleCurl = curl_init();

      settype($strUsername, 'string');
      settype($strPassword, 'string');
      settype($strHost, 'string');

      // Explicit "&" separator: do not depend on the arg_separator.output ini setting.
      $strPostFields = http_build_query(array(
          'mode' => 'savetohost',
          'username' => $strUsername,
          'password' => $strPassword,
          'network' => serialize($this)
      ), '', '&');

      curl_setopt($handleCurl, CURLOPT_URL, $strHost);
      curl_setopt($handleCurl, CURLOPT_POST, TRUE);
      curl_setopt($handleCurl, CURLOPT_POSTFIELDS, $strPostFields);

      curl_exec($handleCurl);

      curl_close($handleCurl);
  }
  /**
   * Load a network from a remote host.
   *
   * The credentials are posted as an application/x-www-form-urlencoded
   * body. Both fields are URL-encoded so that characters such as "&", "=",
   * "+" or "%" in the credentials cannot corrupt the request.
   *
   * @param string $strUsername
   * @param string $strPassword
   * @param string $strHost URL of the remote host
   * @return Network|null the network returned by the host, or null if the
   *                      response is not a serialized Network
   * @uses Exception::__construct()
   * @throws Exception if the curl extension is not loaded
   */
  public static function loadFromHost($strUsername, $strPassword, $strHost)
  {
      if(!extension_loaded('curl'))
          throw new Exception('Curl extension is not installed or active on this system');

      $handleCurl = curl_init();

      settype($strUsername, 'string');
      settype($strPassword, 'string');
      settype($strHost, 'string');

      // Explicit "&" separator: do not depend on the arg_separator.output ini setting.
      $strPostFields = http_build_query(array(
          'mode' => 'loadfromhost',
          'username' => $strUsername,
          'password' => $strPassword
      ), '', '&');

      curl_setopt($handleCurl, CURLOPT_URL, $strHost);
      curl_setopt($handleCurl, CURLOPT_POST, TRUE);
      curl_setopt($handleCurl, CURLOPT_POSTFIELDS, $strPostFields);
      curl_setopt($handleCurl, CURLOPT_RETURNTRANSFER, 1);

      $strResult = curl_exec($handleCurl);

      curl_close($handleCurl);

      // NOTE(review): unserialize() on a remote response can instantiate
      // arbitrary classes; only use this with a trusted host.
      $objNetwork = unserialize(trim($strResult));

      if($objNetwork instanceof Network)
          return $objNetwork;

      return null;
  }
  /**
   * Derive the output type from the stored output patterns
   *
   * Does nothing when $this->arrOutputs is empty. Otherwise the type is set to
   * OUTPUT_LINEAR as soon as one output value lies strictly between 0 and 1,
   * and to OUTPUT_BINARY when no such value exists.
   *
   * @uses setOutputType()
   */
  protected function detectOutputType()
  {
      if(empty($this->arrOutputs))
          return;

      $intDetectedType = self::OUTPUT_BINARY;

      foreach($this->arrOutputs as $arrPattern)
      {
          foreach($arrPattern as $floatValue)
          {
              if($floatValue < 1 && $floatValue > 0)
              {
                  // One fractional value is enough, stop scanning
                  $intDetectedType = self::OUTPUT_LINEAR;
                  break 2;
              }
          }
      }

      $this->setOutputType($intDetectedType);
  }
  /**
   * Set the tolerated output error relative to the desired output
   *
   * Accepted range is 0 to 0.1, both bounds included.
   *
   * @param float $floatOutputErrorTolerance (Default: 0.02)
   * @throws Exception If the value is below 0 or above 0.1
   */
  public function setOutputErrorTolerance($floatOutputErrorTolerance = 0.02)
  {
      $boolOutOfRange = ($floatOutputErrorTolerance < 0 || $floatOutputErrorTolerance > 0.1);

      if($boolOutOfRange)
      {
          throw new Exception('$floatOutputErrorTolerance must be between 0 and 0.1');
      }

      $this->floatOutputErrorTolerance = $floatOutputErrorTolerance;
  }
  /**
   * Set the momentum
   *
   * The value has to be an integer or float greater than 0 and not greater
   * than 1 (0 itself is rejected).
   *
   * @param float $floatMomentum (Default: 0.95) (0 .. 1, 0 excluded)
   * @uses Exception::__construct()
   * @throws Exception If the value is not numeric or outside the range
   */
  public function setMomentum($floatMomentum = 0.95)
  {
      // The messages used to name $floatLearningRate, a leftover from the
      // learning rate setter, which pointed callers at the wrong argument.
      if(!is_float($floatMomentum) && !is_integer($floatMomentum))
          throw new Exception('$floatMomentum should be between 0 and 1');

      if($floatMomentum <= 0 || $floatMomentum > 1)
          throw new Exception('$floatMomentum should be between 0 and 1');

      $this->floatMomentum = $floatMomentum;
  }
  /**
   * Hand this network to a ControllerPrintNetwork instance
   *
   * The controller is only constructed; no method is called on it here and
   * nothing is returned.
   *
   * @uses \ANN\Controller\ControllerPrintNetwork::__construct()
   */
  public function printNetwork()
  {
      // Constructing the controller is the whole operation, no handle is kept
      new \ANN\Controller\ControllerPrintNetwork($this);
  }
  /**
   * Calling the object like a function forwards to printNetwork()
   *
   * @param integer $intLevel (Default: 2) Passed on to printNetwork()
   * @uses printNetwork()
   */
  public function __invoke($intLevel = 2)
  {
      // NOTE(review): printNetwork() as declared in this class takes no
      // parameter, so $intLevel has no effect unless a subclass overrides it.
      $this->printNetwork($intLevel);
  }
  /**
   * String conversion yields whatever getPrintNetwork() returns
   *
   * @uses getPrintNetwork()
   * @return string
   */
  public function __toString()
  {
      $strPrintout = $this->getPrintNetwork();

      return $strPrintout;
  }
  /**
   * Dynamic learning rate
   *
   * Whenever $intLoop is a multiple of 1000 (0 included) a new learning rate
   * of 0.5, 0.6 or 0.7 is drawn at random and applied. For every other loop
   * count the method does nothing.
   *
   * @param integer $intLoop
   * @uses setLearningRate()
   */
  protected function setDynamicLearningRate($intLoop)
  {
      if($intLoop % 1000 === 0)
      {
          $this->setLearningRate(mt_rand(5, 7) / 10);
      }
  }
  /**
   * Collect configuration, counters and environment details of this network
   *
   * The key "detected_output_type" is only present when the output type is
   * OUTPUT_BINARY or OUTPUT_LINEAR. "activation_function" and "learning_rate"
   * are fixed strings. "loops_per_second" is calculated with 0.1 seconds when
   * no training time has been recorded yet.
   *
   * @return array
   * @uses getCPULimit()
   * @uses getMaxExecutionTime()
   * @uses getNetworkError()
   * @uses getNumberInputs()
   * @uses getTrainedInputsPercentage()
   */
  public function getNetworkInfo()
  {
      $arrInfo = array();

      if($this->intOutputType == self::OUTPUT_BINARY)
      {
          $arrInfo['detected_output_type'] = 'Binary';
      }
      elseif($this->intOutputType == self::OUTPUT_LINEAR)
      {
          $arrInfo['detected_output_type'] = 'Linear';
      }

      // Array union appends the keys below after the optional type entry
      $arrInfo += array(
          'activation_function' => 'Sigmoid',
          'momentum' => $this->floatMomentum,
          'learning_rate' => 'Dynamic',
          'network_error' => $this->getNetworkError(),
          'output_error_tolerance' => $this->floatOutputErrorTolerance,
          'total_loops' => $this->intTotalLoops,
          'total_trainings' => $this->intTotalTrainings,
          'total_activations' => $this->intTotalActivations,
          'total_activations_requests' => $this->intTotalActivationsRequests,
          'epoch' => $this->intNumberEpoch,
          'training_time_seconds' => $this->intTrainingTime,
          'training_time_minutes' => round($this->intTrainingTime / 60, 1),
          // 0.1 seconds stands in for a training time of zero
          'loops_per_second' => round($this->intTotalLoops / ($this->intTrainingTime > 0 ? $this->intTrainingTime : 0.1)),
          'training_finished' => $this->boolTrained ? 'Yes' : 'No',
          'max_execution_time' => $this->getMaxExecutionTime(),
          'cpu_limit' => $this->getCPULimit(),
          'network' => array(
              'arrHiddenLayers' => $this->arrHiddenLayers,
              'objOutputLayer' => $this->objOutputLayer,
              'intCountInputs' => $this->getNumberInputs()
          ),
          'trained_percentage' => $this->getTrainedInputsPercentage(),
          'max_execution_time_network' => $this->intMaxExecutionTime,
          'phpversion' => phpversion(),
          'phpinterface' => php_sapi_name()
      );

      return $arrInfo;
  }
  /**
   * Read the max_execution_time ini setting
   *
   * @return integer Seconds
   */
  protected function getMaxExecutionTime()
  {
      return intval(ini_get('max_execution_time'));
  }
  /**
   * Ask the shell for the CPU time limit ("ulimit -t")
   *
   * The command output is converted to an integer, so output that does not
   * start with a number (or no output at all) results in 0.
   *
   * @return integer Seconds
   */
  protected function getCPULimit()
  {
      return intval(shell_exec('ulimit -t'));
  }
  /**
   * Percentage of input patterns whose training is complete
   *
   * Counts the keys of $this->arrInputs for which
   * isTrainingCompleteByInputKey() is true and relates that count to the
   * number of entries in $this->arrOutputs. Returns 0.0 when there are no
   * outputs, because no percentage can be calculated then.
   *
   * NOTE(review): the count runs over arrInputs while the divisor is the size
   * of arrOutputs - presumably both always have the same length; confirm.
   *
   * @return float Percentage, rounded to one decimal place
   * @uses isTrainingCompleteByInputKey()
   */
  protected function getTrainedInputsPercentage()
  {
      // Checked explicitly instead of suppressing count() with "@"
      $boolCountable = is_array($this->arrOutputs) || $this->arrOutputs instanceof \Countable;
      $intCountOutputs = $boolCountable ? count($this->arrOutputs) : 0;

      // Dividing by zero raises DivisionByZeroError on PHP 8 and produced NAN
      // plus a warning on earlier versions.
      if($intCountOutputs === 0)
          return 0.0;

      $intTrained = 0;

      foreach($this->arrInputs as $intKeyInputs => $arrInputs)
      {
          if($this->isTrainingCompleteByInputKey($intKeyInputs))
              $intTrained++;
      }

      return round(($intTrained / $intCountOutputs) * 100, 1);
  }
  1004. }