milvus_helpers.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import sys
  2. import time
  3. from typing import List
  4. from loguru import logger
  5. from pymilvus import (
  6. connections,
  7. Collection,
  8. CollectionSchema,
  9. DataType,
  10. FieldSchema,
  11. utility,
  12. )
  13. from app.core.config import settings
  14. class MilvusHelper:
  15. def __init__(self):
  16. try:
  17. self.collection = None
  18. connections.connect(
  19. host=settings.MILVUS_HOST,
  20. port=settings.MILVUS_PORT,
  21. )
  22. logger.debug(
  23. f"Successfully connect to Milvus with IP: "
  24. f"{settings.MILVUS_HOST} and PORT: {settings.MILVUS_PORT}"
  25. )
  26. except Exception as e:
  27. logger.error(f"Failed to connect Milvus: {e}")
  28. sys.exit(1)
  29. def set_collection(self, collection_name: str):
  30. try:
  31. if self.has_collection(collection_name):
  32. self.collection = Collection(name=collection_name)
  33. else:
  34. raise Exception(f"There has no collection named: {collection_name}")
  35. except Exception as e:
  36. logger.error(f"Failed to search data to Milvus: {e}")
  37. sys.exit(1)
  38. # Return if Milvus has the collection
  39. def has_collection(self, collection_name: str):
  40. try:
  41. status = utility.has_collection(collection_name)
  42. return status
  43. except Exception as e:
  44. logger.error(f"Failed to load data to Milvus: {e}")
  45. sys.exit(1)
  46. # Create milvus collection if not exists
  47. async def create_collection(self, collection_name: str):
  48. try:
  49. if not self.has_collection(collection_name):
  50. field1 = FieldSchema(
  51. name="pk",
  52. dtype=DataType.INT64,
  53. descrition="int64",
  54. is_primary=True,
  55. auto_id=True,
  56. )
  57. field2 = FieldSchema(
  58. name="embeddings",
  59. dtype=DataType.FLOAT_VECTOR,
  60. descrition="float vector",
  61. dim=settings.VECTOR_DIMENSION,
  62. is_primary=False,
  63. )
  64. schema = CollectionSchema(
  65. fields=[field1, field2],
  66. description="Meeting attendee recommendation",
  67. )
  68. self.collection = Collection(name=collection_name, schema=schema)
  69. logger.debug(f"Create Milvus collection: {self.collection}")
  70. return "OK"
  71. except Exception as e:
  72. logger.error(f"Failed to load data to Milvus: {e}")
  73. sys.exit(1)
  74. # Batch insert vectors to milvus collection
  75. async def insert(self, collection_name: str, vectors: List):
  76. try:
  77. await self.create_collection(collection_name)
  78. self.collection = Collection(name=collection_name)
  79. data = [vectors]
  80. mr = self.collection.insert(data)
  81. ids = mr.primary_keys
  82. self.collection.load()
  83. logger.debug(
  84. f"Insert vectors to Milvus in collection: "
  85. f"{collection_name} with {vectors} rows"
  86. )
  87. return ids
  88. except Exception as e:
  89. logger.error(f"Failed to load data to Milvus: {e}")
  90. sys.exit(1)
  91. # Create FLAT index on milvus collection
  92. async def create_index(self, collection_name: str):
  93. try:
  94. self.set_collection(collection_name)
  95. default_index = {
  96. "index_type": "FLAT",
  97. "metric_type": settings.METRIC_TYPE,
  98. "params": {},
  99. }
  100. status = self.collection.create_index(
  101. field_name="embeddings", index_params=default_index
  102. )
  103. if not status.code:
  104. logger.debug(
  105. f"Successfully create index in collection: "
  106. f"{collection_name} with param: {default_index}"
  107. )
  108. return status
  109. else:
  110. raise Exception(status.message)
  111. except Exception as e:
  112. logger.error(f"Failed to create index: {e}")
  113. sys.exit(1)
  114. # Delete Milvus collection
  115. async def delete_collection(self, collection_name: str):
  116. try:
  117. self.set_collection(collection_name)
  118. self.collection.drop()
  119. logger.debug("Successfully drop collection!")
  120. return "ok"
  121. except Exception as e:
  122. logger.error(f"Failed to drop collection: {e}")
  123. sys.exit(1)
  124. # Search vector in milvus collection
  125. async def search_vectors(self, collection_name: str, vectors: List, top_k: int):
  126. # status = utility.list_collections()
  127. try:
  128. self.set_collection(collection_name)
  129. search_params = {
  130. "metric_type": settings.METRIC_TYPE,
  131. }
  132. try:
  133. res = self.collection.search(
  134. vectors, anns_field="embeddings", param=search_params, limit=top_k
  135. )
  136. except BaseException:
  137. self.collection.load()
  138. res = self.collection.search(
  139. vectors, anns_field="embeddings", param=search_params, limit=top_k
  140. )
  141. pk_list = []
  142. for hits in res:
  143. for hit in hits.ids:
  144. pk_list.append(hit)
  145. return pk_list
  146. except Exception as e:
  147. logger.error(f"Failed to search vectors in Milvus: {e}")
  148. sys.exit(1)
  149. # Get the number of milvus collection
  150. async def count(self, collection_name: str):
  151. try:
  152. self.set_collection(collection_name)
  153. num = self.collection.num_entities
  154. return num
  155. except Exception as e:
  156. logger.error(f"Failed to count vectors in Milvus: {e}")
  157. sys.exit(1)
  158. # Query vector by primiary key
  159. async def query_vector_by_pk(self, collection_name: str, pk: int):
  160. try:
  161. self.set_collection(collection_name)
  162. expr = f"pk in [{pk}]"
  163. try:
  164. res = self.collection.query(expr=expr, output_fields=["embeddings"])
  165. except BaseException:
  166. self.collection.load()
  167. res = self.collection.query(expr=expr, output_fields=["embeddings"])
  168. vector = res[0]["embeddings"]
  169. return vector
  170. except Exception as e:
  171. logger.error(f"Faild to query vector in Milvus: {e}")
  172. sys.exit(1)
  173. my_milvus = {}
  174. async def get_milvus_cli() -> None:
  175. MILVUS_CLI = MilvusHelper()
  176. my_milvus.update({"cli": MILVUS_CLI})