milvus.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import time
  2. from loguru import logger
  3. from pymilvus import connections, Collection, SearchResult
  4. from app.core.config import settings
  5. class Milvus:
  6. """
  7. Open-source vector database for unstructured data.
  8. """
  9. def __init__(self) -> None:
  10. self._host = settings.MILVUS_HOST
  11. self._port = settings.MILVUS_PORT
  12. async def query_embedding_by_pk(
  13. self,
  14. collection_name: str,
  15. primary_key_name: str,
  16. pk: int,
  17. output_field: str = "embeddings",
  18. ) -> list:
  19. start = time.perf_counter()
  20. connections.connect("default", host=self._host, port=self._port)
  21. logger.debug(time.perf_counter() - start)
  22. collection = Collection(name=collection_name)
  23. start = time.perf_counter()
  24. collection.load()
  25. logger.debug(time.perf_counter() - start)
  26. expr = f"{primary_key_name} in [{pk}]"
  27. start = time.perf_counter()
  28. res = collection.query(expr=expr, output_fields=[output_field])
  29. logger.debug(time.perf_counter() - start)
  30. try:
  31. emb = res[0].get(output_field)
  32. except [KeyError, IndexError] as e:
  33. emb = []
  34. logger.error(f"Can't find embedding by {pk}, reason: {e}")
  35. start = time.perf_counter()
  36. collection.release()
  37. # connections.disconnect("default")
  38. logger.debug(time.perf_counter() - start)
  39. return emb
  40. async def search(
  41. self,
  42. vec_list: list,
  43. collection_name: str,
  44. field_name: str,
  45. limit: int,
  46. ) -> list:
  47. connections.connect("default", host=self._host, port=self._port)
  48. collection = Collection(name=collection_name)
  49. collection.load()
  50. SEARCH_PARAM = {
  51. "metric_type": "L2",
  52. }
  53. start = time.perf_counter()
  54. res = collection.search(
  55. vec_list,
  56. field_name,
  57. param=SEARCH_PARAM,
  58. limit=limit,
  59. )
  60. logger.debug(time.perf_counter() - start)
  61. pk_list = []
  62. if isinstance(res, SearchResult):
  63. for hits in res:
  64. for hit in hits.ids:
  65. pk_list.append(hit)
  66. collection.release()
  67. # connections.disconnect("default")
  68. return pk_list