# milvus.py
from typing import List

from loguru import logger
from pymilvus import connections, Collection, SearchResult

from app.core.config import settings
  5. class Milvus:
  6. """
  7. Open-source vector database for unstructured data.
  8. """
  9. def __init__(self) -> None:
  10. self._host = settings.MILVUS_HOST
  11. self._port = settings.MILVUS_PORT
  12. async def query_embedding_by_pk(
  13. self,
  14. collection_name: str,
  15. primary_key_name: str,
  16. pk: int,
  17. output_field: str = "embeddings",
  18. ) -> List:
  19. connections.connect(host=self._host, port=self._port)
  20. collection = Collection(name=collection_name)
  21. collection.load()
  22. expr = f"{primary_key_name} in [{pk}]"
  23. res = collection.query(expr=expr, output_fields=[output_field])
  24. try:
  25. emb = res[0].get(output_field)
  26. except [KeyError, IndexError] as e:
  27. emb = []
  28. logger.error(f"Can't find embedding by {pk}, reason: {e}")
  29. collection.release()
  30. return emb
  31. async def search(
  32. self,
  33. vec_list: List,
  34. collection_name: str,
  35. partition_names: List[str],
  36. field_name: str,
  37. limit: int,
  38. ) -> List:
  39. connections.connect(host=self._host, port=self._port)
  40. collection = Collection(name=collection_name)
  41. collection.load()
  42. SEARCH_PARAM = {
  43. "metric_type": "L2",
  44. "params": {"nprobe": 20},
  45. }
  46. res = collection.search(
  47. vec_list,
  48. field_name,
  49. param=SEARCH_PARAM,
  50. limit=limit,
  51. expr=None,
  52. partition_names=partition_names,
  53. output_fields=None,
  54. )
  55. pk_list = []
  56. if isinstance(res, SearchResult):
  57. for hits in res:
  58. for hit in hits.ids:
  59. pk_list.append(hit)
  60. collection.release()
  61. return pk_list