Refactor project structure
This commit is contained in:
286
modules/ANSODEngine/ANSCLIPTokenizer.cpp
Normal file
286
modules/ANSODEngine/ANSCLIPTokenizer.cpp
Normal file
@@ -0,0 +1,286 @@
|
||||
#include "ANSCLIPTokenizer.h"
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include <cctype>
|
||||
#include <set>
|
||||
|
||||
namespace ANSCENTER
|
||||
{
|
||||
// =========================================================================
|
||||
// UTF-8 helper
|
||||
// =========================================================================
|
||||
|
||||
std::string ANSCLIPTokenizer::codepointToUtf8(int cp)
{
    // Encode a single Unicode code point as a UTF-8 byte sequence.
    //
    // The byte encoder only produces code points up to 323, but we handle
    // the full Unicode range anyway: the previous version returned an empty
    // string for cp >= 0x10000 (no 4-byte branch) and for negative input,
    // which would silently corrupt any future caller.
    std::string s;
    if (cp < 0 || cp > 0x10FFFF) {
        return s; // outside the valid Unicode range: empty string
    }
    if (cp < 0x80) {
        // 1-byte sequence (ASCII)
        s += static_cast<char>(cp);
    }
    else if (cp < 0x800) {
        // 2-byte sequence
        s += static_cast<char>(0xC0 | (cp >> 6));
        s += static_cast<char>(0x80 | (cp & 0x3F));
    }
    else if (cp < 0x10000) {
        // 3-byte sequence
        s += static_cast<char>(0xE0 | (cp >> 12));
        s += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
        s += static_cast<char>(0x80 | (cp & 0x3F));
    }
    else {
        // 4-byte sequence (previously unhandled)
        s += static_cast<char>(0xF0 | (cp >> 18));
        s += static_cast<char>(0x80 | ((cp >> 12) & 0x3F));
        s += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
        s += static_cast<char>(0x80 | (cp & 0x3F));
    }
    return s;
}
|
||||
|
||||
// =========================================================================
|
||||
// Byte encoder - CLIP's bytes_to_unicode()
|
||||
// =========================================================================
|
||||
|
||||
void ANSCLIPTokenizer::initByteEncoder()
|
||||
{
|
||||
// 188 "printable" bytes map to their own code point.
|
||||
// The remaining 68 bytes map to code points 256..323.
|
||||
std::set<int> printable;
|
||||
for (int i = 33; i <= 126; ++i) printable.insert(i); // ! to ~
|
||||
for (int i = 161; i <= 172; ++i) printable.insert(i); // inverted-! to not-sign
|
||||
for (int i = 174; i <= 255; ++i) printable.insert(i); // registered-sign to y-umlaut
|
||||
|
||||
int extra = 256;
|
||||
for (int b = 0; b < 256; ++b) {
|
||||
int cp = printable.count(b) ? b : extra++;
|
||||
m_byteEncoder[b] = codepointToUtf8(cp);
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Pre-tokenisation (simplified CLIP regex for ASCII text)
|
||||
// =========================================================================
|
||||
|
||||
std::vector<std::string> ANSCLIPTokenizer::preTokenize(const std::string& text)
{
    // Simplified CLIP pre-tokenisation for ASCII text: lowercase, then split
    // into contractions, letter runs, digit runs, and "other" (punctuation)
    // runs. Whitespace separates tokens and is never emitted.
    std::string lc;
    lc.reserve(text.size());
    for (unsigned char ch : text)
        lc += static_cast<char>(std::tolower(ch));

    const size_t n = lc.size();
    auto isSpace = [&lc](size_t k) { return std::isspace(static_cast<unsigned char>(lc[k])) != 0; };
    auto isAlpha = [&lc](size_t k) { return std::isalpha(static_cast<unsigned char>(lc[k])) != 0; };
    auto isDigit = [&lc](size_t k) { return std::isdigit(static_cast<unsigned char>(lc[k])) != 0; };

    std::vector<std::string> out;
    size_t pos = 0;
    while (pos < n) {
        if (isSpace(pos)) {
            ++pos;
            continue;
        }

        // English contractions: 's 't 'm 'd 're 've 'll
        if (lc[pos] == '\'' && pos + 1 < n) {
            const char nxt = lc[pos + 1];
            if (nxt == 's' || nxt == 't' || nxt == 'm' || nxt == 'd') {
                out.push_back(lc.substr(pos, 2));
                pos += 2;
                continue;
            }
            if (pos + 2 < n) {
                const std::string three = lc.substr(pos, 3);
                if (three == "'re" || three == "'ve" || three == "'ll") {
                    out.push_back(three);
                    pos += 3;
                    continue;
                }
            }
            // Not a contraction: the apostrophe joins the "other" run below.
        }

        // Consume a maximal run of one character class.
        const size_t start = pos;
        if (isAlpha(pos)) {
            do { ++pos; } while (pos < n && isAlpha(pos));
        } else if (isDigit(pos)) {
            do { ++pos; } while (pos < n && isDigit(pos));
        } else {
            // Non-whitespace, non-alphanumeric (punctuation, symbols).
            while (pos < n && !isSpace(pos) && !isAlpha(pos) && !isDigit(pos))
                ++pos;
        }
        if (pos > start)
            out.push_back(lc.substr(start, pos - start));
    }

    return out;
}
|
||||
|
||||
// =========================================================================
|
||||
// BPE algorithm
|
||||
// =========================================================================
|
||||
|
||||
std::vector<std::string> ANSCLIPTokenizer::bpe(const std::string& word) const
|
||||
{
|
||||
// Check cache
|
||||
auto cacheIt = m_cache.find(word);
|
||||
if (cacheIt != m_cache.end())
|
||||
return cacheIt->second;
|
||||
|
||||
// Byte-encode the word
|
||||
std::vector<std::string> symbols;
|
||||
for (unsigned char c : word)
|
||||
symbols.push_back(m_byteEncoder[c]);
|
||||
|
||||
// Append </w> to the last symbol (CLIP end-of-word marker)
|
||||
if (!symbols.empty())
|
||||
symbols.back() += "</w>";
|
||||
|
||||
if (symbols.size() <= 1) {
|
||||
m_cache[word] = symbols;
|
||||
return symbols;
|
||||
}
|
||||
|
||||
// Iteratively merge the highest-priority (lowest-rank) pair
|
||||
while (symbols.size() > 1) {
|
||||
// Find the pair with the lowest rank
|
||||
int bestRank = INT_MAX;
|
||||
int bestIdx = -1;
|
||||
|
||||
for (size_t j = 0; j + 1 < symbols.size(); ++j) {
|
||||
auto it = m_bpeRanks.find({ symbols[j], symbols[j + 1] });
|
||||
if (it != m_bpeRanks.end() && it->second < bestRank) {
|
||||
bestRank = it->second;
|
||||
bestIdx = static_cast<int>(j);
|
||||
}
|
||||
}
|
||||
|
||||
if (bestIdx < 0) break; // no mergeable pairs left
|
||||
|
||||
// Merge ALL occurrences of the best pair
|
||||
std::string a = symbols[bestIdx];
|
||||
std::string b = symbols[bestIdx + 1];
|
||||
std::string merged = a + b;
|
||||
|
||||
std::vector<std::string> next;
|
||||
size_t j = 0;
|
||||
while (j < symbols.size()) {
|
||||
if (j + 1 < symbols.size() && symbols[j] == a && symbols[j + 1] == b) {
|
||||
next.push_back(merged);
|
||||
j += 2;
|
||||
} else {
|
||||
next.push_back(symbols[j]);
|
||||
j += 1;
|
||||
}
|
||||
}
|
||||
symbols = std::move(next);
|
||||
}
|
||||
|
||||
m_cache[word] = symbols;
|
||||
return symbols;
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Load vocabulary from BPE merges file
|
||||
// =========================================================================
|
||||
|
||||
bool ANSCLIPTokenizer::Load(const std::string& mergesFilePath)
|
||||
{
|
||||
initByteEncoder();
|
||||
|
||||
std::ifstream file(mergesFilePath);
|
||||
if (!file.is_open())
|
||||
return false;
|
||||
|
||||
std::string line;
|
||||
std::vector<std::pair<std::string, std::string>> merges;
|
||||
|
||||
while (std::getline(file, line)) {
|
||||
// Trim trailing \r \n
|
||||
while (!line.empty() && (line.back() == '\r' || line.back() == '\n'))
|
||||
line.pop_back();
|
||||
// Skip empty lines and comments (#version: 0.2)
|
||||
if (line.empty() || line[0] == '#')
|
||||
continue;
|
||||
|
||||
size_t sp = line.find(' ');
|
||||
if (sp == std::string::npos)
|
||||
continue;
|
||||
|
||||
merges.emplace_back(line.substr(0, sp), line.substr(sp + 1));
|
||||
}
|
||||
file.close();
|
||||
|
||||
// Build BPE ranks
|
||||
m_bpeRanks.clear();
|
||||
for (int i = 0; i < static_cast<int>(merges.size()); ++i)
|
||||
m_bpeRanks[merges[i]] = i;
|
||||
|
||||
// Build vocabulary (same order as OpenAI CLIP):
|
||||
// [0..255] single byte-encoded chars
|
||||
// [256..511] byte-encoded chars + </w>
|
||||
// [512..512+N-1] merged tokens
|
||||
// [49406] <|startoftext|>
|
||||
// [49407] <|endoftext|>
|
||||
m_encoder.clear();
|
||||
int id = 0;
|
||||
for (int b = 0; b < 256; ++b)
|
||||
m_encoder[m_byteEncoder[b]] = id++;
|
||||
for (int b = 0; b < 256; ++b)
|
||||
m_encoder[m_byteEncoder[b] + "</w>"] = id++;
|
||||
for (const auto& merge : merges)
|
||||
m_encoder[merge.first + merge.second] = id++;
|
||||
m_encoder["<|startoftext|>"] = id++;
|
||||
m_encoder["<|endoftext|>"] = id++;
|
||||
|
||||
m_cache.clear();
|
||||
m_loaded = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Tokenize
|
||||
// =========================================================================
|
||||
|
||||
TokenizerResult ANSCLIPTokenizer::Tokenize(const std::string& text, int maxLength) const
{
    // Tokenize `text` into CLIP input ids: [BOS, tokens..., EOS, 0-padding],
    // with a matching attention mask (1 for real tokens, 0 for padding).
    TokenizerResult result;

    // Guard against non-positive maxLength: resize() with a negative int
    // converts to a huge size_t (throws/aborts), and inputIds[0] below would
    // be out of bounds when maxLength == 0. Return an empty result instead.
    if (maxLength <= 0)
        return result;

    result.inputIds.resize(maxLength, 0); // pad with 0 (CLIP convention)
    result.attentionMask.resize(maxLength, 0);

    if (!m_loaded)
        return result; // vocabulary not loaded: all-zero output

    // BOS token always occupies slot 0.
    result.inputIds[0] = BOS_TOKEN;
    result.attentionMask[0] = 1;
    int pos = 1;

    // Pre-tokenize -> BPE -> vocabulary lookup. Slot maxLength-1 is reserved
    // for EOS, so excess tokens are silently truncated.
    for (const auto& word : preTokenize(text)) {
        for (const auto& tok : bpe(word)) {
            auto it = m_encoder.find(tok);
            if (it != m_encoder.end() && pos < maxLength - 1) {
                result.inputIds[pos] = it->second;
                result.attentionMask[pos] = 1;
                ++pos;
            }
        }
    }

    // EOS token terminates the sequence (only skipped when maxLength == 1,
    // in which case BOS alone fills the buffer).
    if (pos < maxLength) {
        result.inputIds[pos] = EOS_TOKEN;
        result.attentionMask[pos] = 1;
    }

    return result;
}
|
||||
}
|
||||
Reference in New Issue
Block a user