611 lines
18 KiB
Python
611 lines
18 KiB
Python
from abc import ABC, abstractmethod
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
import inspect
|
|
import os
|
|
import re
|
|
import time
|
|
from types import GenericAlias, MethodType, UnionType
|
|
from typing import (
|
|
Any,
|
|
Literal,
|
|
ParamSpec,
|
|
Protocol,
|
|
TypeVar,
|
|
get_args,
|
|
runtime_checkable,
|
|
)
|
|
|
|
import requests
|
|
from pydantic import ( # noqa: F401
|
|
AliasGenerator,
|
|
BaseModel as PydanticBaseModel,
|
|
ConfigDict,
|
|
PrivateAttr,
|
|
create_model,
|
|
computed_field,
|
|
)
|
|
import yaml
|
|
|
|
from dynamix_sdk.config import Config, ConfigWithAuth
|
|
from dynamix_sdk.exceptions import RequestException
|
|
from dynamix_sdk.utils import HTTPMethod, gen_cls_name_from_url_path
|
|
|
|
|
|
# YAML mapping files shipped in the ``api`` directory next to this module:
# one maps model attribute names to API names, the other maps API function
# name parts to URL path parts.
NAME_MAPPING_FILE_PATH = os.path.join(
    os.path.dirname(__file__), 'api', 'name_mapping.yml',
)
PATH_MAPPING_FILE_PATH = os.path.join(
    os.path.dirname(__file__), 'api', 'path_mapping.yml',
)
|
|
|
|
|
|
def read_mapping_file(file_path: str) -> dict[str, str]:
    """Load a flat ``str -> str`` mapping from the YAML file at *file_path*.

    YAML silently keeps only the last of several equal keys, so after
    parsing, the number of keys is compared with the number of content lines
    (non-blank lines that are not full-line comments) to detect duplicates.

    Raises:
        TypeError: the document is not a mapping, or a key/value is not str.
        AssertionError: the file has more content lines than parsed keys
            (most likely because of duplicate keys).
    """
    # 'utf-8-sig' reads plain UTF-8 and also drops a leading BOM, which would
    # otherwise make a first-line comment look like a content line.
    with open(file_path, encoding='utf-8-sig') as file:
        file_text = file.read()

    parsed_data: Any = yaml.safe_load(file_text)
    if not isinstance(parsed_data, dict):
        raise TypeError(
            f'File {file_path} must contain a mapping at the top level,'
            f' got {type(parsed_data).__name__}.'
        )
    result_dict: dict[Any, Any] = parsed_data

    # Full-line comments are skipped even when they are indented.
    code_lines_amount = sum(
        1
        for line in file_text.split('\n')
        if line.strip() and not line.lstrip().startswith('#')
    )

    if len(result_dict) < code_lines_amount:
        raise AssertionError(
            f'File {file_path} contains more code lines than'
            f' keys in the parsed data. Check the file for duplicate keys.'
        )

    for k, v in result_dict.items():
        if not isinstance(k, str) or not isinstance(v, str):
            raise TypeError(
                f'File {file_path}: mapping keys and values must be strings,'
                f' got {k!r}: {v!r}.'
            )
    result_str_dict: dict[str, str] = result_dict

    return result_str_dict
|
|
|
|
|
|
path_mapping_dict = read_mapping_file(PATH_MAPPING_FILE_PATH)

name_mapping_dict = read_mapping_file(NAME_MAPPING_FILE_PATH)

# Values of common mappings (keys without "__") must be unique; only the
# individual "attr_name__ModelClassName" mappings may repeat a value.
common_mappings_values = [
    value for key, value in name_mapping_dict.items() if '__' not in key
]
if len(set(common_mappings_values)) != len(common_mappings_values):
    raise AssertionError(
        f'File {NAME_MAPPING_FILE_PATH} can contain duplicate values'
        f' only for individual mapping (attr_name__model_class_name),'
        f' not common. Check common mappings for duplicate values.'
    )
|
|
|
|
|
|
@runtime_checkable
class BaseAPIFunctionProtocol(Protocol):
    """Root marker protocol for classes declaring API function signatures."""


class BasePostAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Marker parent for protocols whose API functions use POST."""


class BaseGetAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Marker parent for protocols whose API functions use GET."""


class BaseDeleteAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Marker parent for protocols whose API functions use DELETE."""


class BasePatchAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Marker parent for protocols whose API functions use PATCH."""


class BasePutAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Marker parent for protocols whose API functions use PUT."""
|
|
|
|
|
|
# Which HTTP verb an API function uses, keyed by the marker protocol that is
# the direct parent of the protocol class declaring that function.
base_proto_to_http_method: dict[type[BaseAPIFunctionProtocol], HTTPMethod] = dict(
    (
        (BasePostAPIFunctionProtocol, HTTPMethod.POST),
        (BaseGetAPIFunctionProtocol, HTTPMethod.GET),
        (BaseDeleteAPIFunctionProtocol, HTTPMethod.DELETE),
        (BasePatchAPIFunctionProtocol, HTTPMethod.PATCH),
        (BasePutAPIFunctionProtocol, HTTPMethod.PUT),
    )
)
|
|
|
|
|
|
class BaseModel(PydanticBaseModel, ABC):
    # Abstract base for every pydantic model of the SDK.

    @staticmethod
    def _get_datetime_from_timestamp(timestamp: float) -> None | datetime:
        """Convert a Unix *timestamp* into a naive local ``datetime``.

        Returns None for non-positive timestamps (presumably the API's
        "not set" value — confirm against the API docs).
        """
        return datetime.fromtimestamp(timestamp) if timestamp > 0 else None
|
|
|
|
|
|
class BaseModelWithFrozenStrictExtraForbid(BaseModel, ABC):
    # Immutable model with strict (non-coercing) validation that rejects
    # any input keys not declared as fields.
    model_config = ConfigDict(extra='forbid', strict=True, frozen=True)
|
|
|
|
|
|
class BaseAPIParamsModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    # Base of the generated top-level request-parameter models.
    ...
|
|
|
|
|
|
class BaseAPIParamsNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    # Base of models used as nested values inside request parameters.
    ...
|
|
|
|
|
|
@dataclass
class APIResultContext:
    """The call data handed to API result constructors.

    Holds the validated request parameters and the HTTP response the result
    was built from.
    """

    api_params: BaseAPIParamsModel
    http_response: requests.Response
|
|
|
|
|
|
class BaseAPIResult(ABC):
    """Common interface of all API results.

    Declares the private attributes through which a result exposes the
    request parameters and the HTTP response that produced it.
    """

    _api_params: BaseAPIParamsModel
    _http_response: requests.Response
|
|
|
|
|
|
class BaseAPIResultModel(
    BaseModelWithFrozenStrictExtraForbid,
    BaseAPIResult,
    ABC,
):
    # Base of API results that are parsed into pydantic models. The private
    # attributes are filled from the validation context in model_post_init.
    _api_params: BaseAPIParamsModel = PrivateAttr()
    _http_response: requests.Response = PrivateAttr()

    def model_post_init(self, __context: None | APIResultContext):
        """Attach the API call data passed as validation ``context``.

        Pydantic passes ``None`` when no context was given (for example
        ``model_construct()`` or direct instantiation). The private
        attributes are then left unset instead of failing with an
        AttributeError on ``None``.
        """
        if __context is None:
            return
        self._api_params = __context.api_params
        self._http_response = __context.http_response
|
|
|
|
|
|
class BaseAPIResultNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    # Base of models that appear as nested values inside API results.
    ...
|
|
|
|
|
|
class BaseAPIResultBasicType(BaseAPIResult, ABC):
    """Base for API results represented by simple (non-model) values.

    Once an instance's ``_frozen`` flag is True, every attribute assignment
    raises AttributeError.
    """

    _frozen: bool = False

    @abstractmethod
    def __new__(cls, value: Any, context: APIResultContext):
        # Only allocates the instance; *value* and *context* are handled by
        # the concrete subclasses.
        return super().__new__(cls)

    def __setattr__(self, name: str, value: Any) -> None:
        if not self._frozen:
            return super().__setattr__(name, value)
        raise AttributeError(
            'Instance is frozen.',
        )
|
|
|
|
|
|
class BaseAPIResultStr(str, BaseAPIResultBasicType, ABC):
    """API result that is a ``str`` carrying its call context."""

    def __new__(cls, value: str, context: APIResultContext):
        result = super().__new__(cls, value)
        # Attach the call data, then freeze the instance.
        for attr_name, attr_value in (
            ('_api_params', context.api_params),
            ('_http_response', context.http_response),
            ('_frozen', True),
        ):
            setattr(result, attr_name, attr_value)
        return result
|
|
|
|
|
|
class BaseAPIResultInt(int, BaseAPIResultBasicType, ABC):
    """API result that is an ``int`` carrying its call context."""

    def __new__(cls, value: int, context: APIResultContext):
        result = super().__new__(cls, value)
        # Attach the call data, then freeze the instance.
        for attr_name, attr_value in (
            ('_api_params', context.api_params),
            ('_http_response', context.http_response),
            ('_frozen', True),
        ):
            setattr(result, attr_name, attr_value)
        return result
|
|
|
|
|
|
class BaseAPIResultBool(BaseAPIResultBasicType, ABC):
    """API result wrapping a boolean payload.

    ``bool`` cannot be subclassed, so the payload is stored in ``value`` and
    exposed through ``__bool__``. The instance is frozen after construction.
    """

    value: bool

    def __bool__(self):
        return self.value

    def __str__(self):
        return f'{self.__class__.__qualname__}: {self.value}'

    def __repr__(self):
        return f'{self.__class__.__qualname__}(value={self.value!r})'

    def __new__(cls, value: bool, context: APIResultContext):
        if isinstance(value, bool):
            _value = value
        elif value is None:  # BDX-8379
            # A null payload is treated as True (see ticket BDX-8379).
            _value = True
        else:
            raise TypeError(
                f'{cls.__qualname__} expects a bool or None value,'
                f' got {type(value).__qualname__}: {value!r}.'
            )
        instance = super().__new__(cls, _value, context)
        instance.value = _value
        instance._api_params = context.api_params
        instance._http_response = context.http_response
        instance._frozen = True
        return instance
|
|
|
|
|
|
def get_alias(
|
|
field_name: str,
|
|
model_cls: type[BaseModel],
|
|
name_mapping_dict: dict[str, str],
|
|
) -> str:
|
|
if field_name in model_cls.__annotations__.keys():
|
|
individual_alias = name_mapping_dict.get(
|
|
f'{field_name}__{model_cls.__qualname__}',
|
|
)
|
|
if individual_alias:
|
|
return individual_alias
|
|
|
|
if field_name in name_mapping_dict:
|
|
return name_mapping_dict[field_name]
|
|
|
|
raise KeyError(
|
|
f'Mapping for attr {model_cls.__qualname__}.{field_name}'
|
|
f' not found in name mapping dictionary.'
|
|
)
|
|
|
|
if field_name in model_cls.model_computed_fields:
|
|
return field_name
|
|
|
|
for base_cls in model_cls.__bases__:
|
|
if not issubclass(base_cls, BaseModel):
|
|
continue
|
|
|
|
if (
|
|
field_name not in base_cls.model_fields.keys()
|
|
and field_name not in base_cls.model_computed_fields
|
|
):
|
|
continue
|
|
|
|
return get_alias(
|
|
field_name=field_name,
|
|
model_cls=base_cls,
|
|
name_mapping_dict=name_mapping_dict,
|
|
)
|
|
|
|
raise NameError(
|
|
f'Field {field_name} not found in model {model_cls.__qualname__}.'
|
|
)
|
|
|
|
|
|
def set_alias_gen(
    model_cls: type[BaseModel],
    alias_type: Literal['validation', 'serialization'],
    name_mapping_dict: dict[str, str]
):
    """Install a name-mapping alias generator on *model_cls* and its submodels.

    Models referenced by the fields of *model_cls* are processed first. This
    covers direct references and references inside unions, ``list[...]``,
    ``dict[...]``, ``Optional[...]`` and similar. Then *model_cls* gets an
    ``AliasGenerator`` whose ``<alias_type>_alias`` resolves names with
    ``get_alias``, and the model is rebuilt.

    Models that already have an alias generator are skipped. Each model is
    visited at most once per call, so self-referencing or
    mutually-referencing models do not recurse forever.
    """

    def get_model_classes_from_annotation(
        annotation: Any,
    ) -> list[type[BaseModel]]:
        # Collect BaseModel subclasses mentioned anywhere in *annotation*.
        # typing.get_args() unpacks both builtin forms (list[X], X | Y) and
        # typing forms (Optional[X], Union[...], Literal[...]). Non-class
        # leaves such as Literal values are skipped instead of making
        # issubclass() raise TypeError.
        if not annotation or annotation is Any:
            return []
        type_args = get_args(annotation)
        if type_args:
            return [
                found_cls
                for type_arg in type_args
                for found_cls in get_model_classes_from_annotation(type_arg)
            ]
        try:
            is_model_cls = issubclass(annotation, BaseModel)
        except TypeError:
            return []
        return [annotation] if is_model_cls else []

    visited: set[type[BaseModel]] = set()

    def visit(current_cls: type[BaseModel]) -> None:
        if (
            current_cls in visited
            or current_cls.model_config.get('alias_generator')
        ):
            return
        # Mark before recursing so that cycles in the model graph terminate.
        visited.add(current_cls)

        # Nested models are configured and rebuilt before their parent.
        for field in current_cls.model_fields.values():
            for nested_cls in get_model_classes_from_annotation(
                field.annotation
            ):
                visit(nested_cls)

        def alias_gen(field_name: str):
            return get_alias(
                field_name=field_name,
                model_cls=current_cls,
                name_mapping_dict=name_mapping_dict,
            )

        kwargs = {
            f'{alias_type}_alias': alias_gen
        }

        current_cls.model_config['alias_generator'] = AliasGenerator(**kwargs)
        current_cls.model_rebuild(force=True)

    visit(model_cls)
|
|
|
|
|
|
# Module-wide cache of generated params models, keyed by generated class name.
params_model_classes: dict[str, type[BaseAPIParamsModel]] = dict()
|
|
|
|
|
|
def gen_api_params_cls_name(api_path: str):
    """Return the params-model class name for *api_path*.

    Square brackets are removed from the path first, and the class name gets
    the ``ParamsModel`` postfix.
    """
    bracketless_path = re.sub(r'[\[\]]', '', api_path)
    return gen_cls_name_from_url_path(
        url_path=bracketless_path, postfix='ParamsModel',
    )
|
|
|
|
|
|
def create_api_params_cls(
    cls_name: str,
    module_name: str,
    protocol_method: Callable
):
    """Build a params model named *cls_name* from *protocol_method*'s signature.

    Every parameter except ``self`` becomes a field typed by the parameter's
    annotation. A parameter without a default becomes a required field
    (``...``); the others keep their default.
    """
    no_default = inspect.Parameter.empty
    signature_params = inspect.signature(protocol_method).parameters
    field_definitions = {
        param_name: (
            param.annotation,
            ... if param.default is no_default else param.default,
        )
        for param_name, param in signature_params.items()
        if param_name != 'self'
    }
    return create_model(
        cls_name,
        __base__=BaseAPIParamsModel,
        __module__=module_name,
        **field_definitions,
    )
|
|
|
|
|
|
def get_func_api_path(*, func_name: str, path_mapping: dict) -> str:
    """Build the URL path of an API function from its dotted name.

    Each dot-separated part of *func_name* is replaced by its entry in
    *path_mapping*; parts without an entry are kept unchanged. The parts are
    joined with ``/`` under a leading ``/``.
    """
    mapped_parts = (
        path_mapping.get(part, part) for part in func_name.split('.')
    )
    return '/' + '/'.join(mapped_parts)
|
|
|
|
|
|
# Type parameters that let a generated API function keep the parameters and
# the return type of the protocol method it was made from.
APIParamsP = ParamSpec('APIParamsP')
APIResultT = TypeVar(
    'APIResultT', bound=BaseAPIResultModel | BaseAPIResultBasicType,
)
|
|
|
|
|
|
class BaseAPI(ABC):
|
|
_config: Config
|
|
_parent_api_group_names: list[str]
|
|
_name_mapping_dict: dict[str, str] = name_mapping_dict
|
|
_path_mapping_dict: dict[str, str] = path_mapping_dict
|
|
_post_json: bool = True
|
|
|
|
def __init_subclass__(
|
|
cls,
|
|
path_mapping_dict: None | dict[str, str] = None,
|
|
name_mapping_dict: None | dict[str, str] = None,
|
|
post_json: None | bool = None,
|
|
) -> None:
|
|
if path_mapping_dict:
|
|
cls._path_mapping_dict = path_mapping_dict
|
|
if name_mapping_dict:
|
|
cls._name_mapping_dict = name_mapping_dict
|
|
if post_json is not None:
|
|
cls._post_json = post_json
|
|
|
|
def __init__(
|
|
self,
|
|
config: Config,
|
|
parent_api_group_names: None | list = None,
|
|
):
|
|
self._config = config
|
|
self._parent_api_group_names = parent_api_group_names or []
|
|
|
|
def __getattribute__(self, name: str) -> Any:
|
|
if name.startswith('_'):
|
|
return super().__getattribute__(name)
|
|
if name in self.__annotations__:
|
|
annotation = self.__annotations__[name]
|
|
if issubclass(annotation, BaseAPI):
|
|
api_cls = annotation
|
|
return api_cls(
|
|
config=self._config,
|
|
parent_api_group_names=(
|
|
self._parent_api_group_names + [name]
|
|
)
|
|
)
|
|
attr_value = super().__getattribute__(name)
|
|
|
|
if inspect.ismethod(attr_value):
|
|
return self._make_api_function(attr_value)
|
|
|
|
return attr_value
|
|
|
|
def _make_api_function(
|
|
self,
|
|
protocol_method: Callable[APIParamsP, APIResultT],
|
|
/,
|
|
) -> Callable[APIParamsP, APIResultT]:
|
|
return_type = inspect.signature(protocol_method).return_annotation
|
|
if not issubclass(return_type, get_args(APIResultT.__bound__)):
|
|
raise TypeError(
|
|
f'Return type annotation of {protocol_method}'
|
|
f' must be subclass of {get_args(APIResultT.__bound__)}.'
|
|
)
|
|
api_result_cls: type[APIResultT] = return_type
|
|
|
|
if not isinstance(protocol_method, MethodType):
|
|
raise TypeError
|
|
|
|
func_name_parts = (
|
|
self._parent_api_group_names + [protocol_method.__name__]
|
|
)
|
|
func_name = '.'.join(func_name_parts)
|
|
|
|
def api_function(
|
|
*args: APIParamsP.args,
|
|
**kwargs: APIParamsP.kwargs,
|
|
):
|
|
config = self._config
|
|
|
|
api_path = get_func_api_path(
|
|
func_name=func_name,
|
|
path_mapping=self._path_mapping_dict,
|
|
)
|
|
|
|
api_params_cls_name = gen_api_params_cls_name(
|
|
api_path=api_path,
|
|
)
|
|
|
|
api_params_cls = params_model_classes.get(
|
|
api_params_cls_name,
|
|
None,
|
|
)
|
|
if not api_params_cls:
|
|
api_params_cls = create_api_params_cls(
|
|
cls_name=api_params_cls_name,
|
|
module_name=api_result_cls.__module__,
|
|
protocol_method=protocol_method,
|
|
)
|
|
params_model_classes[api_params_cls_name] = api_params_cls
|
|
|
|
set_alias_gen(
|
|
model_cls=api_params_cls,
|
|
alias_type='serialization',
|
|
name_mapping_dict=self._name_mapping_dict,
|
|
)
|
|
|
|
api_params = api_params_cls(*args, **kwargs)
|
|
|
|
req_headers = None
|
|
if isinstance(config, ConfigWithAuth):
|
|
req_headers = {
|
|
'Authorization': f'bearer {config.jwt}',
|
|
}
|
|
|
|
model_data = api_params.model_dump(
|
|
by_alias=True,
|
|
exclude_none=True,
|
|
)
|
|
|
|
for api_mixin_cls in self.__class__.__bases__[1:]:
|
|
if not hasattr(api_mixin_cls, protocol_method.__name__):
|
|
continue
|
|
proto_cls_base = api_mixin_cls.__base__
|
|
if (
|
|
not proto_cls_base
|
|
or not issubclass(proto_cls_base, BaseAPIFunctionProtocol)
|
|
):
|
|
raise TypeError(
|
|
f'Class {api_mixin_cls.__qualname__} must have'
|
|
f' {BaseAPIFunctionProtocol.__qualname__}'
|
|
f' as parent class.'
|
|
)
|
|
http_method = base_proto_to_http_method[proto_cls_base]
|
|
break
|
|
else:
|
|
raise RuntimeError
|
|
|
|
req_params = None
|
|
req_json = None
|
|
req_data = None
|
|
match http_method:
|
|
case HTTPMethod.GET:
|
|
req_params = model_data
|
|
case (
|
|
HTTPMethod.POST
|
|
| HTTPMethod.DELETE
|
|
| HTTPMethod.PATCH
|
|
| HTTPMethod.PUT
|
|
):
|
|
if self._post_json:
|
|
req_json = model_data
|
|
else:
|
|
req_data = model_data
|
|
case _:
|
|
raise RuntimeError
|
|
|
|
http_request = requests.Request(
|
|
method=http_method,
|
|
url=config.get_api_url(api_path=api_path),
|
|
headers=req_headers,
|
|
params=req_params,
|
|
json=req_json,
|
|
data=req_data,
|
|
).prepare()
|
|
|
|
attempts = config.http503_attempts
|
|
status_code = None
|
|
try:
|
|
with requests.Session() as session:
|
|
while True:
|
|
http_response = session.send(
|
|
request=http_request,
|
|
verify=config.verify_ssl,
|
|
)
|
|
|
|
status_code = http_response.status_code
|
|
if status_code == 503:
|
|
if attempts < 1:
|
|
break
|
|
else:
|
|
attempts -= 1
|
|
time.sleep(config.http503_attempts_interval)
|
|
else:
|
|
break
|
|
|
|
http_response.raise_for_status()
|
|
except requests.exceptions.RequestException as e:
|
|
if config.wrap_request_exceptions:
|
|
raise RequestException(
|
|
orig_exception=e,
|
|
func_name=func_name,
|
|
func_kwargs=kwargs,
|
|
status_code=status_code,
|
|
)
|
|
else:
|
|
raise e
|
|
|
|
api_result_context = APIResultContext(
|
|
api_params=api_params,
|
|
http_response=http_response,
|
|
)
|
|
|
|
if issubclass(api_result_cls, BaseAPIResultModel):
|
|
set_alias_gen(
|
|
model_cls=api_result_cls,
|
|
alias_type='validation',
|
|
name_mapping_dict=self._name_mapping_dict,
|
|
)
|
|
|
|
result_extra_allow = (
|
|
api_result_cls.model_config.get('extra') == 'allow'
|
|
)
|
|
if config.result_extra_allow != result_extra_allow:
|
|
if config.result_extra_allow:
|
|
api_result_cls.model_config['extra'] = 'allow'
|
|
else:
|
|
api_result_cls.model_config['extra'] = (
|
|
BaseAPIResultModel.model_config.get('extra')
|
|
)
|
|
api_result_cls.model_rebuild(force=True)
|
|
|
|
api_result = api_result_cls.model_validate_json(
|
|
json_data=http_response.content,
|
|
context=api_result_context,
|
|
)
|
|
else:
|
|
try:
|
|
decoded_content = http_response.json()
|
|
except requests.JSONDecodeError:
|
|
decoded_content = http_response.text
|
|
|
|
api_result = api_result_cls(
|
|
value=decoded_content,
|
|
context=api_result_context,
|
|
)
|
|
|
|
return api_result
|
|
|
|
api_function.__name__ = func_name.replace('.', '__')
|
|
|
|
if self._config.f_decorators:
|
|
for decorator in reversed(self._config.f_decorators):
|
|
api_function = decorator(api_function)
|
|
|
|
return api_function
|