12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857 |
- // © 2021 and later: Unicode, Inc. and others.
- // License & terms of use: http://www.unicode.org/copyright.html
- #include <complex>
- #include <utility>
- #include "unicode/utypes.h"
- #if !UCONFIG_NO_BREAK_ITERATION
- #include "brkeng.h"
- #include "charstr.h"
- #include "cmemory.h"
- #include "lstmbe.h"
- #include "putilimp.h"
- #include "uassert.h"
- #include "ubrkimpl.h"
- #include "uresimp.h"
- #include "uvectr32.h"
- #include "uvector.h"
- #include "unicode/brkiter.h"
- #include "unicode/resbund.h"
- #include "unicode/ubrk.h"
- #include "unicode/uniset.h"
- #include "unicode/ustring.h"
- #include "unicode/utf.h"
- U_NAMESPACE_BEGIN
- // Uncomment the following #define to debug.
- // #define LSTM_DEBUG 1
- // #define LSTM_VECTORIZER_DEBUG 1
/**
 * Read-only interface to a one-dimensional array of floats.
 *
 * Implementations report their length through d1() and expose their elements
 * through get().
 */
class ReadArray1D {
public:
    virtual ~ReadArray1D();

    /** Number of elements in the array. */
    virtual int32_t d1() const = 0;

    /** Value of the element at index i. */
    virtual float get(int32_t i) const = 0;

#ifdef LSTM_DEBUG
    // Debug helper: writes all elements to stdout, breaking the line after
    // every fourth value.
    void print() const {
        printf("\n[");
        for (int32_t idx = 0; idx < d1(); ++idx) {
            printf("%0.8e ", get(idx));
            if (idx % 4 == 3) {
                printf("\n");
            }
        }
        printf("]\n");
    }
#endif
};

ReadArray1D::~ReadArray1D() = default;
/**
 * Read-only interface to a two-dimensional array of floats.
 *
 * Implementations report their extents through d1() and d2() and expose their
 * elements through get().
 */
class ReadArray2D {
public:
    virtual ~ReadArray2D();

    /** Size of the first dimension. */
    virtual int32_t d1() const = 0;

    /** Size of the second dimension. */
    virtual int32_t d2() const = 0;

    /** Value of the element at position (i, j). */
    virtual float get(int32_t i, int32_t j) const = 0;
};

ReadArray2D::~ReadArray2D() = default;
- /**
- * A class to index a float array as a 1D Array without owning the pointer or
- * copy the data.
- */
- class ConstArray1D : public ReadArray1D {
- public:
- ConstArray1D() : data_(nullptr), d1_(0) {}
- ConstArray1D(const float* data, int32_t d1) : data_(data), d1_(d1) {}
- virtual ~ConstArray1D();
- // Init the object, the object does not own the data nor copy.
- // It is designed to directly use data from memory mapped resources.
- void init(const int32_t* data, int32_t d1) {
- U_ASSERT(IEEE_754 == 1);
- data_ = reinterpret_cast<const float*>(data);
- d1_ = d1;
- }
- // ReadArray1D methods.
- virtual int32_t d1() const override { return d1_; }
- virtual float get(int32_t i) const override {
- U_ASSERT(i < d1_);
- return data_[i];
- }
- private:
- const float* data_;
- int32_t d1_;
- };
- ConstArray1D::~ConstArray1D()
- {
- }
- /**
- * A class to index a float array as a 2D Array without owning the pointer or
- * copy the data.
- */
- class ConstArray2D : public ReadArray2D {
- public:
- ConstArray2D() : data_(nullptr), d1_(0), d2_(0) {}
- ConstArray2D(const float* data, int32_t d1, int32_t d2)
- : data_(data), d1_(d1), d2_(d2) {}
- virtual ~ConstArray2D();
- // Init the object, the object does not own the data nor copy.
- // It is designed to directly use data from memory mapped resources.
- void init(const int32_t* data, int32_t d1, int32_t d2) {
- U_ASSERT(IEEE_754 == 1);
- data_ = reinterpret_cast<const float*>(data);
- d1_ = d1;
- d2_ = d2;
- }
- // ReadArray2D methods.
- inline int32_t d1() const override { return d1_; }
- inline int32_t d2() const override { return d2_; }
- float get(int32_t i, int32_t j) const override {
- U_ASSERT(i < d1_);
- U_ASSERT(j < d2_);
- return data_[i * d2_ + j];
- }
- // Expose the ith row as a ConstArray1D
- inline ConstArray1D row(int32_t i) const {
- U_ASSERT(i < d1_);
- return ConstArray1D(data_ + i * d2_, d2_);
- }
- private:
- const float* data_;
- int32_t d1_;
- int32_t d2_;
- };
- ConstArray2D::~ConstArray2D()
- {
- }
- /**
- * A class to allocate data as a writable 1D array.
- * This is the main class implement matrix operation.
- */
- class Array1D : public ReadArray1D {
- public:
- Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {}
- Array1D(int32_t d1, UErrorCode &status)
- : memory_(uprv_malloc(d1 * sizeof(float))),
- data_((float*)memory_), d1_(d1) {
- if (U_SUCCESS(status)) {
- if (memory_ == nullptr) {
- status = U_MEMORY_ALLOCATION_ERROR;
- return;
- }
- clear();
- }
- }
- virtual ~Array1D();
- // A special constructor which does not own the memory but writeable
- // as a slice of an array.
- Array1D(float* data, int32_t d1)
- : memory_(nullptr), data_(data), d1_(d1) {}
- // ReadArray1D methods.
- virtual int32_t d1() const override { return d1_; }
- virtual float get(int32_t i) const override {
- U_ASSERT(i < d1_);
- return data_[i];
- }
- // Return the index which point to the max data in the array.
- inline int32_t maxIndex() const {
- int32_t index = 0;
- float max = data_[0];
- for (int32_t i = 1; i < d1_; i++) {
- if (data_[i] > max) {
- max = data_[i];
- index = i;
- }
- }
- return index;
- }
- // Slice part of the array to a new one.
- inline Array1D slice(int32_t from, int32_t size) const {
- U_ASSERT(from >= 0);
- U_ASSERT(from < d1_);
- U_ASSERT(from + size <= d1_);
- return Array1D(data_ + from, size);
- }
- // Add dot product of a 1D array and a 2D array into this one.
- inline Array1D& addDotProduct(const ReadArray1D& a, const ReadArray2D& b) {
- U_ASSERT(a.d1() == b.d1());
- U_ASSERT(b.d2() == d1());
- for (int32_t i = 0; i < d1(); i++) {
- for (int32_t j = 0; j < a.d1(); j++) {
- data_[i] += a.get(j) * b.get(j, i);
- }
- }
- return *this;
- }
- // Hadamard Product the values of another array of the same size into this one.
- inline Array1D& hadamardProduct(const ReadArray1D& a) {
- U_ASSERT(a.d1() == d1());
- for (int32_t i = 0; i < d1(); i++) {
- data_[i] *= a.get(i);
- }
- return *this;
- }
- // Add the Hadamard Product of two arrays of the same size into this one.
- inline Array1D& addHadamardProduct(const ReadArray1D& a, const ReadArray1D& b) {
- U_ASSERT(a.d1() == d1());
- U_ASSERT(b.d1() == d1());
- for (int32_t i = 0; i < d1(); i++) {
- data_[i] += a.get(i) * b.get(i);
- }
- return *this;
- }
- // Add the values of another array of the same size into this one.
- inline Array1D& add(const ReadArray1D& a) {
- U_ASSERT(a.d1() == d1());
- for (int32_t i = 0; i < d1(); i++) {
- data_[i] += a.get(i);
- }
- return *this;
- }
- // Assign the values of another array of the same size into this one.
- inline Array1D& assign(const ReadArray1D& a) {
- U_ASSERT(a.d1() == d1());
- for (int32_t i = 0; i < d1(); i++) {
- data_[i] = a.get(i);
- }
- return *this;
- }
- // Apply tanh to all the elements in the array.
- inline Array1D& tanh() {
- return tanh(*this);
- }
- // Apply tanh of a and store into this array.
- inline Array1D& tanh(const Array1D& a) {
- U_ASSERT(a.d1() == d1());
- for (int32_t i = 0; i < d1_; i++) {
- data_[i] = std::tanh(a.get(i));
- }
- return *this;
- }
- // Apply sigmoid to all the elements in the array.
- inline Array1D& sigmoid() {
- for (int32_t i = 0; i < d1_; i++) {
- data_[i] = 1.0f/(1.0f + expf(-data_[i]));
- }
- return *this;
- }
- inline Array1D& clear() {
- uprv_memset(data_, 0, d1_ * sizeof(float));
- return *this;
- }
- private:
- void* memory_;
- float* data_;
- int32_t d1_;
- };
- Array1D::~Array1D()
- {
- uprv_free(memory_);
- }
- class Array2D : public ReadArray2D {
- public:
- Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {}
- Array2D(int32_t d1, int32_t d2, UErrorCode &status)
- : memory_(uprv_malloc(d1 * d2 * sizeof(float))),
- data_((float*)memory_), d1_(d1), d2_(d2) {
- if (U_SUCCESS(status)) {
- if (memory_ == nullptr) {
- status = U_MEMORY_ALLOCATION_ERROR;
- return;
- }
- clear();
- }
- }
- virtual ~Array2D();
- // ReadArray2D methods.
- virtual int32_t d1() const override { return d1_; }
- virtual int32_t d2() const override { return d2_; }
- virtual float get(int32_t i, int32_t j) const override {
- U_ASSERT(i < d1_);
- U_ASSERT(j < d2_);
- return data_[i * d2_ + j];
- }
- inline Array1D row(int32_t i) const {
- U_ASSERT(i < d1_);
- return Array1D(data_ + i * d2_, d2_);
- }
- inline Array2D& clear() {
- uprv_memset(data_, 0, d1_ * d2_ * sizeof(float));
- return *this;
- }
- private:
- void* memory_;
- float* data_;
- int32_t d1_;
- int32_t d2_;
- };
- Array2D::~Array2D()
- {
- uprv_free(memory_);
- }
// Class predicted for each input position. The numeric value is the index of
// the winning entry of the 4-element output vector (see the "BIES logic" in
// LSTMBreakEngine::divideUpDictionaryRange).
enum LSTMClass {
    BEGIN = 0,
    INSIDE = 1,
    END = 2,
    SINGLE = 3
};

// How text is converted into embedding indices; parsed from the "type" string
// of the model resource ("codepoints" or "graphclust").
enum EmbeddingType {
    UNKNOWN = 0,
    CODE_POINTS = 1,
    GRAPHEME_CLUSTER = 2
};
- struct LSTMData : public UMemory {
- LSTMData(UResourceBundle* rb, UErrorCode &status);
- ~LSTMData();
- UHashtable* fDict;
- EmbeddingType fType;
- const char16_t* fName;
- ConstArray2D fEmbedding;
- ConstArray2D fForwardW;
- ConstArray2D fForwardU;
- ConstArray1D fForwardB;
- ConstArray2D fBackwardW;
- ConstArray2D fBackwardU;
- ConstArray1D fBackwardB;
- ConstArray2D fOutputW;
- ConstArray1D fOutputB;
- private:
- UResourceBundle* fBundle;
- };
- LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
- : fDict(nullptr), fType(UNKNOWN), fName(nullptr),
- fBundle(rb)
- {
- if (U_FAILURE(status)) {
- return;
- }
- if (IEEE_754 != 1) {
- status = U_UNSUPPORTED_ERROR;
- return;
- }
- LocalUResourceBundlePointer embeddings_res(
- ures_getByKey(rb, "embeddings", nullptr, &status));
- int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status);
- LocalUResourceBundlePointer hunits_res(
- ures_getByKey(rb, "hunits", nullptr, &status));
- if (U_FAILURE(status)) return;
- int32_t hunits = ures_getInt(hunits_res.getAlias(), &status);
- const char16_t* type = ures_getStringByKey(rb, "type", nullptr, &status);
- if (U_FAILURE(status)) return;
- if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) {
- fType = CODE_POINTS;
- } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) {
- fType = GRAPHEME_CLUSTER;
- }
- fName = ures_getStringByKey(rb, "model", nullptr, &status);
- LocalUResourceBundlePointer dataRes(ures_getByKey(rb, "data", nullptr, &status));
- if (U_FAILURE(status)) return;
- int32_t data_len = 0;
- const int32_t* data = ures_getIntVector(dataRes.getAlias(), &data_len, &status);
- fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status);
- StackUResourceBundle stackTempBundle;
- ResourceDataValue value;
- ures_getValueWithFallback(rb, "dict", stackTempBundle.getAlias(), value, status);
- ResourceArray stringArray = value.getArray(status);
- int32_t num_index = stringArray.getSize();
- if (U_FAILURE(status)) { return; }
- // put dict into hash
- int32_t stringLength;
- for (int32_t idx = 0; idx < num_index; idx++) {
- stringArray.getValue(idx, value);
- const char16_t* str = value.getString(stringLength, status);
- uhash_putiAllowZero(fDict, (void*)str, idx, &status);
- if (U_FAILURE(status)) return;
- #ifdef LSTM_VECTORIZER_DEBUG
- printf("Assign [");
- while (*str != 0x0000) {
- printf("U+%04x ", *str);
- str++;
- }
- printf("] map to %d\n", idx-1);
- #endif
- }
- int32_t mat1_size = (num_index + 1) * embedding_size;
- int32_t mat2_size = embedding_size * 4 * hunits;
- int32_t mat3_size = hunits * 4 * hunits;
- int32_t mat4_size = 4 * hunits;
- int32_t mat5_size = mat2_size;
- int32_t mat6_size = mat3_size;
- int32_t mat7_size = mat4_size;
- int32_t mat8_size = 2 * hunits * 4;
- #if U_DEBUG
- int32_t mat9_size = 4;
- U_ASSERT(data_len == mat1_size + mat2_size + mat3_size + mat4_size + mat5_size +
- mat6_size + mat7_size + mat8_size + mat9_size);
- #endif
- fEmbedding.init(data, (num_index + 1), embedding_size);
- data += mat1_size;
- fForwardW.init(data, embedding_size, 4 * hunits);
- data += mat2_size;
- fForwardU.init(data, hunits, 4 * hunits);
- data += mat3_size;
- fForwardB.init(data, 4 * hunits);
- data += mat4_size;
- fBackwardW.init(data, embedding_size, 4 * hunits);
- data += mat5_size;
- fBackwardU.init(data, hunits, 4 * hunits);
- data += mat6_size;
- fBackwardB.init(data, 4 * hunits);
- data += mat7_size;
- fOutputW.init(data, 2 * hunits, 4);
- data += mat8_size;
- fOutputB.init(data, 4);
- }
- LSTMData::~LSTMData() {
- uhash_close(fDict);
- ures_close(fBundle);
- }
- class Vectorizer : public UMemory {
- public:
- Vectorizer(UHashtable* dict) : fDict(dict) {}
- virtual ~Vectorizer();
- virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
- UVector32 &offsets, UVector32 &indices,
- UErrorCode &status) const = 0;
- protected:
- int32_t stringToIndex(const char16_t* str) const {
- UBool found = false;
- int32_t ret = uhash_getiAndFound(fDict, (const void*)str, &found);
- if (!found) {
- ret = fDict->count;
- }
- #ifdef LSTM_VECTORIZER_DEBUG
- printf("[");
- while (*str != 0x0000) {
- printf("U+%04x ", *str);
- str++;
- }
- printf("] map to %d\n", ret);
- #endif
- return ret;
- }
- private:
- UHashtable* fDict;
- };
- Vectorizer::~Vectorizer()
- {
- }
- class CodePointsVectorizer : public Vectorizer {
- public:
- CodePointsVectorizer(UHashtable* dict) : Vectorizer(dict) {}
- virtual ~CodePointsVectorizer();
- virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
- UVector32 &offsets, UVector32 &indices,
- UErrorCode &status) const override;
- };
- CodePointsVectorizer::~CodePointsVectorizer()
- {
- }
- void CodePointsVectorizer::vectorize(
- UText *text, int32_t startPos, int32_t endPos,
- UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
- {
- if (offsets.ensureCapacity(endPos - startPos, status) &&
- indices.ensureCapacity(endPos - startPos, status)) {
- if (U_FAILURE(status)) return;
- utext_setNativeIndex(text, startPos);
- int32_t current;
- char16_t str[2] = {0, 0};
- while (U_SUCCESS(status) &&
- (current = (int32_t)utext_getNativeIndex(text)) < endPos) {
- // Since the LSTMBreakEngine is currently only accept chars in BMP,
- // we can ignore the possibility of hitting supplementary code
- // point.
- str[0] = (char16_t) utext_next32(text);
- U_ASSERT(!U_IS_SURROGATE(str[0]));
- offsets.addElement(current, status);
- indices.addElement(stringToIndex(str), status);
- }
- }
- }
- class GraphemeClusterVectorizer : public Vectorizer {
- public:
- GraphemeClusterVectorizer(UHashtable* dict)
- : Vectorizer(dict)
- {
- }
- virtual ~GraphemeClusterVectorizer();
- virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
- UVector32 &offsets, UVector32 &indices,
- UErrorCode &status) const override;
- };
- GraphemeClusterVectorizer::~GraphemeClusterVectorizer()
- {
- }
- constexpr int32_t MAX_GRAPHEME_CLSTER_LENGTH = 10;
- void GraphemeClusterVectorizer::vectorize(
- UText *text, int32_t startPos, int32_t endPos,
- UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
- {
- if (U_FAILURE(status)) return;
- if (!offsets.ensureCapacity(endPos - startPos, status) ||
- !indices.ensureCapacity(endPos - startPos, status)) {
- return;
- }
- if (U_FAILURE(status)) return;
- LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status));
- if (U_FAILURE(status)) return;
- graphemeIter->setText(text, status);
- if (U_FAILURE(status)) return;
- if (startPos != 0) {
- graphemeIter->preceding(startPos);
- }
- int32_t last = startPos;
- int32_t current = startPos;
- char16_t str[MAX_GRAPHEME_CLSTER_LENGTH];
- while ((current = graphemeIter->next()) != BreakIterator::DONE) {
- if (current >= endPos) {
- break;
- }
- if (current > startPos) {
- utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
- if (U_FAILURE(status)) return;
- offsets.addElement(last, status);
- indices.addElement(stringToIndex(str), status);
- if (U_FAILURE(status)) return;
- }
- last = current;
- }
- if (U_FAILURE(status) || last >= endPos) {
- return;
- }
- utext_extract(text, last, endPos, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
- if (U_SUCCESS(status)) {
- offsets.addElement(last, status);
- indices.addElement(stringToIndex(str), status);
- }
- }
- // Computing LSTM as stated in
- // https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
- // ifco is temp array allocate outside which does not need to be
- // input/output value but could avoid unnecessary memory alloc/free if passing
- // in.
- void compute(
- int32_t hunits,
- const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b,
- const ReadArray1D& x, Array1D& h, Array1D& c,
- Array1D& ifco)
- {
- // ifco = x * W + h * U + b
- ifco.assign(b)
- .addDotProduct(x, W)
- .addDotProduct(h, U);
- ifco.slice(0*hunits, hunits).sigmoid(); // i: sigmod
- ifco.slice(1*hunits, hunits).sigmoid(); // f: sigmoid
- ifco.slice(2*hunits, hunits).tanh(); // c_: tanh
- ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod
- c.hadamardProduct(ifco.slice(hunits, hunits))
- .addHadamardProduct(ifco.slice(0, hunits), ifco.slice(2*hunits, hunits));
- h.tanh(c)
- .hadamardProduct(ifco.slice(3*hunits, hunits));
- }
- // Minimum word size
- static const int32_t MIN_WORD = 2;
- // Minimum number of characters for two words
- static const int32_t MIN_WORD_SPAN = MIN_WORD * 2;
- int32_t
- LSTMBreakEngine::divideUpDictionaryRange( UText *text,
- int32_t startPos,
- int32_t endPos,
- UVector32 &foundBreaks,
- UBool /* isPhraseBreaking */,
- UErrorCode& status) const {
- if (U_FAILURE(status)) return 0;
- int32_t beginFoundBreakSize = foundBreaks.size();
- utext_setNativeIndex(text, startPos);
- utext_moveIndex32(text, MIN_WORD_SPAN);
- if (utext_getNativeIndex(text) >= endPos) {
- return 0; // Not enough characters for two words
- }
- utext_setNativeIndex(text, startPos);
- UVector32 offsets(status);
- UVector32 indices(status);
- if (U_FAILURE(status)) return 0;
- fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status);
- if (U_FAILURE(status)) return 0;
- int32_t* offsetsBuf = offsets.getBuffer();
- int32_t* indicesBuf = indices.getBuffer();
- int32_t input_seq_len = indices.size();
- int32_t hunits = fData->fForwardU.d1();
- // ----- Begin of all the Array memory allocation needed for this function
- // Allocate temp array used inside compute()
- Array1D ifco(4 * hunits, status);
- Array1D c(hunits, status);
- Array1D logp(4, status);
- // TODO: limit size of hBackward. If input_seq_len is too big, we could
- // run out of memory.
- // Backward LSTM
- Array2D hBackward(input_seq_len, hunits, status);
- // Allocate fbRow and slice the internal array in two.
- Array1D fbRow(2 * hunits, status);
- // ----- End of all the Array memory allocation needed for this function
- if (U_FAILURE(status)) return 0;
- // To save the needed memory usage, the following is different from the
- // Python or ICU4X implementation. We first perform the Backward LSTM
- // and then merge the iteration of the forward LSTM and the output layer
- // together because we only neetdto remember the h[t-1] for Forward LSTM.
- for (int32_t i = input_seq_len - 1; i >= 0; i--) {
- Array1D hRow = hBackward.row(i);
- if (i != input_seq_len - 1) {
- hRow.assign(hBackward.row(i+1));
- }
- #ifdef LSTM_DEBUG
- printf("hRow %d\n", i);
- hRow.print();
- printf("indicesBuf[%d] = %d\n", i, indicesBuf[i]);
- printf("fData->fEmbedding.row(indicesBuf[%d]):\n", i);
- fData->fEmbedding.row(indicesBuf[i]).print();
- #endif // LSTM_DEBUG
- compute(hunits,
- fData->fBackwardW, fData->fBackwardU, fData->fBackwardB,
- fData->fEmbedding.row(indicesBuf[i]),
- hRow, c, ifco);
- }
- Array1D forwardRow = fbRow.slice(0, hunits); // point to first half of data in fbRow.
- Array1D backwardRow = fbRow.slice(hunits, hunits); // point to second half of data n fbRow.
- // The following iteration merge the forward LSTM and the output layer
- // together.
- c.clear(); // reuse c since it is the same size.
- for (int32_t i = 0; i < input_seq_len; i++) {
- #ifdef LSTM_DEBUG
- printf("forwardRow %d\n", i);
- forwardRow.print();
- #endif // LSTM_DEBUG
- // Forward LSTM
- // Calculate the result into forwardRow, which point to the data in the first half
- // of fbRow.
- compute(hunits,
- fData->fForwardW, fData->fForwardU, fData->fForwardB,
- fData->fEmbedding.row(indicesBuf[i]),
- forwardRow, c, ifco);
- // assign the data from hBackward.row(i) to second half of fbRowa.
- backwardRow.assign(hBackward.row(i));
- logp.assign(fData->fOutputB).addDotProduct(fbRow, fData->fOutputW);
- #ifdef LSTM_DEBUG
- printf("backwardRow %d\n", i);
- backwardRow.print();
- printf("logp %d\n", i);
- logp.print();
- #endif // LSTM_DEBUG
- // current = argmax(logp)
- LSTMClass current = (LSTMClass)logp.maxIndex();
- // BIES logic.
- if (current == BEGIN || current == SINGLE) {
- if (i != 0) {
- foundBreaks.addElement(offsetsBuf[i], status);
- if (U_FAILURE(status)) return 0;
- }
- }
- }
- return foundBreaks.size() - beginFoundBreakSize;
- }
- Vectorizer* createVectorizer(const LSTMData* data, UErrorCode &status) {
- if (U_FAILURE(status)) {
- return nullptr;
- }
- switch (data->fType) {
- case CODE_POINTS:
- return new CodePointsVectorizer(data->fDict);
- break;
- case GRAPHEME_CLUSTER:
- return new GraphemeClusterVectorizer(data->fDict);
- break;
- default:
- break;
- }
- UPRV_UNREACHABLE_EXIT;
- }
- LSTMBreakEngine::LSTMBreakEngine(const LSTMData* data, const UnicodeSet& set, UErrorCode &status)
- : DictionaryBreakEngine(), fData(data), fVectorizer(createVectorizer(fData, status))
- {
- if (U_FAILURE(status)) {
- fData = nullptr; // If failure, we should not delete fData in destructor because the caller will do so.
- return;
- }
- setCharacters(set);
- }
- LSTMBreakEngine::~LSTMBreakEngine() {
- delete fData;
- delete fVectorizer;
- }
- const char16_t* LSTMBreakEngine::name() const {
- return fData->fName;
- }
- UnicodeString defaultLSTM(UScriptCode script, UErrorCode& status) {
- // open root from brkitr tree.
- UResourceBundle *b = ures_open(U_ICUDATA_BRKITR, "", &status);
- b = ures_getByKeyWithFallback(b, "lstm", b, &status);
- UnicodeString result = ures_getUnicodeStringByKey(b, uscript_getShortName(script), &status);
- ures_close(b);
- return result;
- }
- U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UErrorCode& status)
- {
- if (script != USCRIPT_KHMER && script != USCRIPT_LAO && script != USCRIPT_MYANMAR && script != USCRIPT_THAI) {
- return nullptr;
- }
- UnicodeString name = defaultLSTM(script, status);
- if (U_FAILURE(status)) return nullptr;
- CharString namebuf;
- namebuf.appendInvariantChars(name, status).truncate(namebuf.lastIndexOf('.'));
- LocalUResourceBundlePointer rb(
- ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status));
- if (U_FAILURE(status)) return nullptr;
- return CreateLSTMData(rb.orphan(), status);
- }
- U_CAPI const LSTMData* U_EXPORT2 CreateLSTMData(UResourceBundle* rb, UErrorCode& status)
- {
- return new LSTMData(rb, status);
- }
- U_CAPI const LanguageBreakEngine* U_EXPORT2
- CreateLSTMBreakEngine(UScriptCode script, const LSTMData* data, UErrorCode& status)
- {
- UnicodeString unicodeSetString;
- switch(script) {
- case USCRIPT_THAI:
- unicodeSetString = UnicodeString(u"[[:Thai:]&[:LineBreak=SA:]]");
- break;
- case USCRIPT_MYANMAR:
- unicodeSetString = UnicodeString(u"[[:Mymr:]&[:LineBreak=SA:]]");
- break;
- default:
- delete data;
- return nullptr;
- }
- UnicodeSet unicodeSet;
- unicodeSet.applyPattern(unicodeSetString, status);
- const LanguageBreakEngine* engine = new LSTMBreakEngine(data, unicodeSet, status);
- if (U_FAILURE(status) || engine == nullptr) {
- if (engine != nullptr) {
- delete engine;
- } else {
- status = U_MEMORY_ALLOCATION_ERROR;
- }
- return nullptr;
- }
- return engine;
- }
- U_CAPI void U_EXPORT2 DeleteLSTMData(const LSTMData* data)
- {
- delete data;
- }
- U_CAPI const char16_t* U_EXPORT2 LSTMDataName(const LSTMData* data)
- {
- return data->fName;
- }
- U_NAMESPACE_END
- #endif /* #if !UCONFIG_NO_BREAK_ITERATION */
|