diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs index 5bd204f501..8786eeb9ac 100644 --- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs +++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs @@ -59,6 +59,50 @@ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool specialTokens); } + internal SentencePieceBaseModel( + bool addBos, bool addEos, + string bosToken, int bosId, + string eosToken, int eosId, + string unkToken, int unkId, + bool addDummyPrefix, bool escapeWhiteSpaces, + bool treatWhitespaceAsSuffix, bool byteFallback, + ReadOnlySpan precompiledCharsmap, bool removeExtraWhitespaces, + IReadOnlyDictionary? specialTokens) + { + AddBeginningOfSentence = addBos; + AddEndOfSentence = addEos; + BeginningOfSentenceToken = bosToken; + BeginningOfSentenceId = bosId; + EndOfSentenceToken = eosToken; + EndOfSentenceId = eosId; + UnknownToken = unkToken; + UnknownId = unkId; + AddDummyPrefix = addDummyPrefix; + EscapeWhiteSpaces = escapeWhiteSpaces; + TreatWhitespaceAsSuffix = treatWhitespaceAsSuffix; + ByteFallback = byteFallback; + SpecialTokens = specialTokens; + + if (specialTokens is not null && specialTokens.Count > 0) + { + InternalSpecialTokens = new Dictionary(); + SpecialTokensReverse = new Dictionary(); + + foreach (var item in specialTokens) + { + InternalSpecialTokens.Add(new StringSpanOrdinalKey(item.Key), item.Value); + SpecialTokensReverse.Add(item.Value, item.Key); + } + + SpecialTokensRegex = new Regex(string.Join("|", specialTokens.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled); + } + + Normalizer = new SentencePieceNormalizer( + precompiledCharsmap, removeExtraWhitespaces, + addDummyPrefix, escapeWhiteSpaces, + treatWhitespaceAsSuffix, specialTokens); + } + internal Regex? SpecialTokensRegex { get; } internal Dictionary? InternalSpecialTokens { get; } diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs index cb945d24fa..8ef032682d 100644 --- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs @@ -7,6 +7,7 @@ using System.Buffers; using System.Collections.Generic; using System.IO; +using System.Text.Json; namespace Microsoft.ML.Tokenizers { @@ -30,6 +31,11 @@ internal SentencePieceTokenizer(ModelProto modelProto, bool addBos, bool addEos, }; } + private SentencePieceTokenizer(SentencePieceBaseModel model) + { + _model = model; + } + /// /// The special tokens. /// @@ -457,5 +463,602 @@ public static SentencePieceTokenizer Create( return new SentencePieceTokenizer(modelProto, addBeginningOfSentence, addEndOfSentence, specialTokens); } + + /// + /// Creates a Unigram from an in-memory vocabulary of (piece, score) pairs. + /// + /// + /// The vocabulary as an ordered sequence of (piece, score) pairs. The position of each pair + /// in the sequence determines its token ID. + /// + /// The index (token ID) of the unknown token in . + /// Whether to emit the beginning-of-sentence token during encoding. + /// Whether to emit the end-of-sentence token during encoding. + /// + /// Optional precompiled character normalization map (as found in the SentencePiece normalizer_spec.precompiled_charsmap + /// field or in the Hugging Face tokenizer.json normalizer.precompiled_charsmap property). + /// Pass to skip precompiled normalization. + /// + /// Whether to prepend the dummy whitespace prefix character (U+2581) at the start of the input. + /// Whether to replace spaces with the dummy whitespace character (U+2581) during normalization. + /// Whether to emit the U+2581 character at the end of the last token rather than the beginning of the first token. + /// Whether unknown characters are decomposed into UTF-8 byte pieces (<0x00>..<0xFF>) instead of the unknown token. + /// Additional special tokens to recognize, supplied as a mapping of token string to token ID. + /// A new instance. + /// + /// The beginning-of-sentence and end-of-sentence token IDs are auto-detected by looking for pieces + /// named <s> and </s> in . If a piece is not found it is + /// treated as absent; requesting or + /// when the corresponding piece is absent throws an . A <pad> piece + /// is likewise detected automatically when present. + /// + /// When creating the tokenizer, ensure that the vocabulary is sourced from a trusted provider. + /// + /// + public static SentencePieceTokenizer Create( + IEnumerable<(string Piece, float Score)> vocab, + int unkId, + bool addBeginningOfSentence = true, + bool addEndOfSentence = false, + ReadOnlySpan precompiledCharsMap = default, + bool addDummyPrefix = true, + bool escapeWhiteSpaces = true, + bool treatWhitespaceAsSuffix = false, + bool byteFallback = false, + IReadOnlyDictionary? specialTokens = null) + { + if (vocab is null) + { + throw new ArgumentNullException(nameof(vocab)); + } + + IReadOnlyList<(string Piece, float Score)> pieces = vocab as IReadOnlyList<(string Piece, float Score)> + ?? new List<(string Piece, float Score)>(vocab); + + SentencePieceUnigramModel model = new SentencePieceUnigramModel( + pieces, unkId, addBeginningOfSentence, addEndOfSentence, + precompiledCharsMap, addDummyPrefix, escapeWhiteSpaces, + treatWhitespaceAsSuffix, removeExtraWhitespaces: true, byteFallback, specialTokens); + + return new SentencePieceTokenizer(model); + } + + /// + /// Creates a Unigram by parsing a Hugging Face tokenizer.json + /// that contains a Unigram model (model.type == "Unigram"). + /// + /// A stream containing the UTF-8-encoded tokenizer.json content. + /// Whether to emit the beginning-of-sentence token during encoding. + /// Whether to emit the end-of-sentence token during encoding. + /// Additional special tokens to recognize, supplied as a mapping of token string to token ID. + /// A new instance. + /// + /// The following fields are read from the JSON: + /// + /// model.vocab — array of [piece, score] pairs (required). + /// model.unk_id — index of the unknown token (required). + /// model.byte_fallback — whether unknown characters fall back to UTF-8 byte pieces. + /// added_tokens — special tokens (those with "special": true) and their IDs. + /// normalizer.precompiled_charsmap (base64) — normalization map; also searched inside a Sequence normalizer. + /// pre_tokenizer of type Metaspaceadd_prefix_space and replacement; also searched inside a Sequence pre-tokenizer. + /// post_processor (TemplateProcessing, RobertaProcessing, BertProcessing, or a Sequence of these) — the special tokens that wrap a single sequence, gated by (prefix) and (suffix). + /// + /// + /// remove_extra_whitespaces has no direct representation in tokenizer.json and is assumed to be + /// . Pair-sequence templates and per-token type_ids are not applied. Templates that + /// place a special token in the middle of the sequence are rejected with . + /// + /// + /// When creating the tokenizer, ensure that the JSON stream is sourced from a trusted provider. + /// + /// + public static SentencePieceTokenizer CreateFromTokenizerJson( + Stream tokenizerJsonStream, + bool addBeginningOfSentence = true, + bool addEndOfSentence = false, + IReadOnlyDictionary? specialTokens = null) + { + if (tokenizerJsonStream is null) + { + throw new ArgumentNullException(nameof(tokenizerJsonStream)); + } + + using JsonDocument doc = JsonDocument.Parse(tokenizerJsonStream); + JsonElement root = doc.RootElement; + + // Validate model type + if (!root.TryGetProperty("model", out JsonElement modelElement)) + { + throw new InvalidDataException("The tokenizer.json does not contain a 'model' property."); + } + + if (modelElement.ValueKind != JsonValueKind.Object) + { + throw new InvalidDataException("The tokenizer.json 'model' property must be a JSON object."); + } + + if (!modelElement.TryGetProperty("type", out JsonElement modelTypeElement)) + { + throw new InvalidDataException("The tokenizer.json model does not contain a 'type' property; this factory only supports 'Unigram' models."); + } + + if (!string.Equals(modelTypeElement.GetString(), "Unigram", StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidDataException($"Expected model type 'Unigram' but found '{modelTypeElement.GetString()}'."); + } + + if (!modelElement.TryGetProperty("unk_id", out JsonElement unkIdElement)) + { + throw new InvalidDataException("The tokenizer.json model does not contain an 'unk_id' property."); + } + + int unkId = unkIdElement.GetInt32(); + + bool byteFallback = modelElement.TryGetProperty("byte_fallback", out JsonElement byteFallbackElement) && + byteFallbackElement.ValueKind == JsonValueKind.True; + + if (!modelElement.TryGetProperty("vocab", out JsonElement vocabElement) || + vocabElement.ValueKind != JsonValueKind.Array) + { + throw new InvalidDataException("The tokenizer.json model does not contain a valid 'vocab' array."); + } + + List<(string Piece, float Score)> vocab = new List<(string Piece, float Score)>(vocabElement.GetArrayLength()); + foreach (JsonElement entry in vocabElement.EnumerateArray()) + { + if (entry.ValueKind != JsonValueKind.Array || entry.GetArrayLength() < 2) + { + throw new InvalidDataException("Each entry in 'model.vocab' must be a [piece, score] array."); + } + + string? piece = entry[0].GetString(); + if (piece is null) + { + throw new InvalidDataException("A piece string in 'model.vocab' is null."); + } + + vocab.Add((piece, entry[1].GetSingle())); + } + + // Extract normalizer settings + byte[]? precompiledCharsMap = null; + bool addDummyPrefix = true; + // HF tokenizer.json has no remove_extra_whitespaces flag; SpmConverter encodes that behavior as + // explicit normalizer steps (a right-Strip plus a Replace collapsing runs of spaces). Deduce it from + // those steps, defaulting to false when absent to match the HF fast-tokenizer runtime. + bool removeExtraWhitespaces = false; + if (root.TryGetProperty("normalizer", out JsonElement normalizerElement) && + normalizerElement.ValueKind == JsonValueKind.Object) + { + precompiledCharsMap = ExtractPrecompiledCharsMap(normalizerElement); + removeExtraWhitespaces = NormalizerCollapsesWhitespace(normalizerElement); + } + + // Extract pre_tokenizer settings + bool escapeWhiteSpaces = true; + bool treatWhitespaceAsSuffix = false; + if (root.TryGetProperty("pre_tokenizer", out JsonElement preTokenizerElement) && + preTokenizerElement.ValueKind == JsonValueKind.Object) + { + ExtractMetaspaceSettings(preTokenizerElement, ref addDummyPrefix, ref escapeWhiteSpaces, ref treatWhitespaceAsSuffix); + } + + // Merge the special tokens declared in added_tokens (authoritative source for their IDs) with any + // caller-supplied special tokens; the caller's entries win on conflict. + Dictionary mergedSpecialTokens = ParseAddedTokens(root); + if (specialTokens is not null) + { + foreach (var kvp in specialTokens) + { + mergedSpecialTokens[kvp.Key] = kvp.Value; + } + } + + // Resolve the prefix/suffix special-token wrapping from the post_processor (if present), falling back + // to the SentencePiece-conventional / names otherwise. + ResolvePostProcessorAffixes(root, vocab, mergedSpecialTokens, + out List<(int Id, string Token)> prefixTokens, out List<(int Id, string Token)> suffixTokens); + + // Ensure every wrapping token is registered as a special token so it is classified Control and round-trips on decode. + foreach (var (id, token) in prefixTokens) + { + mergedSpecialTokens[token] = id; + } + foreach (var (id, token) in suffixTokens) + { + mergedSpecialTokens[token] = id; + } + + int padId = mergedSpecialTokens.TryGetValue("", out int p) ? p : FindPieceId(vocab, ""); + + SentencePieceUnigramModel model = new SentencePieceUnigramModel( + vocab, unkId, addBeginningOfSentence, addEndOfSentence, + precompiledCharsMap is not null ? precompiledCharsMap.AsSpan() : default, + addDummyPrefix, escapeWhiteSpaces, treatWhitespaceAsSuffix, removeExtraWhitespaces, byteFallback, + mergedSpecialTokens.Count > 0 ? mergedSpecialTokens : null, + prefixTokens, suffixTokens, padId); + + return new SentencePieceTokenizer(model); + } + + // Reads the special tokens (those marked "special": true) from the top-level added_tokens array. + private static Dictionary ParseAddedTokens(JsonElement root) + { + Dictionary result = new Dictionary(); + if (!root.TryGetProperty("added_tokens", out JsonElement addedTokens) || addedTokens.ValueKind != JsonValueKind.Array) + { + return result; + } + + foreach (JsonElement entry in addedTokens.EnumerateArray()) + { + if (entry.ValueKind != JsonValueKind.Object) + { + continue; + } + + if (!entry.TryGetProperty("special", out JsonElement specialElement) || specialElement.ValueKind != JsonValueKind.True) + { + continue; + } + + if (entry.TryGetProperty("content", out JsonElement contentElement) && + entry.TryGetProperty("id", out JsonElement idElement) && + contentElement.GetString() is string content) + { + result[content] = idElement.GetInt32(); + } + } + + return result; + } + + // Resolves the ordered prefix/suffix special tokens that wrap an encoded sequence, from the post_processor. + private static void ResolvePostProcessorAffixes( + JsonElement root, + IReadOnlyList<(string Piece, float Score)> vocab, + IReadOnlyDictionary specialTokens, + out List<(int Id, string Token)> prefixTokens, + out List<(int Id, string Token)> suffixTokens) + { + prefixTokens = new List<(int Id, string Token)>(); + suffixTokens = new List<(int Id, string Token)>(); + + if (root.TryGetProperty("post_processor", out JsonElement postProcessor) && + postProcessor.ValueKind == JsonValueKind.Object) + { + ProcessPostProcessor(postProcessor, vocab, specialTokens, prefixTokens, suffixTokens); + return; + } + + // No post_processor: fall back to the SentencePiece-conventional names. + AddAffixToken(prefixTokens, "", vocab, specialTokens, required: false); + AddAffixToken(suffixTokens, "", vocab, specialTokens, required: false); + } + + private static void ProcessPostProcessor( + JsonElement postProcessor, + IReadOnlyList<(string Piece, float Score)> vocab, + IReadOnlyDictionary specialTokens, + List<(int Id, string Token)> prefixTokens, + List<(int Id, string Token)> suffixTokens) + { + string? type = postProcessor.TryGetProperty("type", out JsonElement typeEl) ? typeEl.GetString() : null; + + switch (type) + { + case "TemplateProcessing": + ProcessTemplate(postProcessor, vocab, specialTokens, prefixTokens, suffixTokens); + break; + + case "RobertaProcessing": + AddProcessorAffix(postProcessor, "cls", prefixTokens, vocab, specialTokens); + AddProcessorAffix(postProcessor, "sep", suffixTokens, vocab, specialTokens); + break; + + case "BertProcessing": + AddProcessorAffix(postProcessor, "cls", prefixTokens, vocab, specialTokens); + AddProcessorAffix(postProcessor, "sep", suffixTokens, vocab, specialTokens); + break; + + case "Sequence": + if (postProcessor.TryGetProperty("processors", out JsonElement processors) && processors.ValueKind == JsonValueKind.Array) + { + foreach (JsonElement inner in processors.EnumerateArray()) + { + if (inner.ValueKind == JsonValueKind.Object) + { + ProcessPostProcessor(inner, vocab, specialTokens, prefixTokens, suffixTokens); + } + } + } + break; + + default: + // ByteLevel and other processors do not contribute special-token wrapping; ignore them. + break; + } + } + + // Parses a TemplateProcessing "single" template into leading (prefix) and trailing (suffix) special tokens. + private static void ProcessTemplate( + JsonElement postProcessor, + IReadOnlyList<(string Piece, float Score)> vocab, + IReadOnlyDictionary specialTokens, + List<(int Id, string Token)> prefixTokens, + List<(int Id, string Token)> suffixTokens) + { + if (!postProcessor.TryGetProperty("single", out JsonElement single) || single.ValueKind != JsonValueKind.Array) + { + return; + } + + JsonElement? ppSpecialTokens = postProcessor.TryGetProperty("special_tokens", out JsonElement st) && st.ValueKind == JsonValueKind.Object + ? st : (JsonElement?)null; + + bool seenSequence = false; + foreach (JsonElement item in single.EnumerateArray()) + { + if (item.ValueKind != JsonValueKind.Object) + { + continue; + } + + if (item.TryGetProperty("Sequence", out _)) + { + if (seenSequence) + { + throw new NotSupportedException("tokenizer.json post_processor templates with more than one sequence are not supported."); + } + + seenSequence = true; + } + else if (item.TryGetProperty("SpecialToken", out JsonElement specialToken) && + specialToken.TryGetProperty("id", out JsonElement idEl) && + idEl.GetString() is string tokenName) + { + int id = ResolveTemplateTokenId(tokenName, ppSpecialTokens, specialTokens, vocab); + (seenSequence ? suffixTokens : prefixTokens).Add((id, tokenName)); + } + } + + if (!seenSequence) + { + throw new NotSupportedException("tokenizer.json post_processor template does not contain a sequence placeholder."); + } + } + + private static int ResolveTemplateTokenId( + string tokenName, + JsonElement? ppSpecialTokens, + IReadOnlyDictionary specialTokens, + IReadOnlyList<(string Piece, float Score)> vocab) + { + if (ppSpecialTokens is JsonElement st && + st.TryGetProperty(tokenName, out JsonElement entry) && + entry.TryGetProperty("ids", out JsonElement ids) && + ids.ValueKind == JsonValueKind.Array && + ids.GetArrayLength() > 0) + { + return ids[0].GetInt32(); + } + + if (specialTokens.TryGetValue(tokenName, out int specialId)) + { + return specialId; + } + + int vocabId = FindPieceId(vocab, tokenName); + if (vocabId < 0) + { + throw new InvalidDataException($"The tokenizer.json post_processor references special token '{tokenName}' that is not present in the vocabulary."); + } + + return vocabId; + } + + private static void AddProcessorAffix( + JsonElement postProcessor, + string property, + List<(int Id, string Token)> target, + IReadOnlyList<(string Piece, float Score)> vocab, + IReadOnlyDictionary specialTokens) + { + // Roberta/Bert processors store cls/sep as [token, id] arrays. + if (postProcessor.TryGetProperty(property, out JsonElement el) && el.ValueKind == JsonValueKind.Array && el.GetArrayLength() >= 2 && + el[0].GetString() is string token) + { + target.Add((el[1].GetInt32(), token)); + } + } + + private static void AddAffixToken( + List<(int Id, string Token)> target, + string tokenName, + IReadOnlyList<(string Piece, float Score)> vocab, + IReadOnlyDictionary specialTokens, + bool required) + { + int id = specialTokens.TryGetValue(tokenName, out int specialId) ? specialId : FindPieceId(vocab, tokenName); + if (id >= 0) + { + target.Add((id, tokenName)); + } + else if (required) + { + throw new InvalidDataException($"The tokenizer.json does not contain the required special token '{tokenName}'."); + } + } + + private static int FindPieceId(IReadOnlyList<(string Piece, float Score)> vocab, string token) + { + for (int i = 0; i < vocab.Count; i++) + { + if (vocab[i].Piece == token) + { + return i; + } + } + + return -1; + } + + private static byte[]? ExtractPrecompiledCharsMap(JsonElement normalizer) + { + if (!normalizer.TryGetProperty("type", out JsonElement typeEl)) + { + return null; + } + + string? type = typeEl.GetString(); + if (string.Equals(type, "Precompiled", StringComparison.OrdinalIgnoreCase)) + { + if (normalizer.TryGetProperty("precompiled_charsmap", out JsonElement mapEl)) + { + string? base64 = mapEl.GetString(); + if (base64 is not null) + { + return Convert.FromBase64String(base64); + } + } + return null; + } + else if (string.Equals(type, "Sequence", StringComparison.OrdinalIgnoreCase) && + normalizer.TryGetProperty("normalizers", out JsonElement normalizersEl) && + normalizersEl.ValueKind == JsonValueKind.Array) + { + // A Sequence may legitimately interleave the precompiled map with other steps (Nmt, Replace, ...). + // Extract the precompiled map and ignore the steps we don't model rather than failing the load. + byte[]? result = null; + foreach (JsonElement inner in normalizersEl.EnumerateArray()) + { + if (inner.ValueKind != JsonValueKind.Object) + { + continue; + } + + byte[]? innerResult = ExtractPrecompiledCharsMap(inner); + if (innerResult is not null) + { + result = innerResult; + } + } + return result; + } + + // Other normalizer types (Nmt, Replace, Lowercase, ...) carry no precompiled map; treat as absent. + return null; + } + + // Detects whether the normalizer collapses extra whitespace, i.e. SentencePiece's remove_extra_whitespaces. + // HF's SpmConverter emits this as a right-Strip plus a Replace of a runs-of-spaces Regex (" {2,}") -> "▁". + private static bool NormalizerCollapsesWhitespace(JsonElement normalizer) + { + if (normalizer.ValueKind != JsonValueKind.Object || !normalizer.TryGetProperty("type", out JsonElement typeEl)) + { + return false; + } + + string? type = typeEl.GetString(); + + if (string.Equals(type, "Strip", StringComparison.OrdinalIgnoreCase)) + { + // A right-Strip removes trailing whitespace; treat its presence as the strip half of the behavior. + return !normalizer.TryGetProperty("strip_right", out JsonElement stripRight) || stripRight.ValueKind != JsonValueKind.False; + } + + if (string.Equals(type, "Replace", StringComparison.OrdinalIgnoreCase)) + { + return ReplaceCollapsesSpaces(normalizer); + } + + if (string.Equals(type, "Sequence", StringComparison.OrdinalIgnoreCase) && + normalizer.TryGetProperty("normalizers", out JsonElement normalizersEl) && + normalizersEl.ValueKind == JsonValueKind.Array) + { + foreach (JsonElement inner in normalizersEl.EnumerateArray()) + { + if (NormalizerCollapsesWhitespace(inner)) + { + return true; + } + } + } + + return false; + } + + // True only for a Replace whose Regex matches runs of two-or-more spaces, not a single-space Metaspace Replace. + private static bool ReplaceCollapsesSpaces(JsonElement replace) + { + if (!replace.TryGetProperty("pattern", out JsonElement patternEl) || + patternEl.ValueKind != JsonValueKind.Object || + !patternEl.TryGetProperty("Regex", out JsonElement regexEl)) + { + return false; + } + + string? pattern = regexEl.GetString(); + if (pattern is null) + { + return false; + } + + // Do not trim: HF's canonical patterns " {2,}" and " +" carry a significant leading space. + switch (pattern) + { + case " {2,}": + case " +": + case "[ ]+": + case "[ ]{2,}": + case "\\s+": + case "\\s{2,}": + return true; + default: + return false; + } + } + + private static void ExtractMetaspaceSettings(JsonElement preTokenizer, ref bool addDummyPrefix, ref bool escapeWhiteSpaces, ref bool treatWhitespaceAsSuffix) + { + if (!preTokenizer.TryGetProperty("type", out JsonElement typeEl)) + { + return; + } + + string? type = typeEl.GetString(); + if (string.Equals(type, "Metaspace", StringComparison.OrdinalIgnoreCase)) + { + if (preTokenizer.TryGetProperty("add_prefix_space", out JsonElement addPrefixEl)) + { + addDummyPrefix = addPrefixEl.GetBoolean(); + } + + if (preTokenizer.TryGetProperty("replacement", out JsonElement replacementEl)) + { + string? replacement = replacementEl.GetString(); + escapeWhiteSpaces = replacement == "\u2581"; // U+2581 LOWER ONE EIGHTH BLOCK (▁) + } + + if (preTokenizer.TryGetProperty("prepend_scheme", out JsonElement prependSchemeEl)) + { + string? scheme = prependSchemeEl.GetString(); + // "never" suppresses the dummy prefix; "always"/"first" keep the default (true) + if (string.Equals(scheme, "never", StringComparison.OrdinalIgnoreCase)) + { + addDummyPrefix = false; + } + } + } + else if (string.Equals(type, "Sequence", StringComparison.OrdinalIgnoreCase) && + preTokenizer.TryGetProperty("pretokenizers", out JsonElement preTokenizersEl) && + preTokenizersEl.ValueKind == JsonValueKind.Array) + { + foreach (JsonElement inner in preTokenizersEl.EnumerateArray()) + { + ExtractMetaspaceSettings(inner, ref addDummyPrefix, ref escapeWhiteSpaces, ref treatWhitespaceAsSuffix); + } + } + } } } diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs index 3714206cf0..5ecb08c69b 100644 --- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs +++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs @@ -22,6 +22,8 @@ internal sealed class SentencePieceUnigramModel : SentencePieceBaseModel private readonly DoubleArrayTrie _trie; private readonly float _minScore; private readonly float _maxScore; + private readonly (int Id, string Token)[] _prefixTokens; + private readonly (int Id, string Token)[] _suffixTokens; private const float UnkPenalty = 10.0f; public SentencePieceUnigramModel(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary? specialTokens = null) : base(modelProto, addBos, addEos, specialTokens) @@ -91,6 +93,252 @@ public SentencePieceUnigramModel(ModelProto modelProto, bool addBos, bool addEos _vocab[modelProto.TrainerSpec.PadPiece] = modelProto.TrainerSpec.PadId; _vocabReverse[modelProto.TrainerSpec.PadId] = (modelProto.TrainerSpec.PadPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Control); } + + _prefixTokens = DefaultAffix(BeginningOfSentenceId, BeginningOfSentenceToken); + _suffixTokens = DefaultAffix(EndOfSentenceId, EndOfSentenceToken); + } + + // Constructor that builds a Unigram model directly from a list of (piece, score) pairs. + // BOS, EOS, and PAD tokens are identified by their names ("", "", "") in the vocab; + // if not found by name, they are treated as absent (id = -1) to avoid misidentifying real pieces. + internal SentencePieceUnigramModel( + IReadOnlyList<(string Piece, float Score)> pieces, + int unkId, + bool addBos, + bool addEos, + ReadOnlySpan precompiledCharsmap, + bool addDummyPrefix, + bool escapeWhiteSpaces, + bool treatWhitespaceAsSuffix, + bool removeExtraWhitespaces, + bool byteFallback, + IReadOnlyDictionary? specialTokens) + : this(pieces, unkId, addBos, addEos, precompiledCharsmap, addDummyPrefix, escapeWhiteSpaces, + treatWhitespaceAsSuffix, removeExtraWhitespaces, byteFallback, specialTokens, + CheckSpecialId(addBos, FindSpecialTokenId(ValidateVocab(pieces, unkId), ""), "addBeginningOfSentence"), + CheckSpecialId(addEos, FindSpecialTokenId(pieces, ""), "addEndOfSentence"), + FindSpecialTokenId(pieces, ""), prefixTokens: null, suffixTokens: null) + { + } + + // Constructor that builds a Unigram model with explicit prefix/suffix special-token lists, for example + // resolved from a tokenizer.json post_processor template. addBeginningOfSentence gates the prefix list + // and addEndOfSentence gates the suffix list; an empty list is allowed (no tokens are emitted). + internal SentencePieceUnigramModel( + IReadOnlyList<(string Piece, float Score)> pieces, + int unkId, + bool addBos, + bool addEos, + ReadOnlySpan precompiledCharsmap, + bool addDummyPrefix, + bool escapeWhiteSpaces, + bool treatWhitespaceAsSuffix, + bool removeExtraWhitespaces, + bool byteFallback, + IReadOnlyDictionary? specialTokens, + IReadOnlyList<(int Id, string Token)> prefixTokens, + IReadOnlyList<(int Id, string Token)> suffixTokens, + int padId) + : this(pieces, unkId, addBos, addEos, precompiledCharsmap, addDummyPrefix, escapeWhiteSpaces, + treatWhitespaceAsSuffix, removeExtraWhitespaces, byteFallback, specialTokens, + FirstId(prefixTokens), FirstId(suffixTokens), padId, prefixTokens, suffixTokens) + { + } + + private SentencePieceUnigramModel( + IReadOnlyList<(string Piece, float Score)> pieces, + int unkId, + bool addBos, + bool addEos, + ReadOnlySpan precompiledCharsmap, + bool addDummyPrefix, + bool escapeWhiteSpaces, + bool treatWhitespaceAsSuffix, + bool removeExtraWhitespaces, + bool byteFallback, + IReadOnlyDictionary? specialTokens, + int bosId, int eosId, int padId, + IReadOnlyList<(int Id, string Token)>? prefixTokens, + IReadOnlyList<(int Id, string Token)>? suffixTokens) + : base(addBos, addEos, + bosId >= 0 && bosId < GetPieceCount(pieces) ? pieces[bosId].Piece : "", bosId, + eosId >= 0 && eosId < GetPieceCount(pieces) ? pieces[eosId].Piece : "", eosId, + GetPieceAtIndex(pieces, unkId), unkId, + addDummyPrefix, escapeWhiteSpaces, treatWhitespaceAsSuffix, byteFallback, + precompiledCharsmap, removeExtraWhitespaces, specialTokens) + { + Debug.Assert(pieces is not null); + + _vocab = new SortedDictionary(OrdinalUtf8StringComparer.Instance); + _vocabReverse = new (string Piece, float Score, ModelProto.Types.SentencePiece.Types.Type Type)[pieces!.Count]; + _minScore = float.MaxValue; + _maxScore = float.MinValue; + + // Control tokens (BOS/EOS/PAD plus any caller- or added_tokens-supplied special tokens) are kept + // out of the trie so normal segmentation never produces them; they are re-inserted afterwards. + HashSet controlIds = new HashSet(); + AddControlId(controlIds, bosId); + AddControlId(controlIds, eosId); + AddControlId(controlIds, padId); + if (specialTokens is not null) + { + foreach (int specialId in specialTokens.Values) + { + AddControlId(controlIds, specialId); + } + } + + for (int i = 0; i < pieces.Count; i++) + { + var (piece, score) = pieces[i]; + if (i == unkId) + { + _vocabReverse[i] = (piece, score, ModelProto.Types.SentencePiece.Types.Type.Unknown); + } + else if (controlIds.Contains(i)) + { + _vocabReverse[i] = (piece, score, ModelProto.Types.SentencePiece.Types.Type.Control); + } + else + { + _vocabReverse[i] = (piece, score, ModelProto.Types.SentencePiece.Types.Type.Normal); + _vocab.Add(piece, i); + _minScore = Math.Min(_minScore, score); + _maxScore = Math.Max(_maxScore, score); + } + } + + ByteCodeToIdOffset = _vocab.TryGetValue("<0x00>", out int id) ? id : MaxByteId; + OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; + MaxIdByteFallbackId = ByteCodeToIdOffset + 0xFF; + + _trie = new DoubleArrayTrie(_vocab); + + // Re-insert special tokens into the vocab maps after the trie is built so they map like regular tokens. + string unkToken = pieces[unkId].Piece; + _vocab[unkToken] = unkId; + _vocabReverse[unkId] = (unkToken, 0f, ModelProto.Types.SentencePiece.Types.Type.Unknown); + + foreach (int controlId in controlIds) + { + if (controlId == unkId) + { + continue; // unk is classified Unknown above; don't downgrade it to Control. + } + + if (controlId >= 0 && controlId < pieces.Count) + { + string piece = pieces[controlId].Piece; + _vocab[piece] = controlId; + _vocabReverse[controlId] = (piece, 0f, ModelProto.Types.SentencePiece.Types.Type.Control); + } + } + + _prefixTokens = prefixTokens is not null ? ToAffixArray(prefixTokens) : DefaultAffix(BeginningOfSentenceId, BeginningOfSentenceToken); + _suffixTokens = suffixTokens is not null ? ToAffixArray(suffixTokens) : DefaultAffix(EndOfSentenceId, EndOfSentenceToken); + } + + private static (int Id, string Token)[] DefaultAffix(int id, string token) + => id >= 0 ? new[] { (id, token) } : Array.Empty<(int, string)>(); + + private static (int Id, string Token)[] ToAffixArray(IReadOnlyList<(int Id, string Token)> tokens) + { + var array = new (int Id, string Token)[tokens.Count]; + for (int i = 0; i < tokens.Count; i++) + { + array[i] = tokens[i]; + } + + return array; + } + + private static int FirstId(IReadOnlyList<(int Id, string Token)> tokens) => tokens.Count > 0 ? tokens[0].Id : -1; + + private void AddPrefixTokens(List tokens) + { + foreach (var (id, token) in _prefixTokens) + { + tokens.Add(new EncodedToken(id, token, new Range(0, 0))); + } + } + + private void AddSuffixTokens(List tokens, int offset) + { + foreach (var (id, token) in _suffixTokens) + { + tokens.Add(new EncodedToken(id, token, new Range(offset, offset))); + } + } + + private static void AddControlId(HashSet set, int id) + { + if (id >= 0) + { + set.Add(id); + } + } + + private static int GetPieceCount(IReadOnlyList<(string Piece, float Score)>? pieces) + => pieces?.Count ?? 0; + + private static string GetPieceAtIndex(IReadOnlyList<(string Piece, float Score)>? pieces, int index) + { + if (pieces is null) + { + throw new ArgumentNullException("vocab"); + } + + if ((uint)index >= (uint)pieces.Count) + { + throw new ArgumentOutOfRangeException("unkId", "unkId must be a valid index in the vocabulary."); + } + + return pieces[index].Piece; + } + + // Validates pieces is not null and unkId is in range; returns pieces unchanged. + private static IReadOnlyList<(string Piece, float Score)> ValidateVocab( + IReadOnlyList<(string Piece, float Score)>? pieces, int unkId) + { + if (pieces is null) + { + throw new ArgumentNullException("vocab"); + } + + if ((uint)unkId >= (uint)pieces.Count) + { + throw new ArgumentOutOfRangeException("unkId", "unkId must be a valid index in the vocabulary."); + } + + return pieces; + } + + // Finds a special token by name; returns -1 if not found. + private static int FindSpecialTokenId(IReadOnlyList<(string Piece, float Score)>? pieces, string tokenName) + { + if (pieces is null) + { + return -1; + } + + for (int i = 0; i < pieces.Count; i++) + { + if (pieces[i].Piece == tokenName) + { + return i; + } + } + + return -1; + } + + private static int CheckSpecialId(bool required, int id, string paramName) + { + if (required && id < 0) + { + throw new ArgumentException($"The vocabulary does not contain the required special token.", paramName); + } + return id; } public override IReadOnlyDictionary Vocabulary => new ReadOnlyDictionary(_vocab); @@ -218,7 +466,7 @@ private void EncodeToTokensWithSpecialTokens( if (addBeginningOfSentence) { - tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0))); + AddPrefixTokens(tokens); } int currentOffset = 0; @@ -250,7 +498,7 @@ private void EncodeToTokensWithSpecialTokens( if (addEndOfSentence) { - tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(progressOffset, progressOffset))); + AddSuffixTokens(tokens, progressOffset); } normalizedText = normalizedString.AsSpan().Slice(0, normalizedStringIndex).ToString(); @@ -268,7 +516,7 @@ private void EncodeToTokensWithoutSpecialTokens( { if (addBeginningOfSentence) { - tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0))); + AddPrefixTokens(tokens); } int progressOffset = 0; @@ -278,7 +526,7 @@ private void EncodeToTokensWithoutSpecialTokens( if (addEndOfSentence) { - tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(progressOffset, progressOffset))); + AddSuffixTokens(tokens, progressOffset); } normalizedText = normalizedString.AsSpan().Slice(0, normalizedStringIndex).ToString(); @@ -571,12 +819,15 @@ public override IReadOnlyList EncodeToIds( if (addBeginningOfSentence) { - ids.Add(BeginningOfSentenceId); - if (maxTokenCount == 1) + foreach (var (id, _) in _prefixTokens) { - normalizedText = null; - charsConsumed = 0; - return ids; // done. no more space for anything else. + ids.Add(id); + if (ids.Count >= maxTokenCount) + { + normalizedText = null; + charsConsumed = 0; + return ids; // done. no more space for anything else. + } } } @@ -595,9 +846,17 @@ public override IReadOnlyList EncodeToIds( EncodeToIdsWithoutSpecialTokens(textToEncode, considerNormalization, ids, buffer, ref normalizedString, out normalizedText, out charsConsumed, maxTokenCount); } - if (addEndOfSentence && ids.Count < maxTokenCount) + if (addEndOfSentence) { - ids.Add(EndOfSentenceId); + foreach (var (id, _) in _suffixTokens) + { + if (ids.Count >= maxTokenCount) + { + break; + } + + ids.Add(id); + } } if (normalizedString is not null) @@ -960,13 +1219,15 @@ public override int CountTokens( if (addBeginningOfSentence) { - tokenCount++; - - if (maxTokenCount == 1) + foreach (var _ in _prefixTokens) { - normalizedText = null; - charsConsumed = 0; - return tokenCount; + tokenCount++; + if (tokenCount >= maxTokenCount) + { + normalizedText = null; + charsConsumed = 0; + return tokenCount; + } } } @@ -987,7 +1248,15 @@ public override int CountTokens( if (addEndOfSentence && tokenCount < maxTokenCount) { - tokenCount++; + foreach (var _ in _suffixTokens) + { + if (tokenCount >= maxTokenCount) + { + break; + } + + tokenCount++; + } } if (normalizedString is not null) @@ -1228,12 +1497,14 @@ public override int GetIndexByTokenCountFromEnd( if (addEndOfSentence) { - tokenCount++; - - if (maxTokenCount == 1) + foreach (var _ in _suffixTokens) { - normalizedText = null; - return textToEncode.Length; + tokenCount++; + if (tokenCount >= maxTokenCount) + { + normalizedText = null; + return textToEncode.Length; + } } } @@ -1256,7 +1527,15 @@ public override int GetIndexByTokenCountFromEnd( if (addBeginningOfSentence && tokenCount < maxTokenCount) { - tokenCount++; + foreach (var _ in _prefixTokens) + { + if (tokenCount >= maxTokenCount) + { + break; + } + + tokenCount++; + } } ArrayPool.Shared.Return(buffer); diff --git a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs index ca671ddebe..53343aa8d5 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs @@ -562,5 +562,502 @@ public void SpecialTokensTest() Assert.Equal("", _unigramTokenizer.EndOfSentenceToken); Assert.Equal(2, _unigramTokenizer.EndOfSentenceId); } + + [Fact] + public void CreateFromVocabTest() + { + // Build a minimal synthetic Unigram vocab: =0, =1, =2, then normal tokens + var vocab = new List<(string Piece, float Score)> + { + ("", 0f), + ("", 0f), + ("", 0f), + ("▁Hello", -1f), + (",", -2f), + ("▁world", -3f), + ("!", -4f), + }; + + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.Create( + vocab, unkId: 0, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal("", tokenizer.UnknownToken); + Assert.Equal(0, tokenizer.UnknownId); + Assert.Equal("", tokenizer.BeginningOfSentenceToken); + Assert.Equal(1, tokenizer.BeginningOfSentenceId); + Assert.Equal("", tokenizer.EndOfSentenceToken); + Assert.Equal(2, tokenizer.EndOfSentenceId); + + IReadOnlyList ids = tokenizer.EncodeToIds("Hello, world!", addBeginningOfSentence: false, addEndOfSentence: false); + Assert.Equal(new[] { 3, 4, 5, 6 }, ids); + + string decoded = tokenizer.Decode(ids, considerSpecialTokens: false); + Assert.Equal("Hello, world!", decoded); + } + + [Fact] + public void CreateFromVocabNullTest() + { + Assert.Throws(() => + SentencePieceTokenizer.Create((IEnumerable<(string Piece, float Score)>)null!, unkId: 0)); + } + + [Fact] + public void CreateFromVocabInvalidUnkIdTest() + { + var vocab = new List<(string Piece, float Score)> { ("a", 0f) }; + Assert.Throws(() => + SentencePieceTokenizer.Create(vocab, unkId: 5)); + } + + [Fact] + public void CreateFromTokenizerJsonTest() + { + using Stream jsonStream = File.OpenRead(Path.Combine("Paraphrase-multilingual-MiniLM-L12-v2", "tokenizer.json")); + SentencePieceTokenizer jsonTokenizer = SentencePieceTokenizer.CreateFromTokenizerJson( + jsonStream, addBeginningOfSentence: false, addEndOfSentence: false); + + // The tokenizer.json vocab has =0, =1, =2, =3, then normal tokens + // (shifted +1 relative to .model which has =0, =1, =2) + Assert.Equal("", jsonTokenizer.UnknownToken); + Assert.Equal(3, jsonTokenizer.UnknownId); + Assert.Equal("", jsonTokenizer.BeginningOfSentenceToken); + Assert.Equal(0, jsonTokenizer.BeginningOfSentenceId); + Assert.Equal("", jsonTokenizer.EndOfSentenceToken); + Assert.Equal(2, jsonTokenizer.EndOfSentenceId); + + // Pieces produced should match the .model tokenizer; IDs are shifted by +1 + IReadOnlyList jsonTokens = jsonTokenizer.EncodeToTokens("Hello, world!", out _, addBeginningOfSentence: false, addEndOfSentence: false); + IReadOnlyList modelTokens = _unigramTokenizer.EncodeToTokens("Hello, world!", out _, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(modelTokens.Count, jsonTokens.Count); + for (int i = 0; i < modelTokens.Count; i++) + { + Assert.Equal(modelTokens[i].Value, jsonTokens[i].Value); + // JSON IDs are offset by 1 from the .model IDs for normal tokens + Assert.Equal(modelTokens[i].Id + 1, jsonTokens[i].Id); + } + } + + [Fact] + public void CreateFromTokenizerJsonNullStreamTest() + { + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(null!)); + } + + [Fact] + public void CreateFromTokenizerJsonNormalizationTest() + { + // Verify that the JSON tokenizer applies the precompiled charsmap normalization + // (same normalization as the .model tokenizer) + using Stream jsonStream = File.OpenRead(Path.Combine("Paraphrase-multilingual-MiniLM-L12-v2", "tokenizer.json")); + SentencePieceTokenizer jsonTokenizer = SentencePieceTokenizer.CreateFromTokenizerJson( + jsonStream, addBeginningOfSentence: false, addEndOfSentence: false); + + // "㍻" normalizes to "平成" via the precompiled charsmap (NFKC normalization) + IReadOnlyList jsonIds = jsonTokenizer.EncodeToIds("㍻", addBeginningOfSentence: false, addEndOfSentence: false); + IReadOnlyList modelIds = _unigramTokenizer.EncodeToIds("㍻", addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(modelIds.Count, jsonIds.Count); + for (int i = 0; i < modelIds.Count; i++) + { + Assert.Equal(modelIds[i] + 1, jsonIds[i]); + } + } + + [Fact] + public void CreateFromVocabNoSpecialTokensTest() + { + // Vocab without // — resembles bge-m3/potion layout. + // Verify that real pieces (e.g. ",") are not marked Control and remain encodable. + var vocab = new List<(string Piece, float Score)> + { + ("[PAD]", 0f), // 0 + ("[UNK]", 0f), // 1 + (",", -1f), // 2 + ("▁Hello", -2f), // 3 + ("▁world", -3f), // 4 + ("!", -4f), // 5 + }; + + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.Create( + vocab, unkId: 1, addBeginningOfSentence: false, addEndOfSentence: false); + + // "," must be in the vocabulary and encodable (not silently dropped as Control) + IReadOnlyList ids = tokenizer.EncodeToIds("Hello, world!", addBeginningOfSentence: false, addEndOfSentence: false); + Assert.Contains(2, ids); // id 2 is "," + } + + [Fact] + public void CreateFromVocabBosRequiredButAbsentTest() + { + // Vocab without : addBeginningOfSentence:true should throw rather than emit index 0. + var vocab = new List<(string Piece, float Score)> + { + ("[UNK]", 0f), + ("▁Hello", -1f), + }; + + Assert.Throws(() => + SentencePieceTokenizer.Create(vocab, unkId: 0, addBeginningOfSentence: true)); + } + + [Fact] + public void CreateFromTokenizerJsonSequenceNormalizerWithExtraStepsTest() + { + // A Sequence normalizer that interleaves the precompiled map with other steps (e.g. Replace) + // is common in real tokenizers; the precompiled map is extracted and the other steps are ignored + // rather than failing the load. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "normalizer": { + "type": "Sequence", + "normalizers": [ + { "type": "Nmt" }, + { "type": "Precompiled", "precompiled_charsmap": "" }, + { "type": "Replace", "pattern": " ", "content": "_" } + ] + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false); + Assert.NotNull(tokenizer); + } + + // Vocab shared by the remove_extra_whitespaces deduction tests; "▁" is its own piece so a preserved + // extra space surfaces as an extra token. + private const string WhitespaceDeductionVocab = + "\"vocab\": [[\"\", 0.0], [\"\u2581a\", -1.0], [\"\u2581b\", -1.0], [\"\u2581\", -3.0], [\"a\", -10.0], [\"b\", -10.0]]"; + + [Fact] + public void CreateFromTokenizerJsonDeducesRemoveExtraWhitespacesFromReplaceStep() + { + // HF encodes remove_extra_whitespaces as a Strip + Replace(" {2,}" -> "▁"); the collapsing Replace + // ALONE (no sibling Strip) must enable whitespace collapsing so "a b" collapses to two pieces. + string json = $$""" + { + "normalizer": { + "type": "Sequence", + "normalizers": [ + { "type": "Replace", "pattern": { "Regex": " {2,}" }, "content": "\u2581" } + ] + }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + {{WhitespaceDeductionVocab}} + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(2, tokenizer.CountTokens("a b", addBeginningOfSentence: false, addEndOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonDeducesRemoveExtraWhitespacesFromStripStep() + { + // A right-Strip alone (no Replace) also marks the behavior. + string json = $$""" + { + "normalizer": { + "type": "Sequence", + "normalizers": [ + { "type": "Strip", "strip_left": false, "strip_right": true } + ] + }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + {{WhitespaceDeductionVocab}} + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(2, tokenizer.CountTokens("a b", addBeginningOfSentence: false, addEndOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNoCollapseStepPreservesExtraWhitespace() + { + // Without a Strip/Replace collapsing step (e.g. older bare-Precompiled files), remove_extra_whitespaces + // is deduced false to match the HF fast-tokenizer runtime, so the extra space is preserved as a token. + string json = $$""" + { + "normalizer": { "type": "Precompiled", "precompiled_charsmap": "" }, + "pre_tokenizer": { "type": "Metaspace", "replacement": "\u2581", "add_prefix_space": true }, + "model": { + "type": "Unigram", + "unk_id": 0, + {{WhitespaceDeductionVocab}} + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(3, tokenizer.CountTokens("a b", addBeginningOfSentence: false, addEndOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNullNormalizerTest() + { + // A null normalizer value in JSON should not throw. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "normalizer": null + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson( + stream, addBeginningOfSentence: false); + Assert.NotNull(tokenizer); + } + + [Fact] + public void CreateFromVocabAbsentBosNotDecodedAsIdZeroTest() + { + // Vocab without /. With the add flags off, BOS/EOS must stay absent (-1) + // rather than being clamped to 0, so id 0 decodes as its real piece. + var vocab = new List<(string Piece, float Score)> + { + ("", 0f), // 0 + ("▁Hello", -1f), // 1 + ("▁world", -2f), // 2 + }; + + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.Create( + vocab, unkId: 0, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(-1, tokenizer.BeginningOfSentenceId); + Assert.Equal(-1, tokenizer.EndOfSentenceId); + + // id 0 is , not BOS; decoding it with considerSpecialTokens must yield the unk piece. + string decoded = tokenizer.Decode(new[] { 0 }, considerSpecialTokens: true); + Assert.Equal("", decoded); + } + + [Fact] + public void CreateFromTokenizerJsonMissingModelTypeTest() + { + string json = """ + { + "model": { + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNonUnigramModelTypeTest() + { + string json = """ + { + "model": { + "type": "BPE", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNullModelTest() + { + string json = """ + { + "model": null + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false)); + } + + [Fact] + public void CreateFromTokenizerJsonNullPreTokenizerTest() + { + // A null pre_tokenizer value in JSON should not throw. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "pre_tokenizer": null + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson( + stream, addBeginningOfSentence: false); + Assert.NotNull(tokenizer); + } + + [Fact] + public void CreateFromTokenizerJsonTemplateMultiTokenSuffixTest() + { + // XLNet-style template: the sequence is followed by two special tokens ( then ). + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 1, + "vocab": [["", 0.0], ["", 0.0], ["", 0.0], ["a", -1.0], ["b", -2.0]] + }, + "added_tokens": [ + { "id": 0, "content": "", "special": true }, + { "id": 1, "content": "", "special": true }, + { "id": 2, "content": "", "special": true } + ], + "pre_tokenizer": { "type": "Metaspace", "add_prefix_space": false, "replacement": "_" }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { "Sequence": { "id": "A", "type_id": 0 } }, + { "SpecialToken": { "id": "", "type_id": 0 } }, + { "SpecialToken": { "id": "", "type_id": 0 } } + ], + "special_tokens": { + "": { "id": "", "ids": [0], "tokens": [""] }, + "": { "id": "", "ids": [2], "tokens": [""] } + } + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(0, tokenizer.EndOfSentenceId); + Assert.Equal("", tokenizer.EndOfSentenceToken); + + IReadOnlyList withSuffix = tokenizer.EncodeToIds("a", addBeginningOfSentence: false, addEndOfSentence: true); + Assert.Equal(new[] { 3, 0, 2 }, withSuffix); + + IReadOnlyList withoutSuffix = tokenizer.EncodeToIds("a", addBeginningOfSentence: false, addEndOfSentence: false); + Assert.Equal(new[] { 3 }, withoutSuffix); + + Assert.Equal("a", tokenizer.Decode(withSuffix, considerSpecialTokens: false)); + } + + [Fact] + public void CreateFromTokenizerJsonRobertaProcessingTest() + { + // RobertaProcessing wraps the sequence with cls () at the front and sep () at the end. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 1, + "vocab": [["", 0.0], ["", 0.0], ["", 0.0], ["a", -1.0]] + }, + "added_tokens": [ + { "id": 0, "content": "", "special": true }, + { "id": 2, "content": "", "special": true } + ], + "pre_tokenizer": { "type": "Metaspace", "add_prefix_space": false, "replacement": "_" }, + "post_processor": { "type": "RobertaProcessing", "sep": ["", 2], "cls": ["", 0] } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + + Assert.Equal(0, tokenizer.BeginningOfSentenceId); + Assert.Equal("", tokenizer.BeginningOfSentenceToken); + Assert.Equal(2, tokenizer.EndOfSentenceId); + Assert.Equal("", tokenizer.EndOfSentenceToken); + + IReadOnlyList ids = tokenizer.EncodeToIds("a", addBeginningOfSentence: true, addEndOfSentence: true); + Assert.Equal(new[] { 0, 3, 2 }, ids); + } + + [Fact] + public void CreateFromTokenizerJsonAddedTokenRecognizedTest() + { + // A special token from added_tokens that is not // must still be recognized as atomic. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 1, + "vocab": [["", 0.0], ["", 0.0], ["", 0.0], ["a", -1.0], ["", -5.0]] + }, + "added_tokens": [ + { "id": 4, "content": "", "special": true } + ], + "pre_tokenizer": { "type": "Metaspace", "add_prefix_space": false, "replacement": "_" } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + SentencePieceTokenizer tokenizer = SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false); + + IReadOnlyList ids = tokenizer.EncodeToIds("aa", addBeginningOfSentence: false, addEndOfSentence: false); + Assert.Equal(new[] { 3, 4, 3 }, ids); + } + + [Fact] + public void CreateFromTokenizerJsonTemplateMultiSequenceThrowsTest() + { + // A template with more than one sequence placeholder cannot be represented and must be rejected. + string json = """ + { + "model": { + "type": "Unigram", + "unk_id": 0, + "vocab": [["", 0.0], ["a", -1.0]] + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { "Sequence": { "id": "A", "type_id": 0 } }, + { "Sequence": { "id": "B", "type_id": 0 } } + ], + "special_tokens": {} + } + } + """; + + using Stream stream = new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(json)); + Assert.Throws(() => + SentencePieceTokenizer.CreateFromTokenizerJson(stream, addBeginningOfSentence: false, addEndOfSentence: false)); + } } }