diff --git a/runtime/processor/wetext_processor.cc b/runtime/processor/wetext_processor.cc
index 1027a82..e30d117 100644
--- a/runtime/processor/wetext_processor.cc
+++ b/runtime/processor/wetext_processor.cc
@@ -15,8 +15,56 @@
 #include "processor/wetext_processor.h"
 
 #include "utils/wetext_log.h"
+#include "utils/wetext_string.h"
 
 namespace wetext {
+
+// Decodes the single UTF-8 encoded character held in `ch` into its code
+// point. Returns 0 when `ch` is shorter than the length announced by its
+// lead byte, or when that length is not in the range 1..4.
+static char32_t UTF8ToCodePoint(const std::string& ch) {
+  int len = UTF8CharLength(ch[0]);
+  // A truncated sequence would make the indexing below read out of bounds.
+  if (len < 1 || ch.size() < static_cast<size_t>(len)) return 0;
+  char32_t cp = 0;
+  if (len == 1) {
+    cp = static_cast<unsigned char>(ch[0]);
+  } else if (len == 2) {
+    cp = ((static_cast<unsigned char>(ch[0]) & 0x1F) << 6) |
+         (static_cast<unsigned char>(ch[1]) & 0x3F);
+  } else if (len == 3) {
+    cp = ((static_cast<unsigned char>(ch[0]) & 0x0F) << 12) |
+         ((static_cast<unsigned char>(ch[1]) & 0x3F) << 6) |
+         (static_cast<unsigned char>(ch[2]) & 0x3F);
+  } else if (len == 4) {
+    cp = ((static_cast<unsigned char>(ch[0]) & 0x07) << 18) |
+         ((static_cast<unsigned char>(ch[1]) & 0x3F) << 12) |
+         ((static_cast<unsigned char>(ch[2]) & 0x3F) << 6) |
+         (static_cast<unsigned char>(ch[3]) & 0x3F);
+  }
+  return cp;
+}
+
+// Returns true when `cp` lies in one of the code point ranges listed below.
+// TagOOV treats every other code point as out-of-vocabulary.
+static bool IsKnownChar(char32_t cp) {
+  // ASCII printable characters (space to ~)
+  if (cp >= 0x0020 && cp <= 0x007E) return true;
+  // CJK Unified Ideographs
+  if (cp >= 0x4E00 && cp <= 0x9FFF) return true;
+  // CJK Unified Ideographs Extension A
+  if (cp >= 0x3400 && cp <= 0x4DBF) return true;
+  // CJK Compatibility Ideographs
+  if (cp >= 0xF900 && cp <= 0xFAFF) return true;
+  // CJK Symbols and Punctuation
+  if (cp >= 0x3000 && cp <= 0x303F) return true;
+  // General Punctuation
+  if (cp >= 0x2000 && cp <= 0x206F) return true;
+  // Fullwidth forms
+  if (cp >= 0xFF00 && cp <= 0xFFEF) return true;
+  return false;
+}
+
 Processor::Processor(const std::string& tagger_path,
                      const std::string& verbalizer_path) {
   tagger_.reset(StdVectorFst::Read(tagger_path));
@@ -76,8 +124,31 @@ std::string Processor::Verbalize(const std::string& input) {
   return output;
 }
 
+// Copies `input` character by character, wrapping each character that
+// IsKnownChar rejects in <oov>...</oov>; other characters pass through as is.
+std::string Processor::TagOOV(const std::string& input) {
+  std::vector<std::string> chars;
+  SplitUTF8StringToChars(input, &chars);
+  std::string output;
+  for (const auto& ch : chars) {
+    char32_t cp = UTF8ToCodePoint(ch);
+    if (IsKnownChar(cp)) {
+      output += ch;
+    } else {
+      output += "<oov>" + ch + "</oov>";
+    }
+  }
+  return output;
+}
+
 std::string Processor::Normalize(const std::string& input) {
-  return Verbalize(Tag(input));
+  std::string output = Verbalize(Tag(input));
+  // Tag OOV characters only when the output carries no <oov> marker yet.
+  if (parse_type_ == ParseType::kZH_TN &&
+      output.find("<oov>") == std::string::npos) {
+    output = TagOOV(output);
+  }
+  return output;
 }
 
 }  // namespace wetext
diff --git a/runtime/processor/wetext_processor.h b/runtime/processor/wetext_processor.h
index e11d307..78da511 100644
--- a/runtime/processor/wetext_processor.h
+++ b/runtime/processor/wetext_processor.h
@@ -34,6 +34,8 @@ class Processor {
   std::string Tag(const std::string& input);
   std::string Verbalize(const std::string& input);
   std::string Normalize(const std::string& input);
+  // Wraps characters outside the known ranges of `input` in <oov>...</oov>.
+  std::string TagOOV(const std::string& input);
 
  private:
   std::string ShortestPath(const StdVectorFst& lattice);