// lstmbe.cpp (25 KB)
  1. // © 2021 and later: Unicode, Inc. and others.
  2. // License & terms of use: http://www.unicode.org/copyright.html
  3. #include <complex>
  4. #include <utility>
  5. #include "unicode/utypes.h"
  6. #if !UCONFIG_NO_BREAK_ITERATION
  7. #include "brkeng.h"
  8. #include "charstr.h"
  9. #include "cmemory.h"
  10. #include "lstmbe.h"
  11. #include "putilimp.h"
  12. #include "uassert.h"
  13. #include "ubrkimpl.h"
  14. #include "uresimp.h"
  15. #include "uvectr32.h"
  16. #include "uvector.h"
  17. #include "unicode/brkiter.h"
  18. #include "unicode/resbund.h"
  19. #include "unicode/ubrk.h"
  20. #include "unicode/uniset.h"
  21. #include "unicode/ustring.h"
  22. #include "unicode/utf.h"
  23. U_NAMESPACE_BEGIN
  24. // Uncomment the following #define to debug.
  25. // #define LSTM_DEBUG 1
  26. // #define LSTM_VECTORIZER_DEBUG 1
  /**
   * Abstract read-only accessor for a one-dimensional array of floats.
   */
  class ReadArray1D {
  public:
      virtual ~ReadArray1D();

      // Number of elements in the array.
      virtual int32_t d1() const = 0;

      // Value stored at index i.
      virtual float get(int32_t i) const = 0;

  #ifdef LSTM_DEBUG
      // Debug aid: write every element to stdout, four values per line.
      void print() const {
          printf("\n[");
          for (int32_t idx = 0; idx < d1(); ++idx) {
              printf("%0.8e ", get(idx));
              if (idx % 4 == 3) {
                  printf("\n");
              }
          }
          printf("]\n");
      }
  #endif
  };

  // Defined out of line, like the other destructors in this file.
  ReadArray1D::~ReadArray1D() {}
  /**
   * Abstract read-only accessor for a two-dimensional array of floats.
   */
  class ReadArray2D {
  public:
      virtual ~ReadArray2D();

      // Size of the first dimension (number of rows).
      virtual int32_t d1() const = 0;

      // Size of the second dimension (number of columns).
      virtual int32_t d2() const = 0;

      // Value stored at row i, column j.
      virtual float get(int32_t i, int32_t j) const = 0;
  };

  // Defined out of line, like the other destructors in this file.
  ReadArray2D::~ReadArray2D() {}
  62. /**
  63. * A class to index a float array as a 1D Array without owning the pointer or
  64. * copy the data.
  65. */
  66. class ConstArray1D : public ReadArray1D {
  67. public:
  68. ConstArray1D() : data_(nullptr), d1_(0) {}
  69. ConstArray1D(const float* data, int32_t d1) : data_(data), d1_(d1) {}
  70. virtual ~ConstArray1D();
  71. // Init the object, the object does not own the data nor copy.
  72. // It is designed to directly use data from memory mapped resources.
  73. void init(const int32_t* data, int32_t d1) {
  74. U_ASSERT(IEEE_754 == 1);
  75. data_ = reinterpret_cast<const float*>(data);
  76. d1_ = d1;
  77. }
  78. // ReadArray1D methods.
  79. virtual int32_t d1() const override { return d1_; }
  80. virtual float get(int32_t i) const override {
  81. U_ASSERT(i < d1_);
  82. return data_[i];
  83. }
  84. private:
  85. const float* data_;
  86. int32_t d1_;
  87. };
  88. ConstArray1D::~ConstArray1D()
  89. {
  90. }
  91. /**
  92. * A class to index a float array as a 2D Array without owning the pointer or
  93. * copy the data.
  94. */
  95. class ConstArray2D : public ReadArray2D {
  96. public:
  97. ConstArray2D() : data_(nullptr), d1_(0), d2_(0) {}
  98. ConstArray2D(const float* data, int32_t d1, int32_t d2)
  99. : data_(data), d1_(d1), d2_(d2) {}
  100. virtual ~ConstArray2D();
  101. // Init the object, the object does not own the data nor copy.
  102. // It is designed to directly use data from memory mapped resources.
  103. void init(const int32_t* data, int32_t d1, int32_t d2) {
  104. U_ASSERT(IEEE_754 == 1);
  105. data_ = reinterpret_cast<const float*>(data);
  106. d1_ = d1;
  107. d2_ = d2;
  108. }
  109. // ReadArray2D methods.
  110. inline int32_t d1() const override { return d1_; }
  111. inline int32_t d2() const override { return d2_; }
  112. float get(int32_t i, int32_t j) const override {
  113. U_ASSERT(i < d1_);
  114. U_ASSERT(j < d2_);
  115. return data_[i * d2_ + j];
  116. }
  117. // Expose the ith row as a ConstArray1D
  118. inline ConstArray1D row(int32_t i) const {
  119. U_ASSERT(i < d1_);
  120. return ConstArray1D(data_ + i * d2_, d2_);
  121. }
  122. private:
  123. const float* data_;
  124. int32_t d1_;
  125. int32_t d2_;
  126. };
  127. ConstArray2D::~ConstArray2D()
  128. {
  129. }
  130. /**
  131. * A class to allocate data as a writable 1D array.
  132. * This is the main class implement matrix operation.
  133. */
  134. class Array1D : public ReadArray1D {
  135. public:
  136. Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {}
  137. Array1D(int32_t d1, UErrorCode &status)
  138. : memory_(uprv_malloc(d1 * sizeof(float))),
  139. data_((float*)memory_), d1_(d1) {
  140. if (U_SUCCESS(status)) {
  141. if (memory_ == nullptr) {
  142. status = U_MEMORY_ALLOCATION_ERROR;
  143. return;
  144. }
  145. clear();
  146. }
  147. }
  148. virtual ~Array1D();
  149. // A special constructor which does not own the memory but writeable
  150. // as a slice of an array.
  151. Array1D(float* data, int32_t d1)
  152. : memory_(nullptr), data_(data), d1_(d1) {}
  153. // ReadArray1D methods.
  154. virtual int32_t d1() const override { return d1_; }
  155. virtual float get(int32_t i) const override {
  156. U_ASSERT(i < d1_);
  157. return data_[i];
  158. }
  159. // Return the index which point to the max data in the array.
  160. inline int32_t maxIndex() const {
  161. int32_t index = 0;
  162. float max = data_[0];
  163. for (int32_t i = 1; i < d1_; i++) {
  164. if (data_[i] > max) {
  165. max = data_[i];
  166. index = i;
  167. }
  168. }
  169. return index;
  170. }
  171. // Slice part of the array to a new one.
  172. inline Array1D slice(int32_t from, int32_t size) const {
  173. U_ASSERT(from >= 0);
  174. U_ASSERT(from < d1_);
  175. U_ASSERT(from + size <= d1_);
  176. return Array1D(data_ + from, size);
  177. }
  178. // Add dot product of a 1D array and a 2D array into this one.
  179. inline Array1D& addDotProduct(const ReadArray1D& a, const ReadArray2D& b) {
  180. U_ASSERT(a.d1() == b.d1());
  181. U_ASSERT(b.d2() == d1());
  182. for (int32_t i = 0; i < d1(); i++) {
  183. for (int32_t j = 0; j < a.d1(); j++) {
  184. data_[i] += a.get(j) * b.get(j, i);
  185. }
  186. }
  187. return *this;
  188. }
  189. // Hadamard Product the values of another array of the same size into this one.
  190. inline Array1D& hadamardProduct(const ReadArray1D& a) {
  191. U_ASSERT(a.d1() == d1());
  192. for (int32_t i = 0; i < d1(); i++) {
  193. data_[i] *= a.get(i);
  194. }
  195. return *this;
  196. }
  197. // Add the Hadamard Product of two arrays of the same size into this one.
  198. inline Array1D& addHadamardProduct(const ReadArray1D& a, const ReadArray1D& b) {
  199. U_ASSERT(a.d1() == d1());
  200. U_ASSERT(b.d1() == d1());
  201. for (int32_t i = 0; i < d1(); i++) {
  202. data_[i] += a.get(i) * b.get(i);
  203. }
  204. return *this;
  205. }
  206. // Add the values of another array of the same size into this one.
  207. inline Array1D& add(const ReadArray1D& a) {
  208. U_ASSERT(a.d1() == d1());
  209. for (int32_t i = 0; i < d1(); i++) {
  210. data_[i] += a.get(i);
  211. }
  212. return *this;
  213. }
  214. // Assign the values of another array of the same size into this one.
  215. inline Array1D& assign(const ReadArray1D& a) {
  216. U_ASSERT(a.d1() == d1());
  217. for (int32_t i = 0; i < d1(); i++) {
  218. data_[i] = a.get(i);
  219. }
  220. return *this;
  221. }
  222. // Apply tanh to all the elements in the array.
  223. inline Array1D& tanh() {
  224. return tanh(*this);
  225. }
  226. // Apply tanh of a and store into this array.
  227. inline Array1D& tanh(const Array1D& a) {
  228. U_ASSERT(a.d1() == d1());
  229. for (int32_t i = 0; i < d1_; i++) {
  230. data_[i] = std::tanh(a.get(i));
  231. }
  232. return *this;
  233. }
  234. // Apply sigmoid to all the elements in the array.
  235. inline Array1D& sigmoid() {
  236. for (int32_t i = 0; i < d1_; i++) {
  237. data_[i] = 1.0f/(1.0f + expf(-data_[i]));
  238. }
  239. return *this;
  240. }
  241. inline Array1D& clear() {
  242. uprv_memset(data_, 0, d1_ * sizeof(float));
  243. return *this;
  244. }
  245. private:
  246. void* memory_;
  247. float* data_;
  248. int32_t d1_;
  249. };
  250. Array1D::~Array1D()
  251. {
  252. uprv_free(memory_);
  253. }
  254. class Array2D : public ReadArray2D {
  255. public:
  256. Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {}
  257. Array2D(int32_t d1, int32_t d2, UErrorCode &status)
  258. : memory_(uprv_malloc(d1 * d2 * sizeof(float))),
  259. data_((float*)memory_), d1_(d1), d2_(d2) {
  260. if (U_SUCCESS(status)) {
  261. if (memory_ == nullptr) {
  262. status = U_MEMORY_ALLOCATION_ERROR;
  263. return;
  264. }
  265. clear();
  266. }
  267. }
  268. virtual ~Array2D();
  269. // ReadArray2D methods.
  270. virtual int32_t d1() const override { return d1_; }
  271. virtual int32_t d2() const override { return d2_; }
  272. virtual float get(int32_t i, int32_t j) const override {
  273. U_ASSERT(i < d1_);
  274. U_ASSERT(j < d2_);
  275. return data_[i * d2_ + j];
  276. }
  277. inline Array1D row(int32_t i) const {
  278. U_ASSERT(i < d1_);
  279. return Array1D(data_ + i * d2_, d2_);
  280. }
  281. inline Array2D& clear() {
  282. uprv_memset(data_, 0, d1_ * d2_ * sizeof(float));
  283. return *this;
  284. }
  285. private:
  286. void* memory_;
  287. float* data_;
  288. int32_t d1_;
  289. int32_t d2_;
  290. };
  291. Array2D::~Array2D()
  292. {
  293. uprv_free(memory_);
  294. }
  // Output classes of the model. The numeric value is the index of the
  // corresponding entry in the 4-element output vector.
  enum LSTMClass {
      BEGIN = 0,
      INSIDE = 1,
      END = 2,
      SINGLE = 3
  };
  // How the input text is cut into units before the embedding lookup.
  enum EmbeddingType {
      UNKNOWN = 0,
      CODE_POINTS = 1,
      GRAPHEME_CLUSTER = 2
  };
  306. struct LSTMData : public UMemory {
  307. LSTMData(UResourceBundle* rb, UErrorCode &status);
  308. ~LSTMData();
  309. UHashtable* fDict;
  310. EmbeddingType fType;
  311. const char16_t* fName;
  312. ConstArray2D fEmbedding;
  313. ConstArray2D fForwardW;
  314. ConstArray2D fForwardU;
  315. ConstArray1D fForwardB;
  316. ConstArray2D fBackwardW;
  317. ConstArray2D fBackwardU;
  318. ConstArray1D fBackwardB;
  319. ConstArray2D fOutputW;
  320. ConstArray1D fOutputB;
  321. private:
  322. UResourceBundle* fBundle;
  323. };
  324. LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
  325. : fDict(nullptr), fType(UNKNOWN), fName(nullptr),
  326. fBundle(rb)
  327. {
  328. if (U_FAILURE(status)) {
  329. return;
  330. }
  331. if (IEEE_754 != 1) {
  332. status = U_UNSUPPORTED_ERROR;
  333. return;
  334. }
  335. LocalUResourceBundlePointer embeddings_res(
  336. ures_getByKey(rb, "embeddings", nullptr, &status));
  337. int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status);
  338. LocalUResourceBundlePointer hunits_res(
  339. ures_getByKey(rb, "hunits", nullptr, &status));
  340. if (U_FAILURE(status)) return;
  341. int32_t hunits = ures_getInt(hunits_res.getAlias(), &status);
  342. const char16_t* type = ures_getStringByKey(rb, "type", nullptr, &status);
  343. if (U_FAILURE(status)) return;
  344. if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) {
  345. fType = CODE_POINTS;
  346. } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) {
  347. fType = GRAPHEME_CLUSTER;
  348. }
  349. fName = ures_getStringByKey(rb, "model", nullptr, &status);
  350. LocalUResourceBundlePointer dataRes(ures_getByKey(rb, "data", nullptr, &status));
  351. if (U_FAILURE(status)) return;
  352. int32_t data_len = 0;
  353. const int32_t* data = ures_getIntVector(dataRes.getAlias(), &data_len, &status);
  354. fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status);
  355. StackUResourceBundle stackTempBundle;
  356. ResourceDataValue value;
  357. ures_getValueWithFallback(rb, "dict", stackTempBundle.getAlias(), value, status);
  358. ResourceArray stringArray = value.getArray(status);
  359. int32_t num_index = stringArray.getSize();
  360. if (U_FAILURE(status)) { return; }
  361. // put dict into hash
  362. int32_t stringLength;
  363. for (int32_t idx = 0; idx < num_index; idx++) {
  364. stringArray.getValue(idx, value);
  365. const char16_t* str = value.getString(stringLength, status);
  366. uhash_putiAllowZero(fDict, (void*)str, idx, &status);
  367. if (U_FAILURE(status)) return;
  368. #ifdef LSTM_VECTORIZER_DEBUG
  369. printf("Assign [");
  370. while (*str != 0x0000) {
  371. printf("U+%04x ", *str);
  372. str++;
  373. }
  374. printf("] map to %d\n", idx-1);
  375. #endif
  376. }
  377. int32_t mat1_size = (num_index + 1) * embedding_size;
  378. int32_t mat2_size = embedding_size * 4 * hunits;
  379. int32_t mat3_size = hunits * 4 * hunits;
  380. int32_t mat4_size = 4 * hunits;
  381. int32_t mat5_size = mat2_size;
  382. int32_t mat6_size = mat3_size;
  383. int32_t mat7_size = mat4_size;
  384. int32_t mat8_size = 2 * hunits * 4;
  385. #if U_DEBUG
  386. int32_t mat9_size = 4;
  387. U_ASSERT(data_len == mat1_size + mat2_size + mat3_size + mat4_size + mat5_size +
  388. mat6_size + mat7_size + mat8_size + mat9_size);
  389. #endif
  390. fEmbedding.init(data, (num_index + 1), embedding_size);
  391. data += mat1_size;
  392. fForwardW.init(data, embedding_size, 4 * hunits);
  393. data += mat2_size;
  394. fForwardU.init(data, hunits, 4 * hunits);
  395. data += mat3_size;
  396. fForwardB.init(data, 4 * hunits);
  397. data += mat4_size;
  398. fBackwardW.init(data, embedding_size, 4 * hunits);
  399. data += mat5_size;
  400. fBackwardU.init(data, hunits, 4 * hunits);
  401. data += mat6_size;
  402. fBackwardB.init(data, 4 * hunits);
  403. data += mat7_size;
  404. fOutputW.init(data, 2 * hunits, 4);
  405. data += mat8_size;
  406. fOutputB.init(data, 4);
  407. }
  408. LSTMData::~LSTMData() {
  409. uhash_close(fDict);
  410. ures_close(fBundle);
  411. }
  412. class Vectorizer : public UMemory {
  413. public:
  414. Vectorizer(UHashtable* dict) : fDict(dict) {}
  415. virtual ~Vectorizer();
  416. virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
  417. UVector32 &offsets, UVector32 &indices,
  418. UErrorCode &status) const = 0;
  419. protected:
  420. int32_t stringToIndex(const char16_t* str) const {
  421. UBool found = false;
  422. int32_t ret = uhash_getiAndFound(fDict, (const void*)str, &found);
  423. if (!found) {
  424. ret = fDict->count;
  425. }
  426. #ifdef LSTM_VECTORIZER_DEBUG
  427. printf("[");
  428. while (*str != 0x0000) {
  429. printf("U+%04x ", *str);
  430. str++;
  431. }
  432. printf("] map to %d\n", ret);
  433. #endif
  434. return ret;
  435. }
  436. private:
  437. UHashtable* fDict;
  438. };
  439. Vectorizer::~Vectorizer()
  440. {
  441. }
  442. class CodePointsVectorizer : public Vectorizer {
  443. public:
  444. CodePointsVectorizer(UHashtable* dict) : Vectorizer(dict) {}
  445. virtual ~CodePointsVectorizer();
  446. virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
  447. UVector32 &offsets, UVector32 &indices,
  448. UErrorCode &status) const override;
  449. };
  450. CodePointsVectorizer::~CodePointsVectorizer()
  451. {
  452. }
  453. void CodePointsVectorizer::vectorize(
  454. UText *text, int32_t startPos, int32_t endPos,
  455. UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
  456. {
  457. if (offsets.ensureCapacity(endPos - startPos, status) &&
  458. indices.ensureCapacity(endPos - startPos, status)) {
  459. if (U_FAILURE(status)) return;
  460. utext_setNativeIndex(text, startPos);
  461. int32_t current;
  462. char16_t str[2] = {0, 0};
  463. while (U_SUCCESS(status) &&
  464. (current = (int32_t)utext_getNativeIndex(text)) < endPos) {
  465. // Since the LSTMBreakEngine is currently only accept chars in BMP,
  466. // we can ignore the possibility of hitting supplementary code
  467. // point.
  468. str[0] = (char16_t) utext_next32(text);
  469. U_ASSERT(!U_IS_SURROGATE(str[0]));
  470. offsets.addElement(current, status);
  471. indices.addElement(stringToIndex(str), status);
  472. }
  473. }
  474. }
  475. class GraphemeClusterVectorizer : public Vectorizer {
  476. public:
  477. GraphemeClusterVectorizer(UHashtable* dict)
  478. : Vectorizer(dict)
  479. {
  480. }
  481. virtual ~GraphemeClusterVectorizer();
  482. virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
  483. UVector32 &offsets, UVector32 &indices,
  484. UErrorCode &status) const override;
  485. };
  486. GraphemeClusterVectorizer::~GraphemeClusterVectorizer()
  487. {
  488. }
  489. constexpr int32_t MAX_GRAPHEME_CLSTER_LENGTH = 10;
  490. void GraphemeClusterVectorizer::vectorize(
  491. UText *text, int32_t startPos, int32_t endPos,
  492. UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
  493. {
  494. if (U_FAILURE(status)) return;
  495. if (!offsets.ensureCapacity(endPos - startPos, status) ||
  496. !indices.ensureCapacity(endPos - startPos, status)) {
  497. return;
  498. }
  499. if (U_FAILURE(status)) return;
  500. LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status));
  501. if (U_FAILURE(status)) return;
  502. graphemeIter->setText(text, status);
  503. if (U_FAILURE(status)) return;
  504. if (startPos != 0) {
  505. graphemeIter->preceding(startPos);
  506. }
  507. int32_t last = startPos;
  508. int32_t current = startPos;
  509. char16_t str[MAX_GRAPHEME_CLSTER_LENGTH];
  510. while ((current = graphemeIter->next()) != BreakIterator::DONE) {
  511. if (current >= endPos) {
  512. break;
  513. }
  514. if (current > startPos) {
  515. utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
  516. if (U_FAILURE(status)) return;
  517. offsets.addElement(last, status);
  518. indices.addElement(stringToIndex(str), status);
  519. if (U_FAILURE(status)) return;
  520. }
  521. last = current;
  522. }
  523. if (U_FAILURE(status) || last >= endPos) {
  524. return;
  525. }
  526. utext_extract(text, last, endPos, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
  527. if (U_SUCCESS(status)) {
  528. offsets.addElement(last, status);
  529. indices.addElement(stringToIndex(str), status);
  530. }
  531. }
  532. // Computing LSTM as stated in
  533. // https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
  534. // ifco is temp array allocate outside which does not need to be
  535. // input/output value but could avoid unnecessary memory alloc/free if passing
  536. // in.
  537. void compute(
  538. int32_t hunits,
  539. const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b,
  540. const ReadArray1D& x, Array1D& h, Array1D& c,
  541. Array1D& ifco)
  542. {
  543. // ifco = x * W + h * U + b
  544. ifco.assign(b)
  545. .addDotProduct(x, W)
  546. .addDotProduct(h, U);
  547. ifco.slice(0*hunits, hunits).sigmoid(); // i: sigmod
  548. ifco.slice(1*hunits, hunits).sigmoid(); // f: sigmoid
  549. ifco.slice(2*hunits, hunits).tanh(); // c_: tanh
  550. ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod
  551. c.hadamardProduct(ifco.slice(hunits, hunits))
  552. .addHadamardProduct(ifco.slice(0, hunits), ifco.slice(2*hunits, hunits));
  553. h.tanh(c)
  554. .hadamardProduct(ifco.slice(3*hunits, hunits));
  555. }
  // Minimum word size
  static constexpr int32_t MIN_WORD = 2;

  // Minimum number of characters for two words
  static constexpr int32_t MIN_WORD_SPAN = 2 * MIN_WORD;
  560. int32_t
  561. LSTMBreakEngine::divideUpDictionaryRange( UText *text,
  562. int32_t startPos,
  563. int32_t endPos,
  564. UVector32 &foundBreaks,
  565. UBool /* isPhraseBreaking */,
  566. UErrorCode& status) const {
  567. if (U_FAILURE(status)) return 0;
  568. int32_t beginFoundBreakSize = foundBreaks.size();
  569. utext_setNativeIndex(text, startPos);
  570. utext_moveIndex32(text, MIN_WORD_SPAN);
  571. if (utext_getNativeIndex(text) >= endPos) {
  572. return 0; // Not enough characters for two words
  573. }
  574. utext_setNativeIndex(text, startPos);
  575. UVector32 offsets(status);
  576. UVector32 indices(status);
  577. if (U_FAILURE(status)) return 0;
  578. fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status);
  579. if (U_FAILURE(status)) return 0;
  580. int32_t* offsetsBuf = offsets.getBuffer();
  581. int32_t* indicesBuf = indices.getBuffer();
  582. int32_t input_seq_len = indices.size();
  583. int32_t hunits = fData->fForwardU.d1();
  584. // ----- Begin of all the Array memory allocation needed for this function
  585. // Allocate temp array used inside compute()
  586. Array1D ifco(4 * hunits, status);
  587. Array1D c(hunits, status);
  588. Array1D logp(4, status);
  589. // TODO: limit size of hBackward. If input_seq_len is too big, we could
  590. // run out of memory.
  591. // Backward LSTM
  592. Array2D hBackward(input_seq_len, hunits, status);
  593. // Allocate fbRow and slice the internal array in two.
  594. Array1D fbRow(2 * hunits, status);
  595. // ----- End of all the Array memory allocation needed for this function
  596. if (U_FAILURE(status)) return 0;
  597. // To save the needed memory usage, the following is different from the
  598. // Python or ICU4X implementation. We first perform the Backward LSTM
  599. // and then merge the iteration of the forward LSTM and the output layer
  600. // together because we only neetdto remember the h[t-1] for Forward LSTM.
  601. for (int32_t i = input_seq_len - 1; i >= 0; i--) {
  602. Array1D hRow = hBackward.row(i);
  603. if (i != input_seq_len - 1) {
  604. hRow.assign(hBackward.row(i+1));
  605. }
  606. #ifdef LSTM_DEBUG
  607. printf("hRow %d\n", i);
  608. hRow.print();
  609. printf("indicesBuf[%d] = %d\n", i, indicesBuf[i]);
  610. printf("fData->fEmbedding.row(indicesBuf[%d]):\n", i);
  611. fData->fEmbedding.row(indicesBuf[i]).print();
  612. #endif // LSTM_DEBUG
  613. compute(hunits,
  614. fData->fBackwardW, fData->fBackwardU, fData->fBackwardB,
  615. fData->fEmbedding.row(indicesBuf[i]),
  616. hRow, c, ifco);
  617. }
  618. Array1D forwardRow = fbRow.slice(0, hunits); // point to first half of data in fbRow.
  619. Array1D backwardRow = fbRow.slice(hunits, hunits); // point to second half of data n fbRow.
  620. // The following iteration merge the forward LSTM and the output layer
  621. // together.
  622. c.clear(); // reuse c since it is the same size.
  623. for (int32_t i = 0; i < input_seq_len; i++) {
  624. #ifdef LSTM_DEBUG
  625. printf("forwardRow %d\n", i);
  626. forwardRow.print();
  627. #endif // LSTM_DEBUG
  628. // Forward LSTM
  629. // Calculate the result into forwardRow, which point to the data in the first half
  630. // of fbRow.
  631. compute(hunits,
  632. fData->fForwardW, fData->fForwardU, fData->fForwardB,
  633. fData->fEmbedding.row(indicesBuf[i]),
  634. forwardRow, c, ifco);
  635. // assign the data from hBackward.row(i) to second half of fbRowa.
  636. backwardRow.assign(hBackward.row(i));
  637. logp.assign(fData->fOutputB).addDotProduct(fbRow, fData->fOutputW);
  638. #ifdef LSTM_DEBUG
  639. printf("backwardRow %d\n", i);
  640. backwardRow.print();
  641. printf("logp %d\n", i);
  642. logp.print();
  643. #endif // LSTM_DEBUG
  644. // current = argmax(logp)
  645. LSTMClass current = (LSTMClass)logp.maxIndex();
  646. // BIES logic.
  647. if (current == BEGIN || current == SINGLE) {
  648. if (i != 0) {
  649. foundBreaks.addElement(offsetsBuf[i], status);
  650. if (U_FAILURE(status)) return 0;
  651. }
  652. }
  653. }
  654. return foundBreaks.size() - beginFoundBreakSize;
  655. }
  656. Vectorizer* createVectorizer(const LSTMData* data, UErrorCode &status) {
  657. if (U_FAILURE(status)) {
  658. return nullptr;
  659. }
  660. switch (data->fType) {
  661. case CODE_POINTS:
  662. return new CodePointsVectorizer(data->fDict);
  663. break;
  664. case GRAPHEME_CLUSTER:
  665. return new GraphemeClusterVectorizer(data->fDict);
  666. break;
  667. default:
  668. break;
  669. }
  670. UPRV_UNREACHABLE_EXIT;
  671. }
  672. LSTMBreakEngine::LSTMBreakEngine(const LSTMData* data, const UnicodeSet& set, UErrorCode &status)
  673. : DictionaryBreakEngine(), fData(data), fVectorizer(createVectorizer(fData, status))
  674. {
  675. if (U_FAILURE(status)) {
  676. fData = nullptr; // If failure, we should not delete fData in destructor because the caller will do so.
  677. return;
  678. }
  679. setCharacters(set);
  680. }
  681. LSTMBreakEngine::~LSTMBreakEngine() {
  682. delete fData;
  683. delete fVectorizer;
  684. }
  685. const char16_t* LSTMBreakEngine::name() const {
  686. return fData->fName;
  687. }
  688. UnicodeString defaultLSTM(UScriptCode script, UErrorCode& status) {
  689. // open root from brkitr tree.
  690. UResourceBundle *b = ures_open(U_ICUDATA_BRKITR, "", &status);
  691. b = ures_getByKeyWithFallback(b, "lstm", b, &status);
  692. UnicodeString result = ures_getUnicodeStringByKey(b, uscript_getShortName(script), &status);
  693. ures_close(b);
  694. return result;
  695. }
  696. U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UErrorCode& status)
  697. {
  698. if (script != USCRIPT_KHMER && script != USCRIPT_LAO && script != USCRIPT_MYANMAR && script != USCRIPT_THAI) {
  699. return nullptr;
  700. }
  701. UnicodeString name = defaultLSTM(script, status);
  702. if (U_FAILURE(status)) return nullptr;
  703. CharString namebuf;
  704. namebuf.appendInvariantChars(name, status).truncate(namebuf.lastIndexOf('.'));
  705. LocalUResourceBundlePointer rb(
  706. ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status));
  707. if (U_FAILURE(status)) return nullptr;
  708. return CreateLSTMData(rb.orphan(), status);
  709. }
  710. U_CAPI const LSTMData* U_EXPORT2 CreateLSTMData(UResourceBundle* rb, UErrorCode& status)
  711. {
  712. return new LSTMData(rb, status);
  713. }
  714. U_CAPI const LanguageBreakEngine* U_EXPORT2
  715. CreateLSTMBreakEngine(UScriptCode script, const LSTMData* data, UErrorCode& status)
  716. {
  717. UnicodeString unicodeSetString;
  718. switch(script) {
  719. case USCRIPT_THAI:
  720. unicodeSetString = UnicodeString(u"[[:Thai:]&[:LineBreak=SA:]]");
  721. break;
  722. case USCRIPT_MYANMAR:
  723. unicodeSetString = UnicodeString(u"[[:Mymr:]&[:LineBreak=SA:]]");
  724. break;
  725. default:
  726. delete data;
  727. return nullptr;
  728. }
  729. UnicodeSet unicodeSet;
  730. unicodeSet.applyPattern(unicodeSetString, status);
  731. const LanguageBreakEngine* engine = new LSTMBreakEngine(data, unicodeSet, status);
  732. if (U_FAILURE(status) || engine == nullptr) {
  733. if (engine != nullptr) {
  734. delete engine;
  735. } else {
  736. status = U_MEMORY_ALLOCATION_ERROR;
  737. }
  738. return nullptr;
  739. }
  740. return engine;
  741. }
  742. U_CAPI void U_EXPORT2 DeleteLSTMData(const LSTMData* data)
  743. {
  744. delete data;
  745. }
  746. U_CAPI const char16_t* U_EXPORT2 LSTMDataName(const LSTMData* data)
  747. {
  748. return data->fName;
  749. }
  750. U_NAMESPACE_END
  751. #endif /* #if !UCONFIG_NO_BREAK_ITERATION */