early_start.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from typing import Tuple
  2. from httpx import AsyncClient
  3. from joblib import load
  4. from loguru import logger
  5. from sqlalchemy.orm import Session
  6. from app.core.config import settings
  7. from app.crud.model_path.early_start import model_path_early_start_dtr
  8. from app.models.domain.devices import ACATFCEarlyStartPredictionRequest
  9. from app.services.platform import DataPlatformService
  10. from app.services.transfer import SpaceInfoService
  11. from app.services.weather import WeatherService
  12. class EarlyStartTimeDTRBuilder:
  13. """
  14. Build early start time by decision tree regression.
  15. """
  16. def __init__(self, model_path: str):
  17. self.model_path = f"{settings.ML_MODELS_DIR}{model_path}"
  18. async def get_prediction(self, indoor_temp: float, outdoor_temp: float) -> float:
  19. try:
  20. model = load(self.model_path)
  21. except (FileNotFoundError, IsADirectoryError) as e:
  22. logger.debug(e)
  23. return 0
  24. try:
  25. pre = model.predict([[indoor_temp, outdoor_temp]])
  26. pre_time = pre[0]
  27. except (ValueError, IndexError) as e:
  28. logger.debug(e)
  29. pre_time = 0
  30. return pre_time
  31. async def fetch_params(
  32. project_id: str, space_id: str, db: Session
  33. ) -> Tuple[float, float, str]:
  34. async with AsyncClient() as client:
  35. platform = DataPlatformService(client, project_id)
  36. space_service = SpaceInfoService(client, project_id, space_id)
  37. weather_service = WeatherService(client)
  38. indoor_temp = await platform.get_realtime_temperature(space_id)
  39. weather_info = await weather_service.get_realtime_weather(project_id)
  40. outdoor_temp = weather_info.get("temperature")
  41. device_list = await space_service.get_equipment()
  42. device_id = ""
  43. for device in device_list:
  44. if device.get("category") == "ACATFC":
  45. device_id = device.get("id")
  46. break
  47. if device_id:
  48. model_path = model_path_early_start_dtr.get_path_by_device(db, device_id)
  49. model_path = model_path.model_path
  50. else:
  51. model_path = ""
  52. return indoor_temp, outdoor_temp, model_path
  53. @logger.catch()
  54. async def get_recommended_early_start_time(
  55. db: Session, project_id: str, space_id: str
  56. ) -> float:
  57. indoor_temp, outdoor_temp, model_path = await fetch_params(project_id, space_id, db)
  58. builder = EarlyStartTimeDTRBuilder(model_path)
  59. hour = await builder.get_prediction(indoor_temp, outdoor_temp)
  60. logger.debug(
  61. f"{space_id}: indoor-{indoor_temp}, outdoor-{outdoor_temp}, prediction-{hour * 60}"
  62. )
  63. return hour * 60
  64. @logger.catch()
  65. async def build_acatfc_early_start_prediction(
  66. params: ACATFCEarlyStartPredictionRequest, db: Session
  67. ) -> float:
  68. model_path = model_path_early_start_dtr.get_path_by_device(db, params.device_id)
  69. builder = EarlyStartTimeDTRBuilder(model_path.model_path)
  70. hour = await builder.get_prediction(
  71. params.space_realtime_temperature, params.outdoor_realtime_temperature
  72. )
  73. return hour * 60