# -*- coding: utf-8 -*-

from operator import attrgetter
from typing import Dict, List, Optional, Tuple

import numpy as np
from fastapi import HTTPException
from httpx import AsyncClient
from loguru import logger
from sqlalchemy.orm import Session

from app.controllers.equipment.controller import EquipmentController
from app.crud.space.weight import get_weights_by_vav
from app.models.domain.devices import (
    ACATVAInstructionsRequest,
    ACATVAInstructionsRequestV2,
)
from app.schemas.equipment import VAVBox, FCU
from app.schemas.instructions import ACATVAInstructions
from app.schemas.sapce_weight import SpaceWeight
from app.schemas.space import SpaceATVA
from app.services.platform import DataPlatformService
from app.services.transfer import Duoduo, SpaceInfoService, Season
from app.utils.date import get_time_str


class VAVController(EquipmentController):
    """Baseline VAV box controller.

    Builds virtual target/realtime temperatures from the spaces served by the
    box and derives a supply-air-flow set point clamped to the box's limits.
    """

    def __init__(self, equipment: VAVBox):
        super(VAVController, self).__init__()
        self.equipment = equipment

    async def get_strategy(self) -> str:
        """Return "Plan B" if any served space contains an FCU, else "Plan A"."""
        for space in self.equipment.spaces:
            if any(isinstance(eq, FCU) for eq in space.equipment):
                return "Plan B"
        return "Plan A"

    async def build_virtual_temperature(self) -> Tuple[float, float]:
        """Compute the virtual (target, realtime) temperatures of the box.

        Spaces whose target temperature is NaN are ignored. Under "Plan B",
        each space contributes an FCU buffer term ``(4 - air_valve_speed) / 4``
        (taken from the first FCU in the space). The buffer terms are added to
        both sums, while the divisor is the number of counted spaces.

        Side effect: stores the target result in
        ``equipment.setting_temperature`` when at least one space counts.

        Returns:
            (target, realtime); both NaN when no space has a target.
        """
        target_list, realtime_list = [], []
        buffer_list = []
        strategy = await self.get_strategy()
        for space in self.equipment.spaces:
            if np.isnan(space.temperature_target):
                continue
            target_list.append(space.temperature_target)
            realtime_list.append(space.realtime_temperature)
            if strategy == "Plan B":
                for eq in space.equipment:
                    if isinstance(eq, FCU):
                        # NOTE(review): assumes the fan speed is on a 0-4 scale.
                        buffer_list.append((4 - eq.air_valve_speed) / 4)
                        break

        # buffer_list only grows alongside target_list, so a non-empty
        # target_list means both combined lists are non-empty.
        if target_list:
            target_result = np.array(buffer_list + target_list).sum() / len(
                target_list
            )
            realtime_result = np.array(buffer_list + realtime_list).sum() / len(
                realtime_list
            )
            self.equipment.setting_temperature = target_result
        else:
            # np.nan (not the np.NAN alias, which NumPy 2.0 removed).
            target_result, realtime_result = np.nan, np.nan

        return target_result, realtime_result

    async def get_supply_air_flow_set(
        self, temperature_set: float, temperature_realtime: float
    ) -> float:
        """Compute, store and return the supply-air-flow set point.

        The current flow is scaled by
        ``|1 + (realtime - set) / (set - supply)|``. The ratio falls back to 1
        when ``set == supply``. The supply-air temperature falls back to
        19.0 when it is NaN. A NaN input temperature gives a raw set point
        of 0.0. The result is always clamped to the box's lower/upper flow
        limits.

        Side effects: writes ``supply_air_flow_set`` and the virtual
        target/realtime temperatures onto ``self.equipment``.
        """
        if np.isnan(temperature_set) or np.isnan(temperature_realtime):
            supply_air_flow_set = 0.0
        else:
            temperature_supply = self.equipment.supply_air_temperature
            if np.isnan(temperature_supply):
                temperature_supply = 19.0
            denominator = temperature_set - temperature_supply
            # Explicit check: NumPy scalars (e.g. a np.mean result) return
            # inf/nan on division by zero rather than raising ZeroDivisionError.
            if denominator == 0:
                ratio = 1
            else:
                ratio = abs(1 + (temperature_realtime - temperature_set) / denominator)
            supply_air_flow_set = self.equipment.supply_air_flow * ratio

        supply_air_flow_set = max(
            self.equipment.supply_air_flow_lower_limit, supply_air_flow_set
        )
        supply_air_flow_set = min(
            self.equipment.supply_air_flow_upper_limit, supply_air_flow_set
        )
        self.equipment.supply_air_flow_set = supply_air_flow_set
        self.equipment.virtual_target_temperature = temperature_set
        self.equipment.virtual_realtime_temperature = temperature_realtime

        return supply_air_flow_set

    async def run(self):
        """Build the virtual temperatures, then compute the flow set point."""
        temperature_set, temperature_realtime = await self.build_virtual_temperature()
        await self.get_supply_air_flow_set(temperature_set, temperature_realtime)
        self.equipment.running_status = True

    def get_results(self):
        """Return the (mutated) equipment object."""
        return self.equipment


class VAVControllerV2(VAVController):
    """VAV controller that blends served spaces using per-space weights.

    Weights come from ``SpaceWeight`` records. A temporary weight updated
    within the last two hours overrides the default weights. ``rectify`` then
    adjusts the blended temperatures when a space is outside the seasonal
    comfort band.
    """

    def __init__(
        self,
        equipment: VAVBox,
        weights: Optional[List[SpaceWeight]] = None,
        season: Optional[Season] = None,
    ):
        super(VAVControllerV2, self).__init__(equipment)
        self.weights = weights
        self.season = season

    def _wrong_relation_error(self) -> HTTPException:
        """Log and build the 404 used when spaces and weights do not match."""
        logger.error(f"{self.equipment.id} has wrong vav-space relation")
        return HTTPException(
            status_code=404, detail="This VAV box has wrong eq-sp relation"
        )

    async def build_virtual_temperature(self) -> None:
        """Store weighted virtual target/realtime temperatures on the equipment.

        A space counts only when its realtime and target temperatures are
        both positive.

        - Feedback mode: if the most recently updated weight was temporarily
          updated within the last two hours, that space gets its temporary
          weight and every other space gets 0.
        - Default mode: default weights are normalised to sum to 1. If they
          sum to <= 0, the first weighted space is given weight 1.

        Both temperatures are NaN when no space counts.

        Raises:
            HTTPException: 404 when a counted space has no matching weight.
        """
        valid_spaces = []
        weights = []
        for sp in self.equipment.spaces:
            if sp.realtime_temperature > 0.0 and sp.temperature_target > 0.0:
                valid_spaces.append(sp)
                # ``weights`` defaults to None; treat that as "no weights".
                for weight in self.weights or []:
                    if weight.space_id == sp.id:
                        weights.append(weight)

        if not valid_spaces:
            self.equipment.virtual_target_temperature = np.nan
            self.equipment.virtual_realtime_temperature = np.nan
            return

        if not weights:
            # Counted spaces but no weights at all: same wrong-relation case.
            raise self._wrong_relation_error()

        weights = sorted(weights, key=lambda x: x.temporary_weight_update_time)
        latest = weights[-1]
        if latest.temporary_weight_update_time > get_time_str(
            60 * 60 * 2, flag="ago"
        ):
            # Someone submitted feedback in the past two hours: meet that need.
            weight_dic = {weight.space_id: 0.0 for weight in weights}
            weight_dic[latest.space_id] = latest.temporary_weight
        else:
            weight_dic = {weight.space_id: weight.default_weight for weight in weights}
            total_weight_value = sum(weight_dic.values(), 0.0)
            if total_weight_value > 0:
                weight_dic = {k: v / total_weight_value for k, v in weight_dic.items()}
            else:
                weight_dic[next(iter(weight_dic))] = 1.0

        virtual_target, virtual_realtime = 0.0, 0.0
        for sp in valid_spaces:
            # Index (not .get) so a missing weight surfaces as the intended
            # 404 instead of a TypeError from ``float * None``.
            try:
                space_weight = weight_dic[sp.id]
            except KeyError:
                raise self._wrong_relation_error() from None
            virtual_target += sp.temperature_target * space_weight
            virtual_realtime += sp.realtime_temperature * space_weight

        self.equipment.virtual_target_temperature = virtual_target
        self.equipment.virtual_realtime_temperature = virtual_realtime

    async def rectify(self) -> Tuple[float, float]:
        """Override the virtual temperatures when a space is out of band.

        A space is "bad" when it has positive target and realtime readings
        and its realtime temperature lies outside the seasonal band widened
        to include its own target:
        heating [min(20, target), max(26, target)],
        cooling [min(22, target), max(27, target)].

        The worst bad space is chosen by ``diff``, which the builders in this
        module set to ``target - realtime``. In cooling it is the lowest
        diff; in heating it is the highest.

        - Same direction as the virtual diff, and larger: the worst space's
          temperatures replace the virtual ones.
        - Opposite direction: the virtual target is pushed 0.5 inside a band
          edge and the worst space's realtime temperature is used.

        Returns:
            (virtual_target_temperature, virtual_realtime_temperature).
        """
        bad_spaces = []
        for sp in self.equipment.spaces:
            # Skip missing/invalid readings (e.g. -1 sentinels). Written as a
            # positive check so NaN readings are skipped as well.
            if not (sp.temperature_target > 0.0 and sp.realtime_temperature > 0.0):
                continue
            if self.season == Season.heating:
                upper, lower = 26.0, 20.0
            elif self.season == Season.cooling:
                upper, lower = 27.0, 22.0
            else:
                continue
            realtime = sp.realtime_temperature
            if realtime > max(upper, sp.temperature_target) or realtime < min(
                lower, sp.temperature_target
            ):
                bad_spaces.append(sp)

        if bad_spaces:
            virtual_diff = (
                self.equipment.virtual_target_temperature
                - self.equipment.virtual_realtime_temperature
            )
            if self.season == Season.cooling:
                bad_spaces = sorted(bad_spaces, key=attrgetter("diff"))
                worst = bad_spaces[0]
                if worst.diff * virtual_diff >= 0:
                    if abs(worst.diff) > abs(virtual_diff):
                        self.equipment.virtual_target_temperature = worst.temperature_target
                        self.equipment.virtual_realtime_temperature = worst.realtime_temperature
                else:
                    if worst.diff < 0:
                        self.equipment.virtual_target_temperature = min(22.0, worst.temperature_target) + 0.5
                    else:
                        # NOTE(review): the cooling band's upper edge above is
                        # 27.0 but 26.0 is used here; confirm this is intended.
                        self.equipment.virtual_target_temperature = max(26.0, worst.temperature_target) - 0.5
                    self.equipment.virtual_realtime_temperature = worst.realtime_temperature
            elif self.season == Season.heating:
                bad_spaces = sorted(bad_spaces, key=attrgetter("diff"), reverse=True)
                worst = bad_spaces[0]
                if worst.diff * virtual_diff >= 0:
                    if abs(worst.diff) > abs(virtual_diff):
                        self.equipment.virtual_target_temperature = worst.temperature_target
                        self.equipment.virtual_realtime_temperature = worst.realtime_temperature
                else:
                    if worst.diff > 0:
                        self.equipment.virtual_target_temperature = max(26.0, worst.temperature_target) - 0.5
                    else:
                        self.equipment.virtual_target_temperature = min(20.0, worst.temperature_target) + 0.5
                    self.equipment.virtual_realtime_temperature = worst.realtime_temperature

        return (
            self.equipment.virtual_target_temperature,
            self.equipment.virtual_realtime_temperature,
        )

    async def run(self) -> None:
        """Blend temperatures, rectify them, then compute the flow set point."""
        await self.build_virtual_temperature()
        temperature_set, temperature_realtime = await self.rectify()
        await self.get_supply_air_flow_set(temperature_set, temperature_realtime)
        self.equipment.running_status = True


class VAVControllerV3(VAVControllerV2):
    """Controller that reads weights directly from the space objects.

    It uses ``vav_temporary_update_time`` and ``vav_default_weight`` on each
    space instead of separate ``SpaceWeight`` records.
    """

    def __init__(self, vav_params: VAVBox, season: Season):
        super(VAVControllerV3, self).__init__(vav_params)
        self.season = season

    def get_valid_spaces(self) -> List:
        """Return served spaces whose realtime and target temperatures are positive."""
        return [
            sp
            for sp in self.equipment.spaces
            if sp.realtime_temperature > 0.0 and sp.temperature_target > 0.0
        ]

    async def build_virtual_temperature(self) -> None:
        """Store virtual realtime/target temperatures on the equipment.

        - Feedback mode: if the most recently updated valid space was updated
          within the last two hours, its temperatures are used directly.
        - Default mode: a ``vav_default_weight``-weighted mean is used,
          falling back to a plain mean when the weights sum to 0.

        Both temperatures are NaN when there is no valid space.
        """
        valid_spaces = self.get_valid_spaces()

        if not valid_spaces:
            # np.nan (not the np.NAN alias, which NumPy 2.0 removed).
            virtual_realtime, virtual_target = np.nan, np.nan
        else:
            sorted_spaces = sorted(
                valid_spaces, key=lambda x: x.vav_temporary_update_time
            )
            latest = sorted_spaces[-1]
            if latest.vav_temporary_update_time > get_time_str(
                60 * 60 * 2, flag="ago"
            ):
                virtual_realtime = latest.realtime_temperature
                virtual_target = latest.temperature_target
            else:
                virtual_realtime, virtual_target = 0.0, 0.0
                total_weight = 0.0
                for sp in valid_spaces:
                    temp_weight = sp.vav_default_weight
                    virtual_realtime += sp.realtime_temperature * temp_weight
                    virtual_target += sp.temperature_target * temp_weight
                    total_weight += temp_weight

                if total_weight == 0:
                    # Plain mean, computed from fresh sums rather than on top
                    # of the weighted accumulators.
                    count = len(valid_spaces)
                    virtual_realtime = (
                        sum((sp.realtime_temperature for sp in valid_spaces), 0.0)
                        / count
                    )
                    virtual_target = (
                        sum((sp.temperature_target for sp in valid_spaces), 0.0)
                        / count
                    )
                else:
                    virtual_realtime /= total_weight
                    virtual_target /= total_weight

        self.equipment.virtual_realtime_temperature = virtual_realtime
        self.equipment.virtual_target_temperature = virtual_target


class VAVControllerV4(VAVControllerV3):
    """V3 controller that outputs a temperature set point instead of a flow.

    The next set point is the virtual target shifted by the gap between the
    return-air temperature and the virtual realtime temperature.
    """

    def __init__(self, vav_params: VAVBox, season: Season, return_air_temp: float):
        super().__init__(vav_params, season)
        self.return_air_temp = return_air_temp

    def get_next_temp_set(
        self, virtual_realtime_temp: float, virtual_target_temp: float
    ) -> float:
        """Compute, store and return the next temperature set point.

        The set point is ``target + return_air_temp - realtime``. It is NaN
        when either virtual temperature is NaN; a NaN ``return_air_temp``
        also propagates through the arithmetic. The result is written to
        ``equipment.setting_temperature``.
        """
        if np.isnan(virtual_realtime_temp) or np.isnan(virtual_target_temp):
            # np.nan (not the np.NAN alias, which NumPy 2.0 removed).
            next_temp_set = np.nan
        else:
            next_temp_set = (
                virtual_target_temp + self.return_air_temp - virtual_realtime_temp
            )

        self.equipment.setting_temperature = next_temp_set

        return next_temp_set

    async def run(self) -> None:
        """Blend temperatures, rectify them, then compute the next set point."""
        await self.build_virtual_temperature()
        temperature_set, temperature_realtime = await self.rectify()
        self.get_next_temp_set(temperature_realtime, temperature_set)
        self.equipment.running_status = True


async def fetch_vav_control_params(project_id: str, equipment_id: str) -> Dict:
    """Collect the keyword arguments needed to build a ``VAVBox``.

    Queries Duoduo for the season and the spaces served by ``equipment_id``.
    For each space it reads the current target, the realtime temperature and
    the related equipment:

    - ACATFC (FCU) fan speeds become ``FCU`` objects.
    - ACATAH supply-air temperatures are averaged into the box's supply-air
      temperature (NaN when there are none).

    The data platform supplies the box's realtime flow and flow limits.
    """
    async with AsyncClient() as client:
        duo_duo = Duoduo(client, project_id)
        platform = DataPlatformService(client, project_id)

        season = await duo_duo.get_season()
        served_spaces = await duo_duo.get_space_by_equipment(equipment_id)
        space_objects = []
        realtime_supply_air_temperature_list = []
        for sp in served_spaces:
            sp_id = sp.get("id")
            transfer = SpaceInfoService(client, project_id, sp_id)
            current_target = await transfer.get_current_temperature_target()
            realtime_temperature = await platform.get_realtime_temperature(sp_id)

            related_equipment = await transfer.get_equipment()
            equipment_objects = []
            for eq in related_equipment:
                if eq.get("category") == "ACATFC":
                    speed = await platform.get_fan_speed(eq.get("id"))
                    temp_fcu_params = {"id": eq.get("id"), "air_valve_speed": speed}
                    fcu = FCU(**temp_fcu_params)
                    equipment_objects.append(fcu)
                elif eq.get("category") == "ACATAH":
                    realtime_supply_air_temperature_list.append(
                        await platform.get_realtime_supply_air_temperature(eq.get("id"))
                    )
            temp_space_params = {
                "id": sp_id,
                "realtime_temperature": realtime_temperature,
                "equipment": equipment_objects,
                "temperature_target": current_target,
                "diff": current_target - realtime_temperature,
            }
            space = SpaceATVA(**temp_space_params)
            space_objects.append(space)

        # Without any AHU reading, use NaN explicitly instead of relying on
        # np.mean of an empty array (which also emits a RuntimeWarning).
        if realtime_supply_air_temperature_list:
            realtime_supply_air_temperature = np.array(
                realtime_supply_air_temperature_list
            ).mean()
        else:
            realtime_supply_air_temperature = np.nan
        realtime_supply_air_flow = await platform.get_realtime_supply_air_flow(
            equipment_id
        )
        lower_limit, upper_limit = await platform.get_air_flow_limit(equipment_id)

        vav_params = {
            "id": equipment_id,
            "spaces": space_objects,
            "supply_air_temperature": realtime_supply_air_temperature,
            "supply_air_flow": realtime_supply_air_flow,
            "supply_air_flow_lower_limit": lower_limit,
            "supply_air_flow_upper_limit": upper_limit,
            "season": season,
        }

        return vav_params


@logger.catch()
async def get_vav_control_v1(project: str, equipment_id: str) -> VAVBox:
    """Fetch live parameters for ``equipment_id`` and run the V1 controller.

    Returns the regulated ``VAVBox``; exceptions are caught and logged by
    ``logger.catch``.
    """
    params = await fetch_vav_control_params(project, equipment_id)
    controller = VAVController(VAVBox(**params))
    await controller.run()
    return controller.get_results()


@logger.catch()
async def get_vav_control_v2(db: Session, project_id: str, equipment_id: str) -> VAVBox:
    """Fetch live parameters and stored space weights, then run the V2 controller.

    Returns the regulated ``VAVBox``; exceptions are caught and logged by
    ``logger.catch``.
    """
    params = await fetch_vav_control_params(project_id, equipment_id)
    box = VAVBox(**params)
    weight_rows = get_weights_by_vav(db, equipment_id)
    space_weights = [SpaceWeight.from_orm(row) for row in weight_rows]

    controller = VAVControllerV2(box, space_weights, params["season"])
    await controller.run()
    return controller.get_results()


@logger.catch()
async def build_acatva_instructions(
    params: ACATVAInstructionsRequest,
) -> ACATVAInstructions:
    """Run the V3 controller on a request payload and return its instructions.

    The supply-air temperature comes from ``supply_air_temperature``. When
    that is -1 it comes from ``acatah_supply_air_temperature``, and when that
    is also -1 it is NaN (so the controller uses its own fallback).
    """
    space_params = []
    for sp in params.spaces:
        temp_sp = SpaceATVA(**sp.dict())
        temp_sp.diff = temp_sp.temperature_target - temp_sp.realtime_temperature
        space_params.append(temp_sp)

    # -1 marks a missing reading; np.nan replaces the np.NAN alias that
    # NumPy 2.0 removed.
    if params.supply_air_temperature != -1:
        supply_air_temperature = params.supply_air_temperature
    elif params.acatah_supply_air_temperature != -1:
        supply_air_temperature = params.acatah_supply_air_temperature
    else:
        supply_air_temperature = np.nan

    vav_params = VAVBox(
        spaces=space_params,
        supply_air_temperature=supply_air_temperature,
        supply_air_flow=params.supply_air_flow,
        supply_air_flow_lower_limit=params.supply_air_flow_lower_limit,
        supply_air_flow_upper_limit=params.supply_air_flow_upper_limit,
    )

    controller = VAVControllerV3(vav_params=vav_params, season=Season(params.season))
    await controller.run()
    regulated_vav = controller.get_results()

    instructions = ACATVAInstructions(
        supply_air_flow_set=regulated_vav.supply_air_flow_set,
        virtual_realtime_temperature=regulated_vav.virtual_realtime_temperature,
        virtual_temperature_target_set=regulated_vav.virtual_target_temperature,
    )

    return instructions


@logger.catch()
async def build_acatva_instructions_for_JM(
    params: ACATVAInstructionsRequestV2,
) -> dict:
    """Control logic for Jiaming: compute the next temperature set point.

    Runs the V4 controller on the request's spaces and return-air temperature
    (-1.0 is treated as a missing reading).

    Returns:
        A dict with ``temperature_target_set`` (-1.0 when no set point could
        be computed), ``virtual_target_temperature`` and
        ``virtual_realtime_temperature``.
    """
    space_params = []
    for sp in params.spaces:
        temp_sp = SpaceATVA(**sp.dict())
        # A falsy target (e.g. 0 or None) yields a neutral diff of 0.0.
        if temp_sp.temperature_target:
            temp_sp.diff = temp_sp.temperature_target - temp_sp.realtime_temperature
        else:
            temp_sp.diff = 0.0
        space_params.append(temp_sp)

    vav_params = VAVBox(spaces=space_params)
    # np.nan replaces the np.NAN alias that NumPy 2.0 removed.
    return_air_temp = (
        np.nan if params.return_air_temp == -1.0 else params.return_air_temp
    )

    controller = VAVControllerV4(
        vav_params=vav_params,
        season=Season(params.season),
        return_air_temp=return_air_temp,
    )
    await controller.run()
    regulated_vav = controller.get_results()

    next_temp_set = regulated_vav.setting_temperature
    # Falsy (None/0) or NaN set points are reported as -1.0. ``not`` short-
    # circuits, so np.isnan never receives None.
    if not next_temp_set or np.isnan(next_temp_set):
        next_temp_set = -1.0

    # NOTE(review): the virtual temperatures may still be NaN (no valid
    # spaces); confirm the caller can serialize NaN values.
    instructions = {
        'temperature_target_set': next_temp_set,
        'virtual_target_temperature': regulated_vav.virtual_target_temperature,
        'virtual_realtime_temperature': regulated_vav.virtual_realtime_temperature
    }

    return instructions