from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass from datetime import datetime import inspect import os import re import time from types import GenericAlias, MethodType, UnionType from typing import ( Any, Literal, ParamSpec, Protocol, TypeVar, get_args, runtime_checkable, ) import requests from pydantic import ( AliasGenerator, BaseModel as PydanticBaseModel, ConfigDict, PrivateAttr, create_model, ) import yaml from dynamix_sdk.config import Config, ConfigWithAuth from dynamix_sdk.exceptions import RequestException from dynamix_sdk.utils import HTTPMethod, gen_cls_name_from_url_path NAME_MAPPING_FILE_PATH = os.path.join( os.path.dirname(__file__), 'api', 'name_mapping.yml', ) PATH_MAPPING_FILE_PATH = os.path.join( os.path.dirname(__file__), 'api', 'path_mapping.yml', ) def read_mapping_file(file_path: str): with open(file_path) as file: parsed_data: Any = yaml.safe_load(file) if not isinstance(parsed_data, dict): raise TypeError result_dict: dict[Any, Any] = parsed_data with open(file_path) as file: name_mapping_file_lines_amount = sum( 1 for s in file if s.lstrip() and not s.startswith('#') ) if len(result_dict) < name_mapping_file_lines_amount: raise AssertionError( f'File {file_path} contains more code lines than' f' keys in the parsed data. Check the file for duplicate keys.' ) for k, v in result_dict.items(): if not isinstance(k, str) or not isinstance(v, str): raise TypeError result_str_dict: dict[str, str] = result_dict return result_str_dict path_mapping_dict = read_mapping_file(PATH_MAPPING_FILE_PATH) name_mapping_dict = read_mapping_file(NAME_MAPPING_FILE_PATH) common_mappings_values = [ v for k, v in name_mapping_dict.items() if '__' not in k ] if len(common_mappings_values) > len(set(common_mappings_values)): raise AssertionError( f'File {NAME_MAPPING_FILE_PATH} can contain duplicate values' f' only for individual mapping (attr_name__model_class_name),' f' not common. Check common mappings for duplicate values.' ) @runtime_checkable class BaseAPIFunctionProtocol(Protocol): pass class BasePostAPIFunctionProtocol(BaseAPIFunctionProtocol): pass class BaseGetAPIFunctionProtocol(BaseAPIFunctionProtocol): pass class BaseDeleteAPIFunctionProtocol(BaseAPIFunctionProtocol): pass class BasePatchAPIFunctionProtocol(BaseAPIFunctionProtocol): pass class BasePutAPIFunctionProtocol(BaseAPIFunctionProtocol): pass base_proto_to_http_method: dict[type[BaseAPIFunctionProtocol], HTTPMethod] = { BasePostAPIFunctionProtocol: HTTPMethod.POST, BaseGetAPIFunctionProtocol: HTTPMethod.GET, BaseDeleteAPIFunctionProtocol: HTTPMethod.DELETE, BasePatchAPIFunctionProtocol: HTTPMethod.PATCH, BasePutAPIFunctionProtocol: HTTPMethod.PUT, } class BaseModel(PydanticBaseModel, ABC): @staticmethod def _get_datetime_from_timestamp(timestamp: float) -> None | datetime: if timestamp > 0: return datetime.fromtimestamp(timestamp) class BaseModelWithFrozenStrictExtraForbid(BaseModel, ABC): model_config = ConfigDict( extra='forbid', strict=True, frozen=True, ) class BaseAPIParamsModel(BaseModelWithFrozenStrictExtraForbid, ABC): pass class BaseAPIParamsNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC): pass @dataclass class APIResultContext: api_params: BaseAPIParamsModel http_response: requests.Response class BaseAPIResult(ABC): _api_params: BaseAPIParamsModel _http_response: requests.Response class BaseAPIResultModel( BaseModelWithFrozenStrictExtraForbid, BaseAPIResult, ABC, ): _api_params: BaseAPIParamsModel = PrivateAttr() _http_response: requests.Response = PrivateAttr() def model_post_init(self, __context: APIResultContext): self._api_params = __context.api_params self._http_response = __context.http_response class BaseAPIResultNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC): pass class BaseAPIResultBasicType(BaseAPIResult, ABC): _frozen: bool = False @abstractmethod def __new__(cls, value: Any, context: APIResultContext): return super().__new__(cls) def __setattr__(self, name: str, value: Any) -> None: if self._frozen: raise AttributeError( 'Instance is frozen.', ) return super().__setattr__(name, value) class BaseAPIResultStr(str, BaseAPIResultBasicType, ABC): def __new__(cls, value: str, context: APIResultContext): instance = super().__new__(cls, value) instance._api_params = context.api_params instance._http_response = context.http_response instance._frozen = True return instance class BaseAPIResultInt(int, BaseAPIResultBasicType, ABC): def __new__(cls, value: int, context: APIResultContext): instance = super().__new__(cls, value) instance._api_params = context.api_params instance._http_response = context.http_response instance._frozen = True return instance class BaseAPIResultBool(BaseAPIResultBasicType, ABC): value: bool def __bool__(self): return self.value def __str__(self): return f'{self.__class__.__qualname__}: {self.value}' def __new__(cls, value: bool, context: APIResultContext): if isinstance(value, bool): _value = value elif value is None: # BDX-8379 _value = True else: raise TypeError instance = super().__new__(cls, _value, context) instance.value = _value instance._api_params = context.api_params instance._http_response = context.http_response instance._frozen = True return instance def get_alias( field_name: str, model_cls: type[BaseModel], name_mapping_dict: dict[str, str], ) -> str: if field_name in model_cls.__annotations__.keys(): individual_alias = name_mapping_dict.get( f'{field_name}__{model_cls.__qualname__}', ) if individual_alias: return individual_alias if field_name in name_mapping_dict: return name_mapping_dict[field_name] raise KeyError( f'Mapping for attr {model_cls.__qualname__}.{field_name}' f' not found in name mapping dictionary.' ) for base_cls in model_cls.__bases__: if not issubclass(base_cls, BaseModel): continue if field_name not in base_cls.model_fields.keys(): continue return get_alias( field_name=field_name, model_cls=base_cls, name_mapping_dict=name_mapping_dict, ) raise NameError( f'Field {field_name} not found in model {model_cls.__qualname__}.' ) def set_alias_gen( model_cls: type[BaseModel], alias_type: Literal['validation', 'serialization'], name_mapping_dict: dict[str, str] ): if model_cls.model_config.get('alias_generator'): return def get_model_classes_from_annotation( annotation: Any, classes: list[type[BaseModel]] | None = None, ): _classes = classes or [] if annotation and annotation is not Any: if isinstance(annotation, (UnionType, GenericAlias)): for cls in get_args(annotation): _classes += get_model_classes_from_annotation(cls) elif issubclass(annotation, BaseModel): _classes.append(annotation) return _classes for field in model_cls.model_fields.values(): for cls in get_model_classes_from_annotation(field.annotation): set_alias_gen( model_cls=cls, alias_type=alias_type, name_mapping_dict=name_mapping_dict, ) def alias_gen(field_name: str): return get_alias( field_name=field_name, model_cls=model_cls, name_mapping_dict=name_mapping_dict, ) kwargs = { f'{alias_type}_alias': alias_gen } model_cls.model_config['alias_generator'] = AliasGenerator(**kwargs) model_cls.model_rebuild(force=True) params_model_classes: dict[str, type[BaseAPIParamsModel]] = {} def gen_api_params_cls_name(api_path: str): return gen_cls_name_from_url_path( url_path=re.sub(r'[\[\]]', '', api_path), postfix='ParamsModel', ) def create_api_params_cls( cls_name: str, module_name: str, protocol_method: Callable ): field_definitions = {} params = inspect.signature(protocol_method).parameters.values() inspect_empty = inspect.Parameter.empty for param in params: if param.name == 'self': continue field_definitions[param.name] = ( param.annotation, ... if param.default is inspect_empty else param.default, ) return create_model( cls_name, __base__=BaseAPIParamsModel, __module__=module_name, **field_definitions, ) def get_func_api_path(*, func_name: str, path_mapping: dict) -> str: api_path_parts = [] for func_part in func_name.split('.'): api_path_parts.append( path_mapping.get(func_part, func_part) ) return f'/{"/".join(api_path_parts)}' APIParamsP = ParamSpec('APIParamsP') APIResultT = TypeVar( 'APIResultT', bound=BaseAPIResultModel | BaseAPIResultBasicType, ) class BaseAPI(ABC): _config: Config _parent_api_group_names: list[str] _name_mapping_dict: dict[str, str] = name_mapping_dict _path_mapping_dict: dict[str, str] = path_mapping_dict _post_json: bool = True def __init_subclass__( cls, path_mapping_dict: None | dict[str, str] = None, name_mapping_dict: None | dict[str, str] = None, post_json: None | bool = None, ) -> None: if path_mapping_dict: cls._path_mapping_dict = path_mapping_dict if name_mapping_dict: cls._name_mapping_dict = name_mapping_dict if post_json is not None: cls._post_json = post_json def __init__( self, config: Config, parent_api_group_names: None | list = None, ): self._config = config self._parent_api_group_names = parent_api_group_names or [] def __getattribute__(self, name: str) -> Any: if name.startswith('_'): return super().__getattribute__(name) else: if name in self.__annotations__: annotation = self.__annotations__[name] if issubclass(annotation, BaseAPI): api_cls = annotation return api_cls( config=self._config, parent_api_group_names=( self._parent_api_group_names + [name] ) ) else: attr_value = super().__getattribute__(name) if not inspect.ismethod(attr_value): raise ValueError return self._make_api_function(attr_value) def _make_api_function( self, protocol_method: Callable[APIParamsP, APIResultT], /, ) -> Callable[APIParamsP, APIResultT]: return_type = inspect.signature(protocol_method).return_annotation if not issubclass(return_type, get_args(APIResultT.__bound__)): raise TypeError( f'Return type annotation of {protocol_method}' f' must be subclass of {get_args(APIResultT.__bound__)}.' ) api_result_cls: type[APIResultT] = return_type if not isinstance(protocol_method, MethodType): raise TypeError func_name_parts = ( self._parent_api_group_names + [protocol_method.__name__] ) func_name = '.'.join(func_name_parts) def api_function( *args: APIParamsP.args, **kwargs: APIParamsP.kwargs, ): config = self._config api_path = get_func_api_path( func_name=func_name, path_mapping=self._path_mapping_dict, ) api_params_cls_name = gen_api_params_cls_name( api_path=api_path, ) api_params_cls = params_model_classes.get( api_params_cls_name, None, ) if not api_params_cls: api_params_cls = create_api_params_cls( cls_name=api_params_cls_name, module_name=api_result_cls.__module__, protocol_method=protocol_method, ) params_model_classes[api_params_cls_name] = api_params_cls set_alias_gen( model_cls=api_params_cls, alias_type='serialization', name_mapping_dict=self._name_mapping_dict, ) api_params = api_params_cls(*args, **kwargs) req_headers = None if isinstance(config, ConfigWithAuth): req_headers = { 'Authorization': f'bearer {config.jwt}', } model_data = api_params.model_dump( by_alias=True, exclude_none=True, ) for api_mixin_cls in self.__class__.__bases__[1:]: if not hasattr(api_mixin_cls, protocol_method.__name__): continue proto_cls_base = api_mixin_cls.__base__ if ( not proto_cls_base or not issubclass(proto_cls_base, BaseAPIFunctionProtocol) ): raise TypeError( f'Class {api_mixin_cls.__qualname__} must have' f' {BaseAPIFunctionProtocol.__qualname__}' f' as parent class.' ) http_method = base_proto_to_http_method[proto_cls_base] break else: raise RuntimeError req_params = None req_json = None req_data = None match http_method: case HTTPMethod.GET: req_params = model_data case ( HTTPMethod.POST | HTTPMethod.DELETE | HTTPMethod.PATCH | HTTPMethod.PUT ): if self._post_json: req_json = model_data else: req_data = model_data case _: raise RuntimeError http_request = requests.Request( method=http_method, url=config.get_api_url(api_path=api_path), headers=req_headers, params=req_params, json=req_json, data=req_data, ).prepare() attempts = config.http503_attempts try: with requests.Session() as session: while True: http_response = session.send( request=http_request, verify=config.verify_ssl, ) if http_response.status_code == 503: if attempts < 1: break else: attempts -= 1 time.sleep(config.http503_attempts_interval) else: break http_response.raise_for_status() except requests.exceptions.RequestException as e: if config.wrap_request_exceptions: raise RequestException( orig_exception=e, func_name=func_name, func_kwargs=kwargs, ) else: raise e api_result_context = APIResultContext( api_params=api_params, http_response=http_response, ) if issubclass(api_result_cls, BaseAPIResultModel): set_alias_gen( model_cls=api_result_cls, alias_type='validation', name_mapping_dict=self._name_mapping_dict, ) result_extra_allow = ( api_result_cls.model_config.get('extra') == 'allow' ) if config.result_extra_allow != result_extra_allow: if config.result_extra_allow: api_result_cls.model_config['extra'] = 'allow' else: api_result_cls.model_config['extra'] = ( BaseAPIResultModel.model_config.get('extra') ) api_result_cls.model_rebuild(force=True) api_result = api_result_cls.model_validate_json( json_data=http_response.content, context=api_result_context, ) else: try: decoded_content = http_response.json() except requests.JSONDecodeError: decoded_content = http_response.text api_result = api_result_cls( value=decoded_content, context=api_result_context, ) return api_result api_function.__name__ = func_name.replace('.', '__') if self._config.f_decorators: for decorator in reversed(self._config.f_decorators): api_function = decorator(api_function) return api_function