|
@@ -0,0 +1,72 @@
|
|
|
+from typing import List
|
|
|
+from loguru import logger
|
|
|
+from pymilvus_orm import connections, Collection, SearchResult
|
|
|
+
|
|
|
+from app.core.config import settings
|
|
|
+
|
|
|
+
|
|
|
+class Milvus:
|
|
|
+ """
|
|
|
+ Open-source vector database for unstructured data.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self) -> None:
|
|
|
+ self._host = settings.MILVUS_HOST
|
|
|
+ self._port = settings.MILVUS_PORT
|
|
|
+
|
|
|
+ async def query_embedding_by_pk(
|
|
|
+ self,
|
|
|
+ collection_name: str,
|
|
|
+ primary_key_name: str,
|
|
|
+ pk: int,
|
|
|
+ output_field: str = "embeddings",
|
|
|
+ ) -> List:
|
|
|
+ connections.connect(host=self._host, port=self._port)
|
|
|
+ collection = Collection(name=collection_name)
|
|
|
+ collection.load()
|
|
|
+ expr = f"{primary_key_name} in [{pk}]"
|
|
|
+ res = collection.query(expr=expr, output_fields=[output_field])
|
|
|
+ try:
|
|
|
+ emb = res[0].get(output_field)
|
|
|
+ except [KeyError, IndexError] as e:
|
|
|
+ emb = []
|
|
|
+ logger.error(f"Can't find embedding by {pk}, reason: {e}")
|
|
|
+
|
|
|
+ collection.release()
|
|
|
+
|
|
|
+ return emb
|
|
|
+
|
|
|
+ async def search(
|
|
|
+ self,
|
|
|
+ vec_list: List,
|
|
|
+ collection_name: str,
|
|
|
+ partition_names: List[str],
|
|
|
+ field_name: str,
|
|
|
+ limit: int,
|
|
|
+ ) -> List:
|
|
|
+ connections.connect(host=self._host, port=self._port)
|
|
|
+ collection = Collection(name=collection_name)
|
|
|
+ collection.load()
|
|
|
+
|
|
|
+ SEARCH_PARAM = {
|
|
|
+ "metric_type": "L2",
|
|
|
+ "params": {"nprobe": 20},
|
|
|
+ }
|
|
|
+ res = collection.search(
|
|
|
+ vec_list,
|
|
|
+ field_name,
|
|
|
+ param=SEARCH_PARAM,
|
|
|
+ limit=limit,
|
|
|
+ expr=None,
|
|
|
+ partition_names=partition_names,
|
|
|
+ output_fields=None,
|
|
|
+ )
|
|
|
+ pk_list = []
|
|
|
+ if isinstance(res, SearchResult):
|
|
|
+ for hits in res:
|
|
|
+ for hit in hits.ids:
|
|
|
+ pk_list.append(hit)
|
|
|
+
|
|
|
+ collection.release()
|
|
|
+
|
|
|
+ return pk_list
|