Browse Source

add milvus and redis service

highing666 3 years ago
parent
commit
5c5cc9e7e1

+ 38 - 0
app/controllers/algorithms/meeting_attendee_recommendation.py

@@ -0,0 +1,38 @@
+import time
+from typing import List
+from loguru import logger
+
+from app.services.milvus import Milvus
+from app.services.my_redis import MyRedis
+
+
+async def build_recommendations(company_id: str, user_id: str) -> List:
+    myredis = MyRedis()
+    start = time.perf_counter()
+    r = myredis.get_client()
+    logger.debug(time.perf_counter() - start)
+    milvus = Milvus()
+
+    start = time.perf_counter()
+    user_pk = r.hget(f"meeting_user_id_to_embedding_id_map:{company_id}", user_id)
+    logger.debug(time.perf_counter() - start)
+    user_id_list = []
+    if user_pk:
+        start = time.perf_counter()
+        user_emb = await milvus.query_embedding_by_pk(
+            "meeting_attendee_rec", "pk", int(user_pk)
+        )
+        logger.debug(time.perf_counter() - start)
+        start = time.perf_counter()
+        pk_list = await milvus.search(
+            [user_emb], "meeting_attendee_rec", [company_id], "embeddings", 10
+        )
+        logger.debug(time.perf_counter() - start)
+        pk_list = [str(item) for item in pk_list]
+        start = time.perf_counter()
+        user_id_list = r.hmget(
+            f"embedding_id_to_meeting_user_id_map:{company_id}", pk_list
+        )
+        logger.debug(time.perf_counter() - start)
+
+    return user_id_list

+ 72 - 0
app/services/milvus.py

@@ -0,0 +1,72 @@
+from typing import List
+from loguru import logger
+from pymilvus_orm import connections, Collection, SearchResult
+
+from app.core.config import settings
+
+
+class Milvus:
+    """
+    Open-source vector database for unstructured data.
+    """
+
+    def __init__(self) -> None:
+        self._host = settings.MILVUS_HOST
+        self._port = settings.MILVUS_PORT
+
+    async def query_embedding_by_pk(
+        self,
+        collection_name: str,
+        primary_key_name: str,
+        pk: int,
+        output_field: str = "embeddings",
+    ) -> List:
+        connections.connect(host=self._host, port=self._port)
+        collection = Collection(name=collection_name)
+        collection.load()
+        expr = f"{primary_key_name} in [{pk}]"
+        res = collection.query(expr=expr, output_fields=[output_field])
+        try:
+            emb = res[0].get(output_field)
+        except [KeyError, IndexError] as e:
+            emb = []
+            logger.error(f"Can't find embedding by {pk}, reason: {e}")
+
+        collection.release()
+
+        return emb
+
+    async def search(
+        self,
+        vec_list: List,
+        collection_name: str,
+        partition_names: List[str],
+        field_name: str,
+        limit: int,
+    ) -> List:
+        connections.connect(host=self._host, port=self._port)
+        collection = Collection(name=collection_name)
+        collection.load()
+
+        SEARCH_PARAM = {
+            "metric_type": "L2",
+            "params": {"nprobe": 20},
+        }
+        res = collection.search(
+            vec_list,
+            field_name,
+            param=SEARCH_PARAM,
+            limit=limit,
+            expr=None,
+            partition_names=partition_names,
+            output_fields=None,
+        )
+        pk_list = []
+        if isinstance(res, SearchResult):
+            for hits in res:
+                for hit in hits.ids:
+                    pk_list.append(hit)
+
+        collection.release()
+
+        return pk_list

+ 16 - 0
app/services/my_redis.py

@@ -0,0 +1,16 @@
+import redis
+
+from app.core.config import settings
+
+
+class MyRedis:
+    def __init__(self) -> None:
+        self._host = settings.REDIS_HOST
+        self._port = settings.REDIS_PORT
+        self._db = settings.REDIS_DB
+        self._password = settings.REDIS_PASSWORD
+
+    def get_client(self) -> redis.Redis:
+        return redis.Redis(
+            host=self._host, port=self._port, db=self._db, password=self._password
+        )