from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass from datetime import datetime import inspect import os import re import time from types import GenericAlias, UnionType from typing import Any, Literal, ParamSpec, Protocol, TypeVar, get_args import requests from pydantic import ( AliasGenerator, BaseModel as PydanticBaseModel, ConfigDict, PrivateAttr, create_model, ) import yaml from dynamix_sdk.config import Config, ConfigWithAuth from dynamix_sdk.utils import HTTPMethod, gen_cls_name_from_url_path NAME_MAPPING_FILE_PATH = os.path.join( os.path.dirname(__file__), 'api', 'name_mapping.yml', ) PATH_MAPPING_FILE_PATH = os.path.join( os.path.dirname(__file__), 'api', 'path_mapping.yml', ) def read_mapping_file(file_path: str): with open(file_path) as file: parsed_data: Any = yaml.safe_load(file) if not isinstance(parsed_data, dict): raise TypeError result_dict: dict[Any, Any] = parsed_data with open(file_path) as file: name_mapping_file_lines_amount = sum( 1 for s in file if s.lstrip() and not s.startswith('#') ) if len(result_dict) < name_mapping_file_lines_amount: raise AssertionError( f'File {file_path} contains more code lines than' f' keys in the parsed data. Check the file for duplicate keys.' ) for k, v in result_dict.items(): if not isinstance(k, str) or not isinstance(v, str): raise TypeError result_str_dict: dict[str, str] = result_dict return result_str_dict path_mapping_dict = read_mapping_file(PATH_MAPPING_FILE_PATH) name_mapping_dict = read_mapping_file(NAME_MAPPING_FILE_PATH) common_mappings_values = [ v for k, v in name_mapping_dict.items() if '__' not in k ] if len(common_mappings_values) > len(set(common_mappings_values)): raise AssertionError( f'File {NAME_MAPPING_FILE_PATH} can contain duplicate values' f' only for individual mapping (attr_name__model_class_name),' f' not common. Check common mappings for duplicate values.' ) class BaseAPIFunctionProtocol(Protocol): pass class BasePostAPIFunctionProtocol(BaseAPIFunctionProtocol): pass class BaseGetAPIFunctionProtocol(BaseAPIFunctionProtocol): pass class BaseModel(PydanticBaseModel, ABC): @staticmethod def _get_datetime_from_timestamp(timestamp: float) -> None | datetime: if timestamp > 0: return datetime.fromtimestamp(timestamp) class BaseModelWithFrozenStrictExtraForbid(BaseModel, ABC): model_config = ConfigDict( extra='forbid', strict=True, frozen=True, ) class BaseAPIParamsModel(BaseModelWithFrozenStrictExtraForbid, ABC): pass class BaseAPIParamsNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC): def __init_subclass__( cls, *args: tuple[Any], **kwargs: dict[str, Any], ) -> None: super().__init_subclass__(*args, **kwargs) postfix = 'APIParamsNM' if not cls.__qualname__.endswith(postfix): raise ValueError( f'Name of {cls} must end with {postfix}.', ) @dataclass class APIResultContext: api_params: BaseAPIParamsModel http_response: requests.Response class BaseAPIResult(ABC): _api_params: BaseAPIParamsModel _http_response: requests.Response class BaseAPIResultModel( BaseModelWithFrozenStrictExtraForbid, BaseAPIResult, ABC, ): _api_params: BaseAPIParamsModel = PrivateAttr() _http_response: requests.Response = PrivateAttr() def model_post_init(self, __context: APIResultContext): self._api_params = __context.api_params self._http_response = __context.http_response class BaseAPIResultNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC): pass class BaseAPIResultBasicType(BaseAPIResult, ABC): _frozen: bool = False @abstractmethod def __new__(cls, value: Any, context: APIResultContext): return super().__new__(cls) def __setattr__(self, name: str, value: Any) -> None: if self._frozen: raise AttributeError( 'Instance is frozen.', ) return super().__setattr__(name, value) class BaseAPIResultStr(str, BaseAPIResultBasicType, ABC): def __new__(cls, value: str, context: APIResultContext): instance = super().__new__(cls, value) instance._api_params = context.api_params instance._http_response = context.http_response instance._frozen = True return instance class BaseAPIResultInt(int, BaseAPIResultBasicType, ABC): def __new__(cls, value: int, context: APIResultContext): instance = super().__new__(cls, value) instance._api_params = context.api_params instance._http_response = context.http_response instance._frozen = True return instance class BaseAPIResultBool(BaseAPIResultBasicType, ABC): value: bool def __bool__(self): return self.value def __str__(self): return f'{self.__class__.__qualname__}: {self.value}' def __new__(cls, value: bool, context: APIResultContext): if not isinstance(value, bool): raise TypeError instance = super().__new__(cls, value, context) instance.value = value instance._api_params = context.api_params instance._http_response = context.http_response instance._frozen = True return instance def set_alias_gen( model_cls: type[BaseModel], alias_type: Literal['validation', 'serialization'], name_mapping_dict: dict[str, str] ): if model_cls.model_config.get('alias_generator'): return def get_model_classes_from_annotation( annotation: type[Any] | None, classes: list[type[BaseModel]] | None = None, ): _classes = classes or [] if annotation: if isinstance(annotation, (UnionType, GenericAlias)): for cls in get_args(annotation): _classes += get_model_classes_from_annotation(cls) elif issubclass(annotation, BaseModel): _classes.append(annotation) return _classes for field in model_cls.model_fields.values(): for cls in get_model_classes_from_annotation(field.annotation): set_alias_gen( model_cls=cls, alias_type=alias_type, name_mapping_dict=name_mapping_dict, ) def alias_gen(field_name: str): individual_alias = name_mapping_dict.get( f'{field_name}__{model_cls.__qualname__}', ) if individual_alias: return individual_alias else: if field_name in name_mapping_dict: return name_mapping_dict[field_name] else: raise KeyError( f'{field_name} not found in name mapping dictionary.' f' Class model: {model_cls.__name__}.' ) kwargs = { f'{alias_type}_alias': alias_gen } model_cls.model_config['alias_generator'] = AliasGenerator(**kwargs) model_cls.model_rebuild(force=True) params_model_classes: dict[str, type[BaseAPIParamsModel]] = {} APIParamsP = ParamSpec('APIParamsP') APIResultT = TypeVar( 'APIResultT', bound=BaseAPIResultModel | BaseAPIResultBasicType, ) class BaseAPI(ABC): _config: Config _base_api_path: None | str _path_mapping_dict: dict[str, str] = path_mapping_dict _last_call_api_path: None | str = None _name_mapping_dict: dict[str, str] = name_mapping_dict _post_json: bool = True def __init_subclass__( cls, path_mapping_dict: None | dict[str, str] = None, name_mapping_dict: None | dict[str, str] = None, post_json: None | bool = None, ) -> None: if path_mapping_dict: cls._path_mapping_dict = path_mapping_dict if name_mapping_dict: cls._name_mapping_dict = name_mapping_dict if post_json is not None: cls._post_json = post_json def __init__( self, config: Config, base_api_path: None | str = None ): self._config = config self._base_api_path = base_api_path def __getattribute__(self, name: str) -> Any: if name.startswith('_'): return super().__getattribute__(name) else: path_part = self._path_mapping_dict.get(name, name) if name in self.__annotations__: annotation = self.__annotations__[name] if issubclass(annotation, BaseAPI): api_cls = annotation return api_cls( config=self._config, base_api_path=( f'{self._base_api_path or ""}/{path_part}' ) ) else: self._last_call_api_path = ( f'{self._base_api_path or ""}/{path_part}' ) attr_value = super().__getattribute__(name) if not inspect.ismethod(attr_value): raise ValueError return self._make_api_function(attr_value) def _make_api_function( self, protocol_method: Callable[APIParamsP, APIResultT], /, ) -> Callable[APIParamsP, APIResultT]: return_type = inspect.signature(protocol_method).return_annotation if not issubclass(return_type, get_args(APIResultT.__bound__)): raise TypeError( f'Return type annotation of {protocol_method}' f' must be subclass of {get_args(APIResultT.__bound__)}.' ) api_result_cls: type[APIResultT] = return_type def api_function( *args: APIParamsP.args, **kwargs: APIParamsP.kwargs, ): config = self._config if self._last_call_api_path: api_path = self._last_call_api_path else: raise ValueError api_params_cls_name = gen_cls_name_from_url_path( url_path=re.sub(r'[\[\]]', '', api_path), postfix=BaseAPIParamsModel.__qualname__.removeprefix( 'BaseAPI' ), ) api_params_cls = params_model_classes.get( api_params_cls_name, None, ) if not api_params_cls: field_definitions = {} parameters = inspect.signature(protocol_method).parameters inspect_empty = inspect.Parameter.empty for p in parameters.values(): field_definitions[p.name] = ( p.annotation, ... if p.default is inspect_empty else p.default, ) api_params_cls = create_model( api_params_cls_name, __base__=BaseAPIParamsModel, __module__=api_result_cls.__module__, **field_definitions, ) params_model_classes[api_params_cls_name] = api_params_cls set_alias_gen( model_cls=api_params_cls, alias_type='serialization', name_mapping_dict=self._name_mapping_dict, ) api_params = api_params_cls(*args, **kwargs) req_headers = None if isinstance(config, ConfigWithAuth): req_headers = { 'Authorization': f'bearer {config.jwt}', } model_data = api_params.model_dump( by_alias=True, exclude_none=True, ) for api_mixin_cls in self.__class__.__bases__[1:]: if hasattr(api_mixin_cls, protocol_method.__name__): if issubclass(api_mixin_cls, BasePostAPIFunctionProtocol): http_method = HTTPMethod.POST break elif issubclass(api_mixin_cls, BaseGetAPIFunctionProtocol): http_method = HTTPMethod.GET break else: raise RuntimeError req_params = None req_json = None req_data = None match http_method: case HTTPMethod.GET: req_params = model_data case HTTPMethod.POST: if self._post_json: req_json = model_data else: req_data = model_data case _: raise RuntimeError http_request = requests.Request( method=http_method, url=config.get_api_url(api_path=api_path), headers=req_headers, params=req_params, json=req_json, data=req_data, ).prepare() attempts = config.http503_attempts with requests.Session() as session: while True: http_response = session.send( request=http_request, verify=config.verify_ssl, ) if http_response.status_code == 503: if attempts < 1: break else: attempts -= 1 time.sleep(config.http503_attempts_interval) else: break http_response.raise_for_status() api_result_context = APIResultContext( api_params=api_params, http_response=http_response, ) if issubclass(api_result_cls, BaseAPIResultModel): set_alias_gen( model_cls=api_result_cls, alias_type='validation', name_mapping_dict=self._name_mapping_dict, ) result_extra_allow = ( api_result_cls.model_config.get('extra') == 'allow' ) if config.result_extra_allow != result_extra_allow: if config.result_extra_allow: api_result_cls.model_config['extra'] = 'allow' else: api_result_cls.model_config['extra'] = ( BaseAPIResultModel.model_config.get('extra') ) api_result_cls.model_rebuild(force=True) api_result = api_result_cls.model_validate_json( json_data=http_response.content, context=api_result_context, ) else: try: decoded_content = http_response.json() except requests.JSONDecodeError: decoded_content = http_response.text api_result = api_result_cls( value=decoded_content, context=api_result_context, ) return api_result return api_function