Files
ANSCORE/modules/ANSODEngine/ANSCLIPTokenizer.h

66 lines
2.2 KiB
C
Raw Normal View History

2026-03-28 16:54:11 +11:00
#ifndef ANSCLIPTOKENIZER_H
#define ANSCLIPTOKENIZER_H
#pragma once
#include <string>
#include <vector>
#include <unordered_map>
#include <map>
#include <cstdint>
namespace ANSCENTER
{
    /// Value returned by ANSCLIPTokenizer::Tokenize: two sequences of
    /// 64-bit integers, the token IDs and the attention mask.
    struct TokenizerResult
    {
        std::vector<int64_t> inputIds;      ///< Token IDs produced for the prompt.
        std::vector<int64_t> attentionMask; ///< Attention mask that accompanies inputIds.
    };

    /// Byte-pair-encoding (BPE) tokenizer for CLIP text prompts, intended for
    /// text-prompted segmentation models.
    ///
    /// A BPE merges file is mandatory: merges.txt from the HuggingFace
    /// openai/clip-vit-base-patch32 repository. Keep that file in the model
    /// folder, next to the ONNX model.
    class ANSCLIPTokenizer
    {
    public:
        /// Read the BPE vocabulary from a merges file.
        /// @param mergesFilePath Location of the CLIP BPE merges file (merges.txt).
        /// @return true when loading succeeded.
        bool Load(const std::string& mergesFilePath);

        /// @return true once the vocabulary has been loaded.
        bool IsLoaded() const
        {
            return m_loaded;
        }

        /// Convert a text prompt into input IDs plus an attention mask.
        /// @param text      Prompt to tokenize (for example "person").
        /// @param maxLength Length of the output sequence; the result is
        ///                  padded or truncated to it. Defaults to 32.
        /// @return TokenizerResult holding the inputIds and attentionMask vectors.
        TokenizerResult Tokenize(const std::string& text, int maxLength = 32) const;

    private:
        /// Sequence of BPE token strings.
        using TokenList = std::vector<std::string>;
        /// Two adjacent tokens that are candidates for a BPE merge.
        using MergePair = std::pair<std::string, std::string>;

        /// ID of the <|startoftext|> marker.
        static constexpr int BOS_TOKEN = 49406;
        /// ID of the <|endoftext|> marker.
        static constexpr int EOS_TOKEN = 49407;

        /// Set to false until a vocabulary is available; reported by IsLoaded().
        bool m_loaded = false;

        /// CLIP byte encoding: index is a byte value (0-255), element is the
        /// unicode string that represents that byte.
        std::string m_byteEncoder[256];

        /// BPE merge table: (token_a, token_b) mapped to a priority rank,
        /// where a lower rank is merged first.
        std::map<MergePair, int> m_bpeRanks;

        /// Lookup from token string to its integer ID.
        std::unordered_map<std::string, int> m_encoder;

        /// Memoized BPE results, keyed by word. Declared mutable so that
        /// const member functions are able to update it.
        /// NOTE(review): no synchronization is visible in this header, so
        /// concurrent const calls would presumably share this cache unguarded;
        /// confirm against the implementation.
        mutable std::unordered_map<std::string, TokenList> m_cache;

        /// Populate m_byteEncoder (implementation lives outside this header).
        void initByteEncoder();

        /// Apply BPE to a single word and return the resulting token list.
        TokenList bpe(const std::string& word) const;

        /// Split raw text into the pieces that are subsequently fed to BPE.
        static TokenList preTokenize(const std::string& text);

        /// Produce the UTF-8 string for a unicode code point.
        static std::string codepointToUtf8(int codepoint);
    };
} // namespace ANSCENTER
#endif