import sys
import time
from typing import List

from loguru import logger
from pymilvus import (
    connections,
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    utility,
)

from app.core.config import settings


class MilvusHelper:
    """Thin wrapper around a pymilvus connection and a "current" collection.

    Most methods first bind ``self.collection`` (via ``set_collection`` or by
    constructing ``Collection`` directly) and then operate on it.

    Every method logs the failure and calls ``sys.exit(1)`` on error instead
    of raising. ``sys.exit`` raises ``SystemExit``, which is not an
    ``Exception`` subclass, so ordinary ``except Exception`` handlers in
    callers will not intercept it.
    NOTE(review): the async methods may be called from request handlers;
    confirm that terminating the process on any Milvus error is intended.
    """

    def __init__(self):
        """Open the pymilvus connection using host/port from ``settings``."""
        try:
            self.collection = None
            connections.connect(
                host=settings.MILVUS_HOST,
                port=settings.MILVUS_PORT,
            )
            logger.debug(
                f"Successfully connect to Milvus with IP: "
                f"{settings.MILVUS_HOST} and PORT: {settings.MILVUS_PORT}"
            )
        except Exception as e:
            logger.error(f"Failed to connect Milvus: {e}")
            sys.exit(1)

    def set_collection(self, collection_name: str):
        """Bind ``self.collection`` to the existing collection *collection_name*.

        Exits the process if the collection does not exist or the lookup fails.
        """
        try:
            if not self.has_collection(collection_name):
                raise Exception(f"There is no collection named: {collection_name}")
            self.collection = Collection(name=collection_name)
        except Exception as e:
            logger.error(f"Failed to set collection in Milvus: {e}")
            sys.exit(1)

    def has_collection(self, collection_name: str) -> bool:
        """Return whether Milvus has a collection named *collection_name*."""
        try:
            return utility.has_collection(collection_name)
        except Exception as e:
            logger.error(f"Failed to check collection in Milvus: {e}")
            sys.exit(1)

    async def create_collection(self, collection_name: str):
        """Create *collection_name* if it does not exist yet; return ``"OK"``.

        The schema has an auto-generated INT64 primary key ``pk`` and a
        FLOAT_VECTOR field ``embeddings`` of dimension
        ``settings.VECTOR_DIMENSION``. ``self.collection`` is only updated
        when a new collection is actually created.
        """
        try:
            if not self.has_collection(collection_name):
                pk_field = FieldSchema(
                    name="pk",
                    dtype=DataType.INT64,
                    description="int64",
                    is_primary=True,
                    auto_id=True,
                )
                vector_field = FieldSchema(
                    name="embeddings",
                    dtype=DataType.FLOAT_VECTOR,
                    description="float vector",
                    dim=settings.VECTOR_DIMENSION,
                    is_primary=False,
                )
                schema = CollectionSchema(
                    fields=[pk_field, vector_field],
                    description="Meeting attendee recommendation",
                )
                self.collection = Collection(name=collection_name, schema=schema)
                logger.debug(f"Create Milvus collection: {self.collection}")
            return "OK"
        except Exception as e:
            logger.error(f"Failed to create collection in Milvus: {e}")
            sys.exit(1)

    async def insert(self, collection_name: str, vectors: List):
        """Insert *vectors* into *collection_name*; return the new primary keys.

        The collection is created first if needed, and loaded after the insert.
        """
        try:
            await self.create_collection(collection_name)
            self.collection = Collection(name=collection_name)
            # Only the vector column is supplied; this assumes the schema built
            # by create_collection, whose "pk" field is auto_id.
            mr = self.collection.insert([vectors])
            ids = mr.primary_keys
            self.collection.load()
            logger.debug(
                f"Insert vectors to Milvus in collection: "
                f"{collection_name} with {len(vectors)} rows"
            )
            return ids
        except Exception as e:
            logger.error(f"Failed to insert data to Milvus: {e}")
            sys.exit(1)

    async def create_index(self, collection_name: str):
        """Build a FLAT index on the ``embeddings`` field and return its status.

        The metric comes from ``settings.METRIC_TYPE``. A non-zero
        ``status.code`` is treated as failure.
        """
        try:
            self.set_collection(collection_name)
            default_index = {
                "index_type": "FLAT",
                "metric_type": settings.METRIC_TYPE,
                "params": {},
            }
            status = self.collection.create_index(
                field_name="embeddings", index_params=default_index
            )
            if status.code:
                raise Exception(status.message)
            logger.debug(
                f"Successfully create index in collection: "
                f"{collection_name} with param: {default_index}"
            )
            return status
        except Exception as e:
            logger.error(f"Failed to create index: {e}")
            sys.exit(1)

    async def delete_collection(self, collection_name: str):
        """Drop the existing collection *collection_name*; return ``"ok"``."""
        try:
            self.set_collection(collection_name)
            self.collection.drop()
            logger.debug("Successfully drop collection!")
            return "ok"
        except Exception as e:
            logger.error(f"Failed to drop collection: {e}")
            sys.exit(1)

    async def search_vectors(self, collection_name: str, vectors: List, top_k: int):
        """Return primary keys of the *top_k* nearest neighbours of each query.

        Hits for all query vectors are concatenated into one flat list, query
        by query. If the first search attempt fails (presumably because the
        collection is not loaded yet), the collection is loaded and the search
        is retried once.
        """
        try:
            self.set_collection(collection_name)
            search_params = {
                "metric_type": settings.METRIC_TYPE,
            }
            try:
                res = self.collection.search(
                    vectors, anns_field="embeddings", param=search_params, limit=top_k
                )
            except Exception:
                # Catch Exception, not BaseException, so task cancellation and
                # interpreter exit are not swallowed and retried.
                self.collection.load()
                res = self.collection.search(
                    vectors, anns_field="embeddings", param=search_params, limit=top_k
                )
            pk_list = []
            for hits in res:
                pk_list.extend(hits.ids)
            return pk_list
        except Exception as e:
            logger.error(f"Failed to search vectors in Milvus: {e}")
            sys.exit(1)

    async def count(self, collection_name: str):
        """Return ``num_entities`` of the existing collection *collection_name*."""
        try:
            self.set_collection(collection_name)
            return self.collection.num_entities
        except Exception as e:
            logger.error(f"Failed to count vectors in Milvus: {e}")
            sys.exit(1)

    async def query_vector_by_pk(self, collection_name: str, pk: int):
        """Return the ``embeddings`` value of the entity whose primary key is *pk*.

        Like ``search_vectors``, a failed first query triggers a load and one
        retry. Exits the process if no entity has that primary key.
        """
        try:
            self.set_collection(collection_name)
            expr = f"pk in [{pk}]"
            try:
                res = self.collection.query(expr=expr, output_fields=["embeddings"])
            except Exception:
                # Catch Exception, not BaseException (see search_vectors).
                self.collection.load()
                res = self.collection.query(expr=expr, output_fields=["embeddings"])
            if not res:
                raise Exception(f"No vector found with pk: {pk}")
            return res[0]["embeddings"]
        except Exception as e:
            logger.error(f"Failed to query vector in Milvus: {e}")
            sys.exit(1)


# Module-level registry; get_milvus_cli stores the shared helper under "cli".
my_milvus = {}


async def get_milvus_cli() -> None:
    """Connect to Milvus and register the helper as ``my_milvus["cli"]``.

    Each call builds a fresh ``MilvusHelper`` (opening a connection) and
    replaces any helper previously registered under the ``"cli"`` key.
    """
    my_milvus["cli"] = MilvusHelper()