From 432e260524285b25c254b744a4b288ed2b388e0b Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 10 Jun 2026 14:10:46 -0700 Subject: [PATCH] feat: Gemma 4 local tokenizer support PiperOrigin-RevId: 930069410 --- google/genai/_local_tokenizer_loader.py | 36 +++------ google/genai/local_tokenizer.py | 41 +++-------- .../local_tokenizer/test_local_tokenizer.py | 73 +------------------ .../test_local_tokenizer_loader.py | 59 ++++----------- pyproject.toml | 2 +- requirements.txt | 1 - 6 files changed, 35 insertions(+), 177 deletions(-) diff --git a/google/genai/_local_tokenizer_loader.py b/google/genai/_local_tokenizer_loader.py index da883fcac..0f6edc3ad 100644 --- a/google/genai/_local_tokenizer_loader.py +++ b/google/genai/_local_tokenizer_loader.py @@ -18,13 +18,12 @@ import hashlib import os import tempfile -from typing import Any, Optional, cast +from typing import Optional, cast import uuid import requests # type: ignore import sentencepiece as spm from sentencepiece import sentencepiece_model_pb2 -from transformers import AutoProcessor # Source of truth: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models @@ -48,18 +47,6 @@ "gemini-3-pro-preview": "gemma3", } -# https://github.com/google/gemma_pytorch stop supporting gemma 4 moving forward. -_GEMINI_MODELS_TO_HUGGINGFACE_TOKENIZER_NAMES = { - "gemini-3.5-flash": "gemma4", - "gemini-3.1-flash-lite": "gemma4", - "gemini-3.1-pro-preview": "gemma4", - "gemini-4-flash-preview": "gemma4", -} - -GEMMA_TOKENIZER_TO_MODEL_NAMES = { - "gemma4": "google/gemma-4-E4B-it", -} - @dataclasses.dataclass(frozen=True) class _TokenizerConfig: @@ -67,7 +54,14 @@ class _TokenizerConfig: model_hash: str +# TODO: update gemma3 tokenizer _TOKENIZERS = { + "gemma2": _TokenizerConfig( + model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model", + model_hash=( + "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2" + ), + ), "gemma3": _TokenizerConfig( model_url="https://raw.githubusercontent.com/google/gemma_pytorch/014acb7ac4563a5f77c76d7ff98f31b568c16508/tokenizer/gemma3_cleaned_262144_v2.spiece.model", model_hash=( @@ -183,7 +177,7 @@ def _load_model_proto_bytes(tokenizer_name: str) -> bytes: """Loads model proto bytes from the given tokenizer name.""" if tokenizer_name not in _TOKENIZERS: raise ValueError( - f"Tokenizer {tokenizer_name} is not supported. " + f"Tokenizer {tokenizer_name} is not supported." f"Supported tokenizers: {list(_TOKENIZERS.keys())}" ) return _load( @@ -208,23 +202,11 @@ def get_tokenizer_name(model_name: str) -> str: return _GEMINI_MODELS_TO_TOKENIZER_NAMES[model_name] if model_name in _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES.keys(): return _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES[model_name] - if model_name in _GEMINI_MODELS_TO_HUGGINGFACE_TOKENIZER_NAMES.keys(): - return _GEMINI_MODELS_TO_HUGGINGFACE_TOKENIZER_NAMES[model_name] raise ValueError( f"Model {model_name} is not supported. Supported models: {', '.join(_GEMINI_MODELS_TO_TOKENIZER_NAMES.keys())}, {', '.join(_GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES.keys())}.\n" # pylint: disable=line-too-long ) -def get_huggingface_tokenizer(tokenizer_name: str) -> Any: - """Loads huggingface tokenizer from the given tokenizer name.""" - # Load the processor which includes the tokenizer - processor = AutoProcessor.from_pretrained( # type: ignore[no-untyped-call] - GEMMA_TOKENIZER_TO_MODEL_NAMES[tokenizer_name] - ) - # Access the underlying tokenizer if needed - return processor.tokenizer - - @functools.lru_cache() def get_sentencepiece(tokenizer_name: str) -> spm.SentencePieceProcessor: """Loads sentencepiece tokenizer from the given tokenizer name.""" diff --git a/google/genai/local_tokenizer.py b/google/genai/local_tokenizer.py index cf947c39f..0527ff5a1 100644 --- a/google/genai/local_tokenizer.py +++ b/google/genai/local_tokenizer.py @@ -291,12 +291,8 @@ class LocalTokenizer: def __init__(self, model_name: str): self._tokenizer_name = loader.get_tokenizer_name(model_name) - self._model_proto = None - if self._tokenizer_name in loader.GEMMA_TOKENIZER_TO_MODEL_NAMES: - self._tokenizer = loader.get_huggingface_tokenizer(self._tokenizer_name) - else: - self._model_proto = loader.load_model_proto(self._tokenizer_name) - self._tokenizer = loader.get_sentencepiece(self._tokenizer_name) + self._model_proto = loader.load_model_proto(self._tokenizer_name) + self._tokenizer = loader.get_sentencepiece(self._tokenizer_name) @_common.experimental_warning( "The SDK's local tokenizer implementation is experimental and may change" @@ -369,46 +365,27 @@ def compute_tokens( # tokens_info=[TokensInfo(token_ids=[279, 329, 1313, 2508, 13], tokens=[b' What', b' is', b' your', b' name', b'?'], role='user')] """ processed_contents = t.t_contents(contents) - roles = [] - text_accumulator = _TextsAccumulator() for content in processed_contents: text_accumulator.add_content(content) + tokens_protos = self._tokenizer.EncodeAsImmutableProto( + text_accumulator.get_texts() + ) + + roles = [] + for content in processed_contents: if content.parts: for _ in content.parts: roles.append(content.role) token_infos = [] - if self._tokenizer_name in loader.GEMMA_TOKENIZER_TO_MODEL_NAMES: - # Use the HuggingFace tokenizer since gemma_pytorch is not available for - # gemma 4+. - token_ids = self._tokenizer.encode(list(text_accumulator.get_texts())) - for token_id, role in zip(token_ids, roles): - token_infos.append( - types.TokensInfo( - token_ids=token_id, - tokens=[ - token.replace("_", " ") - .encode("utf-8") - .replace(b"\xe2\x96\x81", b" ") - for token in self._tokenizer.convert_ids_to_tokens(token_id) - ], - role=role, - ) - ) - return types.ComputeTokensResult(tokens_info=token_infos) - - tokens_protos = self._tokenizer.EncodeAsImmutableProto( - text_accumulator.get_texts() - ) - for tokens_proto, role in zip(tokens_protos, roles): token_infos.append( types.TokensInfo( token_ids=[piece.id for piece in tokens_proto.pieces], tokens=[ _token_str_to_bytes( - piece.piece, self._model_proto.pieces[piece.id].type # type: ignore[union-attr] + piece.piece, self._model_proto.pieces[piece.id].type ) for piece in tokens_proto.pieces ], diff --git a/google/genai/tests/local_tokenizer/test_local_tokenizer.py b/google/genai/tests/local_tokenizer/test_local_tokenizer.py index 1d9a12db9..f50dfaaa5 100644 --- a/google/genai/tests/local_tokenizer/test_local_tokenizer.py +++ b/google/genai/tests/local_tokenizer/test_local_tokenizer.py @@ -29,7 +29,6 @@ def setUp(self): self.mock_load_model_proto = patch( 'genai._local_tokenizer_loader.load_model_proto' ).start() - self.addCleanup(patch.stopall) self.mock_get_sentencepiece = patch( 'genai._local_tokenizer_loader.get_sentencepiece' ).start() @@ -40,6 +39,9 @@ def setUp(self): self.tokenizer = local_tokenizer.LocalTokenizer(model_name='gemini-3-pro-preview') + def tearDown(self): + patch.stopall() + def test_count_tokens_simple_string(self): self.mock_tokenizer.encode.return_value = [[1, 2, 3]] result = self.tokenizer.count_tokens('Hello world') @@ -339,72 +341,3 @@ def test_invalid_format(self): def test_invalid_hex_value(self): with self.assertRaisesRegex(ValueError, 'Invalid hex value'): local_tokenizer._parse_hex_byte('<0xFG>') - - -class TestLocalTokenizerHuggingFace(unittest.TestCase): - - def setUp(self): - self.mock_get_huggingface_tokenizer = patch( - 'genai._local_tokenizer_loader.get_huggingface_tokenizer' - ).start() - self.addCleanup(patch.stopall) - - self.mock_tokenizer = MagicMock() - self.mock_get_huggingface_tokenizer.return_value = self.mock_tokenizer - - # gemini-3.5-flash maps to gemma4 (HuggingFace) - self.tokenizer = local_tokenizer.LocalTokenizer(model_name='gemini-3.5-flash') - - def test_count_tokens_simple_string(self): - self.mock_tokenizer.encode.return_value = [[1, 2, 3]] - result = self.tokenizer.count_tokens('Hello world') - self.assertEqual(result.total_tokens, 3) - self.mock_tokenizer.encode.assert_called_once_with(['Hello world']) - - def test_compute_tokens_simple_string(self): - self.mock_tokenizer.encode.return_value = [[1, 2, 3]] - self.mock_tokenizer.convert_ids_to_tokens.return_value = ['He', 'llo', ' world'] - - result = self.tokenizer.compute_tokens('Hello world') - - self.assertEqual(len(result.tokens_info), 1) - self.assertEqual(result.tokens_info[0].token_ids, [1, 2, 3]) - self.assertEqual(result.tokens_info[0].tokens, [b'He', b'llo', b' world']) - self.assertEqual(result.tokens_info[0].role, 'user') - - self.mock_tokenizer.encode.assert_called_once_with(['Hello world']) - self.mock_tokenizer.convert_ids_to_tokens.assert_called_once_with([1, 2, 3]) - - def test_compute_tokens_special_characters(self): - self.mock_tokenizer.encode.return_value = [[1, 2]] - # Use U+2581 (lower one eighth block) and underscore - self.mock_tokenizer.convert_ids_to_tokens.return_value = ['_world', '\u2581hello'] - - result = self.tokenizer.compute_tokens('dummy') - - self.assertEqual(result.tokens_info[0].tokens, [b' world', b' hello']) - - def test_compute_tokens_with_chat_history(self): - self.mock_tokenizer.encode.return_value = [[1], [2, 3]] - self.mock_tokenizer.convert_ids_to_tokens.side_effect = [ - ['Hello'], - ['Hi', ' there!'] - ] - history = [ - types.Content(role='user', parts=[types.Part(text='Hello')]), - types.Content(role='model', parts=[types.Part(text='Hi there!')]), - ] - result = self.tokenizer.compute_tokens(history) - self.assertEqual(len(result.tokens_info), 2) - self.assertEqual(result.tokens_info[0].token_ids, [1]) - self.assertEqual(result.tokens_info[0].tokens, [b'Hello']) - self.assertEqual(result.tokens_info[0].role, 'user') - self.assertEqual(result.tokens_info[1].token_ids, [2, 3]) - self.assertEqual(result.tokens_info[1].tokens, [b'Hi', b' there!']) - self.assertEqual(result.tokens_info[1].role, 'model') - - self.mock_tokenizer.encode.assert_called_once_with(['Hello', 'Hi there!']) - self.mock_tokenizer.convert_ids_to_tokens.assert_has_calls([ - unittest.mock.call([1]), - unittest.mock.call([2, 3]) - ]) diff --git a/google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py b/google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py index 6070729cb..09cd63f42 100644 --- a/google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py +++ b/google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py @@ -47,7 +47,7 @@ ] ).SerializeToString() -GEMMA3_HASH = "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c" +GEMMA2_HASH = "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2" class TestGetTokenizerName(unittest.TestCase): @@ -58,18 +58,6 @@ def test_get_tokenizer_name_success(self): loader.get_tokenizer_name("gemini-2.5-pro-preview-06-05"), "gemma3" ) - def test_get_tokenizer_name_huggingface(self): - self.assertEqual(loader.get_tokenizer_name("gemini-3.5-flash"), "gemma4") - self.assertEqual( - loader.get_tokenizer_name("gemini-3.1-flash-lite"), "gemma4" - ) - self.assertEqual( - loader.get_tokenizer_name("gemini-3.1-pro-preview"), "gemma4" - ) - self.assertEqual( - loader.get_tokenizer_name("gemini-4-flash-preview"), "gemma4" - ) - def test_get_tokenizer_name_unsupported(self): with self.assertRaisesRegex( ValueError, "Model unsupported-model is not supported" @@ -117,9 +105,9 @@ def test_load_model_proto_from_url( ): mock_exists.return_value = False # Don't use cache self._setup_get_mock(mock_get) - mock_sha256.return_value.hexdigest.return_value = GEMMA3_HASH + mock_sha256.return_value.hexdigest.return_value = GEMMA2_HASH - proto = loader.load_model_proto("gemma3") + proto = loader.load_model_proto("gemma2") self.assertIsInstance(proto, sentencepiece_model_pb2.ModelProto) self.assertEqual(len(proto.pieces), 4) @@ -140,9 +128,9 @@ def test_load_model_proto_from_cache( ): mock_exists.return_value = True # Use cache mock_open_func.return_value.read.return_value = FAKE_MODEL_CONTENT - mock_sha256.return_value.hexdigest.return_value = GEMMA3_HASH + mock_sha256.return_value.hexdigest.return_value = GEMMA2_HASH - proto = loader.load_model_proto("gemma3") + proto = loader.load_model_proto("gemma2") self.assertIsInstance(proto, sentencepiece_model_pb2.ModelProto) mock_get.assert_not_called() @@ -166,10 +154,10 @@ def test_load_model_proto_corrupted_cache( # First hash for corrupted cache, second for good download mock_sha256.side_effect = [ MagicMock(hexdigest=MagicMock(return_value="wrong_hash")), - MagicMock(hexdigest=MagicMock(return_value=GEMMA3_HASH)), + MagicMock(hexdigest=MagicMock(return_value=GEMMA2_HASH)), ] - proto = loader.load_model_proto("gemma3") + proto = loader.load_model_proto("gemma2") self.assertIsInstance(proto, sentencepiece_model_pb2.ModelProto) mock_remove.assert_called_once() @@ -192,7 +180,7 @@ def test_load_model_proto_bad_hash_from_url( with self.assertRaisesRegex( ValueError, "Downloaded model file is corrupted" ): - loader.load_model_proto("gemma3") + loader.load_model_proto("gemma2") def test_load_model_proto_unsupported(self, *args): with self.assertRaisesRegex( @@ -212,9 +200,9 @@ def test_get_sentencepiece_success( ): mock_exists.return_value = False self._setup_get_mock(mock_get) - mock_sha256.return_value.hexdigest.return_value = GEMMA3_HASH + mock_sha256.return_value.hexdigest.return_value = GEMMA2_HASH - processor = loader.get_sentencepiece("gemma3") + processor = loader.get_sentencepiece("gemma2") self.assertIsInstance(processor, spm.SentencePieceProcessor) mock_get.assert_called_once() @@ -237,32 +225,11 @@ def test_get_sentencepiece_caching( ): mock_exists.return_value = False self._setup_get_mock(mock_get) - mock_sha256.return_value.hexdigest.return_value = GEMMA3_HASH + mock_sha256.return_value.hexdigest.return_value = GEMMA2_HASH # Call twice - loader.get_sentencepiece("gemma3") - loader.get_sentencepiece("gemma3") + loader.get_sentencepiece("gemma2") + loader.get_sentencepiece("gemma2") # Should only be loaded once due to lru_cache mock_get.assert_called_once() - - -class TestGetHuggingFaceTokenizer(unittest.TestCase): - - @patch("genai._local_tokenizer_loader.AutoProcessor") - def test_get_huggingface_tokenizer_success(self, mock_auto_processor): - mock_processor = MagicMock() - mock_tokenizer = MagicMock() - mock_processor.tokenizer = mock_tokenizer - mock_auto_processor.from_pretrained.return_value = mock_processor - - tokenizer = loader.get_huggingface_tokenizer("gemma4") - - self.assertEqual(tokenizer, mock_tokenizer) - mock_auto_processor.from_pretrained.assert_called_once_with( - "google/gemma-4-E4B-it" - ) - - def test_get_huggingface_tokenizer_unsupported(self): - with self.assertRaises(KeyError): - loader.get_huggingface_tokenizer("unsupported") diff --git a/pyproject.toml b/pyproject.toml index ca4f19131..eaf5fad64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ [project.optional-dependencies] aiohttp = ["aiohttp>=3.10.11, <4.0.0"] -local-tokenizer = ["sentencepiece>=0.2.0", "protobuf", "transformers"] +local-tokenizer = ["sentencepiece>=0.2.0", "protobuf"] pyopenssl = ["pyopenssl"] [project.urls] diff --git a/requirements.txt b/requirements.txt index 6c021d229..322d9fe27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,4 +32,3 @@ websockets==16.0 mcp>=1.14.0; python_version > '3.9' sentencepiece>=0.2.0 protobuf -transformers>=5.10.1