from typing import List

from loguru import logger
from pymilvus import connections, Collection, SearchResult

from app.core.config import settings


class Milvus:
    """
    Open-source vector database for unstructured data.

    Thin wrapper around pymilvus that connects using the host and port
    taken from the application settings.
    """

    def __init__(self) -> None:
        self._host = settings.MILVUS_HOST
        self._port = settings.MILVUS_PORT

    async def query_embedding_by_pk(
        self,
        collection_name: str,
        primary_key_name: str,
        pk: int,
        output_field: str = "embeddings",
    ) -> List:
        """Fetch the value of ``output_field`` for the entity whose key is ``pk``.

        Args:
            collection_name: Name of the collection to query.
            primary_key_name: Name of the field used in the filter expression.
            pk: Key value to look up.
            output_field: Field to read from the first matching row.

        Returns:
            The ``output_field`` value of the first row returned by the query,
            or an empty list when the query returns no rows.
        """
        connections.connect(host=self._host, port=self._port)
        collection = Collection(name=collection_name)
        collection.load()
        try:
            expr = f"{primary_key_name} in [{pk}]"
            res = collection.query(expr=expr, output_fields=[output_field])
            try:
                emb = res[0].get(output_field)
            # Must be a tuple: a list here raises TypeError at match time
            # instead of catching the lookup failure.
            except (KeyError, IndexError) as e:
                emb = []
                logger.error(f"Can't find embedding by {pk}, reason: {e}")
            return emb
        finally:
            # Release the collection even when the query itself raises.
            collection.release()

    async def search(
        self,
        vec_list: List,
        collection_name: str,
        partition_names: List[str],
        field_name: str,
        limit: int,
    ) -> List:
        """Run a vector search and return the ids of all hits.

        Args:
            vec_list: Query vectors passed to ``Collection.search``.
            collection_name: Name of the collection to search.
            partition_names: Partitions to restrict the search to.
            field_name: Name of the vector field to search against.
            limit: Maximum number of hits requested per query vector.

        Returns:
            A flat list of hit ids across all query vectors, in the order the
            results are iterated. Empty when the result is not a
            ``SearchResult``.
        """
        connections.connect(host=self._host, port=self._port)
        collection = Collection(name=collection_name)
        collection.load()
        try:
            search_params = {
                "metric_type": "L2",
                "params": {"nprobe": 20},
            }
            res = collection.search(
                vec_list,
                field_name,
                param=search_params,
                limit=limit,
                expr=None,
                partition_names=partition_names,
                output_fields=None,
            )
            if not isinstance(res, SearchResult):
                return []
            # Flatten the per-query hit ids into a single list.
            return [hit_id for hits in res for hit_id in hits.ids]
        finally:
            # Release the collection even when the search itself raises.
            collection.release()