瀏覽代碼

replace my_milvus by milvus helpers

highing666 3 年之前
父節點
當前提交
9f83eca1ca

+ 2 - 1
app/api/routers/algorithms.py

@@ -36,7 +36,8 @@ async def get_recommended(meeting_info: AttendeesRecommendationRequest):
         res = await build_recommendations(
             meeting_info.companyId, meeting_info.initiatorId
         )
-    except [KeyError, IndexError, ValueError]:
+    except Exception as e:
+        logger.error(e)
         return {"code": 500, "message": "Recommend failure."}
     else:
         return {"code": 200, "message": "sucess", "data": {"userIdList": res}}

+ 39 - 12
app/controllers/algorithms/meeting_attendee_recommendation.py

@@ -2,37 +2,64 @@ import time
 from typing import List
 from loguru import logger
 
-from app.services.milvus import Milvus
 from app.services.my_redis import MyRedis
+from app.services.milvus import Milvus
+from app.services.milvus_helpers import my_milvus
 
 
+@logger.catch()
 async def build_recommendations(company_id: str, user_id: str) -> List:
+    MILVUS_CLI = my_milvus.get("cli")
+    myredis = MyRedis()
+    r = myredis.get_client()
+
+    user_pk = r.hget(f"meeting_user_id_to_embedding_id_map:{company_id}", user_id)
+    user_id_list = []
+    if user_pk:
+        user_emb = await MILVUS_CLI.query_vector_by_pk(
+            f"meeting_attendee_rec_{company_id}",
+            int(user_pk)
+        )
+        pk_list = await MILVUS_CLI.search_vectors(
+            f"meeting_attendee_rec_{company_id}",
+            [user_emb],
+            10
+        )
+        pk_list = [str(item) for item in pk_list]
+        user_id_list = r.hmget(
+            f"embedding_id_to_meeting_user_id_map:{company_id}", pk_list
+        )
+
+    return user_id_list
+
+
+@logger.catch()
+async def build_recommendations_v2(company_id: str, user_id: str) -> List:
+    old_milvus = Milvus()
     myredis = MyRedis()
-    start = time.perf_counter()
     r = myredis.get_client()
-    logger.debug(time.perf_counter() - start)
-    milvus = Milvus()
 
-    start = time.perf_counter()
     user_pk = r.hget(f"meeting_user_id_to_embedding_id_map:{company_id}", user_id)
-    logger.debug(time.perf_counter() - start)
     user_id_list = []
     if user_pk:
         start = time.perf_counter()
-        user_emb = await milvus.query_embedding_by_pk(
-            "meeting_attendee_rec", "pk", int(user_pk)
+        user_emb = await old_milvus.query_embedding_by_pk(
+            f"meeting_attendee_rec_{company_id}",
+            "pk",
+            int(user_pk)
         )
         logger.debug(time.perf_counter() - start)
         start = time.perf_counter()
-        pk_list = await milvus.search(
-            [user_emb], "meeting_attendee_rec", [company_id], "embeddings", 10
+        pk_list = await old_milvus.search(
+            [user_emb],
+            f"meeting_attendee_rec_{company_id}",
+            "embeddings",
+            10
         )
         logger.debug(time.perf_counter() - start)
         pk_list = [str(item) for item in pk_list]
-        start = time.perf_counter()
         user_id_list = r.hmget(
             f"embedding_id_to_meeting_user_id_map:{company_id}", pk_list
         )
-        logger.debug(time.perf_counter() - start)
 
     return user_id_list

+ 2 - 1
app/core/events.py

@@ -5,11 +5,12 @@ from typing import Callable, Optional
 from fastapi import FastAPI
 
 # from app.controllers.equipment.events import regulate_ahu_freq
+from app.services.milvus_helpers import get_milvus_cli
 
 
 def create_start_app_handler(app: Optional[FastAPI] = None) -> Callable:
     async def start_app() -> None:
         # await regulate_ahu_freq()
-        pass
+        await get_milvus_cli()
 
     return start_app

+ 15 - 5
app/services/milvus.py

@@ -1,3 +1,4 @@
+import time
 from typing import List
 from loguru import logger
 from pymilvus import connections, Collection, SearchResult
@@ -21,18 +22,27 @@ class Milvus:
         pk: int,
         output_field: str = "embeddings",
     ) -> List:
-        connections.connect(host=self._host, port=self._port)
+        start = time.perf_counter()
+        connections.connect("default", host=self._host, port=self._port)
+        logger.debug(time.perf_counter() - start)
         collection = Collection(name=collection_name)
+        start = time.perf_counter()
         collection.load()
+        logger.debug(time.perf_counter() - start)
         expr = f"{primary_key_name} in [{pk}]"
+        start = time.perf_counter()
         res = collection.query(expr=expr, output_fields=[output_field])
+        logger.debug(time.perf_counter() - start)
         try:
             emb = res[0].get(output_field)
         except [KeyError, IndexError] as e:
             emb = []
             logger.error(f"Can't find embedding by {pk}, reason: {e}")
 
+        start = time.perf_counter()
         collection.release()
+        # connections.disconnect("default")
+        logger.debug(time.perf_counter() - start)
 
         return emb
 
@@ -40,27 +50,26 @@ class Milvus:
         self,
         vec_list: List,
         collection_name: str,
-        partition_names: List[str],
         field_name: str,
         limit: int,
     ) -> List:
-        connections.connect(host=self._host, port=self._port)
+        connections.connect("default", host=self._host, port=self._port)
         collection = Collection(name=collection_name)
         collection.load()
 
         SEARCH_PARAM = {
             "metric_type": "L2",
-            "params": {"nprobe": 20},
         }
+        start = time.perf_counter()
         res = collection.search(
             vec_list,
             field_name,
             param=SEARCH_PARAM,
             limit=limit,
             expr=None,
-            partition_names=partition_names,
             output_fields=None,
         )
+        logger.debug(time.perf_counter() - start)
         pk_list = []
         if isinstance(res, SearchResult):
             for hits in res:
@@ -68,5 +77,6 @@ class Milvus:
                     pk_list.append(hit)
 
         collection.release()
+        # connections.disconnect("default")
 
         return pk_list