Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion google/genai/_local_tokenizer_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import requests # type: ignore
import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2
from transformers import AutoProcessor


# Source of truth: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
Expand Down Expand Up @@ -218,6 +217,13 @@ def get_tokenizer_name(model_name: str) -> str:
def get_huggingface_tokenizer(tokenizer_name: str) -> Any:
"""Loads huggingface tokenizer from the given tokenizer name."""
# Load the processor which includes the tokenizer
try:
from transformers import AutoProcessor
except ImportError:
raise ImportError(
"Please install transformers to use huggingface tokenizer: pip install"
" transformers"
) from ImportError
processor = AutoProcessor.from_pretrained( # type: ignore[no-untyped-call]
GEMMA_TOKENIZER_TO_MODEL_NAMES[tokenizer_name]
)
Expand Down
19 changes: 0 additions & 19 deletions google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,22 +247,3 @@ def test_get_sentencepiece_caching(
mock_get.assert_called_once()


class TestGetHuggingFaceTokenizer(unittest.TestCase):

@patch("genai._local_tokenizer_loader.AutoProcessor")
def test_get_huggingface_tokenizer_success(self, mock_auto_processor):
mock_processor = MagicMock()
mock_tokenizer = MagicMock()
mock_processor.tokenizer = mock_tokenizer
mock_auto_processor.from_pretrained.return_value = mock_processor

tokenizer = loader.get_huggingface_tokenizer("gemma4")

self.assertEqual(tokenizer, mock_tokenizer)
mock_auto_processor.from_pretrained.assert_called_once_with(
"google/gemma-4-E4B-it"
)

def test_get_huggingface_tokenizer_unsupported(self):
with self.assertRaises(KeyError):
loader.get_huggingface_tokenizer("unsupported")
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,14 @@ dependencies = [

[project.optional-dependencies]
aiohttp = ["aiohttp>=3.10.11, <4.0.0"]
local-tokenizer = ["sentencepiece>=0.2.0", "protobuf", "transformers"]
local-tokenizer = [
"sentencepiece>=0.2.0",
"protobuf",
"pillow",
"torch",
"torchvision",
"transformers",
]
pyopenssl = ["pyopenssl"]

[project.urls]
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,3 @@ websockets==16.0
mcp>=1.14.0; python_version > '3.9'
sentencepiece>=0.2.0
protobuf
transformers>=5.10.1
Loading