From 480aaa613f123d67028cb9b9160e5c7afd7d04b6 Mon Sep 17 00:00:00 2001 From: Cheney Zhang Date: Thu, 2 Jul 2026 03:48:10 +0000 Subject: [PATCH 1/4] feat: add Milvus memory service Signed-off-by: Cheney Zhang --- contributing/samples/milvus_memory/README.md | 111 +++ pyproject.toml | 4 + src/google/adk_community/memory/__init__.py | 5 +- .../memory/milvus_memory_service.py | 669 ++++++++++++++++++ .../memory/test_milvus_memory_service_e2e.py | 148 ++++ .../memory/test_milvus_memory_service.py | 430 +++++++++++ 6 files changed, 1366 insertions(+), 1 deletion(-) create mode 100644 contributing/samples/milvus_memory/README.md create mode 100644 src/google/adk_community/memory/milvus_memory_service.py create mode 100644 tests/integration/memory/test_milvus_memory_service_e2e.py create mode 100644 tests/unittests/memory/test_milvus_memory_service.py diff --git a/contributing/samples/milvus_memory/README.md b/contributing/samples/milvus_memory/README.md new file mode 100644 index 00000000..f4ee6651 --- /dev/null +++ b/contributing/samples/milvus_memory/README.md @@ -0,0 +1,111 @@ +# Milvus Memory Service sample + +This sample shows how to use Milvus as an ADK `BaseMemoryService` backend for +cross-session memory. + +`MilvusMemoryService` supports the same configuration shape for: + +- Milvus Lite: local development with a local database path +- Milvus server: self-hosted Milvus, such as `http://localhost:19530` +- Zilliz Cloud: managed Milvus with a cloud endpoint and token + +## Installation + +```bash +pip install "google-adk-community[milvus]" +``` + +The `milvus` extra installs current `pymilvus` and Milvus Lite packages. +Milvus Lite 3.x local storage is not compatible with older 2.x local database +files or directories, so create a new local database path for new projects. + +## Configuration + +Use `MILVUS_URI` and `MILVUS_TOKEN` for all deployment modes: + +```bash +# Milvus Lite +export MILVUS_URI="./adk_milvus_memory.db" + +# Milvus server +export MILVUS_URI="http://localhost:19530" + +# Zilliz Cloud +export MILVUS_URI="https://your-endpoint.api.gcp-us-west1.zillizcloud.com" +export MILVUS_TOKEN="your-token" +``` + +`MILVUS_TOKEN` is only needed for authenticated deployments such as Zilliz +Cloud. If you use a non-default Milvus database, set `MILVUS_DB_NAME`. + +## Use with Runner + +```python +from google.adk.agents import Agent +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.adk_community.memory import MilvusMemoryService +from google.genai import Client + + +genai_client = Client() + + +def embedding_function(texts): + response = genai_client.models.embed_content( + model="gemini-embedding-001", + contents=list(texts), + ) + return [list(embedding.values) for embedding in response.embeddings] + + +memory_service = MilvusMemoryService( + embedding_function=embedding_function, + dimension=3072, + collection_name="adk_memory", +) + +agent = Agent( + name="memory_agent", + model="gemini-flash-latest", + instruction="Use memory to personalize responses when relevant.", +) + +runner = Runner( + app_name="milvus_memory_app", + agent=agent, + session_service=InMemorySessionService(), + memory_service=memory_service, +) +``` + +After a session has useful conversation history, persist it: + +```python +session = await runner.session_service.get_session( + app_name="milvus_memory_app", + user_id="user-1", + session_id="session-1", +) +await memory_service.add_session_to_memory(session) +``` + +Later, retrieve relevant memories: + +```python +result = await memory_service.search_memory( + app_name="milvus_memory_app", + user_id="user-1", + query="what did the user say about database preferences?", +) +for memory in result.memories: + print(memory.content.parts[0].text) +``` + +## Notes + +- `dimension` must match the embedding model output dimension. +- Re-ingesting the same ADK event uses a stable ID and updates the existing + Milvus record instead of creating duplicates. +- Search is scoped by `app_name` and `user_id`, so users cannot retrieve each + other's memories through the memory service. diff --git a/pyproject.toml b/pyproject.toml index a03bdcab..9eb9904e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,10 @@ hitl = [ "fastapi>=0.110.0", "sqlalchemy>=2.0.0", ] +milvus = [ + "milvus-lite>=3.0", + "pymilvus>=3.0.0", +] [tool.pyink] diff --git a/src/google/adk_community/memory/__init__.py b/src/google/adk_community/memory/__init__.py index 1f3442c0..6acebef6 100644 --- a/src/google/adk_community/memory/__init__.py +++ b/src/google/adk_community/memory/__init__.py @@ -14,11 +14,14 @@ """Community memory services for ADK.""" +from .milvus_memory_service import MilvusMemoryService +from .milvus_memory_service import MilvusMemoryServiceConfig from .open_memory_service import OpenMemoryService from .open_memory_service import OpenMemoryServiceConfig __all__ = [ + "MilvusMemoryService", + "MilvusMemoryServiceConfig", "OpenMemoryService", "OpenMemoryServiceConfig", ] - diff --git a/src/google/adk_community/memory/milvus_memory_service.py b/src/google/adk_community/memory/milvus_memory_service.py new file mode 100644 index 00000000..7d48290a --- /dev/null +++ b/src/google/adk_community/memory/milvus_memory_service.py @@ -0,0 +1,669 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Milvus-backed memory service for ADK.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable +from collections.abc import Callable +from collections.abc import Mapping +from collections.abc import Sequence +import hashlib +import inspect +import json +import logging +import os +import re +from typing import Optional +from typing import TYPE_CHECKING + +from google.adk.memory import _utils +from google.adk.memory.base_memory_service import BaseMemoryService +from google.adk.memory.base_memory_service import SearchMemoryResponse +from google.adk.memory.memory_entry import MemoryEntry +from google.genai import types +from pydantic import BaseModel +from pydantic import Field +from typing_extensions import override + +from .utils import extract_text_from_event + +if TYPE_CHECKING: + from google.adk.events.event import Event + from google.adk.sessions.session import Session + +logger = logging.getLogger("google_adk_community." + __name__) + +EmbeddingFunction = Callable[ + [Sequence[str]], + Sequence[Sequence[float]] | Awaitable[Sequence[Sequence[float]]], +] + +_DEFAULT_COLLECTION_NAME = "adk_memory" +_DEFAULT_LITE_URI = "./adk_milvus_memory.db" +_UNKNOWN_SESSION_ID = "__unknown_session_id__" +_EVENT_SOURCE = "adk_event" +_DIRECT_MEMORY_SOURCE = "adk_memory" +_FIELD_NAME_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def _load_pymilvus(): + try: + from pymilvus import DataType + from pymilvus import MilvusClient + except ImportError as exc: + raise ImportError( + "pymilvus is required to use MilvusMemoryService. " + "Install it with: pip install google-adk-community[milvus]" + ) from exc + return MilvusClient, DataType + + +def _env_value(name: str) -> Optional[str]: + value = os.getenv(name) + return value if value else None + + +def _json_safe(value: Mapping[str, object] | None) -> dict[str, object]: + if not value: + return {} + return json.loads(json.dumps(dict(value), default=str)) + + +def _quote_filter_value(value: str) -> str: + return json.dumps(value) + + +def _hash_id(*parts: object) -> str: + digest = hashlib.sha256() + for part in parts: + digest.update(str(part).encode("utf-8")) + digest.update(b"\0") + return "adk-" + digest.hexdigest() + + +def _content_to_text(content: types.Content | None) -> str: + if not content or not content.parts: + return "" + text_parts = [ + part.text for part in content.parts if part.text and not part.thought + ] + return " ".join(text_parts) + + +def _timestamp_from_event(event: Event) -> str: + if event.timestamp is None: + return "" + return _utils.format_timestamp(event.timestamp) + + +def _field_name(field: Mapping[str, object]) -> str | None: + raw_name = field.get("name", field.get("field_name")) + return str(raw_name) if raw_name is not None else None + + +def _field_dim(field: Mapping[str, object]) -> int | None: + params = field.get("params", {}) + candidates = [] + if isinstance(params, Mapping): + candidates.append(params.get("dim")) + nested_params = params.get("params") + if isinstance(nested_params, Mapping): + candidates.append(nested_params.get("dim")) + candidates.append(field.get("dim")) + for candidate in candidates: + if candidate is not None: + return int(candidate) + return None + + +def _field_type(field: Mapping[str, object]) -> object | None: + return field.get("type", field.get("data_type", field.get("datatype"))) + + +def _type_label(data_type: object) -> str: + return str(getattr(data_type, "name", data_type)) + + +def _field_type_matches(actual: object, expected: object) -> bool: + if actual == expected: + return True + actual_value = getattr(actual, "value", actual) + expected_value = getattr(expected, "value", expected) + if actual_value == expected_value: + return True + return _type_label(actual) == _type_label(expected) + + +def _validate_field_names(config: "MilvusMemoryServiceConfig") -> None: + field_names = [ + config.id_field, + config.vector_field, + config.text_field, + config.app_name_field, + config.user_id_field, + config.session_id_field, + config.event_id_field, + config.author_field, + config.timestamp_field, + config.source_field, + config.metadata_field, + ] + invalid_names = [ + name for name in field_names if not _FIELD_NAME_PATTERN.fullmatch(name) + ] + if invalid_names: + raise ValueError( + "Milvus field names must be valid identifiers: " + + ", ".join(sorted(set(invalid_names))) + ) + duplicate_names = sorted( + {name for name in field_names if field_names.count(name) > 1} + ) + if duplicate_names: + raise ValueError( + "Milvus field names must be unique: " + ", ".join(duplicate_names) + ) + + +class MilvusMemoryServiceConfig(BaseModel): + """Configuration for Milvus memory storage.""" + + uri: str = Field( + default_factory=lambda: _env_value("MILVUS_URI") or _DEFAULT_LITE_URI + ) + token: Optional[str] = Field( + default_factory=lambda: _env_value("MILVUS_TOKEN") + ) + db_name: Optional[str] = Field( + default_factory=lambda: _env_value("MILVUS_DB_NAME") + ) + collection_name: str = Field(default=_DEFAULT_COLLECTION_NAME, min_length=1) + dimension: int = Field(gt=0) + search_top_k: int = Field(default=10, ge=1, le=100) + metric_type: str = Field(default="COSINE") + index_type: str = Field(default="AUTOINDEX") + consistency_level: Optional[str] = Field(default="Session") + timeout: Optional[float] = Field(default=None, gt=0.0) + id_field: str = Field(default="id") + vector_field: str = Field(default="embedding") + text_field: str = Field(default="text") + app_name_field: str = Field(default="app_name") + user_id_field: str = Field(default="user_id") + session_id_field: str = Field(default="session_id") + event_id_field: str = Field(default="event_id") + author_field: str = Field(default="author") + timestamp_field: str = Field(default="timestamp") + source_field: str = Field(default="source") + metadata_field: str = Field(default="metadata") + id_max_length: int = Field(default=512, gt=0) + text_max_length: int = Field(default=65535, gt=0) + scalar_max_length: int = Field(default=1024, gt=0) + + +class MilvusMemoryService(BaseMemoryService): + """A Milvus-backed implementation of ADK's BaseMemoryService.""" + + def __init__( + self, + *, + embedding_function: EmbeddingFunction, + dimension: int | None = None, + config: MilvusMemoryServiceConfig | None = None, + uri: str | None = None, + token: str | None = None, + db_name: str | None = None, + collection_name: str | None = None, + search_top_k: int | None = None, + consistency_level: str | None = None, + ): + """Initializes the Milvus memory service. + + Args: + embedding_function: Function that embeds a batch of texts. + dimension: Embedding vector dimension. Required unless provided by config. + config: Optional MilvusMemoryServiceConfig. + uri: Optional Milvus URI override. Defaults to MILVUS_URI or local Lite. + token: Optional token override. Defaults to MILVUS_TOKEN. + db_name: Optional Milvus database name override. Defaults to + MILVUS_DB_NAME. + collection_name: Optional collection name override. + search_top_k: Optional search result limit override. + consistency_level: Optional collection consistency level override. + """ + if config is None: + if dimension is None: + raise ValueError("dimension is required when config is not provided.") + config = MilvusMemoryServiceConfig(dimension=dimension) + elif dimension is not None: + config = config.model_copy(update={"dimension": dimension}) + + updates: dict[str, object] = {} + if uri is not None: + updates["uri"] = uri + if token is not None: + updates["token"] = token + if db_name is not None: + updates["db_name"] = db_name + if collection_name is not None: + updates["collection_name"] = collection_name + if search_top_k is not None: + updates["search_top_k"] = search_top_k + if consistency_level is not None: + updates["consistency_level"] = consistency_level + if updates: + config = config.model_copy(update=updates) + + _validate_field_names(config) + + self._embedding_function = embedding_function + self._config = config + milvus_client, data_type = _load_pymilvus() + self._data_type = data_type + client_kwargs: dict[str, object] = { + "uri": self._config.uri, + } + if self._config.token: + client_kwargs["token"] = self._config.token + if self._config.db_name: + client_kwargs["db_name"] = self._config.db_name + if self._config.timeout is not None: + client_kwargs["timeout"] = self._config.timeout + self._client = milvus_client(**client_kwargs) + self._ensure_collection() + + def _ensure_collection(self) -> None: + if self._client.has_collection( + collection_name=self._config.collection_name, + timeout=self._config.timeout, + ): + self._validate_existing_collection() + return + + schema = self._client.create_schema( + auto_id=False, + enable_dynamic_field=True, + ) + schema.add_field( + field_name=self._config.id_field, + datatype=self._data_type.VARCHAR, + is_primary=True, + max_length=self._config.id_max_length, + ) + schema.add_field( + field_name=self._config.vector_field, + datatype=self._data_type.FLOAT_VECTOR, + dim=self._config.dimension, + ) + for field_name, max_length in [ + (self._config.text_field, self._config.text_max_length), + (self._config.app_name_field, self._config.scalar_max_length), + (self._config.user_id_field, self._config.scalar_max_length), + (self._config.session_id_field, self._config.scalar_max_length), + (self._config.event_id_field, self._config.scalar_max_length), + (self._config.author_field, self._config.scalar_max_length), + (self._config.timestamp_field, self._config.scalar_max_length), + (self._config.source_field, self._config.scalar_max_length), + ]: + schema.add_field( + field_name=field_name, + datatype=self._data_type.VARCHAR, + max_length=max_length, + ) + schema.add_field( + field_name=self._config.metadata_field, + datatype=self._data_type.JSON, + ) + + index_params = self._client.prepare_index_params() + index_params.add_index( + field_name=self._config.vector_field, + index_type=self._config.index_type, + metric_type=self._config.metric_type, + ) + create_kwargs: dict[str, object] = { + "collection_name": self._config.collection_name, + "schema": schema, + "index_params": index_params, + "timeout": self._config.timeout, + } + if self._config.consistency_level: + create_kwargs["consistency_level"] = self._config.consistency_level + self._client.create_collection(**create_kwargs) + + def _validate_existing_collection(self) -> None: + description = self._client.describe_collection( + collection_name=self._config.collection_name, + timeout=self._config.timeout, + ) + fields = { + name: field + for field in description.get("fields", []) + if isinstance(field, Mapping) and (name := _field_name(field)) + } + required_fields = [ + self._config.id_field, + self._config.vector_field, + self._config.text_field, + self._config.app_name_field, + self._config.user_id_field, + ] + missing_fields = [field for field in required_fields if field not in fields] + if missing_fields: + raise ValueError( + "Milvus collection " + f"{self._config.collection_name!r} is missing required fields: " + + ", ".join(missing_fields) + ) + if description.get("auto_id") is True: + raise ValueError( + "Milvus collection " + f"{self._config.collection_name!r} must use auto_id=False." + ) + id_field = fields[self._config.id_field] + if id_field.get("is_primary") is not None and not id_field.get( + "is_primary" + ): + raise ValueError( + "Milvus collection " + f"{self._config.collection_name!r} field " + f"{self._config.id_field!r} must be the primary key." + ) + self._validate_field_type( + id_field, self._data_type.VARCHAR, self._config.id_field + ) + self._validate_field_type( + fields[self._config.vector_field], + self._data_type.FLOAT_VECTOR, + self._config.vector_field, + ) + for field_name in [ + self._config.text_field, + self._config.app_name_field, + self._config.user_id_field, + ]: + self._validate_field_type( + fields[field_name], self._data_type.VARCHAR, field_name + ) + vector_dim = _field_dim(fields[self._config.vector_field]) + if vector_dim is not None and vector_dim != self._config.dimension: + raise ValueError( + "Milvus collection " + f"{self._config.collection_name!r} has vector dimension " + f"{vector_dim}, expected {self._config.dimension}." + ) + + def _validate_field_type( + self, field: Mapping[str, object], expected_type: object, field_name: str + ) -> None: + actual_type = _field_type(field) + if actual_type is not None and not _field_type_matches( + actual_type, expected_type + ): + raise ValueError( + "Milvus collection " + f"{self._config.collection_name!r} field {field_name!r} has type " + f"{_type_label(actual_type)}, expected {_type_label(expected_type)}." + ) + + async def _embed_texts(self, texts: Sequence[str]) -> list[list[float]]: + if not texts: + return [] + embeddings = self._embedding_function(texts) + if inspect.isawaitable(embeddings): + embeddings = await embeddings + + vectors = [list(vector) for vector in embeddings] + if len(vectors) != len(texts): + raise ValueError( + "embedding_function returned " + f"{len(vectors)} vectors for {len(texts)} texts." + ) + for vector in vectors: + if len(vector) != self._config.dimension: + raise ValueError( + "embedding_function returned vector dimension " + f"{len(vector)}, expected {self._config.dimension}." + ) + return vectors + + def _scope_filter(self, *, app_name: str, user_id: str) -> str: + return ( + f"{self._config.app_name_field} == {_quote_filter_value(app_name)} " + f"and {self._config.user_id_field} == {_quote_filter_value(user_id)}" + ) + + def _event_to_record( + self, + *, + app_name: str, + user_id: str, + session_id: str | None, + event: Event, + text: str, + embedding: Sequence[float], + custom_metadata: Mapping[str, object] | None = None, + ) -> dict[str, object]: + scoped_session_id = session_id or _UNKNOWN_SESSION_ID + event_id = event.id or _hash_id( + app_name, + user_id, + scoped_session_id, + event.author, + event.timestamp, + text, + ) + record_id = _hash_id(app_name, user_id, scoped_session_id, event_id) + metadata = _json_safe(custom_metadata) + metadata.update({ + "invocation_id": event.invocation_id, + "source": _EVENT_SOURCE, + }) + return self._record( + record_id=record_id, + app_name=app_name, + user_id=user_id, + session_id=scoped_session_id, + event_id=event_id, + author=event.author or "", + timestamp=_timestamp_from_event(event), + text=text, + embedding=embedding, + source=_EVENT_SOURCE, + metadata=metadata, + ) + + def _memory_to_record( + self, + *, + app_name: str, + user_id: str, + memory: MemoryEntry, + text: str, + embedding: Sequence[float], + index: int, + custom_metadata: Mapping[str, object] | None = None, + ) -> dict[str, object]: + record_id = memory.id or _hash_id(app_name, user_id, index, text) + metadata = _json_safe(custom_metadata) + metadata.update(_json_safe(memory.custom_metadata)) + metadata["source"] = _DIRECT_MEMORY_SOURCE + return self._record( + record_id=record_id, + app_name=app_name, + user_id=user_id, + session_id="", + event_id="", + author=memory.author or "", + timestamp=memory.timestamp or "", + text=text, + embedding=embedding, + source=_DIRECT_MEMORY_SOURCE, + metadata=metadata, + ) + + def _record( + self, + *, + record_id: str, + app_name: str, + user_id: str, + session_id: str, + event_id: str, + author: str, + timestamp: str, + text: str, + embedding: Sequence[float], + source: str, + metadata: Mapping[str, object], + ) -> dict[str, object]: + return { + self._config.id_field: record_id[: self._config.id_max_length], + self._config.vector_field: list(embedding), + self._config.text_field: text[: self._config.text_max_length], + self._config.app_name_field: app_name[: self._config.scalar_max_length], + self._config.user_id_field: user_id[: self._config.scalar_max_length], + self._config.session_id_field: session_id[ + : self._config.scalar_max_length + ], + self._config.event_id_field: event_id[: self._config.scalar_max_length], + self._config.author_field: author[: self._config.scalar_max_length], + self._config.timestamp_field: timestamp[ + : self._config.scalar_max_length + ], + self._config.source_field: source[: self._config.scalar_max_length], + self._config.metadata_field: dict(metadata), + } + + async def _upsert_records(self, records: Sequence[dict[str, object]]) -> None: + if not records: + return + await asyncio.to_thread( + self._client.upsert, + collection_name=self._config.collection_name, + data=list(records), + timeout=self._config.timeout, + ) + + @override + async def add_session_to_memory(self, session: Session) -> None: + await self.add_events_to_memory( + app_name=session.app_name, + user_id=session.user_id, + events=session.events, + session_id=session.id, + ) + + @override + async def add_events_to_memory( + self, + *, + app_name: str, + user_id: str, + events: Sequence[Event], + session_id: str | None = None, + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + items = [ + (event, text) + for event in events + if (text := extract_text_from_event(event)) + ] + embeddings = await self._embed_texts([text for _, text in items]) + records = [ + self._event_to_record( + app_name=app_name, + user_id=user_id, + session_id=session_id, + event=event, + text=text, + embedding=embedding, + custom_metadata=custom_metadata, + ) + for (event, text), embedding in zip(items, embeddings) + ] + await self._upsert_records(records) + logger.info("Added %d memories to Milvus.", len(records)) + + @override + async def add_memory( + self, + *, + app_name: str, + user_id: str, + memories: Sequence[MemoryEntry], + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + items = [ + (index, memory, text) + for index, memory in enumerate(memories) + if (text := _content_to_text(memory.content)) + ] + embeddings = await self._embed_texts([text for _, _, text in items]) + records = [ + self._memory_to_record( + app_name=app_name, + user_id=user_id, + memory=memory, + text=text, + embedding=embedding, + index=index, + custom_metadata=custom_metadata, + ) + for (index, memory, text), embedding in zip(items, embeddings) + ] + await self._upsert_records(records) + logger.info("Added %d direct memories to Milvus.", len(records)) + + @override + async def search_memory( + self, *, app_name: str, user_id: str, query: str + ) -> SearchMemoryResponse: + query_embedding = (await self._embed_texts([query]))[0] + results = await asyncio.to_thread( + self._client.search, + collection_name=self._config.collection_name, + data=[query_embedding], + filter=self._scope_filter(app_name=app_name, user_id=user_id), + limit=self._config.search_top_k, + output_fields=[ + self._config.text_field, + self._config.author_field, + self._config.timestamp_field, + self._config.metadata_field, + ], + anns_field=self._config.vector_field, + search_params={"metric_type": self._config.metric_type}, + timeout=self._config.timeout, + ) + memories: list[MemoryEntry] = [] + for hit in results[0] if results else []: + entity = hit.get("entity", {}) + text = entity.get(self._config.text_field, "") + if not text: + continue + memories.append( + MemoryEntry( + content=types.Content(parts=[types.Part(text=text)]), + author=entity.get(self._config.author_field) or None, + timestamp=entity.get(self._config.timestamp_field) or None, + custom_metadata=entity.get(self._config.metadata_field) or {}, + ) + ) + return SearchMemoryResponse(memories=memories) + + async def close(self) -> None: + await asyncio.to_thread(self._client.close) diff --git a/tests/integration/memory/test_milvus_memory_service_e2e.py b/tests/integration/memory/test_milvus_memory_service_e2e.py new file mode 100644 index 00000000..bdb8295c --- /dev/null +++ b/tests/integration/memory/test_milvus_memory_service_e2e.py @@ -0,0 +1,148 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import os +from pathlib import Path +import uuid + +from google.adk.events.event import Event +from google.adk.sessions.session import Session +from google.genai import types +import pytest + +from google.adk_community.memory.milvus_memory_service import MilvusMemoryService +from google.adk_community.memory.milvus_memory_service import MilvusMemoryServiceConfig + +_VOCAB = ("milvus", "vector", "memory", "cooking") + + +def _keyword_embedding(texts): + vectors = [] + for text in texts: + lowered = text.lower() + vector = [float(lowered.count(word)) for word in _VOCAB] + if not any(vector): + vector[-1] = 0.01 + vectors.append(vector) + return vectors + + +def _session() -> Session: + return Session( + app_name="milvus-memory-e2e", + user_id="user-1", + id="session-1", + last_update_time=1000, + events=[ + Event( + id="event-1", + invocation_id="inv-1", + author="user", + timestamp=12345, + content=types.Content( + parts=[ + types.Part(text="Milvus stores vector memory for agents.") + ] + ), + ), + Event( + id="event-2", + invocation_id="inv-2", + author="user", + timestamp=12346, + content=types.Content( + parts=[types.Part(text="I enjoy cooking pasta.")] + ), + ), + ], + ) + + +async def _drop_collection(service: MilvusMemoryService) -> None: + try: + await asyncio.to_thread( + service._client.drop_collection, # pylint: disable=protected-access + collection_name=service._config.collection_name, # pylint: disable=protected-access + ) + except Exception: + pass + + +async def _run_memory_e2e(config: MilvusMemoryServiceConfig) -> None: + service = MilvusMemoryService( + embedding_function=_keyword_embedding, + config=config, + ) + try: + await service.add_session_to_memory(_session()) + + result = await service.search_memory( + app_name="milvus-memory-e2e", + user_id="user-1", + query="vector memory", + ) + assert result.memories + assert ( + "Milvus stores vector memory" + in result.memories[0].content.parts[0].text + ) + + isolated_result = await service.search_memory( + app_name="milvus-memory-e2e", + user_id="user-2", + query="vector memory", + ) + assert isolated_result.memories == [] + finally: + await _drop_collection(service) + await service.close() + + +@pytest.mark.asyncio +@pytest.mark.skipif( + os.getenv("RUN_MILVUS_LITE_E2E") != "1", + reason="Set RUN_MILVUS_LITE_E2E=1 to run Milvus Lite E2E.", +) +async def test_milvus_lite_memory_e2e(tmp_path: Path): + collection_name = f"adk_memory_e2e_{uuid.uuid4().hex[:8]}" + await _run_memory_e2e( + MilvusMemoryServiceConfig( + uri=str(tmp_path / "milvus_lite.db"), + collection_name=collection_name, + dimension=len(_VOCAB), + consistency_level="Strong", + ) + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not os.getenv("ZILLIZ_URI") or not os.getenv("ZILLIZ_TOKEN"), + reason="Set ZILLIZ_URI and ZILLIZ_TOKEN to run Zilliz Cloud E2E.", +) +async def test_zilliz_cloud_memory_e2e(): + collection_name = f"adk_memory_e2e_{uuid.uuid4().hex[:8]}" + await _run_memory_e2e( + MilvusMemoryServiceConfig( + uri=os.environ["ZILLIZ_URI"], + token=os.environ["ZILLIZ_TOKEN"], + db_name=os.getenv("ZILLIZ_DB_NAME") or os.getenv("MILVUS_DB_NAME"), + collection_name=collection_name, + dimension=len(_VOCAB), + consistency_level="Strong", + ) + ) diff --git a/tests/unittests/memory/test_milvus_memory_service.py b/tests/unittests/memory/test_milvus_memory_service.py new file mode 100644 index 00000000..bd044e00 --- /dev/null +++ b/tests/unittests/memory/test_milvus_memory_service.py @@ -0,0 +1,430 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +from google.adk.events.event import Event +from google.adk.memory.memory_entry import MemoryEntry +from google.adk.sessions.session import Session +from google.genai import types +import pytest + +from google.adk_community.memory import milvus_memory_service as milvus_module +from google.adk_community.memory.milvus_memory_service import MilvusMemoryService +from google.adk_community.memory.milvus_memory_service import MilvusMemoryServiceConfig + + +class FakeDataType: + VARCHAR = "VARCHAR" + FLOAT_VECTOR = "FLOAT_VECTOR" + JSON = "JSON" + + +class FakeSchema: + + def __init__(self): + self.fields = [] + + def add_field(self, **kwargs): + self.fields.append(kwargs) + + +class FakeIndexParams: + + def __init__(self): + self.indexes = [] + + def add_index(self, **kwargs): + self.indexes.append(kwargs) + + +@pytest.fixture +def fake_milvus(monkeypatch): + client = MagicMock() + client.has_collection.return_value = False + schema = FakeSchema() + index_params = FakeIndexParams() + client.create_schema.return_value = schema + client.prepare_index_params.return_value = index_params + client.describe_collection.return_value = { + "auto_id": False, + "fields": [ + {"name": "id", "type": FakeDataType.VARCHAR, "is_primary": True}, + { + "name": "embedding", + "type": FakeDataType.FLOAT_VECTOR, + "params": {"dim": 3}, + }, + {"name": "text", "type": FakeDataType.VARCHAR}, + {"name": "app_name", "type": FakeDataType.VARCHAR}, + {"name": "user_id", "type": FakeDataType.VARCHAR}, + ], + } + + class FakeMilvusClient: + + def __new__(cls, **kwargs): + client.init_kwargs = kwargs + return client + + monkeypatch.setattr( + milvus_module, + "_load_pymilvus", + lambda: (FakeMilvusClient, FakeDataType), + ) + return client, schema, index_params + + +def embedding_function(texts): + return [[float(index + 1), 0.0, 0.0] for index, _ in enumerate(texts)] + + +async def async_embedding_function(texts): + return embedding_function(texts) + + +def create_service(fake_milvus, **kwargs): + _ = fake_milvus + return MilvusMemoryService( + embedding_function=embedding_function, + dimension=3, + uri="memory.db", + **kwargs, + ) + + +def test_constructor_creates_collection(fake_milvus): + client, schema, index_params = fake_milvus + + service = MilvusMemoryService( + embedding_function=embedding_function, + dimension=3, + uri="memory.db", + token="token", + db_name="default", + collection_name="memory_collection", + consistency_level="Strong", + ) + + assert service is not None + assert client.init_kwargs == { + "uri": "memory.db", + "token": "token", + "db_name": "default", + } + client.create_collection.assert_called_once() + create_kwargs = client.create_collection.call_args.kwargs + assert create_kwargs["collection_name"] == "memory_collection" + assert create_kwargs["consistency_level"] == "Strong" + assert {field["field_name"] for field in schema.fields} >= { + "id", + "embedding", + "text", + "app_name", + "user_id", + "metadata", + } + assert index_params.indexes == [{ + "field_name": "embedding", + "index_type": "AUTOINDEX", + "metric_type": "COSINE", + }] + + +def test_existing_collection_dimension_mismatch_raises(fake_milvus): + client, _, _ = fake_milvus + client.has_collection.return_value = True + client.describe_collection.return_value = { + "auto_id": False, + "fields": [ + {"name": "id", "type": FakeDataType.VARCHAR, "is_primary": True}, + { + "name": "embedding", + "type": FakeDataType.FLOAT_VECTOR, + "params": {"dim": 8}, + }, + {"name": "text", "type": FakeDataType.VARCHAR}, + {"name": "app_name", "type": FakeDataType.VARCHAR}, + {"name": "user_id", "type": FakeDataType.VARCHAR}, + ], + } + + with pytest.raises(ValueError, match="vector dimension 8"): + create_service(fake_milvus) + + +def test_existing_collection_auto_id_raises(fake_milvus): + client, _, _ = fake_milvus + client.has_collection.return_value = True + client.describe_collection.return_value = { + "auto_id": True, + "fields": [ + {"name": "id", "type": FakeDataType.VARCHAR, "is_primary": True}, + { + "name": "embedding", + "type": FakeDataType.FLOAT_VECTOR, + "params": {"dim": 3}, + }, + {"name": "text", "type": FakeDataType.VARCHAR}, + {"name": "app_name", "type": FakeDataType.VARCHAR}, + {"name": "user_id", "type": FakeDataType.VARCHAR}, + ], + } + + with pytest.raises(ValueError, match="auto_id=False"): + create_service(fake_milvus) + + +def test_existing_collection_primary_key_mismatch_raises(fake_milvus): + client, _, _ = fake_milvus + client.has_collection.return_value = True + client.describe_collection.return_value = { + "auto_id": False, + "fields": [ + {"name": "id", "type": FakeDataType.VARCHAR, "is_primary": False}, + { + "name": "embedding", + "type": FakeDataType.FLOAT_VECTOR, + "params": {"dim": 3}, + }, + {"name": "text", "type": FakeDataType.VARCHAR}, + {"name": "app_name", "type": FakeDataType.VARCHAR}, + {"name": "user_id", "type": FakeDataType.VARCHAR}, + ], + } + + with pytest.raises(ValueError, match="primary key"): + create_service(fake_milvus) + + +def test_existing_collection_type_mismatch_raises(fake_milvus): + client, _, _ = fake_milvus + client.has_collection.return_value = True + client.describe_collection.return_value = { + "auto_id": False, + "fields": [ + {"name": "id", "type": FakeDataType.VARCHAR, "is_primary": True}, + { + "name": "embedding", + "type": FakeDataType.VARCHAR, + "params": {"dim": 3}, + }, + {"name": "text", "type": FakeDataType.VARCHAR}, + {"name": "app_name", "type": FakeDataType.VARCHAR}, + {"name": "user_id", "type": FakeDataType.VARCHAR}, + ], + } + + with pytest.raises(ValueError, match="embedding.*expected FLOAT_VECTOR"): + create_service(fake_milvus) + + +def test_invalid_field_name_raises(fake_milvus): + _ = fake_milvus + config = MilvusMemoryServiceConfig( + dimension=3, + app_name_field="app-name", + ) + + with pytest.raises(ValueError, match="valid identifiers"): + MilvusMemoryService( + embedding_function=embedding_function, + config=config, + uri="memory.db", + ) + + +def test_duplicate_field_name_raises(fake_milvus): + _ = fake_milvus + config = MilvusMemoryServiceConfig( + dimension=3, + user_id_field="app_name", + ) + + with pytest.raises(ValueError, match="must be unique"): + MilvusMemoryService( + embedding_function=embedding_function, + config=config, + uri="memory.db", + ) + + +@pytest.mark.asyncio +async def test_add_session_to_memory_upserts_text_events(fake_milvus): + client, _, _ = fake_milvus + service = create_service(fake_milvus) + session = Session( + app_name="test-app", + user_id="test-user", + id="session-1", + last_update_time=1000, + events=[ + Event( + id="event-1", + invocation_id="inv-1", + author="user", + timestamp=12345, + content=types.Content( + parts=[types.Part(text="Milvus stores vectors.")] + ), + ), + Event( + id="event-2", + invocation_id="inv-2", + author="model", + timestamp=12346, + content=types.Content( + parts=[types.Part(text="Semantic search is supported.")] + ), + ), + Event(id="event-empty", author="user", timestamp=12347), + Event( + id="event-tool", + author="agent", + timestamp=12348, + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall(name="lookup") + ) + ] + ), + ), + ], + ) + + await service.add_session_to_memory(session) + + client.upsert.assert_called_once() + records = client.upsert.call_args.kwargs["data"] + assert len(records) == 2 + assert records[0]["app_name"] == "test-app" + assert records[0]["user_id"] == "test-user" + assert records[0]["session_id"] == "session-1" + assert records[0]["event_id"] == "event-1" + assert records[0]["source"] == "adk_event" + assert records[0]["metadata"]["invocation_id"] == "inv-1" + assert records[1]["embedding"] == [2.0, 0.0, 0.0] + + +@pytest.mark.asyncio +async def test_add_memory_upserts_direct_memory(fake_milvus): + client, _, _ = fake_milvus + service = create_service(fake_milvus) + memory = MemoryEntry( + id="memory-1", + author="user", + timestamp="2026-01-01T00:00:00Z", + content=types.Content(parts=[types.Part(text="Remember Milvus.")]), + custom_metadata={"kind": "fact"}, + ) + + await service.add_memory( + app_name="test-app", + user_id="test-user", + memories=[memory], + custom_metadata={"source_name": "manual"}, + ) + + records = client.upsert.call_args.kwargs["data"] + assert records == [{ + "id": "memory-1", + "embedding": [1.0, 0.0, 0.0], + "text": "Remember Milvus.", + "app_name": "test-app", + "user_id": "test-user", + "session_id": "", + "event_id": "", + "author": "user", + "timestamp": "2026-01-01T00:00:00Z", + "source": "adk_memory", + "metadata": { + "source_name": "manual", + "kind": "fact", + "source": "adk_memory", + }, + }] + + +@pytest.mark.asyncio +async def test_search_memory_returns_entries(fake_milvus): + client, _, _ = fake_milvus + service = MilvusMemoryService( + embedding_function=async_embedding_function, + dimension=3, + uri="memory.db", + search_top_k=3, + ) + client.search.return_value = [[{ + "entity": { + "text": "Milvus supports semantic memory.", + "author": "user", + "timestamp": "2026-01-01T00:00:00Z", + "metadata": {"source": "adk_event"}, + } + }]] + + result = await service.search_memory( + app_name='app "quoted"', + user_id="user-1", + query="semantic memory", + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == ( + "Milvus supports semantic memory." + ) + assert result.memories[0].custom_metadata == {"source": "adk_event"} + search_kwargs = client.search.call_args.kwargs + assert search_kwargs["filter"] == ( + 'app_name == "app \\"quoted\\"" and user_id == "user-1"' + ) + assert search_kwargs["limit"] == 3 + assert search_kwargs["data"] == [[1.0, 0.0, 0.0]] + + +@pytest.mark.asyncio +async def test_embedding_dimension_mismatch_raises(fake_milvus): + client, _, _ = fake_milvus + + def bad_embedding_function(texts): + return [[1.0] for _ in texts] + + service = MilvusMemoryService( + embedding_function=bad_embedding_function, + dimension=3, + uri="memory.db", + ) + + with pytest.raises(ValueError, match="vector dimension 1"): + await service.add_memory( + app_name="test-app", + user_id="test-user", + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text="Bad vector.")]) + ) + ], + ) + client.upsert.assert_not_called() + + +@pytest.mark.asyncio +async def test_close_closes_client(fake_milvus): + client, _, _ = fake_milvus + service = create_service(fake_milvus) + + await service.close() + + client.close.assert_called_once() From 483f0888c8267c85d45b31aeacc0e33cc7d28f86 Mon Sep 17 00:00:00 2001 From: Cheney Zhang Date: Thu, 2 Jul 2026 04:07:51 +0000 Subject: [PATCH 2/4] test: add live embedding memory coverage Signed-off-by: Cheney Zhang --- .../memory/test_milvus_memory_service_e2e.py | 167 +++++++++++++++--- 1 file changed, 146 insertions(+), 21 deletions(-) diff --git a/tests/integration/memory/test_milvus_memory_service_e2e.py b/tests/integration/memory/test_milvus_memory_service_e2e.py index bdb8295c..d432b172 100644 --- a/tests/integration/memory/test_milvus_memory_service_e2e.py +++ b/tests/integration/memory/test_milvus_memory_service_e2e.py @@ -21,13 +21,27 @@ from google.adk.events.event import Event from google.adk.sessions.session import Session +from google.genai import Client from google.genai import types +import httpx import pytest from google.adk_community.memory.milvus_memory_service import MilvusMemoryService from google.adk_community.memory.milvus_memory_service import MilvusMemoryServiceConfig _VOCAB = ("milvus", "vector", "memory", "cooking") +_OPENAI_EMBEDDING_MODEL = os.getenv( + "OPENAI_EMBEDDING_MODEL", "text-embedding-3-small" +) +_OPENAI_EMBEDDING_DIMENSION = int( + os.getenv("OPENAI_EMBEDDING_DIMENSION", "1536") +) +_GOOGLE_EMBEDDING_MODEL = os.getenv( + "GOOGLE_EMBEDDING_MODEL", "gemini-embedding-001" +) +_GOOGLE_EMBEDDING_DIMENSION = int( + os.getenv("GOOGLE_EMBEDDING_DIMENSION", "3072") +) def _keyword_embedding(texts): @@ -41,6 +55,40 @@ def _keyword_embedding(texts): return vectors +def _openai_embedding(texts): + response = httpx.post( + "https://api.openai.com/v1/embeddings", + headers={"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"}, + json={ + "model": _OPENAI_EMBEDDING_MODEL, + "input": list(texts), + }, + timeout=30, + ) + if response.status_code >= 400: + raise RuntimeError( + "OpenAI embeddings request failed with " + f"{response.status_code}: {response.text[:500]}" + ) + data = sorted(response.json()["data"], key=lambda item: item["index"]) + return [item["embedding"] for item in data] + + +def _google_embedding(texts): + api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY") + client = Client(api_key=api_key) + try: + response = client.models.embed_content( + model=_GOOGLE_EMBEDDING_MODEL, + contents=list(texts), + ) + except Exception as exc: + if "User location is not supported" in str(exc): + pytest.skip("Google embeddings API is unavailable from this location.") + raise + return [list(embedding.values) for embedding in response.embeddings] + + def _session() -> Session: return Session( app_name="milvus-memory-e2e", @@ -82,9 +130,12 @@ async def _drop_collection(service: MilvusMemoryService) -> None: pass -async def _run_memory_e2e(config: MilvusMemoryServiceConfig) -> None: +async def _run_memory_e2e( + config: MilvusMemoryServiceConfig, + embedding_function=_keyword_embedding, +) -> None: service = MilvusMemoryService( - embedding_function=_keyword_embedding, + embedding_function=embedding_function, config=config, ) try: @@ -96,9 +147,9 @@ async def _run_memory_e2e(config: MilvusMemoryServiceConfig) -> None: query="vector memory", ) assert result.memories - assert ( - "Milvus stores vector memory" - in result.memories[0].content.parts[0].text + assert any( + "Milvus stores vector memory" in memory.content.parts[0].text + for memory in result.memories ) isolated_result = await service.search_memory( @@ -112,20 +163,70 @@ async def _run_memory_e2e(config: MilvusMemoryServiceConfig) -> None: await service.close() +def _lite_config( + tmp_path: Path, *, dimension: int +) -> MilvusMemoryServiceConfig: + collection_name = f"adk_memory_e2e_{uuid.uuid4().hex[:8]}" + return MilvusMemoryServiceConfig( + uri=str(tmp_path / "milvus_lite.db"), + collection_name=collection_name, + dimension=dimension, + consistency_level="Strong", + ) + + +def _zilliz_config(*, dimension: int) -> MilvusMemoryServiceConfig: + collection_name = f"adk_memory_e2e_{uuid.uuid4().hex[:8]}" + return MilvusMemoryServiceConfig( + uri=os.environ["ZILLIZ_URI"], + token=os.environ["ZILLIZ_TOKEN"], + db_name=os.getenv("ZILLIZ_DB_NAME") or os.getenv("MILVUS_DB_NAME"), + collection_name=collection_name, + dimension=dimension, + consistency_level="Strong", + ) + + @pytest.mark.asyncio @pytest.mark.skipif( os.getenv("RUN_MILVUS_LITE_E2E") != "1", reason="Set RUN_MILVUS_LITE_E2E=1 to run Milvus Lite E2E.", ) async def test_milvus_lite_memory_e2e(tmp_path: Path): - collection_name = f"adk_memory_e2e_{uuid.uuid4().hex[:8]}" await _run_memory_e2e( - MilvusMemoryServiceConfig( - uri=str(tmp_path / "milvus_lite.db"), - collection_name=collection_name, - dimension=len(_VOCAB), - consistency_level="Strong", - ) + _lite_config(tmp_path, dimension=len(_VOCAB)), + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + os.getenv("RUN_MILVUS_LITE_E2E") != "1", + reason="Set RUN_MILVUS_LITE_E2E=1 to run Milvus Lite E2E.", +) +@pytest.mark.skipif( + not os.getenv("OPENAI_API_KEY"), + reason="Set OPENAI_API_KEY to run OpenAI embeddings E2E.", +) +async def test_milvus_lite_memory_openai_embedding_e2e(tmp_path: Path): + await _run_memory_e2e( + _lite_config(tmp_path, dimension=_OPENAI_EMBEDDING_DIMENSION), + embedding_function=_openai_embedding, + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + os.getenv("RUN_MILVUS_LITE_E2E") != "1", + reason="Set RUN_MILVUS_LITE_E2E=1 to run Milvus Lite E2E.", +) +@pytest.mark.skipif( + not os.getenv("GEMINI_API_KEY") and not os.getenv("GOOGLE_API_KEY"), + reason="Set GEMINI_API_KEY or GOOGLE_API_KEY to run Google embeddings E2E.", +) +async def test_milvus_lite_memory_google_embedding_e2e(tmp_path: Path): + await _run_memory_e2e( + _lite_config(tmp_path, dimension=_GOOGLE_EMBEDDING_DIMENSION), + embedding_function=_google_embedding, ) @@ -135,14 +236,38 @@ async def test_milvus_lite_memory_e2e(tmp_path: Path): reason="Set ZILLIZ_URI and ZILLIZ_TOKEN to run Zilliz Cloud E2E.", ) async def test_zilliz_cloud_memory_e2e(): - collection_name = f"adk_memory_e2e_{uuid.uuid4().hex[:8]}" await _run_memory_e2e( - MilvusMemoryServiceConfig( - uri=os.environ["ZILLIZ_URI"], - token=os.environ["ZILLIZ_TOKEN"], - db_name=os.getenv("ZILLIZ_DB_NAME") or os.getenv("MILVUS_DB_NAME"), - collection_name=collection_name, - dimension=len(_VOCAB), - consistency_level="Strong", - ) + _zilliz_config(dimension=len(_VOCAB)), + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not os.getenv("ZILLIZ_URI") or not os.getenv("ZILLIZ_TOKEN"), + reason="Set ZILLIZ_URI and ZILLIZ_TOKEN to run Zilliz Cloud E2E.", +) +@pytest.mark.skipif( + not os.getenv("OPENAI_API_KEY"), + reason="Set OPENAI_API_KEY to run OpenAI embeddings E2E.", +) +async def test_zilliz_cloud_memory_openai_embedding_e2e(): + await _run_memory_e2e( + _zilliz_config(dimension=_OPENAI_EMBEDDING_DIMENSION), + embedding_function=_openai_embedding, + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not os.getenv("ZILLIZ_URI") or not os.getenv("ZILLIZ_TOKEN"), + reason="Set ZILLIZ_URI and ZILLIZ_TOKEN to run Zilliz Cloud E2E.", +) +@pytest.mark.skipif( + not os.getenv("GEMINI_API_KEY") and not os.getenv("GOOGLE_API_KEY"), + reason="Set GEMINI_API_KEY or GOOGLE_API_KEY to run Google embeddings E2E.", +) +async def test_zilliz_cloud_memory_google_embedding_e2e(): + await _run_memory_e2e( + _zilliz_config(dimension=_GOOGLE_EMBEDDING_DIMENSION), + embedding_function=_google_embedding, ) From 84b764963c63afa3f56d288ac487b314ea8421a9 Mon Sep 17 00:00:00 2001 From: Cheney Zhang Date: Thu, 2 Jul 2026 04:19:13 +0000 Subject: [PATCH 3/4] test: gate live Milvus integration checks Signed-off-by: Cheney Zhang --- contributing/samples/milvus_memory/README.md | 31 +++++++++++++++++++ .../memory/test_milvus_memory_service_e2e.py | 28 +++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/contributing/samples/milvus_memory/README.md b/contributing/samples/milvus_memory/README.md index f4ee6651..341f0505 100644 --- a/contributing/samples/milvus_memory/README.md +++ b/contributing/samples/milvus_memory/README.md @@ -40,6 +40,9 @@ Cloud. If you use a non-default Milvus database, set `MILVUS_DB_NAME`. ## Use with Runner +`MilvusMemoryService` accepts any embedding function that returns one vector per +input text. This example uses Gemini embeddings: + ```python from google.adk.agents import Agent from google.adk.runners import Runner @@ -79,6 +82,34 @@ runner = Runner( ) ``` +You can also use another hosted embedding provider as long as `dimension` +matches the returned vectors. For example, OpenAI `text-embedding-3-small` +returns 1536-dimensional vectors: + +```python +import os + +import httpx + + +def embedding_function(texts): + response = httpx.post( + "https://api.openai.com/v1/embeddings", + headers={"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"}, + json={"model": "text-embedding-3-small", "input": list(texts)}, + timeout=30, + ) + response.raise_for_status() + data = sorted(response.json()["data"], key=lambda item: item["index"]) + return [item["embedding"] for item in data] + + +memory_service = MilvusMemoryService( + embedding_function=embedding_function, + dimension=1536, +) +``` + After a session has useful conversation history, persist it: ```python diff --git a/tests/integration/memory/test_milvus_memory_service_e2e.py b/tests/integration/memory/test_milvus_memory_service_e2e.py index d432b172..14bf991c 100644 --- a/tests/integration/memory/test_milvus_memory_service_e2e.py +++ b/tests/integration/memory/test_milvus_memory_service_e2e.py @@ -207,6 +207,10 @@ async def test_milvus_lite_memory_e2e(tmp_path: Path): not os.getenv("OPENAI_API_KEY"), reason="Set OPENAI_API_KEY to run OpenAI embeddings E2E.", ) +@pytest.mark.skipif( + os.getenv("RUN_OPENAI_EMBEDDING_E2E") != "1", + reason="Set RUN_OPENAI_EMBEDDING_E2E=1 to run OpenAI embeddings E2E.", +) async def test_milvus_lite_memory_openai_embedding_e2e(tmp_path: Path): await _run_memory_e2e( _lite_config(tmp_path, dimension=_OPENAI_EMBEDDING_DIMENSION), @@ -223,6 +227,10 @@ async def test_milvus_lite_memory_openai_embedding_e2e(tmp_path: Path): not os.getenv("GEMINI_API_KEY") and not os.getenv("GOOGLE_API_KEY"), reason="Set GEMINI_API_KEY or GOOGLE_API_KEY to run Google embeddings E2E.", ) +@pytest.mark.skipif( + os.getenv("RUN_GOOGLE_EMBEDDING_E2E") != "1", + reason="Set RUN_GOOGLE_EMBEDDING_E2E=1 to run Google embeddings E2E.", +) async def test_milvus_lite_memory_google_embedding_e2e(tmp_path: Path): await _run_memory_e2e( _lite_config(tmp_path, dimension=_GOOGLE_EMBEDDING_DIMENSION), @@ -231,6 +239,10 @@ async def test_milvus_lite_memory_google_embedding_e2e(tmp_path: Path): @pytest.mark.asyncio +@pytest.mark.skipif( + os.getenv("RUN_ZILLIZ_CLOUD_E2E") != "1", + reason="Set RUN_ZILLIZ_CLOUD_E2E=1 to run Zilliz Cloud E2E.", +) @pytest.mark.skipif( not os.getenv("ZILLIZ_URI") or not os.getenv("ZILLIZ_TOKEN"), reason="Set ZILLIZ_URI and ZILLIZ_TOKEN to run Zilliz Cloud E2E.", @@ -242,6 +254,10 @@ async def test_zilliz_cloud_memory_e2e(): @pytest.mark.asyncio +@pytest.mark.skipif( + os.getenv("RUN_ZILLIZ_CLOUD_E2E") != "1", + reason="Set RUN_ZILLIZ_CLOUD_E2E=1 to run Zilliz Cloud E2E.", +) @pytest.mark.skipif( not os.getenv("ZILLIZ_URI") or not os.getenv("ZILLIZ_TOKEN"), reason="Set ZILLIZ_URI and ZILLIZ_TOKEN to run Zilliz Cloud E2E.", @@ -250,6 +266,10 @@ async def test_zilliz_cloud_memory_e2e(): not os.getenv("OPENAI_API_KEY"), reason="Set OPENAI_API_KEY to run OpenAI embeddings E2E.", ) +@pytest.mark.skipif( + os.getenv("RUN_OPENAI_EMBEDDING_E2E") != "1", + reason="Set RUN_OPENAI_EMBEDDING_E2E=1 to run OpenAI embeddings E2E.", +) async def test_zilliz_cloud_memory_openai_embedding_e2e(): await _run_memory_e2e( _zilliz_config(dimension=_OPENAI_EMBEDDING_DIMENSION), @@ -258,6 +278,10 @@ async def test_zilliz_cloud_memory_openai_embedding_e2e(): @pytest.mark.asyncio +@pytest.mark.skipif( + os.getenv("RUN_ZILLIZ_CLOUD_E2E") != "1", + reason="Set RUN_ZILLIZ_CLOUD_E2E=1 to run Zilliz Cloud E2E.", +) @pytest.mark.skipif( not os.getenv("ZILLIZ_URI") or not os.getenv("ZILLIZ_TOKEN"), reason="Set ZILLIZ_URI and ZILLIZ_TOKEN to run Zilliz Cloud E2E.", @@ -266,6 +290,10 @@ async def test_zilliz_cloud_memory_openai_embedding_e2e(): not os.getenv("GEMINI_API_KEY") and not os.getenv("GOOGLE_API_KEY"), reason="Set GEMINI_API_KEY or GOOGLE_API_KEY to run Google embeddings E2E.", ) +@pytest.mark.skipif( + os.getenv("RUN_GOOGLE_EMBEDDING_E2E") != "1", + reason="Set RUN_GOOGLE_EMBEDDING_E2E=1 to run Google embeddings E2E.", +) async def test_zilliz_cloud_memory_google_embedding_e2e(): await _run_memory_e2e( _zilliz_config(dimension=_GOOGLE_EMBEDDING_DIMENSION), From 01546f3e08fed946050bd89156dc1123ae88717d Mon Sep 17 00:00:00 2001 From: Cheney Zhang Date: Thu, 2 Jul 2026 05:27:30 +0000 Subject: [PATCH 4/4] feat: add Milvus RAG toolset Signed-off-by: Cheney Zhang --- contributing/samples/milvus_rag/README.md | 124 ++++ .../adk_community/tools/milvus/__init__.py | 29 + .../tools/milvus/milvus_toolset.py | 655 ++++++++++++++++++ .../tools/test_milvus_toolset_e2e.py | 274 ++++++++ .../tools/milvus/test_milvus_toolset.py | 303 ++++++++ 5 files changed, 1385 insertions(+) create mode 100644 contributing/samples/milvus_rag/README.md create mode 100644 src/google/adk_community/tools/milvus/__init__.py create mode 100644 src/google/adk_community/tools/milvus/milvus_toolset.py create mode 100644 tests/integration/tools/test_milvus_toolset_e2e.py create mode 100644 tests/unittests/tools/milvus/test_milvus_toolset.py diff --git a/contributing/samples/milvus_rag/README.md b/contributing/samples/milvus_rag/README.md new file mode 100644 index 00000000..f86b3d36 --- /dev/null +++ b/contributing/samples/milvus_rag/README.md @@ -0,0 +1,124 @@ +# Milvus RAG Toolset sample + +This sample shows how to use Milvus as a vector store for ADK retrieval tools. +`MilvusToolset` exposes a `milvus_similarity_search` tool that agents can call +to retrieve relevant context from indexed text. + +## Installation + +```bash +pip install "google-adk-community[milvus]" +``` + +The `milvus` extra installs current `pymilvus` and Milvus Lite packages. +Milvus Lite 3.x local storage is not compatible with older 2.x local database +files or directories, so create a new local database path for new projects. + +## Configuration + +Use `MILVUS_URI` and `MILVUS_TOKEN` for all deployment modes: + +```bash +# Milvus Lite +export MILVUS_URI="./adk_milvus_rag.db" + +# Milvus server +export MILVUS_URI="http://localhost:19530" + +# Zilliz Cloud +export MILVUS_URI="https://your-endpoint.api.gcp-us-west1.zillizcloud.com" +export MILVUS_TOKEN="your-token" +``` + +`MILVUS_TOKEN` is only needed for authenticated deployments such as Zilliz +Cloud. If you use a non-default Milvus database, set `MILVUS_DB_NAME`. + +## Build a Vector Store + +`MilvusVectorStore` accepts any embedding function that returns one vector per +input text. This example uses OpenAI `text-embedding-3-small`, which returns +1536-dimensional vectors: + +```python +import os + +import httpx +from google.adk_community.tools.milvus import MilvusVectorStore +from google.adk_community.tools.milvus import MilvusVectorStoreSettings + + +def embedding_function(texts): + response = httpx.post( + "https://api.openai.com/v1/embeddings", + headers={"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"}, + json={"model": "text-embedding-3-small", "input": list(texts)}, + timeout=30, + ) + response.raise_for_status() + data = sorted(response.json()["data"], key=lambda item: item["index"]) + return [item["embedding"] for item in data] + + +vector_store = MilvusVectorStore( + embedding_function=embedding_function, + settings=MilvusVectorStoreSettings( + collection_name="adk_rag", + dimension=1536, + ), +) + +vector_store.add_texts( + [ + "Milvus Lite is useful for local RAG development.", + "Zilliz Cloud provides managed Milvus for production workloads.", + ], + metadatas=[ + {"source": "milvus-lite"}, + {"source": "zilliz-cloud"}, + ], +) +``` + +## Use with Agent + +```python +from google.adk.agents import Agent +from google.adk_community.tools.milvus import MilvusToolset + + +milvus_toolset = MilvusToolset(vector_store=vector_store) +tools = await milvus_toolset.get_tools_with_prefix() + +agent = Agent( + name="rag_agent", + model="gemini-flash-latest", + instruction="Use retrieval context when answering questions.", + tools=tools, +) +``` + +The exposed tool name is `milvus_similarity_search`. It accepts a single +`query` argument and returns: + +```python +{ + "status": "SUCCESS", + "rows": [ + { + "id": "...", + "content": "...", + "source": "...", + "metadata": {...}, + "distance": 0.12, + } + ], +} +``` + +## Notes + +- `dimension` must match the embedding model output dimension. +- `MilvusVectorStore` creates the collection if it does not already exist and + validates the existing schema before reuse. +- Tool names are prefixed through ADK's `BaseToolset` prefix mechanism, matching + the pattern used by other ADK toolsets. diff --git a/src/google/adk_community/tools/milvus/__init__.py b/src/google/adk_community/tools/milvus/__init__.py new file mode 100644 index 00000000..14e19477 --- /dev/null +++ b/src/google/adk_community/tools/milvus/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Milvus RAG tools for Google ADK agents.""" + +from .milvus_toolset import MilvusSimilaritySearchTool +from .milvus_toolset import MilvusToolset +from .milvus_toolset import MilvusToolSettings +from .milvus_toolset import MilvusVectorStore +from .milvus_toolset import MilvusVectorStoreSettings + +__all__ = [ + "MilvusSimilaritySearchTool", + "MilvusToolset", + "MilvusToolSettings", + "MilvusVectorStore", + "MilvusVectorStoreSettings", +] diff --git a/src/google/adk_community/tools/milvus/milvus_toolset.py b/src/google/adk_community/tools/milvus/milvus_toolset.py new file mode 100644 index 00000000..579fe007 --- /dev/null +++ b/src/google/adk_community/tools/milvus/milvus_toolset.py @@ -0,0 +1,655 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Milvus vector store and retrieval toolset for ADK.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Mapping +from collections.abc import Sequence +import hashlib +import inspect +import json +import os +import re +from typing import Any +from typing import Optional + +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.base_toolset import BaseToolset +from google.adk.tools.base_toolset import ToolPredicate +from google.adk.tools.retrieval import BaseRetrievalTool +from google.adk.tools.tool_context import ToolContext +from pydantic import BaseModel +from pydantic import Field +from typing_extensions import override + +EmbeddingFunction = Callable[ + [Sequence[str]], + Sequence[Sequence[float]] | Awaitable[Sequence[Sequence[float]]], +] + +DEFAULT_MILVUS_TOOL_NAME_PREFIX = "milvus" +_DEFAULT_COLLECTION_NAME = "adk_rag" +_DEFAULT_LITE_URI = "./adk_milvus_rag.db" +_FIELD_NAME_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def _load_pymilvus(): + try: + from pymilvus import DataType + from pymilvus import MilvusClient + except ImportError as exc: + raise ImportError( + "pymilvus is required to use MilvusToolset. " + "Install it with: pip install google-adk-community[milvus]" + ) from exc + return MilvusClient, DataType + + +def _env_value(name: str) -> Optional[str]: + value = os.getenv(name) + return value if value else None + + +def _json_safe(value: Mapping[str, object] | None) -> dict[str, object]: + if not value: + return {} + return json.loads(json.dumps(dict(value), default=str)) + + +def _hash_id(*parts: object) -> str: + digest = hashlib.sha256() + for part in parts: + digest.update(str(part).encode("utf-8")) + digest.update(b"\0") + return "adk-rag-" + digest.hexdigest() + + +def _field_name(field: Mapping[str, object]) -> str | None: + raw_name = field.get("name", field.get("field_name")) + return str(raw_name) if raw_name is not None else None + + +def _field_dim(field: Mapping[str, object]) -> int | None: + params = field.get("params", {}) + candidates = [] + if isinstance(params, Mapping): + candidates.append(params.get("dim")) + nested_params = params.get("params") + if isinstance(nested_params, Mapping): + candidates.append(nested_params.get("dim")) + candidates.append(field.get("dim")) + for candidate in candidates: + if candidate is not None: + return int(candidate) + return None + + +def _field_type(field: Mapping[str, object]) -> object | None: + return field.get("type", field.get("data_type", field.get("datatype"))) + + +def _type_label(data_type: object) -> str: + return str(getattr(data_type, "name", data_type)) + + +def _field_type_matches(actual: object, expected: object) -> bool: + if actual == expected: + return True + actual_value = getattr(actual, "value", actual) + expected_value = getattr(expected, "value", expected) + if actual_value == expected_value: + return True + return _type_label(actual) == _type_label(expected) + + +def _validate_field_names(settings: "MilvusVectorStoreSettings") -> None: + field_names = [ + settings.id_field, + settings.vector_field, + settings.content_field, + settings.source_field, + settings.metadata_field, + ] + invalid_names = [ + name for name in field_names if not _FIELD_NAME_PATTERN.fullmatch(name) + ] + if invalid_names: + raise ValueError( + "Milvus field names must be valid identifiers: " + + ", ".join(sorted(set(invalid_names))) + ) + duplicate_names = sorted( + {name for name in field_names if field_names.count(name) > 1} + ) + if duplicate_names: + raise ValueError( + "Milvus field names must be unique: " + ", ".join(duplicate_names) + ) + + +class MilvusVectorStoreSettings(BaseModel): + """Settings for the Milvus vector store used by retrieval tools.""" + + uri: str = Field( + default_factory=lambda: _env_value("MILVUS_URI") or _DEFAULT_LITE_URI + ) + token: Optional[str] = Field( + default_factory=lambda: _env_value("MILVUS_TOKEN") + ) + db_name: Optional[str] = Field( + default_factory=lambda: _env_value("MILVUS_DB_NAME") + ) + collection_name: str = Field(default=_DEFAULT_COLLECTION_NAME, min_length=1) + dimension: int = Field(gt=0) + search_top_k: int = Field(default=4, ge=1, le=100) + metric_type: str = Field(default="COSINE") + index_type: str = Field(default="AUTOINDEX") + consistency_level: Optional[str] = Field(default="Session") + timeout: Optional[float] = Field(default=None, gt=0.0) + auto_create: bool = True + id_field: str = Field(default="id") + vector_field: str = Field(default="embedding") + content_field: str = Field(default="content") + source_field: str = Field(default="source") + metadata_field: str = Field(default="metadata") + id_max_length: int = Field(default=512, gt=0) + content_max_length: int = Field(default=65535, gt=0) + scalar_max_length: int = Field(default=1024, gt=0) + + +class MilvusToolSettings(BaseModel): + """Settings for Milvus tools.""" + + vector_store_settings: MilvusVectorStoreSettings | None = None + similarity_search_description: str = ( + "Performs semantic similarity search over a Milvus vector store and " + "returns relevant context for the user's query." + ) + + +class MilvusVectorStore: + """Utility class for Milvus collection setup, ingestion, and search.""" + + def __init__( + self, + *, + embedding_function: EmbeddingFunction, + settings: MilvusVectorStoreSettings, + ): + _validate_field_names(settings) + self._embedding_function = embedding_function + self._settings = settings + milvus_client, data_type = _load_pymilvus() + self._data_type = data_type + client_kwargs: dict[str, object] = {"uri": self._settings.uri} + if self._settings.token: + client_kwargs["token"] = self._settings.token + if self._settings.db_name: + client_kwargs["db_name"] = self._settings.db_name + if self._settings.timeout is not None: + client_kwargs["timeout"] = self._settings.timeout + self._client = milvus_client(**client_kwargs) + if self._settings.auto_create: + self.create_vector_store() + + def create_vector_store(self) -> None: + """Create the Milvus collection if it does not already exist.""" + if self._client.has_collection( + collection_name=self._settings.collection_name, + timeout=self._settings.timeout, + ): + self._validate_existing_collection() + return + + schema = self._client.create_schema( + auto_id=False, + enable_dynamic_field=True, + ) + schema.add_field( + field_name=self._settings.id_field, + datatype=self._data_type.VARCHAR, + is_primary=True, + max_length=self._settings.id_max_length, + ) + schema.add_field( + field_name=self._settings.vector_field, + datatype=self._data_type.FLOAT_VECTOR, + dim=self._settings.dimension, + ) + schema.add_field( + field_name=self._settings.content_field, + datatype=self._data_type.VARCHAR, + max_length=self._settings.content_max_length, + ) + schema.add_field( + field_name=self._settings.source_field, + datatype=self._data_type.VARCHAR, + max_length=self._settings.scalar_max_length, + ) + schema.add_field( + field_name=self._settings.metadata_field, + datatype=self._data_type.JSON, + ) + + index_params = self._client.prepare_index_params() + index_params.add_index( + field_name=self._settings.vector_field, + index_type=self._settings.index_type, + metric_type=self._settings.metric_type, + ) + create_kwargs: dict[str, object] = { + "collection_name": self._settings.collection_name, + "schema": schema, + "index_params": index_params, + "timeout": self._settings.timeout, + } + if self._settings.consistency_level: + create_kwargs["consistency_level"] = self._settings.consistency_level + self._client.create_collection(**create_kwargs) + + async def create_vector_store_async(self) -> None: + """Asynchronously create the Milvus collection if needed.""" + await asyncio.to_thread(self.create_vector_store) + + def _validate_existing_collection(self) -> None: + description = self._client.describe_collection( + collection_name=self._settings.collection_name, + timeout=self._settings.timeout, + ) + fields = { + name: field + for field in description.get("fields", []) + if isinstance(field, Mapping) and (name := _field_name(field)) + } + required_fields = [ + self._settings.id_field, + self._settings.vector_field, + self._settings.content_field, + self._settings.source_field, + self._settings.metadata_field, + ] + missing_fields = [field for field in required_fields if field not in fields] + if missing_fields: + raise ValueError( + "Milvus collection " + f"{self._settings.collection_name!r} is missing required fields: " + + ", ".join(missing_fields) + ) + if description.get("auto_id") is True: + raise ValueError( + "Milvus collection " + f"{self._settings.collection_name!r} must use auto_id=False." + ) + id_field = fields[self._settings.id_field] + if id_field.get("is_primary") is not None and not id_field.get( + "is_primary" + ): + raise ValueError( + "Milvus collection " + f"{self._settings.collection_name!r} field " + f"{self._settings.id_field!r} must be the primary key." + ) + self._validate_field_type( + id_field, self._data_type.VARCHAR, self._settings.id_field + ) + self._validate_field_type( + fields[self._settings.vector_field], + self._data_type.FLOAT_VECTOR, + self._settings.vector_field, + ) + self._validate_field_type( + fields[self._settings.content_field], + self._data_type.VARCHAR, + self._settings.content_field, + ) + self._validate_field_type( + fields[self._settings.source_field], + self._data_type.VARCHAR, + self._settings.source_field, + ) + self._validate_field_type( + fields[self._settings.metadata_field], + self._data_type.JSON, + self._settings.metadata_field, + ) + vector_dim = _field_dim(fields[self._settings.vector_field]) + if vector_dim is not None and vector_dim != self._settings.dimension: + raise ValueError( + "Milvus collection " + f"{self._settings.collection_name!r} has vector dimension " + f"{vector_dim}, expected {self._settings.dimension}." + ) + + def _validate_field_type( + self, field: Mapping[str, object], expected_type: object, field_name: str + ) -> None: + actual_type = _field_type(field) + if actual_type is not None and not _field_type_matches( + actual_type, expected_type + ): + raise ValueError( + "Milvus collection " + f"{self._settings.collection_name!r} field {field_name!r} has type " + f"{_type_label(actual_type)}, expected {_type_label(expected_type)}." + ) + + def _embed_texts_sync(self, texts: Sequence[str]) -> list[list[float]]: + embeddings = self._embedding_function(texts) + if inspect.isawaitable(embeddings): + raise ValueError( + "embedding_function returned an awaitable. " + "Use add_texts_async() or similarity_search_async() instead." + ) + return self._validate_embeddings(texts, embeddings) + + async def _embed_texts(self, texts: Sequence[str]) -> list[list[float]]: + embeddings = self._embedding_function(texts) + if inspect.isawaitable(embeddings): + embeddings = await embeddings + return self._validate_embeddings(texts, embeddings) + + def _validate_embeddings( + self, texts: Sequence[str], embeddings: Sequence[Sequence[float]] + ) -> list[list[float]]: + vectors = [list(vector) for vector in embeddings] + if len(vectors) != len(texts): + raise ValueError( + "embedding_function returned " + f"{len(vectors)} vectors for {len(texts)} texts." + ) + for vector in vectors: + if len(vector) != self._settings.dimension: + raise ValueError( + "embedding_function returned vector dimension " + f"{len(vector)}, expected {self._settings.dimension}." + ) + return vectors + + def _records( + self, + *, + contents: Sequence[str], + embeddings: Sequence[Sequence[float]], + metadatas: Sequence[Mapping[str, object]], + ids: Sequence[str | None], + ) -> list[dict[str, object]]: + records = [] + for index, (content, embedding, metadata, record_id) in enumerate( + zip(contents, embeddings, metadatas, ids) + ): + safe_metadata = _json_safe(metadata) + source = safe_metadata.get("source", "") + if not isinstance(source, str): + source = str(source) + records.append({ + self._settings.id_field: ( + record_id or _hash_id(index, content, safe_metadata) + )[: self._settings.id_max_length], + self._settings.vector_field: list(embedding), + self._settings.content_field: content[ + : self._settings.content_max_length + ], + self._settings.source_field: source[ + : self._settings.scalar_max_length + ], + self._settings.metadata_field: safe_metadata, + }) + return records + + def _prepare_inputs( + self, + contents: Iterable[str], + metadatas: Iterable[Mapping[str, object]] | None, + ids: Iterable[str] | None, + ) -> tuple[list[str], list[Mapping[str, object]], list[str | None]]: + content_list = list(contents) + metadata_list = ( + list(metadatas) if metadatas is not None else [{} for _ in content_list] + ) + id_list = list(ids) if ids is not None else [None for _ in content_list] + if len(metadata_list) != len(content_list): + raise ValueError( + "metadatas must contain one item per content item. " + f"Got {len(metadata_list)} metadata items for " + f"{len(content_list)} contents." + ) + if len(id_list) != len(content_list): + raise ValueError( + "ids must contain one item per content item. " + f"Got {len(id_list)} ids for {len(content_list)} contents." + ) + return content_list, metadata_list, id_list + + def add_texts( + self, + contents: Iterable[str], + *, + metadatas: Iterable[Mapping[str, object]] | None = None, + ids: Iterable[str] | None = None, + ) -> dict[str, object]: + """Embed and upsert text content into the Milvus vector store.""" + content_list, metadata_list, id_list = self._prepare_inputs( + contents, metadatas, ids + ) + if not content_list: + return {"status": "SUCCESS", "inserted_count": 0} + embeddings = self._embed_texts_sync(content_list) + records = self._records( + contents=content_list, + embeddings=embeddings, + metadatas=metadata_list, + ids=id_list, + ) + self._client.upsert( + collection_name=self._settings.collection_name, + data=records, + timeout=self._settings.timeout, + ) + return {"status": "SUCCESS", "inserted_count": len(records)} + + async def add_texts_async( + self, + contents: Iterable[str], + *, + metadatas: Iterable[Mapping[str, object]] | None = None, + ids: Iterable[str] | None = None, + ) -> dict[str, object]: + """Asynchronously embed and upsert text content into Milvus.""" + content_list, metadata_list, id_list = self._prepare_inputs( + contents, metadatas, ids + ) + if not content_list: + return {"status": "SUCCESS", "inserted_count": 0} + embeddings = await self._embed_texts(content_list) + records = self._records( + contents=content_list, + embeddings=embeddings, + metadatas=metadata_list, + ids=id_list, + ) + await asyncio.to_thread( + self._client.upsert, + collection_name=self._settings.collection_name, + data=records, + timeout=self._settings.timeout, + ) + return {"status": "SUCCESS", "inserted_count": len(records)} + + def _search_result(self, hits: Sequence[Mapping[str, object]]) -> dict: + rows = [] + for hit in hits: + entity = hit.get("entity", {}) + if not isinstance(entity, Mapping): + entity = {} + rows.append({ + "id": hit.get("id") or entity.get(self._settings.id_field), + "content": entity.get(self._settings.content_field, ""), + "source": entity.get(self._settings.source_field, ""), + "metadata": entity.get(self._settings.metadata_field, {}), + "distance": hit.get("distance"), + }) + return {"status": "SUCCESS", "rows": rows} + + def similarity_search( + self, + query: str, + *, + top_k: int | None = None, + filter_expr: str | None = None, + ) -> dict: + """Perform semantic similarity search over the Milvus vector store.""" + query_embedding = self._embed_texts_sync([query])[0] + search_kwargs = self._search_kwargs(query_embedding, top_k, filter_expr) + results = self._client.search(**search_kwargs) + return self._search_result(results[0] if results else []) + + def _search_kwargs( + self, + query_embedding: Sequence[float], + top_k: int | None, + filter_expr: str | None, + ) -> dict[str, object]: + search_kwargs: dict[str, object] = { + "collection_name": self._settings.collection_name, + "data": [list(query_embedding)], + "limit": top_k or self._settings.search_top_k, + "output_fields": [ + self._settings.id_field, + self._settings.content_field, + self._settings.source_field, + self._settings.metadata_field, + ], + "anns_field": self._settings.vector_field, + "search_params": {"metric_type": self._settings.metric_type}, + "timeout": self._settings.timeout, + } + if filter_expr: + search_kwargs["filter"] = filter_expr + return search_kwargs + + async def similarity_search_async( + self, + query: str, + *, + top_k: int | None = None, + filter_expr: str | None = None, + ) -> dict: + """Asynchronously perform semantic similarity search over Milvus.""" + query_embedding = (await self._embed_texts([query]))[0] + search_kwargs = self._search_kwargs(query_embedding, top_k, filter_expr) + results = await asyncio.to_thread( + self._client.search, + **search_kwargs, + ) + return self._search_result(results[0] if results else []) + + async def close(self) -> None: + await asyncio.to_thread(self._client.close) + + +class MilvusSimilaritySearchTool(BaseRetrievalTool): + """Retrieval tool that performs similarity search over Milvus.""" + + def __init__( + self, + *, + vector_store: MilvusVectorStore, + description: str, + name: str = "similarity_search", + ): + super().__init__(name=name, description=description) + self._vector_store = vector_store + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + _ = tool_context + return await self._vector_store.similarity_search_async(args["query"]) + + +class MilvusToolset(BaseToolset): + """Toolset for retrieving context from a Milvus vector store.""" + + def __init__( + self, + *, + embedding_function: EmbeddingFunction | None = None, + vector_store: MilvusVectorStore | None = None, + milvus_tool_settings: MilvusToolSettings | None = None, + tool_filter: ToolPredicate | list[str] | None = None, + ): + super().__init__( + tool_filter=tool_filter, + tool_name_prefix=DEFAULT_MILVUS_TOOL_NAME_PREFIX, + ) + self._tool_settings = milvus_tool_settings + if vector_store is None: + if ( + milvus_tool_settings is None + or milvus_tool_settings.vector_store_settings is None + ): + raise ValueError( + "milvus_tool_settings.vector_store_settings is required when " + "vector_store is not provided." + ) + if embedding_function is None: + raise ValueError( + "embedding_function is required when vector_store is not provided." + ) + vector_store = MilvusVectorStore( + embedding_function=embedding_function, + settings=milvus_tool_settings.vector_store_settings, + ) + elif ( + milvus_tool_settings is None + or milvus_tool_settings.vector_store_settings is None + ): + existing_settings = ( + milvus_tool_settings.model_dump() if milvus_tool_settings else {} + ) + existing_settings["vector_store_settings"] = ( + vector_store._settings # pylint: disable=protected-access + ) + milvus_tool_settings = MilvusToolSettings(**existing_settings) + self._vector_store = vector_store + self._tool_settings = milvus_tool_settings + + @override + async def get_tools( + self, readonly_context: ReadonlyContext | None = None + ) -> list[BaseTool]: + """Get Milvus tools from the toolset.""" + all_tools: list[BaseTool] = [ + MilvusSimilaritySearchTool( + vector_store=self._vector_store, + description=self._tool_settings.similarity_search_description, + ) + ] + return [ + tool + for tool in all_tools + if self._is_tool_selected(tool, readonly_context) + ] + + @override + async def close(self): + await self._vector_store.close() diff --git a/tests/integration/tools/test_milvus_toolset_e2e.py b/tests/integration/tools/test_milvus_toolset_e2e.py new file mode 100644 index 00000000..0dc5258d --- /dev/null +++ b/tests/integration/tools/test_milvus_toolset_e2e.py @@ -0,0 +1,274 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import os +from pathlib import Path +import uuid + +from google.genai import Client +import httpx +import pytest + +from google.adk_community.tools.milvus import MilvusToolset +from google.adk_community.tools.milvus import MilvusVectorStore +from google.adk_community.tools.milvus import MilvusVectorStoreSettings + +_VOCAB = ("milvus", "local", "cloud", "production") +_OPENAI_EMBEDDING_MODEL = os.getenv( + "OPENAI_EMBEDDING_MODEL", "text-embedding-3-small" +) +_OPENAI_EMBEDDING_DIMENSION = int( + os.getenv("OPENAI_EMBEDDING_DIMENSION", "1536") +) +_GOOGLE_EMBEDDING_MODEL = os.getenv( + "GOOGLE_EMBEDDING_MODEL", "gemini-embedding-001" +) +_GOOGLE_EMBEDDING_DIMENSION = int( + os.getenv("GOOGLE_EMBEDDING_DIMENSION", "3072") +) + + +def _keyword_embedding(texts): + vectors = [] + for text in texts: + lowered = text.lower() + vector = [float(lowered.count(word)) for word in _VOCAB] + if not any(vector): + vector[-1] = 0.01 + vectors.append(vector) + return vectors + + +def _openai_embedding(texts): + response = httpx.post( + "https://api.openai.com/v1/embeddings", + headers={"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"}, + json={ + "model": _OPENAI_EMBEDDING_MODEL, + "input": list(texts), + }, + timeout=30, + ) + if response.status_code >= 400: + raise RuntimeError( + "OpenAI embeddings request failed with " + f"{response.status_code}: {response.text[:500]}" + ) + data = sorted(response.json()["data"], key=lambda item: item["index"]) + return [item["embedding"] for item in data] + + +def _google_embedding(texts): + api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY") + client = Client(api_key=api_key) + try: + response = client.models.embed_content( + model=_GOOGLE_EMBEDDING_MODEL, + contents=list(texts), + ) + except Exception as exc: + if "User location is not supported" in str(exc): + pytest.skip("Google embeddings API is unavailable from this location.") + raise + return [list(embedding.values) for embedding in response.embeddings] + + +def _lite_settings( + tmp_path: Path, *, dimension: int +) -> MilvusVectorStoreSettings: + return MilvusVectorStoreSettings( + uri=str(tmp_path / "milvus_toolset.db"), + collection_name=f"adk_rag_e2e_{uuid.uuid4().hex[:8]}", + dimension=dimension, + consistency_level="Strong", + ) + + +def _zilliz_settings(*, dimension: int) -> MilvusVectorStoreSettings: + return MilvusVectorStoreSettings( + uri=os.environ["ZILLIZ_URI"], + token=os.environ["ZILLIZ_TOKEN"], + db_name=os.getenv("ZILLIZ_DB_NAME") or os.getenv("MILVUS_DB_NAME"), + collection_name=f"adk_rag_e2e_{uuid.uuid4().hex[:8]}", + dimension=dimension, + consistency_level="Strong", + ) + + +async def _drop_collection(vector_store: MilvusVectorStore) -> None: + try: + await asyncio.to_thread( + vector_store._client.drop_collection, # pylint: disable=protected-access + collection_name=vector_store._settings.collection_name, # pylint: disable=protected-access + ) + except Exception: + pass + + +async def _run_toolset_e2e( + settings: MilvusVectorStoreSettings, + embedding_function, +) -> None: + vector_store = MilvusVectorStore( + embedding_function=embedding_function, + settings=settings, + ) + try: + await vector_store.add_texts_async( + [ + "Milvus Lite is useful for local RAG development.", + "Zilliz Cloud provides managed Milvus for production workloads.", + ], + metadatas=[ + {"source": "milvus-lite"}, + {"source": "zilliz-cloud"}, + ], + ids=["local-doc", "cloud-doc"], + ) + toolset = MilvusToolset(vector_store=vector_store) + tools = await toolset.get_tools_with_prefix() + assert [tool.name for tool in tools] == ["milvus_similarity_search"] + + result = await tools[0].run_async( + args={"query": "managed cloud production Milvus"}, + tool_context=None, + ) + + assert result["status"] == "SUCCESS" + assert any( + "Zilliz Cloud provides managed Milvus" in row["content"] + for row in result["rows"] + ) + finally: + await _drop_collection(vector_store) + await vector_store.close() + + +@pytest.mark.asyncio +@pytest.mark.skipif( + os.getenv("RUN_MILVUS_LITE_E2E") != "1", + reason="Set RUN_MILVUS_LITE_E2E=1 to run Milvus Lite E2E.", +) +async def test_milvus_lite_toolset_e2e(tmp_path: Path): + await _run_toolset_e2e( + _lite_settings(tmp_path, dimension=len(_VOCAB)), + _keyword_embedding, + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + os.getenv("RUN_MILVUS_LITE_E2E") != "1", + reason="Set RUN_MILVUS_LITE_E2E=1 to run Milvus Lite E2E.", +) +@pytest.mark.skipif( + not os.getenv("OPENAI_API_KEY"), + reason="Set OPENAI_API_KEY to run OpenAI embeddings E2E.", +) +@pytest.mark.skipif( + os.getenv("RUN_OPENAI_EMBEDDING_E2E") != "1", + reason="Set RUN_OPENAI_EMBEDDING_E2E=1 to run OpenAI embeddings E2E.", +) +async def test_milvus_lite_toolset_openai_embedding_e2e(tmp_path: Path): + await _run_toolset_e2e( + _lite_settings(tmp_path, dimension=_OPENAI_EMBEDDING_DIMENSION), + _openai_embedding, + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + os.getenv("RUN_MILVUS_LITE_E2E") != "1", + reason="Set RUN_MILVUS_LITE_E2E=1 to run Milvus Lite E2E.", +) +@pytest.mark.skipif( + not os.getenv("GEMINI_API_KEY") and not os.getenv("GOOGLE_API_KEY"), + reason="Set GEMINI_API_KEY or GOOGLE_API_KEY to run Google embeddings E2E.", +) +@pytest.mark.skipif( + os.getenv("RUN_GOOGLE_EMBEDDING_E2E") != "1", + reason="Set RUN_GOOGLE_EMBEDDING_E2E=1 to run Google embeddings E2E.", +) +async def test_milvus_lite_toolset_google_embedding_e2e(tmp_path: Path): + await _run_toolset_e2e( + _lite_settings(tmp_path, dimension=_GOOGLE_EMBEDDING_DIMENSION), + _google_embedding, + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + os.getenv("RUN_ZILLIZ_CLOUD_E2E") != "1", + reason="Set RUN_ZILLIZ_CLOUD_E2E=1 to run Zilliz Cloud E2E.", +) +@pytest.mark.skipif( + not os.getenv("ZILLIZ_URI") or not os.getenv("ZILLIZ_TOKEN"), + reason="Set ZILLIZ_URI and ZILLIZ_TOKEN to run Zilliz Cloud E2E.", +) +async def test_zilliz_cloud_toolset_e2e(): + await _run_toolset_e2e( + _zilliz_settings(dimension=len(_VOCAB)), + _keyword_embedding, + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + os.getenv("RUN_ZILLIZ_CLOUD_E2E") != "1", + reason="Set RUN_ZILLIZ_CLOUD_E2E=1 to run Zilliz Cloud E2E.", +) +@pytest.mark.skipif( + not os.getenv("ZILLIZ_URI") or not os.getenv("ZILLIZ_TOKEN"), + reason="Set ZILLIZ_URI and ZILLIZ_TOKEN to run Zilliz Cloud E2E.", +) +@pytest.mark.skipif( + not os.getenv("OPENAI_API_KEY"), + reason="Set OPENAI_API_KEY to run OpenAI embeddings E2E.", +) +@pytest.mark.skipif( + os.getenv("RUN_OPENAI_EMBEDDING_E2E") != "1", + reason="Set RUN_OPENAI_EMBEDDING_E2E=1 to run OpenAI embeddings E2E.", +) +async def test_zilliz_cloud_toolset_openai_embedding_e2e(): + await _run_toolset_e2e( + _zilliz_settings(dimension=_OPENAI_EMBEDDING_DIMENSION), + _openai_embedding, + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + os.getenv("RUN_ZILLIZ_CLOUD_E2E") != "1", + reason="Set RUN_ZILLIZ_CLOUD_E2E=1 to run Zilliz Cloud E2E.", +) +@pytest.mark.skipif( + not os.getenv("ZILLIZ_URI") or not os.getenv("ZILLIZ_TOKEN"), + reason="Set ZILLIZ_URI and ZILLIZ_TOKEN to run Zilliz Cloud E2E.", +) +@pytest.mark.skipif( + not os.getenv("GEMINI_API_KEY") and not os.getenv("GOOGLE_API_KEY"), + reason="Set GEMINI_API_KEY or GOOGLE_API_KEY to run Google embeddings E2E.", +) +@pytest.mark.skipif( + os.getenv("RUN_GOOGLE_EMBEDDING_E2E") != "1", + reason="Set RUN_GOOGLE_EMBEDDING_E2E=1 to run Google embeddings E2E.", +) +async def test_zilliz_cloud_toolset_google_embedding_e2e(): + await _run_toolset_e2e( + _zilliz_settings(dimension=_GOOGLE_EMBEDDING_DIMENSION), + _google_embedding, + ) diff --git a/tests/unittests/tools/milvus/test_milvus_toolset.py b/tests/unittests/tools/milvus/test_milvus_toolset.py new file mode 100644 index 00000000..b07955b7 --- /dev/null +++ b/tests/unittests/tools/milvus/test_milvus_toolset.py @@ -0,0 +1,303 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +import pytest + +from google.adk_community.tools.milvus import milvus_toolset as milvus_module +from google.adk_community.tools.milvus import MilvusToolset +from google.adk_community.tools.milvus import MilvusToolSettings +from google.adk_community.tools.milvus import MilvusVectorStore +from google.adk_community.tools.milvus import MilvusVectorStoreSettings + + +class FakeDataType: + VARCHAR = "VARCHAR" + FLOAT_VECTOR = "FLOAT_VECTOR" + JSON = "JSON" + + +class FakeSchema: + + def __init__(self): + self.fields = [] + + def add_field(self, **kwargs): + self.fields.append(kwargs) + + +class FakeIndexParams: + + def __init__(self): + self.indexes = [] + + def add_index(self, **kwargs): + self.indexes.append(kwargs) + + +@pytest.fixture +def fake_milvus(monkeypatch): + client = MagicMock() + client.has_collection.return_value = False + client.create_schema.return_value = FakeSchema() + client.prepare_index_params.return_value = FakeIndexParams() + client.describe_collection.return_value = { + "auto_id": False, + "fields": [ + {"name": "id", "type": FakeDataType.VARCHAR, "is_primary": True}, + { + "name": "embedding", + "type": FakeDataType.FLOAT_VECTOR, + "params": {"dim": 3}, + }, + {"name": "content", "type": FakeDataType.VARCHAR}, + {"name": "source", "type": FakeDataType.VARCHAR}, + {"name": "metadata", "type": FakeDataType.JSON}, + ], + } + + class FakeMilvusClient: + + def __new__(cls, **kwargs): + client.init_kwargs = kwargs + return client + + monkeypatch.setattr( + milvus_module, + "_load_pymilvus", + lambda: (FakeMilvusClient, FakeDataType), + ) + return client + + +def embedding_function(texts): + return [[float(index + 1), 0.0, 0.0] for index, _ in enumerate(texts)] + + +async def async_embedding_function(texts): + return embedding_function(texts) + + +def create_settings(**kwargs): + return MilvusVectorStoreSettings( + uri="memory.db", + collection_name="rag_collection", + dimension=3, + **kwargs, + ) + + +def create_vector_store(fake_milvus, **kwargs): + _ = fake_milvus + return MilvusVectorStore( + embedding_function=embedding_function, + settings=create_settings(**kwargs), + ) + + +def test_vector_store_creates_collection(fake_milvus): + client = fake_milvus + vector_store = create_vector_store(fake_milvus, consistency_level="Strong") + + assert vector_store is not None + assert client.init_kwargs == {"uri": "memory.db"} + client.create_collection.assert_called_once() + create_kwargs = client.create_collection.call_args.kwargs + assert create_kwargs["collection_name"] == "rag_collection" + assert create_kwargs["consistency_level"] == "Strong" + schema = client.create_schema.return_value + assert {field["field_name"] for field in schema.fields} == { + "id", + "embedding", + "content", + "source", + "metadata", + } + assert client.prepare_index_params.return_value.indexes == [{ + "field_name": "embedding", + "index_type": "AUTOINDEX", + "metric_type": "COSINE", + }] + + +def test_existing_collection_dimension_mismatch_raises(fake_milvus): + client = fake_milvus + client.has_collection.return_value = True + client.describe_collection.return_value["fields"][1]["params"] = {"dim": 8} + + with pytest.raises(ValueError, match="vector dimension 8"): + create_vector_store(fake_milvus) + + +def test_invalid_field_name_raises(fake_milvus): + _ = fake_milvus + settings = create_settings(content_field="bad-name") + + with pytest.raises(ValueError, match="valid identifiers"): + MilvusVectorStore( + embedding_function=embedding_function, + settings=settings, + ) + + +def test_add_texts_upserts_records(fake_milvus): + client = fake_milvus + vector_store = create_vector_store(fake_milvus) + + result = vector_store.add_texts( + ["Milvus stores vectors.", "ADK tools retrieve context."], + metadatas=[ + {"source": "doc-1", "section": 1}, + {"source": "doc-2", "section": 2}, + ], + ids=["id-1", "id-2"], + ) + + assert result == {"status": "SUCCESS", "inserted_count": 2} + client.upsert.assert_called_once() + records = client.upsert.call_args.kwargs["data"] + assert records[0] == { + "id": "id-1", + "embedding": [1.0, 0.0, 0.0], + "content": "Milvus stores vectors.", + "source": "doc-1", + "metadata": {"source": "doc-1", "section": 1}, + } + assert records[1]["embedding"] == [2.0, 0.0, 0.0] + + +def test_add_texts_metadata_length_mismatch_raises(fake_milvus): + vector_store = create_vector_store(fake_milvus) + + with pytest.raises(ValueError, match="one item per content"): + vector_store.add_texts(["one", "two"], metadatas=[{}]) + + +def test_similarity_search_returns_rows(fake_milvus): + client = fake_milvus + vector_store = create_vector_store(fake_milvus, search_top_k=3) + client.search.return_value = [[{ + "id": "id-1", + "distance": 0.12, + "entity": { + "content": "Milvus stores vectors.", + "source": "doc-1", + "metadata": {"section": 1}, + }, + }]] + + result = vector_store.similarity_search("vector database") + + assert result == { + "status": "SUCCESS", + "rows": [{ + "id": "id-1", + "content": "Milvus stores vectors.", + "source": "doc-1", + "metadata": {"section": 1}, + "distance": 0.12, + }], + } + search_kwargs = client.search.call_args.kwargs + assert search_kwargs["collection_name"] == "rag_collection" + assert search_kwargs["data"] == [[1.0, 0.0, 0.0]] + assert search_kwargs["limit"] == 3 + assert "filter" not in search_kwargs + + +@pytest.mark.asyncio +async def test_toolset_returns_prefixed_similarity_search_tool(fake_milvus): + client = fake_milvus + client.search.return_value = [[{ + "id": "id-1", + "distance": 0.12, + "entity": { + "content": "Milvus stores vectors.", + "source": "doc-1", + "metadata": {}, + }, + }]] + toolset = MilvusToolset( + embedding_function=async_embedding_function, + milvus_tool_settings=MilvusToolSettings( + vector_store_settings=create_settings() + ), + ) + + tools = await toolset.get_tools_with_prefix() + + assert [tool.name for tool in tools] == ["milvus_similarity_search"] + result = await tools[0].run_async( + args={"query": "vector database"}, + tool_context=None, + ) + assert result["status"] == "SUCCESS" + assert result["rows"][0]["content"] == "Milvus stores vectors." + + +@pytest.mark.asyncio +async def test_toolset_reuses_vector_store_with_custom_description( + fake_milvus, +): + vector_store = create_vector_store(fake_milvus) + toolset = MilvusToolset( + vector_store=vector_store, + milvus_tool_settings=MilvusToolSettings( + similarity_search_description="Search indexed support content.", + ), + ) + + tools = await toolset.get_tools() + + assert [tool.name for tool in tools] == ["similarity_search"] + assert tools[0].description == "Search indexed support content." + + +@pytest.mark.asyncio +async def test_toolset_filter_can_include_similarity_search(fake_milvus): + toolset = MilvusToolset( + embedding_function=embedding_function, + milvus_tool_settings=MilvusToolSettings( + vector_store_settings=create_settings() + ), + tool_filter=["similarity_search"], + ) + + tools = await toolset.get_tools() + + assert [tool.name for tool in tools] == ["similarity_search"] + + +@pytest.mark.asyncio +async def test_toolset_filter_can_exclude_similarity_search(fake_milvus): + toolset = MilvusToolset( + embedding_function=embedding_function, + milvus_tool_settings=MilvusToolSettings( + vector_store_settings=create_settings() + ), + tool_filter=["other_tool"], + ) + + assert await toolset.get_tools() == [] + + +@pytest.mark.asyncio +async def test_close_closes_client(fake_milvus): + client = fake_milvus + vector_store = create_vector_store(fake_milvus) + + await vector_store.close() + + client.close.assert_called_once()