# -*- coding: utf-8 -*-

from typing import Dict, Optional

import numpy as np
from httpx import AsyncClient
from loguru import logger

from app.controllers.events import q_learning_models
from app.services.platform import DataPlatformService
from app.services.transfer import Duoduo, SpaceInfoService
from app.services.transfer import Season


class QLearningCommandBuilder:
    """
    Build FCU command by Q learning net.

    The network is looked up in ``q_learning_models`` by season ('summer' for
    cooling, 'winter' for heating). For any other season ``self.model`` is
    None, and ``predict_speed`` cannot be used.

    The model is indexed as ``model[0, layer][0, 0]``; element ``[0][0]`` of
    that record is the layer type and element ``[1][0, idx]`` holds the
    layer's parameters (idx 0 is used as the weight, idx 1 as the bias).
    """

    def __init__(self, season: 'Season'):
        self.season = season
        if season == Season.cooling:
            self.model = q_learning_models.get('summer')
        elif season == Season.heating:
            self.model = q_learning_models.get('winter')
        else:
            self.model = None

    def get_type(self, layer: int) -> str:
        """Return the type tag stored for ``layer`` (e.g. 'mlp', 'linear', 'relu')."""
        return self.model[0, layer][0, 0][0][0]

    def get_weight(self, layer: int, idx: int) -> np.ndarray:
        """Return the ``idx``-th parameter array stored for ``layer``."""
        return self.model[0, layer][0, 0][1][0, idx]

    @staticmethod
    def linear(input_v: np.ndarray, weight: np.ndarray, bias: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Apply an affine layer: ``weight . input_v`` plus ``bias`` when given.

        :param input_v: layer input.
        :param weight: weight matrix, multiplied from the left.
        :param bias: optional bias; ``None`` or an empty array means no bias.
        :return: the layer output.
        """
        y = np.dot(weight, input_v)
        # The default bias is None, so guard before touching ``.size``.
        if bias is not None and bias.size > 0:
            # Out-of-place add: avoids in-place dtype casting errors when the
            # product and the bias have different dtypes.
            y = y + bias

        return y

    @staticmethod
    def relu(x: np.ndarray) -> np.ndarray:
        """Element-wise ``max(x, 0)``."""
        return np.maximum(x, 0)

    def predict_speed(self, input_v: np.ndarray) -> int:
        """
        Run ``input_v`` through every layer of the model and return the index
        of the largest output as the speed.

        Layers whose type is not 'mlp', 'linear' or 'relu' leave the
        activation unchanged.
        """
        # Carry a single running activation instead of indexing a history list
        # by layer number, which goes out of step as soon as a layer is skipped.
        activation = input_v
        for layer in range(self.model.shape[1]):
            layer_type = self.get_type(layer)
            if layer_type == 'mlp' or layer_type == 'linear':
                activation = self.linear(activation, self.get_weight(layer, 0), self.get_weight(layer, 1))
            elif layer_type == 'relu':
                activation = self.relu(activation)

        return int(np.argmax(activation))

    async def get_command(self, current_temperature: float, pre_temperature: float, actual_target: float) -> Dict:
        """
        Build the FCU command dict for the given temperatures.

        :param current_temperature: current space temperature.
        :param pre_temperature: earlier space temperature used for the trend input.
        :param actual_target: target temperature; NaN means no target.
        :return: dict with keys 'onOff', 'mode', 'speed', 'temperature', 'water'.

        Speed selection:
        - target is NaN -> speed 0 (unit and water valve off);
        - current or previous temperature is NaN -> fixed speed 2;
        - otherwise the network prediction.
        """
        # Resolve the NaN fallbacks first: the prediction would be discarded in
        # those cases anyway, and skipping it avoids touching the model when it
        # is not needed.
        if np.isnan(actual_target):
            speed = 0
        elif np.isnan(current_temperature) or np.isnan(pre_temperature):
            speed = 2
        else:
            # Network inputs: offset from target and recent change, both / 5.
            input_value = np.array([
                [(current_temperature - actual_target) / 5],
                [(current_temperature - pre_temperature) / 5]
            ])
            speed = self.predict_speed(input_value)

        # Fan and water valve are switched together: off only at speed 0.
        on_off = 0 if speed == 0 else 1
        water_on_off = on_off

        if self.season == Season.cooling:
            season = 1
        elif self.season == Season.heating:
            season = 2
        else:
            season = 0

        command = {
            'onOff': on_off,
            'mode': season,
            'speed': int(speed),
            'temperature': actual_target if not np.isnan(actual_target) else None,
            'water': water_on_off
        }
        return command


@logger.catch()
async def get_fcu_q_learning_control_result(project_id: str, equipment_id: str) -> Dict:
    """
    Fetch the inputs for one FCU and build its Q-learning command.

    Looks up the spaces served by ``equipment_id`` and uses the first one.
    Returns an empty dict when the FCU has no space or when the season is
    ``Season.transition``; otherwise returns the dict produced by
    ``QLearningCommandBuilder.get_command``.
    """
    async with AsyncClient() as http_client:
        duoduo_service = Duoduo(http_client, project_id)
        data_platform = DataPlatformService(http_client, project_id)

        spaces = await duoduo_service.get_space_by_equipment(equipment_id)
        if not spaces:
            logger.error(f'FCU {equipment_id} does not have space')
            return {}
        if len(spaces) > 1:
            # Only reported; the first space is still used below.
            logger.error(f'FCU {equipment_id} control more than one spaces!')

        space_id = spaces[0].get('id')
        space_info = SpaceInfoService(http_client, project_id, space_id)
        season = await duoduo_service.get_season()
        current_target = await space_info.get_current_temperature_target()
        realtime_temperature = await data_platform.get_realtime_temperature(space_id)
        # 15 * 60 is presumably a look-back of 15 minutes in seconds - confirm
        # against DataPlatformService.get_past_temperature.
        past_temperature = await data_platform.get_past_temperature(space_id, 15 * 60)

    logger.debug(
        f'{spaces[0]["id"]} - {equipment_id} - '
        f'realtime Tdb: {realtime_temperature} - '
        f'pre Tdb: {past_temperature} - '
        f'target: {current_target}'
    )

    # No command is built in the transition season.
    if season == Season.transition:
        return {}

    command_builder = QLearningCommandBuilder(season)
    return await command_builder.get_command(realtime_temperature, past_temperature, current_target)