// ANSCLIPTokenizer.cpp — CLIP byte-pair-encoding tokenizer implementation.
#include "ANSCLIPTokenizer.h"
|
||
|
|
#include <fstream>
|
||
|
|
#include <algorithm>
|
||
|
|
#include <climits>
|
||
|
|
#include <cctype>
|
||
|
|
#include <set>
|
||
|
|
|
||
|
|
namespace ANSCENTER
|
||
|
|
{
|
||
|
|
// =========================================================================
|
||
|
|
// UTF-8 helper
|
||
|
|
// =========================================================================
|
||
|
|
|
||
|
|
std::string ANSCLIPTokenizer::codepointToUtf8(int cp)
{
    // Encode a single Unicode code point as a UTF-8 byte sequence.
    //
    // The byte encoder only produces code points up to 323, but the
    // previous version silently returned an empty string for
    // cp >= 0x10000; the 4-byte form is added so the helper is correct
    // for the full Unicode range (general-purpose, backward compatible).
    std::string s;
    if (cp < 0x80) {
        // 1-byte sequence: 0xxxxxxx
        s += static_cast<char>(cp);
    }
    else if (cp < 0x800) {
        // 2-byte sequence: 110xxxxx 10xxxxxx
        s += static_cast<char>(0xC0 | (cp >> 6));
        s += static_cast<char>(0x80 | (cp & 0x3F));
    }
    else if (cp < 0x10000) {
        // 3-byte sequence: 1110xxxx 10xxxxxx 10xxxxxx
        s += static_cast<char>(0xE0 | (cp >> 12));
        s += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
        s += static_cast<char>(0x80 | (cp & 0x3F));
    }
    else {
        // 4-byte sequence: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
        // (covers cp in [0x10000, 0x10FFFF])
        s += static_cast<char>(0xF0 | (cp >> 18));
        s += static_cast<char>(0x80 | ((cp >> 12) & 0x3F));
        s += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
        s += static_cast<char>(0x80 | (cp & 0x3F));
    }
    return s;
}
|
||
|
|
|
||
|
|
// =========================================================================
|
||
|
|
// Byte encoder - CLIP's bytes_to_unicode()
|
||
|
|
// =========================================================================
|
||
|
|
|
||
|
|
void ANSCLIPTokenizer::initByteEncoder()
|
||
|
|
{
|
||
|
|
// 188 "printable" bytes map to their own code point.
|
||
|
|
// The remaining 68 bytes map to code points 256..323.
|
||
|
|
std::set<int> printable;
|
||
|
|
for (int i = 33; i <= 126; ++i) printable.insert(i); // ! to ~
|
||
|
|
for (int i = 161; i <= 172; ++i) printable.insert(i); // inverted-! to not-sign
|
||
|
|
for (int i = 174; i <= 255; ++i) printable.insert(i); // registered-sign to y-umlaut
|
||
|
|
|
||
|
|
int extra = 256;
|
||
|
|
for (int b = 0; b < 256; ++b) {
|
||
|
|
int cp = printable.count(b) ? b : extra++;
|
||
|
|
m_byteEncoder[b] = codepointToUtf8(cp);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// =========================================================================
|
||
|
|
// Pre-tokenisation (simplified CLIP regex for ASCII text)
|
||
|
|
// =========================================================================
|
||
|
|
|
||
|
|
std::vector<std::string> ANSCLIPTokenizer::preTokenize(const std::string& text)
{
    // Simplified CLIP pre-tokenisation for ASCII text: lowercase the
    // input, then split it into runs of letters, runs of digits, runs
    // of other non-whitespace characters, and the English contractions
    // 's 't 'm 'd 're 've 'll (each contraction is its own token).

    // Lowercase up front (CLIP lowercases before tokenising).
    std::string src;
    src.reserve(text.size());
    for (unsigned char ch : text)
        src += static_cast<char>(std::tolower(ch));

    // Character classifiers; the unsigned-char cast keeps the <cctype>
    // calls well-defined for bytes >= 0x80.
    const auto isAlpha = [](char c) { return std::isalpha(static_cast<unsigned char>(c)) != 0; };
    const auto isDigit = [](char c) { return std::isdigit(static_cast<unsigned char>(c)) != 0; };
    const auto isSpace = [](char c) { return std::isspace(static_cast<unsigned char>(c)) != 0; };

    std::vector<std::string> out;
    const size_t n = src.size();
    size_t pos = 0;

    while (pos < n) {
        const char cur = src[pos];

        // Whitespace separates tokens but is never emitted.
        if (isSpace(cur)) {
            ++pos;
            continue;
        }

        // Contractions: 's 't 'm 'd 're 've 'll
        if (cur == '\'' && pos + 1 < n) {
            const char after = src[pos + 1];
            if (after == 's' || after == 't' || after == 'm' || after == 'd') {
                out.push_back(src.substr(pos, 2));
                pos += 2;
                continue;
            }
            if (pos + 2 < n) {
                const std::string tri = src.substr(pos, 3);
                if (tri == "'re" || tri == "'ve" || tri == "'ll") {
                    out.push_back(tri);
                    pos += 3;
                    continue;
                }
            }
            // Not a known contraction: the apostrophe falls through to
            // the punctuation run below.
        }

        // Run of letters.
        if (isAlpha(cur)) {
            const size_t begin = pos;
            while (pos < n && isAlpha(src[pos]))
                ++pos;
            out.push_back(src.substr(begin, pos - begin));
            continue;
        }

        // Run of digits.
        if (isDigit(cur)) {
            const size_t begin = pos;
            while (pos < n && isDigit(src[pos]))
                ++pos;
            out.push_back(src.substr(begin, pos - begin));
            continue;
        }

        // Run of anything else (punctuation, symbols).
        const size_t begin = pos;
        while (pos < n &&
               !isSpace(src[pos]) &&
               !isAlpha(src[pos]) &&
               !isDigit(src[pos]))
            ++pos;
        if (pos > begin)
            out.push_back(src.substr(begin, pos - begin));
    }

    return out;
}
|
||
|
|
|
||
|
|
// =========================================================================
|
||
|
|
// BPE algorithm
|
||
|
|
// =========================================================================
|
||
|
|
|
||
|
|
std::vector<std::string> ANSCLIPTokenizer::bpe(const std::string& word) const
{
    // Apply byte-pair encoding to one pre-tokenised word and return the
    // resulting BPE symbols (keys into m_encoder). Results are memoised
    // in m_cache, so repeated words cost a single map lookup.

    // Check cache
    auto cacheIt = m_cache.find(word);
    if (cacheIt != m_cache.end())
        return cacheIt->second;

    // Byte-encode the word: each raw byte becomes its printable
    // stand-in string from bytes_to_unicode().
    std::vector<std::string> symbols;
    for (unsigned char c : word)
        symbols.push_back(m_byteEncoder[c]);

    // Append </w> to the last symbol (CLIP end-of-word marker)
    if (!symbols.empty())
        symbols.back() += "</w>";

    // Zero or one symbol: nothing to merge, cache and return as-is.
    if (symbols.size() <= 1) {
        m_cache[word] = symbols;
        return symbols;
    }

    // Iteratively merge the highest-priority (lowest-rank) pair
    while (symbols.size() > 1) {
        // Find the pair with the lowest rank
        int bestRank = INT_MAX;
        int bestIdx = -1;

        for (size_t j = 0; j + 1 < symbols.size(); ++j) {
            auto it = m_bpeRanks.find({ symbols[j], symbols[j + 1] });
            if (it != m_bpeRanks.end() && it->second < bestRank) {
                bestRank = it->second;
                bestIdx = static_cast<int>(j);
            }
        }

        if (bestIdx < 0) break; // no mergeable pairs left

        // Merge ALL occurrences of the best pair in a single pass
        // (mirrors the reference BPE algorithm).
        std::string a = symbols[bestIdx];
        std::string b = symbols[bestIdx + 1];
        std::string merged = a + b;

        std::vector<std::string> next;
        size_t j = 0;
        while (j < symbols.size()) {
            if (j + 1 < symbols.size() && symbols[j] == a && symbols[j + 1] == b) {
                next.push_back(merged);
                j += 2;
            } else {
                next.push_back(symbols[j]);
                j += 1;
            }
        }
        symbols = std::move(next);
    }

    m_cache[word] = symbols;
    return symbols;
}
|
||
|
|
|
||
|
|
// =========================================================================
|
||
|
|
// Load vocabulary from BPE merges file
|
||
|
|
// =========================================================================
|
||
|
|
|
||
|
|
bool ANSCLIPTokenizer::Load(const std::string& mergesFilePath)
|
||
|
|
{
|
||
|
|
initByteEncoder();
|
||
|
|
|
||
|
|
std::ifstream file(mergesFilePath);
|
||
|
|
if (!file.is_open())
|
||
|
|
return false;
|
||
|
|
|
||
|
|
std::string line;
|
||
|
|
std::vector<std::pair<std::string, std::string>> merges;
|
||
|
|
|
||
|
|
while (std::getline(file, line)) {
|
||
|
|
// Trim trailing \r \n
|
||
|
|
while (!line.empty() && (line.back() == '\r' || line.back() == '\n'))
|
||
|
|
line.pop_back();
|
||
|
|
// Skip empty lines and comments (#version: 0.2)
|
||
|
|
if (line.empty() || line[0] == '#')
|
||
|
|
continue;
|
||
|
|
|
||
|
|
size_t sp = line.find(' ');
|
||
|
|
if (sp == std::string::npos)
|
||
|
|
continue;
|
||
|
|
|
||
|
|
merges.emplace_back(line.substr(0, sp), line.substr(sp + 1));
|
||
|
|
}
|
||
|
|
file.close();
|
||
|
|
|
||
|
|
// Build BPE ranks
|
||
|
|
m_bpeRanks.clear();
|
||
|
|
for (int i = 0; i < static_cast<int>(merges.size()); ++i)
|
||
|
|
m_bpeRanks[merges[i]] = i;
|
||
|
|
|
||
|
|
// Build vocabulary (same order as OpenAI CLIP):
|
||
|
|
// [0..255] single byte-encoded chars
|
||
|
|
// [256..511] byte-encoded chars + </w>
|
||
|
|
// [512..512+N-1] merged tokens
|
||
|
|
// [49406] <|startoftext|>
|
||
|
|
// [49407] <|endoftext|>
|
||
|
|
m_encoder.clear();
|
||
|
|
int id = 0;
|
||
|
|
for (int b = 0; b < 256; ++b)
|
||
|
|
m_encoder[m_byteEncoder[b]] = id++;
|
||
|
|
for (int b = 0; b < 256; ++b)
|
||
|
|
m_encoder[m_byteEncoder[b] + "</w>"] = id++;
|
||
|
|
for (const auto& merge : merges)
|
||
|
|
m_encoder[merge.first + merge.second] = id++;
|
||
|
|
m_encoder["<|startoftext|>"] = id++;
|
||
|
|
m_encoder["<|endoftext|>"] = id++;
|
||
|
|
|
||
|
|
m_cache.clear();
|
||
|
|
m_loaded = true;
|
||
|
|
return true;
|
||
|
|
}
|
||
|
|
|
||
|
|
// =========================================================================
|
||
|
|
// Tokenize
|
||
|
|
// =========================================================================
|
||
|
|
|
||
|
|
TokenizerResult ANSCLIPTokenizer::Tokenize(const std::string& text, int maxLength) const
{
    // Encode `text` as fixed-length CLIP model inputs:
    //   inputIds      = [BOS, tok..., EOS, 0, 0, ...]
    //   attentionMask = [1,   1...,   1,   0, 0, ...]
    // Tokens that do not fit before the final EOS slot are truncated.
    TokenizerResult result;

    // Fix: a non-positive maxLength previously resized the buffers to
    // empty and then wrote result.inputIds[0] — out-of-bounds undefined
    // behaviour. Return empty buffers instead.
    if (maxLength <= 0)
        return result;

    result.inputIds.resize(maxLength, 0); // pad with 0 (CLIP convention)
    result.attentionMask.resize(maxLength, 0);

    if (!m_loaded)
        return result; // vocabulary not loaded: all-padding output

    // BOS token
    result.inputIds[0] = BOS_TOKEN;
    result.attentionMask[0] = 1;
    int pos = 1;

    // Pre-tokenize -> BPE -> encode. Stop scanning once only the EOS
    // slot (position maxLength - 1) remains; further tokens would be
    // dropped anyway, so the early exit is output-identical.
    bool full = false;
    for (const auto& word : preTokenize(text)) {
        if (full)
            break;
        for (const auto& tok : bpe(word)) {
            if (pos >= maxLength - 1) {
                full = true;
                break;
            }
            auto it = m_encoder.find(tok);
            if (it != m_encoder.end()) {
                result.inputIds[pos] = it->second;
                result.attentionMask[pos] = 1;
                ++pos;
            }
        }
    }

    // EOS token (pos <= maxLength - 1 here, so it always fits when
    // maxLength >= 2; for maxLength == 1 only BOS is emitted).
    if (pos < maxLength) {
        result.inputIds[pos] = EOS_TOKEN;
        result.attentionMask[pos] = 1;
    }

    return result;
}
|
||
|
|
}
|