Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 9 additions & 27 deletions google/genai/_local_tokenizer_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
import hashlib
import os
import tempfile
from typing import Any, Optional, cast
from typing import Optional, cast
import uuid

import requests # type: ignore
import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2
from transformers import AutoProcessor


# Source of truth: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
Expand All @@ -48,26 +47,21 @@
"gemini-3-pro-preview": "gemma3",
}

# https://github.com/google/gemma_pytorch stop supporting gemma 4 moving forward.
_GEMINI_MODELS_TO_HUGGINGFACE_TOKENIZER_NAMES = {
"gemini-3.5-flash": "gemma4",
"gemini-3.1-flash-lite": "gemma4",
"gemini-3.1-pro-preview": "gemma4",
"gemini-4-flash-preview": "gemma4",
}

GEMMA_TOKENIZER_TO_MODEL_NAMES = {
"gemma4": "google/gemma-4-E4B-it",
}


@dataclasses.dataclass(frozen=True)
class _TokenizerConfig:
model_url: str
model_hash: str


# TODO: update gemma3 tokenizer
_TOKENIZERS = {
"gemma2": _TokenizerConfig(
model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
model_hash=(
"61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"
),
),
"gemma3": _TokenizerConfig(
model_url="https://raw.githubusercontent.com/google/gemma_pytorch/014acb7ac4563a5f77c76d7ff98f31b568c16508/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
model_hash=(
Expand Down Expand Up @@ -183,7 +177,7 @@ def _load_model_proto_bytes(tokenizer_name: str) -> bytes:
"""Loads model proto bytes from the given tokenizer name."""
if tokenizer_name not in _TOKENIZERS:
raise ValueError(
f"Tokenizer {tokenizer_name} is not supported. "
f"Tokenizer {tokenizer_name} is not supported."
f"Supported tokenizers: {list(_TOKENIZERS.keys())}"
)
return _load(
Expand All @@ -208,23 +202,11 @@ def get_tokenizer_name(model_name: str) -> str:
return _GEMINI_MODELS_TO_TOKENIZER_NAMES[model_name]
if model_name in _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES.keys():
return _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES[model_name]
if model_name in _GEMINI_MODELS_TO_HUGGINGFACE_TOKENIZER_NAMES.keys():
return _GEMINI_MODELS_TO_HUGGINGFACE_TOKENIZER_NAMES[model_name]
raise ValueError(
f"Model {model_name} is not supported. Supported models: {', '.join(_GEMINI_MODELS_TO_TOKENIZER_NAMES.keys())}, {', '.join(_GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES.keys())}.\n" # pylint: disable=line-too-long
)


def get_huggingface_tokenizer(tokenizer_name: str) -> Any:
"""Loads huggingface tokenizer from the given tokenizer name."""
# Load the processor which includes the tokenizer
processor = AutoProcessor.from_pretrained( # type: ignore[no-untyped-call]
GEMMA_TOKENIZER_TO_MODEL_NAMES[tokenizer_name]
)
# Access the underlying tokenizer if needed
return processor.tokenizer


@functools.lru_cache()
def get_sentencepiece(tokenizer_name: str) -> spm.SentencePieceProcessor:
"""Loads sentencepiece tokenizer from the given tokenizer name."""
Expand Down
41 changes: 9 additions & 32 deletions google/genai/local_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,8 @@ class LocalTokenizer:

def __init__(self, model_name: str):
self._tokenizer_name = loader.get_tokenizer_name(model_name)
self._model_proto = None
if self._tokenizer_name in loader.GEMMA_TOKENIZER_TO_MODEL_NAMES:
self._tokenizer = loader.get_huggingface_tokenizer(self._tokenizer_name)
else:
self._model_proto = loader.load_model_proto(self._tokenizer_name)
self._tokenizer = loader.get_sentencepiece(self._tokenizer_name)
self._model_proto = loader.load_model_proto(self._tokenizer_name)
self._tokenizer = loader.get_sentencepiece(self._tokenizer_name)

@_common.experimental_warning(
"The SDK's local tokenizer implementation is experimental and may change"
Expand Down Expand Up @@ -369,46 +365,27 @@ def compute_tokens(
# tokens_info=[TokensInfo(token_ids=[279, 329, 1313, 2508, 13], tokens=[b' What', b' is', b' your', b' name', b'?'], role='user')]
"""
processed_contents = t.t_contents(contents)
roles = []

text_accumulator = _TextsAccumulator()
for content in processed_contents:
text_accumulator.add_content(content)
tokens_protos = self._tokenizer.EncodeAsImmutableProto(
text_accumulator.get_texts()
)

roles = []
for content in processed_contents:
if content.parts:
for _ in content.parts:
roles.append(content.role)

token_infos = []
if self._tokenizer_name in loader.GEMMA_TOKENIZER_TO_MODEL_NAMES:
# Use the HuggingFace tokenizer since gemma_pytorch is not available for
# gemma 4+.
token_ids = self._tokenizer.encode(list(text_accumulator.get_texts()))
for token_id, role in zip(token_ids, roles):
token_infos.append(
types.TokensInfo(
token_ids=token_id,
tokens=[
token.replace("_", " ")
.encode("utf-8")
.replace(b"\xe2\x96\x81", b" ")
for token in self._tokenizer.convert_ids_to_tokens(token_id)
],
role=role,
)
)
return types.ComputeTokensResult(tokens_info=token_infos)

tokens_protos = self._tokenizer.EncodeAsImmutableProto(
text_accumulator.get_texts()
)

for tokens_proto, role in zip(tokens_protos, roles):
token_infos.append(
types.TokensInfo(
token_ids=[piece.id for piece in tokens_proto.pieces],
tokens=[
_token_str_to_bytes(
piece.piece, self._model_proto.pieces[piece.id].type # type: ignore[union-attr]
piece.piece, self._model_proto.pieces[piece.id].type
)
for piece in tokens_proto.pieces
],
Expand Down
73 changes: 3 additions & 70 deletions google/genai/tests/local_tokenizer/test_local_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def setUp(self):
self.mock_load_model_proto = patch(
'genai._local_tokenizer_loader.load_model_proto'
).start()
self.addCleanup(patch.stopall)
self.mock_get_sentencepiece = patch(
'genai._local_tokenizer_loader.get_sentencepiece'
).start()
Expand All @@ -40,6 +39,9 @@ def setUp(self):

self.tokenizer = local_tokenizer.LocalTokenizer(model_name='gemini-3-pro-preview')

def tearDown(self):
patch.stopall()

def test_count_tokens_simple_string(self):
self.mock_tokenizer.encode.return_value = [[1, 2, 3]]
result = self.tokenizer.count_tokens('Hello world')
Expand Down Expand Up @@ -339,72 +341,3 @@ def test_invalid_format(self):
def test_invalid_hex_value(self):
with self.assertRaisesRegex(ValueError, 'Invalid hex value'):
local_tokenizer._parse_hex_byte('<0xFG>')


class TestLocalTokenizerHuggingFace(unittest.TestCase):

def setUp(self):
self.mock_get_huggingface_tokenizer = patch(
'genai._local_tokenizer_loader.get_huggingface_tokenizer'
).start()
self.addCleanup(patch.stopall)

self.mock_tokenizer = MagicMock()
self.mock_get_huggingface_tokenizer.return_value = self.mock_tokenizer

# gemini-3.5-flash maps to gemma4 (HuggingFace)
self.tokenizer = local_tokenizer.LocalTokenizer(model_name='gemini-3.5-flash')

def test_count_tokens_simple_string(self):
self.mock_tokenizer.encode.return_value = [[1, 2, 3]]
result = self.tokenizer.count_tokens('Hello world')
self.assertEqual(result.total_tokens, 3)
self.mock_tokenizer.encode.assert_called_once_with(['Hello world'])

def test_compute_tokens_simple_string(self):
self.mock_tokenizer.encode.return_value = [[1, 2, 3]]
self.mock_tokenizer.convert_ids_to_tokens.return_value = ['He', 'llo', ' world']

result = self.tokenizer.compute_tokens('Hello world')

self.assertEqual(len(result.tokens_info), 1)
self.assertEqual(result.tokens_info[0].token_ids, [1, 2, 3])
self.assertEqual(result.tokens_info[0].tokens, [b'He', b'llo', b' world'])
self.assertEqual(result.tokens_info[0].role, 'user')

self.mock_tokenizer.encode.assert_called_once_with(['Hello world'])
self.mock_tokenizer.convert_ids_to_tokens.assert_called_once_with([1, 2, 3])

def test_compute_tokens_special_characters(self):
self.mock_tokenizer.encode.return_value = [[1, 2]]
# Use U+2581 (lower one eighth block) and underscore
self.mock_tokenizer.convert_ids_to_tokens.return_value = ['_world', '\u2581hello']

result = self.tokenizer.compute_tokens('dummy')

self.assertEqual(result.tokens_info[0].tokens, [b' world', b' hello'])

def test_compute_tokens_with_chat_history(self):
self.mock_tokenizer.encode.return_value = [[1], [2, 3]]
self.mock_tokenizer.convert_ids_to_tokens.side_effect = [
['Hello'],
['Hi', ' there!']
]
history = [
types.Content(role='user', parts=[types.Part(text='Hello')]),
types.Content(role='model', parts=[types.Part(text='Hi there!')]),
]
result = self.tokenizer.compute_tokens(history)
self.assertEqual(len(result.tokens_info), 2)
self.assertEqual(result.tokens_info[0].token_ids, [1])
self.assertEqual(result.tokens_info[0].tokens, [b'Hello'])
self.assertEqual(result.tokens_info[0].role, 'user')
self.assertEqual(result.tokens_info[1].token_ids, [2, 3])
self.assertEqual(result.tokens_info[1].tokens, [b'Hi', b' there!'])
self.assertEqual(result.tokens_info[1].role, 'model')

self.mock_tokenizer.encode.assert_called_once_with(['Hello', 'Hi there!'])
self.mock_tokenizer.convert_ids_to_tokens.assert_has_calls([
unittest.mock.call([1]),
unittest.mock.call([2, 3])
])
59 changes: 13 additions & 46 deletions google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
]
).SerializeToString()

GEMMA3_HASH = "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c"
GEMMA2_HASH = "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"


class TestGetTokenizerName(unittest.TestCase):
Expand All @@ -58,18 +58,6 @@ def test_get_tokenizer_name_success(self):
loader.get_tokenizer_name("gemini-2.5-pro-preview-06-05"), "gemma3"
)

def test_get_tokenizer_name_huggingface(self):
self.assertEqual(loader.get_tokenizer_name("gemini-3.5-flash"), "gemma4")
self.assertEqual(
loader.get_tokenizer_name("gemini-3.1-flash-lite"), "gemma4"
)
self.assertEqual(
loader.get_tokenizer_name("gemini-3.1-pro-preview"), "gemma4"
)
self.assertEqual(
loader.get_tokenizer_name("gemini-4-flash-preview"), "gemma4"
)

def test_get_tokenizer_name_unsupported(self):
with self.assertRaisesRegex(
ValueError, "Model unsupported-model is not supported"
Expand Down Expand Up @@ -117,9 +105,9 @@ def test_load_model_proto_from_url(
):
mock_exists.return_value = False # Don't use cache
self._setup_get_mock(mock_get)
mock_sha256.return_value.hexdigest.return_value = GEMMA3_HASH
mock_sha256.return_value.hexdigest.return_value = GEMMA2_HASH

proto = loader.load_model_proto("gemma3")
proto = loader.load_model_proto("gemma2")

self.assertIsInstance(proto, sentencepiece_model_pb2.ModelProto)
self.assertEqual(len(proto.pieces), 4)
Expand All @@ -140,9 +128,9 @@ def test_load_model_proto_from_cache(
):
mock_exists.return_value = True # Use cache
mock_open_func.return_value.read.return_value = FAKE_MODEL_CONTENT
mock_sha256.return_value.hexdigest.return_value = GEMMA3_HASH
mock_sha256.return_value.hexdigest.return_value = GEMMA2_HASH

proto = loader.load_model_proto("gemma3")
proto = loader.load_model_proto("gemma2")

self.assertIsInstance(proto, sentencepiece_model_pb2.ModelProto)
mock_get.assert_not_called()
Expand All @@ -166,10 +154,10 @@ def test_load_model_proto_corrupted_cache(
# First hash for corrupted cache, second for good download
mock_sha256.side_effect = [
MagicMock(hexdigest=MagicMock(return_value="wrong_hash")),
MagicMock(hexdigest=MagicMock(return_value=GEMMA3_HASH)),
MagicMock(hexdigest=MagicMock(return_value=GEMMA2_HASH)),
]

proto = loader.load_model_proto("gemma3")
proto = loader.load_model_proto("gemma2")

self.assertIsInstance(proto, sentencepiece_model_pb2.ModelProto)
mock_remove.assert_called_once()
Expand All @@ -192,7 +180,7 @@ def test_load_model_proto_bad_hash_from_url(
with self.assertRaisesRegex(
ValueError, "Downloaded model file is corrupted"
):
loader.load_model_proto("gemma3")
loader.load_model_proto("gemma2")

def test_load_model_proto_unsupported(self, *args):
with self.assertRaisesRegex(
Expand All @@ -212,9 +200,9 @@ def test_get_sentencepiece_success(
):
mock_exists.return_value = False
self._setup_get_mock(mock_get)
mock_sha256.return_value.hexdigest.return_value = GEMMA3_HASH
mock_sha256.return_value.hexdigest.return_value = GEMMA2_HASH

processor = loader.get_sentencepiece("gemma3")
processor = loader.get_sentencepiece("gemma2")

self.assertIsInstance(processor, spm.SentencePieceProcessor)
mock_get.assert_called_once()
Expand All @@ -237,32 +225,11 @@ def test_get_sentencepiece_caching(
):
mock_exists.return_value = False
self._setup_get_mock(mock_get)
mock_sha256.return_value.hexdigest.return_value = GEMMA3_HASH
mock_sha256.return_value.hexdigest.return_value = GEMMA2_HASH

# Call twice
loader.get_sentencepiece("gemma3")
loader.get_sentencepiece("gemma3")
loader.get_sentencepiece("gemma2")
loader.get_sentencepiece("gemma2")

# Should only be loaded once due to lru_cache
mock_get.assert_called_once()


class TestGetHuggingFaceTokenizer(unittest.TestCase):

@patch("genai._local_tokenizer_loader.AutoProcessor")
def test_get_huggingface_tokenizer_success(self, mock_auto_processor):
mock_processor = MagicMock()
mock_tokenizer = MagicMock()
mock_processor.tokenizer = mock_tokenizer
mock_auto_processor.from_pretrained.return_value = mock_processor

tokenizer = loader.get_huggingface_tokenizer("gemma4")

self.assertEqual(tokenizer, mock_tokenizer)
mock_auto_processor.from_pretrained.assert_called_once_with(
"google/gemma-4-E4B-it"
)

def test_get_huggingface_tokenizer_unsupported(self):
with self.assertRaises(KeyError):
loader.get_huggingface_tokenizer("unsupported")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ dependencies = [

[project.optional-dependencies]
aiohttp = ["aiohttp>=3.10.11, <4.0.0"]
local-tokenizer = ["sentencepiece>=0.2.0", "protobuf", "transformers"]
local-tokenizer = ["sentencepiece>=0.2.0", "protobuf"]
pyopenssl = ["pyopenssl"]

[project.urls]
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,3 @@ websockets==16.0
mcp>=1.14.0; python_version > '3.9'
sentencepiece>=0.2.0
protobuf
transformers>=5.10.1
Loading