Sfoglia il codice sorgente

add new milvus helpers

highing666 3 anni fa
parent
commit
9b0aed628d
1 ha cambiato i file con 193 aggiunte e 0 eliminazioni
  1. 193 0
      app/services/milvus_helpers.py

+ 193 - 0
app/services/milvus_helpers.py

@@ -0,0 +1,193 @@
+import sys
+import time
+from typing import List
+
+from loguru import logger
+from pymilvus import (
+    connections,
+    Collection,
+    CollectionSchema,
+    DataType,
+    FieldSchema,
+    utility,
+)
+
+from app.core.config import settings
+
+
+class MilvusHelper:
+    def __init__(self):
+        try:
+            self.collection = None
+            connections.connect(
+                host=settings.MILVUS_HOST,
+                port=settings.MILVUS_PORT,
+            )
+            logger.debug(
+                f"Successfully connect to Milvus with IP: "
+                f"{settings.MILVUS_HOST} and PORT: {settings.MILVUS_PORT}"
+            )
+        except Exception as e:
+            logger.error(f"Failed to connect Milvus: {e}")
+            sys.exit(1)
+
+    def set_collection(self, collection_name: str):
+        try:
+            if self.has_collection(collection_name):
+                self.collection = Collection(name=collection_name)
+            else:
+                raise Exception(f"There has no collection named: {collection_name}")
+        except Exception as e:
+            logger.error(f"Failed  to search data to Milvus: {e}")
+            sys.exit(1)
+
+    # Return if Milvus has the collection
+    def has_collection(self, collection_name: str):
+        try:
+            status = utility.has_collection(collection_name)
+            return status
+        except Exception as e:
+            logger.error(f"Failed to load data to Milvus: {e}")
+            sys.exit(1)
+
+    # Create milvus collection if not exists
+    async def create_collection(self, collection_name: str):
+        try:
+            if not self.has_collection(collection_name):
+                field1 = FieldSchema(
+                    name="pk",
+                    dtype=DataType.INT64,
+                    descrition="int64",
+                    is_primary=True,
+                    auto_id=True,
+                )
+                field2 = FieldSchema(
+                    name="embeddings",
+                    dtype=DataType.FLOAT_VECTOR,
+                    descrition="float vector",
+                    dim=settings.VECTOR_DIMENSION,
+                    is_primary=False,
+                )
+                schema = CollectionSchema(
+                    fields=[field1, field2],
+                    description="Meeting attendee recommendation",
+                )
+                self.collection = Collection(name=collection_name, schema=schema)
+                logger.debug(f"Create Milvus collection: {self.collection}")
+            return "OK"
+        except Exception as e:
+            logger.error(f"Failed to load data to Milvus: {e}")
+            sys.exit(1)
+
+    # Batch insert vectors to milvus collection
+    async def insert(self, collection_name: str, vectors: List):
+        try:
+            await self.create_collection(collection_name)
+            self.collection = Collection(name=collection_name)
+            data = [vectors]
+            mr = self.collection.insert(data)
+            ids = mr.primary_keys
+            self.collection.load()
+            logger.debug(
+                f"Insert vectors to Milvus in collection: "
+                f"{collection_name} with {vectors} rows"
+            )
+            return ids
+        except Exception as e:
+            logger.error(f"Failed to load data to Milvus: {e}")
+            sys.exit(1)
+
+    # Create FLAT index on milvus collection
+    async def create_index(self, collection_name: str):
+        try:
+            self.set_collection(collection_name)
+            default_index = {
+                "index_type": "FLAT",
+                "metric_type": settings.METRIC_TYPE,
+                "params": {},
+            }
+            status = self.collection.create_index(
+                field_name="embeddings", index_params=default_index
+            )
+            if not status.code:
+                logger.debug(
+                    f"Successfully create index in collection: "
+                    f"{collection_name} with param: {default_index}"
+                )
+                return status
+            else:
+                raise Exception(status.message)
+        except Exception as e:
+            logger.error(f"Failed to create index: {e}")
+            sys.exit(1)
+
+    # Delete Milvus collection
+    async def delete_collection(self, collection_name: str):
+        try:
+            self.set_collection(collection_name)
+            self.collection.drop()
+            logger.debug("Successfully drop collection!")
+            return "ok"
+        except Exception as e:
+            logger.error(f"Failed to drop collection: {e}")
+            sys.exit(1)
+
+    # Search vector in milvus collection
+    async def search_vectors(self, collection_name: str, vectors: List, top_k: int):
+        # status = utility.list_collections()
+        try:
+            self.set_collection(collection_name)
+            search_params = {
+                "metric_type": settings.METRIC_TYPE,
+            }
+            try:
+                res = self.collection.search(
+                    vectors, anns_field="embeddings", param=search_params, limit=top_k
+                )
+            except BaseException:
+                self.collection.load()
+                res = self.collection.search(
+                    vectors, anns_field="embeddings", param=search_params, limit=top_k
+                )
+            pk_list = []
+            for hits in res:
+                for hit in hits.ids:
+                    pk_list.append(hit)
+            return pk_list
+        except Exception as e:
+            logger.error(f"Failed to search vectors in Milvus: {e}")
+            sys.exit(1)
+
+    # Get the number of milvus collection
+    async def count(self, collection_name: str):
+        try:
+            self.set_collection(collection_name)
+            num = self.collection.num_entities
+            return num
+        except Exception as e:
+            logger.error(f"Failed to count vectors in Milvus: {e}")
+            sys.exit(1)
+
+    # Query vector by primiary key
+    async def query_vector_by_pk(self, collection_name: str, pk: int):
+        try:
+            self.set_collection(collection_name)
+            expr = f"pk in [{pk}]"
+            try:
+                res = self.collection.query(expr=expr, output_fields=["embeddings"])
+            except BaseException:
+                self.collection.load()
+                res = self.collection.query(expr=expr, output_fields=["embeddings"])
+            vector = res[0]["embeddings"]
+            return vector
+        except Exception as e:
+            logger.error(f"Faild to query vector in Milvus: {e}")
+            sys.exit(1)
+
+
+my_milvus = {}
+
+
+async def get_milvus_cli() -> None:
+    MILVUS_CLI = MilvusHelper()
+    my_milvus.update({"cli": MILVUS_CLI})