12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- import time
- from typing import List
- from loguru import logger
- from pymilvus import connections, Collection, SearchResult
- from app.core.config import settings
class Milvus:
    """
    Thin wrapper around a Milvus vector-database connection.

    The host and port are read from ``settings.MILVUS_HOST`` and
    ``settings.MILVUS_PORT``. Each public method connects under the
    ``"default"`` alias, loads the target collection, performs a single
    operation and releases the collection again.

    NOTE: the pymilvus calls are made without ``await`` even though the
    public methods are coroutines, so they run inline in the calling task.
    """

    def __init__(self) -> None:
        """Store the Milvus host and port taken from the application settings."""
        self._host = settings.MILVUS_HOST
        self._port = settings.MILVUS_PORT

    async def query_embedding_by_pk(
        self,
        collection_name: str,
        primary_key_name: str,
        pk: int,
        output_field: str = "embeddings",
    ) -> List:
        """
        Fetch ``output_field`` of the entity whose primary key equals ``pk``.

        Args:
            collection_name: Name of the collection to query.
            primary_key_name: Name of the primary-key field used in the filter.
            pk: Primary-key value to look up.
            output_field: Field to return from the matching row.

        Returns:
            The value of ``output_field`` taken from the first returned row,
            or an empty list when the lookup fails with ``KeyError`` or
            ``IndexError`` (for example, when the query returns no rows).

        The elapsed time of the connect, load, query and release steps is
        logged at debug level. The collection is released even if the query
        raises.
        """
        started = time.perf_counter()
        connections.connect("default", host=self._host, port=self._port)
        logger.debug(time.perf_counter() - started)

        collection = Collection(name=collection_name)

        started = time.perf_counter()
        collection.load()
        logger.debug(time.perf_counter() - started)

        expr = f"{primary_key_name} in [{pk}]"
        try:
            started = time.perf_counter()
            rows = collection.query(expr=expr, output_fields=[output_field])
            logger.debug(time.perf_counter() - started)

            try:
                emb = rows[0].get(output_field)
            # The exception classes must be given as a tuple: a list here
            # makes Python raise TypeError instead of running the fallback.
            except (KeyError, IndexError) as e:
                emb = []
                logger.error(f"Can't find embedding by {pk}, reason: {e}")
            return emb
        finally:
            # Always release the loaded collection, including on failure.
            # The "default" connection itself is not disconnected here.
            started = time.perf_counter()
            collection.release()
            logger.debug(time.perf_counter() - started)

    async def search(
        self,
        vec_list: List,
        collection_name: str,
        field_name: str,
        limit: int,
    ) -> List:
        """
        Run a vector search and return the ids of all hits.

        Args:
            vec_list: Query vectors passed to ``Collection.search``.
            collection_name: Name of the collection to search.
            field_name: Name of the vector field to search against.
            limit: Value passed as ``limit`` to ``Collection.search``.

        Returns:
            A flat list of the ids in ``hits.ids`` for every entry of the
            search result, in iteration order. An empty list is returned when
            the result is not a ``SearchResult`` instance.

        The search uses the ``"L2"`` metric type and its elapsed time is
        logged at debug level. The collection is released even if the search
        raises.
        """
        connections.connect("default", host=self._host, port=self._port)
        collection = Collection(name=collection_name)
        collection.load()

        search_param = {
            "metric_type": "L2",
        }
        try:
            started = time.perf_counter()
            res = collection.search(
                vec_list,
                field_name,
                param=search_param,
                limit=limit,
                expr=None,
                output_fields=None,
            )
            logger.debug(time.perf_counter() - started)

            if not isinstance(res, SearchResult):
                return []
            # Flatten the per-query hit groups into one list of ids.
            return [hit_id for hits in res for hit_id in hits.ids]
        finally:
            # Always release the loaded collection, including on failure.
            # The "default" connection itself is not disconnected here.
            collection.release()
|