Files
ANSCORE/modules/ANSODEngine/ANSCLIPTokenizer.cpp

287 lines
9.7 KiB
C++
Raw Permalink Normal View History

2026-03-28 16:54:11 +11:00
#include "ANSCLIPTokenizer.h"
#include <fstream>
#include <algorithm>
#include <climits>
#include <cctype>
#include <set>
namespace ANSCENTER
{
// =========================================================================
// UTF-8 helper
// =========================================================================
/// Encode a Unicode code point as a UTF-8 byte sequence.
///
/// Covers the whole Unicode range (1- to 4-byte forms). The previous version
/// had no 4-byte branch and silently returned an empty string for code points
/// at or above U+10000.
///
/// @param cp  Code point to encode.
/// @return    The UTF-8 bytes for @p cp, or an empty string when @p cp is
///            negative or greater than U+10FFFF.
std::string ANSCLIPTokenizer::codepointToUtf8(int cp)
{
    std::string s;
    if (cp < 0 || cp > 0x10FFFF)
        return s;  // not a Unicode code point: nothing sensible to encode

    if (cp < 0x80) {
        s += static_cast<char>(cp);
    }
    else if (cp < 0x800) {
        s += static_cast<char>(0xC0 | (cp >> 6));
        s += static_cast<char>(0x80 | (cp & 0x3F));
    }
    else if (cp < 0x10000) {
        s += static_cast<char>(0xE0 | (cp >> 12));
        s += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
        s += static_cast<char>(0x80 | (cp & 0x3F));
    }
    else {
        s += static_cast<char>(0xF0 | (cp >> 18));
        s += static_cast<char>(0x80 | ((cp >> 12) & 0x3F));
        s += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
        s += static_cast<char>(0x80 | (cp & 0x3F));
    }
    return s;
}
// =========================================================================
// Byte encoder - CLIP's bytes_to_unicode()
// =========================================================================
/// Fill m_byteEncoder with CLIP's byte -> unicode-string mapping.
///
/// Bytes in the ranges '!'..'~' (33..126), 0xA1..0xAC (161..172) and
/// 0xAE..0xFF (174..255) keep their own code point (188 bytes). The remaining
/// 68 bytes are handed the code points 256, 257, ... in ascending byte order.
/// Every entry is stored as the UTF-8 encoding of its code point.
void ANSCLIPTokenizer::initByteEncoder()
{
    // True when byte value `v` maps onto its own code point.
    const auto keepsOwnCodepoint = [](int v) {
        return (v >= 33 && v <= 126)     // '!' .. '~'
            || (v >= 161 && v <= 172)    // inverted '!' .. not-sign
            || (v >= 174 && v <= 255);   // registered-sign .. y-umlaut
    };

    int nextRemapped = 256;  // first code point handed to a non-printable byte
    for (int byteValue = 0; byteValue <= 255; ++byteValue) {
        int codepoint = byteValue;
        if (!keepsOwnCodepoint(byteValue))
            codepoint = nextRemapped++;
        m_byteEncoder[byteValue] = codepointToUtf8(codepoint);
    }
}
// =========================================================================
// Pre-tokenisation (simplified CLIP regex for ASCII text)
// =========================================================================
/// Split text into CLIP pre-tokens (simplified CLIP regex for ASCII text).
///
/// Approximates CLIP's pattern
///   's|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+
/// applied to lower-cased input:
///   - whitespace separates tokens and is dropped;
///   - the contractions 's 't 'm 'd 're 've 'll are standalone tokens;
///   - each run of ASCII letters is one token;
///   - each digit is its own token ([\p{N}] has no '+'), so "2024" yields
///     "2", "0", "2", "4";
///   - any other run of non-space, non-letter, non-digit bytes is one token.
///
/// Lower-casing and classification are ASCII-only and do not depend on the
/// C locale, so bytes >= 0x80 (UTF-8 multi-byte sequences) are never altered.
/// Such bytes are grouped with punctuation, which differs from CLIP for
/// non-ASCII letters.
///
/// @param text  Raw input text.
/// @return      Pre-tokens in input order (empty for blank input).
std::vector<std::string> ANSCLIPTokenizer::preTokenize(const std::string& text)
{
    // Locale-independent ASCII classes. std::isalpha / std::tolower consult the
    // current C locale and, under a single-byte locale, could reclassify or
    // rewrite bytes that belong to UTF-8 multi-byte sequences.
    const auto isSpace = [](char c) {
        return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r';
    };
    const auto isAlpha = [](char c) {
        return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
    };
    const auto isDigit = [](char c) { return c >= '0' && c <= '9'; };

    // Lowercase ASCII letters only.
    std::string lower(text);
    for (char& c : lower) {
        if (c >= 'A' && c <= 'Z')
            c = static_cast<char>(c - 'A' + 'a');
    }

    std::vector<std::string> tokens;
    const size_t n = lower.size();
    size_t i = 0;
    while (i < n) {
        const char c = lower[i];

        // Skip whitespace
        if (isSpace(c)) {
            ++i;
            continue;
        }

        // Contractions: 's 't 'm 'd 're 've 'll
        if (c == '\'' && i + 1 < n) {
            const char c2 = lower[i + 1];
            if (c2 == 's' || c2 == 't' || c2 == 'm' || c2 == 'd') {
                tokens.push_back(lower.substr(i, 2));
                i += 2;
                continue;
            }
            if (i + 2 < n) {
                const std::string tri = lower.substr(i, 3);
                if (tri == "'re" || tri == "'ve" || tri == "'ll") {
                    tokens.push_back(tri);
                    i += 3;
                    continue;
                }
            }
        }

        // Letters: one token per run
        if (isAlpha(c)) {
            const size_t start = i;
            while (i < n && isAlpha(lower[i]))
                ++i;
            tokens.push_back(lower.substr(start, i - start));
            continue;
        }

        // Digits: one token per digit, as CLIP's [\p{N}] matches a single char
        if (isDigit(c)) {
            tokens.push_back(std::string(1, c));
            ++i;
            continue;
        }

        // Punctuation / symbols / non-ASCII bytes. lower[i] is none of
        // space, letter or digit here, so the run consumes at least one byte.
        const size_t start = i;
        while (i < n && !isSpace(lower[i]) && !isAlpha(lower[i]) && !isDigit(lower[i]))
            ++i;
        tokens.push_back(lower.substr(start, i - start));
    }
    return tokens;
}
// =========================================================================
// BPE algorithm
// =========================================================================
std::vector<std::string> ANSCLIPTokenizer::bpe(const std::string& word) const
{
// Check cache
auto cacheIt = m_cache.find(word);
if (cacheIt != m_cache.end())
return cacheIt->second;
// Byte-encode the word
std::vector<std::string> symbols;
for (unsigned char c : word)
symbols.push_back(m_byteEncoder[c]);
// Append </w> to the last symbol (CLIP end-of-word marker)
if (!symbols.empty())
symbols.back() += "</w>";
if (symbols.size() <= 1) {
m_cache[word] = symbols;
return symbols;
}
// Iteratively merge the highest-priority (lowest-rank) pair
while (symbols.size() > 1) {
// Find the pair with the lowest rank
int bestRank = INT_MAX;
int bestIdx = -1;
for (size_t j = 0; j + 1 < symbols.size(); ++j) {
auto it = m_bpeRanks.find({ symbols[j], symbols[j + 1] });
if (it != m_bpeRanks.end() && it->second < bestRank) {
bestRank = it->second;
bestIdx = static_cast<int>(j);
}
}
if (bestIdx < 0) break; // no mergeable pairs left
// Merge ALL occurrences of the best pair
std::string a = symbols[bestIdx];
std::string b = symbols[bestIdx + 1];
std::string merged = a + b;
std::vector<std::string> next;
size_t j = 0;
while (j < symbols.size()) {
if (j + 1 < symbols.size() && symbols[j] == a && symbols[j + 1] == b) {
next.push_back(merged);
j += 2;
} else {
next.push_back(symbols[j]);
j += 1;
}
}
symbols = std::move(next);
}
m_cache[word] = symbols;
return symbols;
}
// =========================================================================
// Load vocabulary from BPE merges file
// =========================================================================
/// Load a CLIP BPE merges file and build the rank and vocabulary tables.
///
/// Each non-header line is "left right". A leading "#version" header line is
/// skipped. As in CLIP's reference tokenizer, at most 49152 - 256 - 2 = 48894
/// merges are used, so a full OpenAI merges file and a pre-truncated
/// merges.txt both produce the same vocabulary:
///   [0..255]         single byte-encoded chars
///   [256..511]       byte-encoded chars + </w>
///   [512..512+N-1]   merged tokens, in rank order (N <= 48894)
///   [512+N]          <|startoftext|>   (49406 when N == 48894)
///   [512+N+1]        <|endoftext|>     (49407 when N == 48894)
/// Clears the BPE cache and marks the tokenizer as loaded.
///
/// @param mergesFilePath  Path to the merges text file.
/// @return false if the file cannot be opened, true otherwise.
bool ANSCLIPTokenizer::Load(const std::string& mergesFilePath)
{
    // CLIP keeps only the first 49152 - 256 - 2 merges. Extra merges would
    // create token ids past the vocabulary layout documented above.
    constexpr size_t kMaxMerges = 49152 - 256 - 2;

    initByteEncoder();
    std::ifstream file(mergesFilePath);
    if (!file.is_open())
        return false;

    std::vector<std::pair<std::string, std::string>> merges;
    std::string line;
    while (merges.size() < kMaxMerges && std::getline(file, line)) {
        // Trim trailing \r \n (files with Windows line endings)
        while (!line.empty() && (line.back() == '\r' || line.back() == '\n'))
            line.pop_back();
        // Skip blank lines and the "#version: 0.2" header only. '#' is itself
        // a byte symbol, so real merge lines may start with it; dropping them
        // would shift every later rank and token id.
        if (line.empty() || line.rfind("#version", 0) == 0)
            continue;
        const size_t sp = line.find(' ');
        if (sp == std::string::npos)
            continue;
        merges.emplace_back(line.substr(0, sp), line.substr(sp + 1));
    }
    file.close();

    // Build BPE ranks: lower rank = merged earlier.
    m_bpeRanks.clear();
    for (int i = 0; i < static_cast<int>(merges.size()); ++i)
        m_bpeRanks[merges[i]] = i;

    // Build vocabulary in the same order as OpenAI CLIP (see layout above).
    m_encoder.clear();
    int id = 0;
    for (int b = 0; b < 256; ++b)
        m_encoder[m_byteEncoder[b]] = id++;
    for (int b = 0; b < 256; ++b)
        m_encoder[m_byteEncoder[b] + "</w>"] = id++;
    for (const auto& merge : merges)
        m_encoder[merge.first + merge.second] = id++;
    m_encoder["<|startoftext|>"] = id++;
    m_encoder["<|endoftext|>"] = id++;

    m_cache.clear();
    m_loaded = true;
    return true;
}
// =========================================================================
// Tokenize
// =========================================================================
/// Convert text into fixed-length CLIP input ids and an attention mask.
///
/// Output layout is: BOS_TOKEN, then the BPE token ids of @p text, then
/// EOS_TOKEN, then 0 padding up to @p maxLength. Content is truncated so
/// that EOS always fits in the last slot. BPE symbols that are not in the
/// vocabulary are skipped.
///
/// @param text       Input text.
/// @param maxLength  Length of both output vectors; values <= 0 yield empty
///                   vectors.
/// @return inputIds / attentionMask of size maxLength. Both are all zero if
///         Load() has not succeeded.
TokenizerResult ANSCLIPTokenizer::Tokenize(const std::string& text, int maxLength) const
{
    TokenizerResult result;
    // A negative length would be converted to a huge size by resize(), and a
    // zero length leaves no room for BOS (inputIds[0] would be out of bounds).
    if (maxLength <= 0)
        return result;

    result.inputIds.resize(maxLength, 0);  // pad with 0 (CLIP convention)
    result.attentionMask.resize(maxLength, 0);
    if (!m_loaded)
        return result;

    // BOS token
    result.inputIds[0] = BOS_TOKEN;
    result.attentionMask[0] = 1;
    int pos = 1;

    // Content occupies indices [1, maxLength - 1); the last slot is kept for EOS.
    const int contentEnd = maxLength - 1;

    // Pre-tokenize -> BPE -> encode, stopping as soon as the buffer is full
    // so the remaining words are not BPE-encoded for nothing.
    for (const auto& word : preTokenize(text)) {
        if (pos >= contentEnd)
            break;
        for (const auto& tok : bpe(word)) {
            if (pos >= contentEnd)
                break;
            auto it = m_encoder.find(tok);
            if (it != m_encoder.end()) {
                result.inputIds[pos] = it->second;
                result.attentionMask[pos] = 1;
                ++pos;
            }
        }
    }

    // EOS token (skipped only when maxLength == 1, where BOS fills the buffer)
    if (pos < maxLength) {
        result.inputIds[pos] = EOS_TOKEN;
        result.attentionMask[pos] = 1;
    }
    return result;
}
}