# milvus_helpers.py
  1. import sys
  2. from loguru import logger
  3. from pymilvus import (
  4. connections,
  5. Collection,
  6. CollectionSchema,
  7. DataType,
  8. FieldSchema,
  9. utility,
  10. )
  11. from app.core.config import settings
  12. class MilvusHelper:
  13. def __init__(self):
  14. try:
  15. self.collection = None
  16. connections.connect(
  17. host=settings.MILVUS_HOST,
  18. port=settings.MILVUS_PORT,
  19. )
  20. logger.debug(
  21. f"Successfully connect to Milvus with IP: "
  22. f"{settings.MILVUS_HOST} and PORT: {settings.MILVUS_PORT}"
  23. )
  24. except Exception as e:
  25. logger.error(f"Failed to connect Milvus: {e}")
  26. sys.exit(1)
  27. def set_collection(self, collection_name: str):
  28. try:
  29. if self.has_collection(collection_name):
  30. self.collection = Collection(name=collection_name)
  31. else:
  32. raise Exception(f"There has no collection named: {collection_name}")
  33. except Exception as e:
  34. logger.error(f"Failed to search data to Milvus: {e}")
  35. sys.exit(1)
  36. # Return if Milvus has the collection
  37. @staticmethod
  38. def has_collection(collection_name: str):
  39. try:
  40. status = utility.has_collection(collection_name)
  41. return status
  42. except Exception as e:
  43. logger.error(f"Failed to load data to Milvus: {e}")
  44. sys.exit(1)
  45. # Create milvus collection if not exists
  46. async def create_collection(self, collection_name: str):
  47. try:
  48. if not self.has_collection(collection_name):
  49. field1 = FieldSchema(
  50. name="pk",
  51. dtype=DataType.INT64,
  52. descrition="int64",
  53. is_primary=True,
  54. auto_id=True,
  55. )
  56. field2 = FieldSchema(
  57. name="embeddings",
  58. dtype=DataType.FLOAT_VECTOR,
  59. descrition="float vector",
  60. dim=settings.VECTOR_DIMENSION,
  61. is_primary=False,
  62. )
  63. schema = CollectionSchema(
  64. fields=[field1, field2],
  65. description="Meeting attendee recommendation",
  66. )
  67. self.collection = Collection(name=collection_name, schema=schema)
  68. logger.debug(f"Create Milvus collection: {self.collection}")
  69. return "OK"
  70. except Exception as e:
  71. logger.error(f"Failed to load data to Milvus: {e}")
  72. sys.exit(1)
  73. # Batch insert vectors to milvus collection
  74. async def insert(self, collection_name: str, vectors: list):
  75. try:
  76. await self.create_collection(collection_name)
  77. self.collection = Collection(name=collection_name)
  78. data = [vectors]
  79. mr = self.collection.insert(data)
  80. ids = mr.primary_keys
  81. self.collection.load()
  82. logger.debug(
  83. f"Insert vectors to Milvus in collection: "
  84. f"{collection_name} with {vectors} rows"
  85. )
  86. return ids
  87. except Exception as e:
  88. logger.error(f"Failed to load data to Milvus: {e}")
  89. sys.exit(1)
  90. # Create FLAT index on milvus collection
  91. async def create_index(self, collection_name: str):
  92. try:
  93. self.set_collection(collection_name)
  94. default_index = {
  95. "index_type": "FLAT",
  96. "metric_type": settings.METRIC_TYPE,
  97. "params": {},
  98. }
  99. status = self.collection.create_index(
  100. field_name="embeddings", index_params=default_index
  101. )
  102. if not status.code:
  103. logger.debug(
  104. f"Successfully create index in collection: "
  105. f"{collection_name} with param: {default_index}"
  106. )
  107. return status
  108. else:
  109. raise Exception(status.message)
  110. except Exception as e:
  111. logger.error(f"Failed to create index: {e}")
  112. sys.exit(1)
  113. # Delete Milvus collection
  114. async def delete_collection(self, collection_name: str):
  115. try:
  116. self.set_collection(collection_name)
  117. self.collection.drop()
  118. logger.debug("Successfully drop collection!")
  119. return "ok"
  120. except Exception as e:
  121. logger.error(f"Failed to drop collection: {e}")
  122. sys.exit(1)
  123. # Search vector in milvus collection
  124. async def search_vectors(self, collection_name: str, vectors: list, top_k: int):
  125. # status = utility.list_collections()
  126. try:
  127. self.set_collection(collection_name)
  128. search_params = {
  129. "metric_type": settings.METRIC_TYPE,
  130. }
  131. try:
  132. res = self.collection.search(
  133. vectors, anns_field="embeddings", param=search_params, limit=top_k
  134. )
  135. except BaseException:
  136. self.collection.load()
  137. res = self.collection.search(
  138. vectors, anns_field="embeddings", param=search_params, limit=top_k
  139. )
  140. pk_list = []
  141. for hits in res:
  142. for hit in hits.ids:
  143. pk_list.append(hit)
  144. return pk_list
  145. except Exception as e:
  146. logger.error(f"Failed to search vectors in Milvus: {e}")
  147. sys.exit(1)
  148. # Get the number of milvus collection
  149. async def count(self, collection_name: str):
  150. try:
  151. self.set_collection(collection_name)
  152. num = self.collection.num_entities
  153. return num
  154. except Exception as e:
  155. logger.error(f"Failed to count vectors in Milvus: {e}")
  156. sys.exit(1)
  157. # Query vector by primary key
  158. async def query_vector_by_pk(self, collection_name: str, pk: int):
  159. try:
  160. self.set_collection(collection_name)
  161. expr = f"pk in [{pk}]"
  162. try:
  163. res = self.collection.query(expr=expr, output_fields=["embeddings"])
  164. except BaseException:
  165. self.collection.load()
  166. res = self.collection.query(expr=expr, output_fields=["embeddings"])
  167. vector = res[0]["embeddings"]
  168. return vector
  169. except Exception as e:
  170. logger.error(f"Failed to query vector in Milvus: {e}")
  171. sys.exit(1)
  172. my_milvus = {}
  173. async def get_milvus_cli() -> None:
  174. MILVUS_CLI = MilvusHelper()
  175. my_milvus.update({"cli": MILVUS_CLI})