# -*- coding: utf-8 -*-

from operator import attrgetter
from typing import Dict, List, Tuple

import numpy as np
from fastapi import HTTPException
from httpx import AsyncClient
from loguru import logger
from sqlalchemy.orm import Session

from app.controllers.equipment.controller import EquipmentController
from app.crud.space.weight import get_weights_by_vav
from app.schemas.equipment import VAVBox, FCU
from app.schemas.sapce_weight import SpaceWeight
from app.schemas.space import Space
from app.services.platform import DataPlatformService
from app.services.transfer import Duoduo, SpaceInfoService, Season
from app.utils.date import get_time_str


class VAVController(EquipmentController):
    """First-generation VAV box controller.

    Averages the served spaces' target and realtime temperatures into virtual
    temperatures and converts them into a supply-air-flow set point clamped to
    the box's configured flow limits.
    """

    def __init__(self, equipment: VAVBox):
        """
        Args:
            equipment: The VAV box to regulate; results are written back onto it.
        """
        super().__init__()
        self.equipment = equipment

    async def get_strategy(self):
        """Return ``'Plan B'`` if any served space also has an FCU, else ``'Plan A'``."""
        has_fcu = any(
            isinstance(eq, FCU)
            for space in self.equipment.spaces
            for eq in space.equipment
        )
        return 'Plan B' if has_fcu else 'Plan A'

    async def build_virtual_temperature(self) -> Tuple[float, float]:
        """Average the served spaces' target and realtime temperatures.

        Spaces whose target temperature is NaN are skipped. Under 'Plan B', each
        counted space that has an FCU contributes ``(4 - air_valve_speed) / 4``
        (from its first FCU) to both sums before they are divided by the number
        of counted spaces.

        Side effect: when at least one space is counted, the target average is
        stored in ``self.equipment.setting_temperature``.

        Returns:
            ``(target, realtime)`` averages; both NaN when no space has a target.
        """
        target_list, realtime_list = [], []
        buffer_list = []
        strategy = await self.get_strategy()
        for space in self.equipment.spaces:
            if np.isnan(space.temperature_target):
                continue
            target_list.append(space.temperature_target)
            realtime_list.append(space.realtime_temperature)
            if strategy == 'Plan B':
                for eq in space.equipment:
                    if isinstance(eq, FCU):
                        buffer_list.append((4 - eq.air_valve_speed) / 4)
                        break

        if not target_list:
            # np.NAN was removed in NumPy 2.0; np.nan is the canonical constant.
            return np.nan, np.nan

        # NOTE(review): the FCU buffers are added to both sums but the divisor
        # is the number of spaces, so both averages shift by sum(buffers) / n.
        # Confirm that this is intended.
        target_result = np.array(buffer_list + target_list).sum() / len(target_list)
        realtime_result = np.array(buffer_list + realtime_list).sum() / len(realtime_list)
        self.equipment.setting_temperature = target_result

        return target_result, realtime_result

    async def get_supply_air_flow_set(self, temperature_set: float, temperature_realtime: float) -> float:
        """Scale the current supply air flow by the temperature deviation.

        ``ratio = |1 + (realtime - set) / (set - supply)|``, where the box's
        supply air temperature falls back to 19.0 when it is NaN. If either
        input temperature is NaN, the unclamped flow is 0.0. The result is
        clamped to the lower limit first, then to the upper limit.

        NOTE(review): ``set == supply`` divides by zero. That raises
        ZeroDivisionError for Python floats and gives inf/nan for NumPy floats.

        Side effects: stores the flow set point and both virtual temperatures
        on ``self.equipment``.

        Returns:
            The clamped supply-air-flow set point.
        """
        if np.isnan(temperature_set) or np.isnan(temperature_realtime):
            supply_air_flow_set = 0.0
        else:
            temperature_supply = self.equipment.supply_air_temperature
            if np.isnan(temperature_supply):
                temperature_supply = 19.0
            ratio = abs(1 + (temperature_realtime - temperature_set) / (temperature_set - temperature_supply))
            supply_air_flow_set = self.equipment.supply_air_flow * ratio

        supply_air_flow_set = max(self.equipment.supply_air_flow_lower_limit, supply_air_flow_set)
        supply_air_flow_set = min(self.equipment.supply_air_flow_upper_limit, supply_air_flow_set)
        self.equipment.supply_air_flow_set = supply_air_flow_set
        self.equipment.virtual_target_temperature = temperature_set
        self.equipment.virtual_realtime_temperature = temperature_realtime

        return supply_air_flow_set

    async def run(self):
        """Compute the virtual temperatures and the flow set point, then mark the box running."""
        temperature_set, temperature_realtime = await self.build_virtual_temperature()
        await self.get_supply_air_flow_set(temperature_set, temperature_realtime)
        self.equipment.running_status = True

    def get_results(self):
        """Return the (mutated) equipment object holding the control results."""
        return self.equipment


class VAVControllerV2(VAVController):
    """Weighted VAV box controller.

    Builds virtual temperatures as a weighted combination of the served
    spaces. It then pulls them towards the worst-off space for the current
    season (``rectify``) and reuses ``get_supply_air_flow_set`` inherited from
    ``VAVController``.
    """

    def __init__(self, equipment: VAVBox, weights: List[SpaceWeight], season: Season):
        """
        Args:
            equipment: The VAV box to regulate; results are written back onto it.
            weights: Weight records, matched to spaces through ``space_id``.
            season: Current season; only cooling/heating trigger rectification.
        """
        super().__init__(equipment)
        self.weights = weights
        self.season = season

    def _wrong_relation_error(self) -> HTTPException:
        """Log and build the error raised when VAV/space/weight data do not line up."""
        logger.error(f'{self.equipment.id} has wrong vav-space relation')
        return HTTPException(status_code=404, detail='This VAV box has wrong eq-sp relation')

    async def build_virtual_temperature(self) -> None:
        """Store weighted virtual target/realtime temperatures on the equipment.

        A space is used only if both its realtime and target temperatures are
        > 0.0; NaN compares False, so NaN readings are excluded too.

        Weights are chosen as follows:

        * If the most recently updated weight record is newer than two hours
          ago, that space gets its ``temporary_weight`` and every other space
          gets 0.0.
        * Otherwise the ``default_weight`` values are normalised to sum to 1.
          When they do not sum to a positive value, the first space in the
          update-time-sorted list is given weight 1.0.

        Both virtual temperatures are set to NaN when no space is usable.

        Raises:
            HTTPException: 404 when usable spaces have no matching weight record.
        """
        valid_spaces = []
        weights = []
        for sp in self.equipment.spaces:
            if sp.realtime_temperature > 0.0 and sp.temperature_target > 0.0:
                valid_spaces.append(sp)
                weights.extend(w for w in self.weights if w.space_id == sp.id)

        if not valid_spaces:
            self.equipment.virtual_target_temperature = np.nan
            self.equipment.virtual_realtime_temperature = np.nan
            return

        if not weights:
            # Without this guard, weights[-1] below would raise IndexError.
            raise self._wrong_relation_error()

        weights = sorted(weights, key=lambda x: x.temporary_weight_update_time)
        latest = weights[-1]
        if latest.temporary_weight_update_time > get_time_str(60 * 60 * 2, flag='ago'):
            # If someone has submitted feedback in the past two hours, meet that need.
            weight_dic = {w.space_id: 0.0 for w in weights}
            weight_dic[latest.space_id] = latest.temporary_weight
        else:
            weight_dic = {w.space_id: w.default_weight for w in weights}
            total_weight_value = sum(weight_dic.values(), 0.0)
            if total_weight_value > 0:
                weight_dic = {k: v / total_weight_value for k, v in weight_dic.items()}
            else:
                weight_dic[next(iter(weight_dic))] = 1.0

        # Check membership explicitly: dict.get would return None and the
        # multiplication would fail with TypeError instead of reporting the
        # broken relation.
        if any(sp.id not in weight_dic for sp in valid_spaces):
            raise self._wrong_relation_error()

        virtual_target, virtual_realtime = 0.0, 0.0
        for sp in valid_spaces:
            weight = weight_dic[sp.id]
            virtual_target += sp.temperature_target * weight
            virtual_realtime += sp.realtime_temperature * weight

        self.equipment.virtual_target_temperature = virtual_target
        self.equipment.virtual_realtime_temperature = virtual_realtime

    async def rectify(self) -> Tuple[float, float]:
        """Pull the virtual temperatures towards the worst-off space.

        A space is "bad" when its realtime temperature is above
        ``max(27.0, target)`` or below ``min(21.0, target)``. ``diff`` on a space
        is presumably target minus realtime, as built by
        ``fetch_vav_control_params`` (confirm for other callers).

        * Cooling: the worst space is the bad space with the smallest ``diff``.
          If its ``diff`` is <= 0 and below the current virtual diff, its target
          and realtime replace the virtual values. If its ``diff`` is > 0, the
          virtual target becomes ``min(21.0, target) + 0.5`` and the virtual
          realtime becomes its realtime temperature.
        * Heating: the mirror image, using the largest ``diff`` and
          ``max(27.0, target) - 0.5``.
        * Any other season: no change.

        Returns:
            The (possibly updated) virtual target and realtime temperatures.
        """
        bad_spaces = list()
        for sp in self.equipment.spaces:
            if (sp.realtime_temperature > max(27.0, sp.temperature_target) or
                    sp.realtime_temperature < min(21.0, sp.temperature_target)):
                bad_spaces.append(sp)

        if bad_spaces:
            virtual_diff = self.equipment.virtual_target_temperature - self.equipment.virtual_realtime_temperature
            if self.season == Season.cooling:
                bad_spaces = sorted(bad_spaces, key=attrgetter('diff'))
                worst = bad_spaces[0]
                if worst.diff <= 0:
                    if worst.diff < virtual_diff:
                        self.equipment.virtual_target_temperature = worst.temperature_target
                        self.equipment.virtual_realtime_temperature = worst.realtime_temperature
                else:
                    self.equipment.virtual_target_temperature = min(21.0, worst.temperature_target) + 0.5
                    self.equipment.virtual_realtime_temperature = worst.realtime_temperature
            elif self.season == Season.heating:
                bad_spaces = sorted(bad_spaces, key=attrgetter('diff'), reverse=True)
                worst = bad_spaces[0]
                if worst.diff >= 0:
                    if worst.diff > virtual_diff:
                        self.equipment.virtual_target_temperature = worst.temperature_target
                        self.equipment.virtual_realtime_temperature = worst.realtime_temperature
                else:
                    self.equipment.virtual_target_temperature = max(27.0, worst.temperature_target) - 0.5
                    self.equipment.virtual_realtime_temperature = worst.realtime_temperature

        return self.equipment.virtual_target_temperature, self.equipment.virtual_realtime_temperature

    async def run(self) -> None:
        """Build and rectify the virtual temperatures, set the flow, and mark the box running."""
        await self.build_virtual_temperature()
        temperature_set, temperature_realtime = await self.rectify()
        await self.get_supply_air_flow_set(temperature_set, temperature_realtime)
        self.equipment.running_status = True


async def fetch_vav_control_params(project_id: str, equipment_id: str) -> Dict:
    """Collect the inputs a VAV controller needs for one VAV box.

    Queries Duoduo for the season and the served spaces, and the data platform
    for realtime readings. For each served space it builds a ``Space`` holding:

    * its current target temperature and realtime temperature;
    * ``diff`` (target minus realtime);
    * FCU objects (category ``ACATFC``) with their fan speed.

    The box's supply air temperature is the mean over the distinct AHUs
    (category ``ACATAH``) related to the served spaces. It is NaN when no AHU
    is found.

    Args:
        project_id: Project identifier passed to the remote services.
        equipment_id: Identifier of the VAV box.

    Returns:
        Keyword arguments for ``VAVBox``, including ``season``.
    """
    async with AsyncClient() as client:
        duo_duo = Duoduo(client, project_id)
        platform = DataPlatformService(client, project_id)

        season = await duo_duo.get_season()
        served_spaces = await duo_duo.get_space_by_equipment(equipment_id)
        space_objects = []
        # AHU id -> realtime supply air temperature. An AHU related to several
        # served spaces is queried once and counted once in the mean.
        ahu_supply_air_temperatures = {}
        for sp in served_spaces:
            sp_id = sp.get('id')
            transfer = SpaceInfoService(client, project_id, sp_id)
            current_target = await transfer.get_current_temperature_target()
            realtime_temperature = await platform.get_realtime_temperature(sp_id)

            equipment_objects = []
            for eq in await transfer.get_equipment():
                eq_id = eq.get('id')
                category = eq.get('category')
                if category == 'ACATFC':
                    speed = await platform.get_fan_speed(eq_id)
                    equipment_objects.append(FCU(id=eq_id, air_valve_speed=speed))
                elif category == 'ACATAH' and eq_id not in ahu_supply_air_temperatures:
                    ahu_supply_air_temperatures[eq_id] = (
                        await platform.get_realtime_supply_air_temperature(eq_id)
                    )

            space_objects.append(Space(
                id=sp_id,
                realtime_temperature=realtime_temperature,
                equipment=equipment_objects,
                temperature_target=current_target,
                diff=current_target - realtime_temperature,
            ))

        if ahu_supply_air_temperatures:
            realtime_supply_air_temperature = np.array(
                list(ahu_supply_air_temperatures.values())
            ).mean()
        else:
            # np.mean of an empty array warns ("Mean of empty slice") and yields
            # NaN. Return NaN directly; the controller falls back to a default.
            realtime_supply_air_temperature = np.nan
        realtime_supply_air_flow = await platform.get_realtime_supply_air_flow(equipment_id)
        lower_limit, upper_limit = await platform.get_air_flow_limit(equipment_id)

        vav_params = {
            'id': equipment_id,
            'spaces': space_objects,
            'supply_air_temperature': realtime_supply_air_temperature,
            'supply_air_flow': realtime_supply_air_flow,
            'supply_air_flow_lower_limit': lower_limit,
            'supply_air_flow_upper_limit': upper_limit,
            'season': season
        }

        return vav_params


@logger.catch()
async def get_vav_control_v1(project: str, equipment_id: str) -> VAVBox:
    """Run the first-generation VAV control algorithm for one VAV box.

    Args:
        project: Project identifier.
        equipment_id: Identifier of the VAV box.

    Returns:
        The VAV box with its control results filled in. Exceptions are caught
        and logged by ``logger.catch``.
    """
    params = await fetch_vav_control_params(project, equipment_id)
    controller = VAVController(VAVBox(**params))
    await controller.run()
    return controller.get_results()


@logger.catch()
async def get_vav_control_v2(db: Session, project_id: str, equipment_id: str) -> VAVBox:
    """Run the weighted (v2) VAV control algorithm for one VAV box.

    Args:
        db: Database session used to load the per-space weights.
        project_id: Project identifier.
        equipment_id: Identifier of the VAV box.

    Returns:
        The VAV box with its control results filled in. Exceptions are caught
        and logged by ``logger.catch``.
    """
    params = await fetch_vav_control_params(project_id, equipment_id)
    box = VAVBox(**params)
    space_weights = [SpaceWeight.from_orm(row) for row in get_weights_by_vav(db, equipment_id)]

    controller = VAVControllerV2(box, space_weights, params['season'])
    await controller.run()
    return controller.get_results()