|
@@ -0,0 +1,193 @@
|
|
|
|
+import sys
|
|
|
|
+import time
|
|
|
|
+from typing import List
|
|
|
|
+
|
|
|
|
+from loguru import logger
|
|
|
|
+from pymilvus import (
|
|
|
|
+ connections,
|
|
|
|
+ Collection,
|
|
|
|
+ CollectionSchema,
|
|
|
|
+ DataType,
|
|
|
|
+ FieldSchema,
|
|
|
|
+ utility,
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+from app.core.config import settings
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class MilvusHelper:
|
|
|
|
+ def __init__(self):
|
|
|
|
+ try:
|
|
|
|
+ self.collection = None
|
|
|
|
+ connections.connect(
|
|
|
|
+ host=settings.MILVUS_HOST,
|
|
|
|
+ port=settings.MILVUS_PORT,
|
|
|
|
+ )
|
|
|
|
+ logger.debug(
|
|
|
|
+ f"Successfully connect to Milvus with IP: "
|
|
|
|
+ f"{settings.MILVUS_HOST} and PORT: {settings.MILVUS_PORT}"
|
|
|
|
+ )
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Failed to connect Milvus: {e}")
|
|
|
|
+ sys.exit(1)
|
|
|
|
+
|
|
|
|
+ def set_collection(self, collection_name: str):
|
|
|
|
+ try:
|
|
|
|
+ if self.has_collection(collection_name):
|
|
|
|
+ self.collection = Collection(name=collection_name)
|
|
|
|
+ else:
|
|
|
|
+ raise Exception(f"There has no collection named: {collection_name}")
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Failed to search data to Milvus: {e}")
|
|
|
|
+ sys.exit(1)
|
|
|
|
+
|
|
|
|
+ # Return if Milvus has the collection
|
|
|
|
+ def has_collection(self, collection_name: str):
|
|
|
|
+ try:
|
|
|
|
+ status = utility.has_collection(collection_name)
|
|
|
|
+ return status
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Failed to load data to Milvus: {e}")
|
|
|
|
+ sys.exit(1)
|
|
|
|
+
|
|
|
|
+ # Create milvus collection if not exists
|
|
|
|
+ async def create_collection(self, collection_name: str):
|
|
|
|
+ try:
|
|
|
|
+ if not self.has_collection(collection_name):
|
|
|
|
+ field1 = FieldSchema(
|
|
|
|
+ name="pk",
|
|
|
|
+ dtype=DataType.INT64,
|
|
|
|
+ descrition="int64",
|
|
|
|
+ is_primary=True,
|
|
|
|
+ auto_id=True,
|
|
|
|
+ )
|
|
|
|
+ field2 = FieldSchema(
|
|
|
|
+ name="embeddings",
|
|
|
|
+ dtype=DataType.FLOAT_VECTOR,
|
|
|
|
+ descrition="float vector",
|
|
|
|
+ dim=settings.VECTOR_DIMENSION,
|
|
|
|
+ is_primary=False,
|
|
|
|
+ )
|
|
|
|
+ schema = CollectionSchema(
|
|
|
|
+ fields=[field1, field2],
|
|
|
|
+ description="Meeting attendee recommendation",
|
|
|
|
+ )
|
|
|
|
+ self.collection = Collection(name=collection_name, schema=schema)
|
|
|
|
+ logger.debug(f"Create Milvus collection: {self.collection}")
|
|
|
|
+ return "OK"
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Failed to load data to Milvus: {e}")
|
|
|
|
+ sys.exit(1)
|
|
|
|
+
|
|
|
|
+ # Batch insert vectors to milvus collection
|
|
|
|
+ async def insert(self, collection_name: str, vectors: List):
|
|
|
|
+ try:
|
|
|
|
+ await self.create_collection(collection_name)
|
|
|
|
+ self.collection = Collection(name=collection_name)
|
|
|
|
+ data = [vectors]
|
|
|
|
+ mr = self.collection.insert(data)
|
|
|
|
+ ids = mr.primary_keys
|
|
|
|
+ self.collection.load()
|
|
|
|
+ logger.debug(
|
|
|
|
+ f"Insert vectors to Milvus in collection: "
|
|
|
|
+ f"{collection_name} with {vectors} rows"
|
|
|
|
+ )
|
|
|
|
+ return ids
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Failed to load data to Milvus: {e}")
|
|
|
|
+ sys.exit(1)
|
|
|
|
+
|
|
|
|
+ # Create FLAT index on milvus collection
|
|
|
|
+ async def create_index(self, collection_name: str):
|
|
|
|
+ try:
|
|
|
|
+ self.set_collection(collection_name)
|
|
|
|
+ default_index = {
|
|
|
|
+ "index_type": "FLAT",
|
|
|
|
+ "metric_type": settings.METRIC_TYPE,
|
|
|
|
+ "params": {},
|
|
|
|
+ }
|
|
|
|
+ status = self.collection.create_index(
|
|
|
|
+ field_name="embeddings", index_params=default_index
|
|
|
|
+ )
|
|
|
|
+ if not status.code:
|
|
|
|
+ logger.debug(
|
|
|
|
+ f"Successfully create index in collection: "
|
|
|
|
+ f"{collection_name} with param: {default_index}"
|
|
|
|
+ )
|
|
|
|
+ return status
|
|
|
|
+ else:
|
|
|
|
+ raise Exception(status.message)
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Failed to create index: {e}")
|
|
|
|
+ sys.exit(1)
|
|
|
|
+
|
|
|
|
+ # Delete Milvus collection
|
|
|
|
+ async def delete_collection(self, collection_name: str):
|
|
|
|
+ try:
|
|
|
|
+ self.set_collection(collection_name)
|
|
|
|
+ self.collection.drop()
|
|
|
|
+ logger.debug("Successfully drop collection!")
|
|
|
|
+ return "ok"
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Failed to drop collection: {e}")
|
|
|
|
+ sys.exit(1)
|
|
|
|
+
|
|
|
|
+ # Search vector in milvus collection
|
|
|
|
+ async def search_vectors(self, collection_name: str, vectors: List, top_k: int):
|
|
|
|
+ # status = utility.list_collections()
|
|
|
|
+ try:
|
|
|
|
+ self.set_collection(collection_name)
|
|
|
|
+ search_params = {
|
|
|
|
+ "metric_type": settings.METRIC_TYPE,
|
|
|
|
+ }
|
|
|
|
+ try:
|
|
|
|
+ res = self.collection.search(
|
|
|
|
+ vectors, anns_field="embeddings", param=search_params, limit=top_k
|
|
|
|
+ )
|
|
|
|
+ except BaseException:
|
|
|
|
+ self.collection.load()
|
|
|
|
+ res = self.collection.search(
|
|
|
|
+ vectors, anns_field="embeddings", param=search_params, limit=top_k
|
|
|
|
+ )
|
|
|
|
+ pk_list = []
|
|
|
|
+ for hits in res:
|
|
|
|
+ for hit in hits.ids:
|
|
|
|
+ pk_list.append(hit)
|
|
|
|
+ return pk_list
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Failed to search vectors in Milvus: {e}")
|
|
|
|
+ sys.exit(1)
|
|
|
|
+
|
|
|
|
+ # Get the number of milvus collection
|
|
|
|
+ async def count(self, collection_name: str):
|
|
|
|
+ try:
|
|
|
|
+ self.set_collection(collection_name)
|
|
|
|
+ num = self.collection.num_entities
|
|
|
|
+ return num
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Failed to count vectors in Milvus: {e}")
|
|
|
|
+ sys.exit(1)
|
|
|
|
+
|
|
|
|
+ # Query vector by primiary key
|
|
|
|
+ async def query_vector_by_pk(self, collection_name: str, pk: int):
|
|
|
|
+ try:
|
|
|
|
+ self.set_collection(collection_name)
|
|
|
|
+ expr = f"pk in [{pk}]"
|
|
|
|
+ try:
|
|
|
|
+ res = self.collection.query(expr=expr, output_fields=["embeddings"])
|
|
|
|
+ except BaseException:
|
|
|
|
+ self.collection.load()
|
|
|
|
+ res = self.collection.query(expr=expr, output_fields=["embeddings"])
|
|
|
|
+ vector = res[0]["embeddings"]
|
|
|
|
+ return vector
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Faild to query vector in Milvus: {e}")
|
|
|
|
+ sys.exit(1)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+my_milvus = {}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+async def get_milvus_cli() -> None:
|
|
|
|
+ MILVUS_CLI = MilvusHelper()
|
|
|
|
+ my_milvus.update({"cli": MILVUS_CLI})
|