milvus.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import time
  2. from typing import List
  3. from loguru import logger
  4. from pymilvus import connections, Collection, SearchResult
  5. from app.core.config import settings
  6. class Milvus:
  7. """
  8. Open-source vector database for unstructured data.
  9. """
  10. def __init__(self) -> None:
  11. self._host = settings.MILVUS_HOST
  12. self._port = settings.MILVUS_PORT
  13. async def query_embedding_by_pk(
  14. self,
  15. collection_name: str,
  16. primary_key_name: str,
  17. pk: int,
  18. output_field: str = "embeddings",
  19. ) -> List:
  20. start = time.perf_counter()
  21. connections.connect("default", host=self._host, port=self._port)
  22. logger.debug(time.perf_counter() - start)
  23. collection = Collection(name=collection_name)
  24. start = time.perf_counter()
  25. collection.load()
  26. logger.debug(time.perf_counter() - start)
  27. expr = f"{primary_key_name} in [{pk}]"
  28. start = time.perf_counter()
  29. res = collection.query(expr=expr, output_fields=[output_field])
  30. logger.debug(time.perf_counter() - start)
  31. try:
  32. emb = res[0].get(output_field)
  33. except [KeyError, IndexError] as e:
  34. emb = []
  35. logger.error(f"Can't find embedding by {pk}, reason: {e}")
  36. start = time.perf_counter()
  37. collection.release()
  38. # connections.disconnect("default")
  39. logger.debug(time.perf_counter() - start)
  40. return emb
  41. async def search(
  42. self,
  43. vec_list: List,
  44. collection_name: str,
  45. field_name: str,
  46. limit: int,
  47. ) -> List:
  48. connections.connect("default", host=self._host, port=self._port)
  49. collection = Collection(name=collection_name)
  50. collection.load()
  51. SEARCH_PARAM = {
  52. "metric_type": "L2",
  53. }
  54. start = time.perf_counter()
  55. res = collection.search(
  56. vec_list,
  57. field_name,
  58. param=SEARCH_PARAM,
  59. limit=limit,
  60. expr=None,
  61. output_fields=None,
  62. )
  63. logger.debug(time.perf_counter() - start)
  64. pk_list = []
  65. if isinstance(res, SearchResult):
  66. for hits in res:
  67. for hit in hits.ids:
  68. pk_list.append(hit)
  69. collection.release()
  70. # connections.disconnect("default")
  71. return pk_list