You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					491 lines
				
				15 KiB
			
		
		
			
		
	
	
					491 lines
				
				15 KiB
			| 
											7 months ago
										 | from abc import ABC, abstractmethod | ||
|  | from collections.abc import Callable | ||
|  | from dataclasses import dataclass | ||
|  | from datetime import datetime | ||
|  | import inspect | ||
|  | import os | ||
|  | import re | ||
|  | import time | ||
|  | from types import GenericAlias, UnionType | ||
|  | from typing import Any, Literal, ParamSpec, Protocol, TypeVar, get_args | ||
|  | 
 | ||
|  | import requests | ||
|  | from pydantic import ( | ||
|  |     AliasGenerator, | ||
|  |     BaseModel as PydanticBaseModel, | ||
|  |     ConfigDict, | ||
|  |     PrivateAttr, | ||
|  |     create_model, | ||
|  | ) | ||
|  | import yaml | ||
|  | 
 | ||
|  | from dynamix_sdk.config import Config, ConfigWithAuth | ||
|  | from dynamix_sdk.utils import HTTPMethod, gen_cls_name_from_url_path | ||
|  | 
 | ||
|  | 
 | ||
# Absolute paths of the YAML mapping files stored in the ``api`` directory
# next to this module: attribute/field name -> API name, and attribute
# name -> URL path part.
NAME_MAPPING_FILE_PATH, PATH_MAPPING_FILE_PATH = (
    os.path.join(os.path.dirname(__file__), 'api', mapping_file_name)
    for mapping_file_name in ('name_mapping.yml', 'path_mapping.yml')
)
|  | 
 | ||
|  | 
 | ||
def read_mapping_file(file_path: str) -> dict[str, str]:
    """Read a flat ``str -> str`` mapping from a YAML file.

    YAML silently keeps only the last of several equal keys, so the file is
    also checked for duplicates: the number of parsed keys must not be
    smaller than the number of lines that are neither blank nor comments.

    Args:
        file_path: Path to the YAML mapping file.

    Returns:
        The parsed mapping; every key and every value is a ``str``.

    Raises:
        TypeError: If the document is not a mapping, or if any key or value
            is not a string.
        AssertionError: If the file has more code lines than parsed keys,
            which usually means a duplicate key.
    """
    # Read the file once so that parsing and line counting see the same
    # text; YAML files are UTF-8, independent of the process locale.
    with open(file_path, encoding='utf-8') as file:
        content = file.read()

    parsed_data: Any = yaml.safe_load(content)
    if not isinstance(parsed_data, dict):
        raise TypeError(
            f'File {file_path} must contain a mapping,'
            f' got {type(parsed_data).__name__}.'
        )
    result_dict: dict[Any, Any] = parsed_data

    # Strip before testing for '#' so that indented comment lines are not
    # mistaken for code lines (which would look like a duplicate key).
    code_lines_amount = sum(
        1
        for line in content.split('\n')
        if (stripped := line.strip()) and not stripped.startswith('#')
    )
    if len(result_dict) < code_lines_amount:
        raise AssertionError(
            f'File {file_path} contains more code lines than'
            f' keys in the parsed data. Check the file for duplicate keys.'
        )

    for k, v in result_dict.items():
        if not isinstance(k, str) or not isinstance(v, str):
            raise TypeError(
                f'File {file_path} must map strings to strings,'
                f' got {k!r}: {v!r}.'
            )
    result_str_dict: dict[str, str] = result_dict

    return result_str_dict
|  | 
 | ||
|  | 
 | ||
# Default mappings shared by all API classes: attribute name -> URL path
# part, and Python name -> API name.
path_mapping_dict = read_mapping_file(PATH_MAPPING_FILE_PATH)
name_mapping_dict = read_mapping_file(NAME_MAPPING_FILE_PATH)

# Keys containing '__' are individual mappings (attr_name__model_class_name)
# and may reuse a value; the remaining, common mappings must be one-to-one.
common_mappings_values = [
    value
    for key, value in name_mapping_dict.items()
    if '__' not in key
]
if len(set(common_mappings_values)) != len(common_mappings_values):
    raise AssertionError(
        f'File {NAME_MAPPING_FILE_PATH} can contain duplicate values'
        f' only for individual mapping (attr_name__model_class_name),'
        f' not common. Check common mappings for duplicate values.'
    )
|  | 
 | ||
|  | 
 | ||
class BaseAPIFunctionProtocol(Protocol):
    """Root of the classes that declare API functions as methods."""


class BasePostAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Marker base: API functions declared below it are sent as HTTP POST."""


class BaseGetAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Marker base: API functions declared below it are sent as HTTP GET."""
|  | 
 | ||
|  | 
 | ||
class BaseModel(PydanticBaseModel, ABC):
    """Abstract pydantic model that all SDK models derive from."""

    @staticmethod
    def _get_datetime_from_timestamp(timestamp: float) -> None | datetime:
        """Convert a Unix timestamp to a naive local-time ``datetime``.

        Timestamps that are not greater than zero produce ``None``.
        """
        return datetime.fromtimestamp(timestamp) if timestamp > 0 else None
|  | 
 | ||
|  | 
 | ||
class BaseModelWithFrozenStrictExtraForbid(BaseModel, ABC):
    """Abstract model that is immutable, strictly typed and rejects unknown fields."""

    model_config = ConfigDict(frozen=True, strict=True, extra='forbid')
|  | 
 | ||
|  | 
 | ||
class BaseAPIParamsModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base for the models that hold the arguments of one API function call."""
|  | 
 | ||
|  | 
 | ||
class BaseAPIParamsNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base for models nested inside API parameter models.

    The name of every subclass must end with ``APIParamsNM``; this is
    checked when the subclass is created.
    """

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        """Validate the subclass name after the normal subclass set-up.

        Raises:
            ValueError: If the subclass ``__qualname__`` does not end with
                ``APIParamsNM``.
        """
        # ``*args: Any`` / ``**kwargs: Any`` annotate each element; the
        # previous ``tuple[Any]`` / ``dict[str, Any]`` claimed every single
        # argument was itself a tuple or a dict.
        super().__init_subclass__(*args, **kwargs)

        postfix = 'APIParamsNM'
        if not cls.__qualname__.endswith(postfix):
            raise ValueError(
                f'Name of {cls} must end with {postfix}.',
            )
|  | 
 | ||
|  | 
 | ||
@dataclass
class APIResultContext:
    """Data handed to an API result object while it is being built.

    Attributes:
        api_params: Parameters model instance built from the arguments the
            API function was called with.
        http_response: HTTP response the result is built from.
    """

    api_params: BaseAPIParamsModel
    http_response: requests.Response
|  | 
 | ||
|  | 
 | ||
class BaseAPIResult(ABC):
    """Common base of all API results, both models and wrapped basic values."""

    # Parameters model instance the API function was called with.
    _api_params: BaseAPIParamsModel
    # HTTP response the result was built from.
    _http_response: requests.Response
|  | 
 | ||
|  | 
 | ||
class BaseAPIResultModel(
    BaseModelWithFrozenStrictExtraForbid,
    BaseAPIResult,
    ABC,
):
    """Base for pydantic models parsed from an API response body.

    The call parameters and the raw HTTP response are stored as private
    attributes, taken from the validation context.
    """

    _api_params: BaseAPIParamsModel = PrivateAttr()
    _http_response: requests.Response = PrivateAttr()

    def model_post_init(self, __context: APIResultContext):
        # pydantic invokes this hook after validation with the ``context``
        # that was passed to ``model_validate_json``.
        result_context = __context
        self._api_params, self._http_response = (
            result_context.api_params,
            result_context.http_response,
        )
|  | 
 | ||
|  | 
 | ||
class BaseAPIResultNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base for models nested inside API result models."""
|  | 
 | ||
|  | 
 | ||
class BaseAPIResultBasicType(BaseAPIResult, ABC):
    """Base for API results that wrap a decoded non-model response value.

    Once an instance's ``_frozen`` flag is ``True``, every attribute
    assignment raises ``AttributeError``.
    """

    # Class-level default; an instance switches it to True when it is fully
    # initialised.
    _frozen: bool = False

    @abstractmethod
    def __new__(cls, value: Any, context: APIResultContext):
        # Only allocates the object; subclasses store value and context.
        return super().__new__(cls)

    def __setattr__(self, name: str, value: Any) -> None:
        if not self._frozen:
            return super().__setattr__(name, value)
        raise AttributeError(
            'Instance is frozen.',
        )
|  | 
 | ||
|  | 
 | ||
class BaseAPIResultStr(str, BaseAPIResultBasicType, ABC):
    """String API result that also carries the call parameters and response."""

    def __new__(cls, value: str, context: APIResultContext):
        text = super().__new__(cls, value)
        for attr_name, attr_value in (
            ('_api_params', context.api_params),
            ('_http_response', context.http_response),
            ('_frozen', True),  # must come last: freezes the instance
        ):
            setattr(text, attr_name, attr_value)
        return text
|  | 
 | ||
|  | 
 | ||
class BaseAPIResultInt(int, BaseAPIResultBasicType, ABC):
    """Integer API result that also carries the call parameters and response."""

    def __new__(cls, value: int, context: APIResultContext):
        number = super().__new__(cls, value)
        for attr_name, attr_value in (
            ('_api_params', context.api_params),
            ('_http_response', context.http_response),
            ('_frozen', True),  # must come last: freezes the instance
        ):
            setattr(number, attr_name, attr_value)
        return number
|  | 
 | ||
|  | 
 | ||
class BaseAPIResultBool(BaseAPIResultBasicType, ABC):
    """Boolean API result.

    ``bool`` cannot be subclassed, so the flag is kept in ``value`` and
    exposed through ``__bool__``. ``str()`` renders ``'<ClassName>: <value>'``.
    """

    # The wrapped boolean returned by the API.
    value: bool

    def __new__(cls, value: bool, context: APIResultContext):
        # Only genuine booleans are accepted (ints such as 0/1 are rejected).
        if not isinstance(value, bool):
            raise TypeError
        flag = super().__new__(cls, value, context)
        for attr_name, attr_value in (
            ('value', value),
            ('_api_params', context.api_params),
            ('_http_response', context.http_response),
            ('_frozen', True),  # must come last: freezes the instance
        ):
            setattr(flag, attr_name, attr_value)
        return flag

    def __bool__(self):
        return self.value

    def __str__(self):
        return '{}: {}'.format(type(self).__qualname__, self.value)
|  | 
 | ||
|  | 
 | ||
def _get_model_classes_from_annotation(
    annotation: Any,
) -> list[type[BaseModel]]:
    """Collect the ``BaseModel`` subclasses referenced by a field annotation.

    Type arguments of generic and union annotations (``list[X]``,
    ``X | None``, ``typing.Optional[X]``, ``dict[str, X]``, ...) are searched
    recursively. Anything that is not a class (values from ``Literal[...]``,
    ``None``, metadata objects, unresolved forward references) is skipped
    instead of being passed to ``issubclass``, which would raise ``TypeError``.
    """
    type_args = get_args(annotation)
    if type_args:
        return [
            model_cls
            for type_arg in type_args
            for model_cls in _get_model_classes_from_annotation(type_arg)
        ]
    if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
        return [annotation]
    return []


def set_alias_gen(
    model_cls: type[BaseModel],
    alias_type: Literal['validation', 'serialization'],
    name_mapping_dict: dict[str, str],
    _visiting: None | set[type[BaseModel]] = None,
):
    """Install a name-mapping alias generator on a model and its nested models.

    The alias of a field is looked up in ``name_mapping_dict``, first under
    the individual key ``'<field_name>__<model qualname>'`` and then under
    the common key ``'<field_name>'``. Models that already have an
    ``alias_generator`` are left untouched, and so are the models nested
    inside them. Nested models are processed (and rebuilt) before their
    parent.

    Args:
        model_cls: Model to configure.
        alias_type: Whether the aliases apply to validation or serialization.
        name_mapping_dict: Python name -> API name mapping.
        _visiting: Internal; models currently being processed higher up the
            call stack, used to stop on recursive model graphs.

    Raises:
        KeyError: Raised by the generator (while the model schema is rebuilt)
            for a field name without a mapping.
    """
    if model_cls.model_config.get('alias_generator'):
        return

    visiting = set() if _visiting is None else _visiting
    if model_cls in visiting:
        # Recursive model graph: this model is already being processed
        # further up the stack and will be rebuilt there.
        # NOTE(review): with mutually recursive models, a model rebuilt
        # earlier may still embed this model's not-yet-aliased schema;
        # confirm against the pydantic version in use.
        return
    visiting.add(model_cls)

    for field in model_cls.model_fields.values():
        for nested_cls in _get_model_classes_from_annotation(field.annotation):
            set_alias_gen(
                model_cls=nested_cls,
                alias_type=alias_type,
                name_mapping_dict=name_mapping_dict,
                _visiting=visiting,
            )

    def alias_gen(field_name: str):
        # An individual mapping for this model wins over the common one.
        individual_alias = name_mapping_dict.get(
            f'{field_name}__{model_cls.__qualname__}',
        )
        if individual_alias:
            return individual_alias
        if field_name in name_mapping_dict:
            return name_mapping_dict[field_name]
        raise KeyError(
            f'{field_name} not found in name mapping dictionary.'
            f' Class model: {model_cls.__name__}.'
        )

    model_cls.model_config['alias_generator'] = AliasGenerator(
        **{f'{alias_type}_alias': alias_gen}
    )
    # Rebuild so the new generator is applied to the model schema.
    model_cls.model_rebuild(force=True)
|  | 
 | ||
|  | 
 | ||
# Cache of generated API parameter models, keyed by generated class name.
params_model_classes: dict[str, type[BaseAPIParamsModel]] = dict()
|  | 
 | ||
|  | 
 | ||
# Parameters of a protocol method, reused by the generated API function.
APIParamsP = ParamSpec('APIParamsP')
# Result of an API function: a result model or a wrapped basic value.
APIResultT = TypeVar(
    'APIResultT', bound=BaseAPIResultModel | BaseAPIResultBasicType
)
|  | 
 | ||
|  | 
 | ||
|  | class BaseAPI(ABC): | ||
|  |     _config: Config | ||
|  |     _base_api_path: None | str | ||
|  |     _path_mapping_dict: dict[str, str] = path_mapping_dict | ||
|  |     _last_call_api_path: None | str = None | ||
|  |     _name_mapping_dict: dict[str, str] = name_mapping_dict | ||
|  |     _post_json: bool = True | ||
|  | 
 | ||
|  |     def __init_subclass__( | ||
|  |         cls, | ||
|  |         path_mapping_dict: None | dict[str, str] = None, | ||
|  |         name_mapping_dict: None | dict[str, str] = None, | ||
|  |         post_json: None | bool = None, | ||
|  |     ) -> None: | ||
|  |         if path_mapping_dict: | ||
|  |             cls._path_mapping_dict = path_mapping_dict | ||
|  |         if name_mapping_dict: | ||
|  |             cls._name_mapping_dict = name_mapping_dict | ||
|  |         if post_json is not None: | ||
|  |             cls._post_json = post_json | ||
|  | 
 | ||
|  |     def __init__( | ||
|  |         self, | ||
|  |         config: Config, | ||
|  |         base_api_path: None | str = None | ||
|  |     ): | ||
|  |         self._config = config | ||
|  |         self._base_api_path = base_api_path | ||
|  | 
 | ||
|  |     def __getattribute__(self, name: str) -> Any: | ||
|  |         if name.startswith('_'): | ||
|  |             return super().__getattribute__(name) | ||
|  |         else: | ||
|  |             path_part = self._path_mapping_dict.get(name, name) | ||
|  |             if name in self.__annotations__: | ||
|  |                 annotation = self.__annotations__[name] | ||
|  |                 if issubclass(annotation, BaseAPI): | ||
|  |                     api_cls = annotation | ||
|  |                     return api_cls( | ||
|  |                         config=self._config, | ||
|  |                         base_api_path=( | ||
|  |                             f'{self._base_api_path or ""}/{path_part}' | ||
|  |                         ) | ||
|  |                     ) | ||
|  |             else: | ||
|  |                 self._last_call_api_path = ( | ||
|  |                     f'{self._base_api_path or ""}/{path_part}' | ||
|  |                 ) | ||
|  | 
 | ||
|  |                 attr_value = super().__getattribute__(name) | ||
|  |                 if not inspect.ismethod(attr_value): | ||
|  |                     raise ValueError | ||
|  | 
 | ||
|  |                 return self._make_api_function(attr_value) | ||
|  | 
 | ||
|  |     def _make_api_function( | ||
|  |         self, | ||
|  |         protocol_method: Callable[APIParamsP, APIResultT], | ||
|  |         /, | ||
|  |     ) -> Callable[APIParamsP, APIResultT]: | ||
|  |         return_type = inspect.signature(protocol_method).return_annotation | ||
|  |         if not issubclass(return_type, get_args(APIResultT.__bound__)): | ||
|  |             raise TypeError( | ||
|  |                 f'Return type annotation of {protocol_method}' | ||
|  |                 f' must be subclass of {get_args(APIResultT.__bound__)}.' | ||
|  |             ) | ||
|  |         api_result_cls: type[APIResultT] = return_type | ||
|  | 
 | ||
|  |         def api_function( | ||
|  |             *args: APIParamsP.args, | ||
|  |             **kwargs: APIParamsP.kwargs, | ||
|  |         ): | ||
|  |             config = self._config | ||
|  | 
 | ||
|  |             if self._last_call_api_path: | ||
|  |                 api_path = self._last_call_api_path | ||
|  |             else: | ||
|  |                 raise ValueError | ||
|  | 
 | ||
|  |             api_params_cls_name = gen_cls_name_from_url_path( | ||
|  |                 url_path=re.sub(r'[\[\]]', '', api_path), | ||
|  |                 postfix=BaseAPIParamsModel.__qualname__.removeprefix( | ||
|  |                     'BaseAPI' | ||
|  |                 ), | ||
|  |             ) | ||
|  | 
 | ||
|  |             api_params_cls = params_model_classes.get( | ||
|  |                 api_params_cls_name, | ||
|  |                 None, | ||
|  |             ) | ||
|  |             if not api_params_cls: | ||
|  |                 field_definitions = {} | ||
|  |                 parameters = inspect.signature(protocol_method).parameters | ||
|  |                 inspect_empty = inspect.Parameter.empty | ||
|  |                 for p in parameters.values(): | ||
|  |                     field_definitions[p.name] = ( | ||
|  |                         p.annotation, | ||
|  |                         ... if p.default is inspect_empty else p.default, | ||
|  |                     ) | ||
|  | 
 | ||
|  |                 api_params_cls = create_model( | ||
|  |                     api_params_cls_name, | ||
|  |                     __base__=BaseAPIParamsModel, | ||
|  |                     __module__=api_result_cls.__module__, | ||
|  |                     **field_definitions, | ||
|  |                 ) | ||
|  |                 params_model_classes[api_params_cls_name] = api_params_cls | ||
|  | 
 | ||
|  |                 set_alias_gen( | ||
|  |                     model_cls=api_params_cls, | ||
|  |                     alias_type='serialization', | ||
|  |                     name_mapping_dict=self._name_mapping_dict, | ||
|  |                 ) | ||
|  | 
 | ||
|  |             api_params = api_params_cls(*args, **kwargs) | ||
|  | 
 | ||
|  |             req_headers = None | ||
|  |             if isinstance(config, ConfigWithAuth): | ||
|  |                 req_headers = { | ||
|  |                     'Authorization': f'bearer {config.jwt}', | ||
|  |                 } | ||
|  | 
 | ||
|  |             model_data = api_params.model_dump( | ||
|  |                 by_alias=True, | ||
|  |                 exclude_none=True, | ||
|  |             ) | ||
|  | 
 | ||
|  |             for api_mixin_cls in self.__class__.__bases__[1:]: | ||
|  |                 if hasattr(api_mixin_cls, protocol_method.__name__): | ||
|  |                     if issubclass(api_mixin_cls, BasePostAPIFunctionProtocol): | ||
|  |                         http_method = HTTPMethod.POST | ||
|  |                         break | ||
|  |                     elif issubclass(api_mixin_cls, BaseGetAPIFunctionProtocol): | ||
|  |                         http_method = HTTPMethod.GET | ||
|  |                         break | ||
|  |             else: | ||
|  |                 raise RuntimeError | ||
|  | 
 | ||
|  |             req_params = None | ||
|  |             req_json = None | ||
|  |             req_data = None | ||
|  |             match http_method: | ||
|  |                 case HTTPMethod.GET: | ||
|  |                     req_params = model_data | ||
|  |                 case HTTPMethod.POST: | ||
|  |                     if self._post_json: | ||
|  |                         req_json = model_data | ||
|  |                     else: | ||
|  |                         req_data = model_data | ||
|  |                 case _: | ||
|  |                     raise RuntimeError | ||
|  | 
 | ||
|  |             http_request = requests.Request( | ||
|  |                 method=http_method, | ||
|  |                 url=config.get_api_url(api_path=api_path), | ||
|  |                 headers=req_headers, | ||
|  |                 params=req_params, | ||
|  |                 json=req_json, | ||
|  |                 data=req_data, | ||
|  |             ).prepare() | ||
|  | 
 | ||
|  |             attempts = config.http503_attempts | ||
|  |             with requests.Session() as session: | ||
|  |                 while True: | ||
|  |                     http_response = session.send( | ||
|  |                         request=http_request, | ||
|  |                         verify=config.verify_ssl, | ||
|  |                     ) | ||
|  | 
 | ||
|  |                     if http_response.status_code == 503: | ||
|  |                         if attempts < 1: | ||
|  |                             break | ||
|  |                         else: | ||
|  |                             attempts -= 1 | ||
|  |                             time.sleep(config.http503_attempts_interval) | ||
|  |                     else: | ||
|  |                         break | ||
|  | 
 | ||
|  |             http_response.raise_for_status() | ||
|  | 
 | ||
|  |             api_result_context = APIResultContext( | ||
|  |                 api_params=api_params, | ||
|  |                 http_response=http_response, | ||
|  |             ) | ||
|  | 
 | ||
|  |             if issubclass(api_result_cls, BaseAPIResultModel): | ||
|  |                 set_alias_gen( | ||
|  |                     model_cls=api_result_cls, | ||
|  |                     alias_type='validation', | ||
|  |                     name_mapping_dict=self._name_mapping_dict, | ||
|  |                 ) | ||
|  | 
 | ||
|  |                 result_extra_allow = ( | ||
|  |                     api_result_cls.model_config.get('extra') == 'allow' | ||
|  |                 ) | ||
|  |                 if config.result_extra_allow != result_extra_allow: | ||
|  |                     if config.result_extra_allow: | ||
|  |                         api_result_cls.model_config['extra'] = 'allow' | ||
|  |                     else: | ||
|  |                         api_result_cls.model_config['extra'] = ( | ||
|  |                             BaseAPIResultModel.model_config.get('extra') | ||
|  |                         ) | ||
|  |                     api_result_cls.model_rebuild(force=True) | ||
|  | 
 | ||
|  |                 api_result = api_result_cls.model_validate_json( | ||
|  |                     json_data=http_response.content, | ||
|  |                     context=api_result_context, | ||
|  |                 ) | ||
|  |             else: | ||
|  |                 try: | ||
|  |                     decoded_content = http_response.json() | ||
|  |                 except requests.JSONDecodeError: | ||
|  |                     decoded_content = http_response.text | ||
|  | 
 | ||
|  |                 api_result = api_result_cls( | ||
|  |                     value=decoded_content, | ||
|  |                     context=api_result_context, | ||
|  |                 ) | ||
|  | 
 | ||
|  |             return api_result | ||
|  | 
 | ||
|  |         return api_function |