Files
dynamix-python-sdk/src/dynamix_sdk/base.py
2026-03-13 17:18:28 +03:00

611 lines
18 KiB
Python

from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
import inspect
import os
import re
import time
from types import GenericAlias, MethodType, UnionType
from typing import (
Any,
Literal,
ParamSpec,
Protocol,
TypeVar,
get_args,
runtime_checkable,
)
import requests
from pydantic import ( # noqa: F401
AliasGenerator,
BaseModel as PydanticBaseModel,
ConfigDict,
PrivateAttr,
create_model,
computed_field,
)
import yaml
from dynamix_sdk.config import Config, ConfigWithAuth
from dynamix_sdk.exceptions import RequestException
from dynamix_sdk.utils import HTTPMethod, gen_cls_name_from_url_path
# Both YAML mapping files live in the ``api`` directory next to this module.
_api_dir = os.path.join(os.path.dirname(__file__), 'api')
NAME_MAPPING_FILE_PATH = os.path.join(_api_dir, 'name_mapping.yml')
PATH_MAPPING_FILE_PATH = os.path.join(_api_dir, 'path_mapping.yml')
del _api_dir
def read_mapping_file(file_path: str):
    """Load a flat ``str -> str`` mapping from the YAML file ``file_path``.

    The file is read once. Its text is parsed with ``yaml.safe_load`` and is
    also scanned line by line: PyYAML does not reject duplicate keys (only
    one entry survives). A file with more content lines (non-blank,
    non-comment) than parsed keys is therefore treated as containing
    duplicates.

    :param file_path: path to a UTF-8 encoded YAML file.
    :return: the parsed mapping.
    :raises TypeError: the top-level document is not a mapping, or some key
        or value is not a string.
    :raises AssertionError: the file has more content lines than parsed keys.
    """
    # Read as UTF-8 explicitly instead of relying on the platform locale.
    with open(file_path, encoding='utf-8') as file:
        text = file.read()
    parsed_data: Any = yaml.safe_load(text)
    if not isinstance(parsed_data, dict):
        raise TypeError(
            f'File {file_path} must contain a mapping at the top level,'
            f' got {type(parsed_data).__name__}.'
        )
    result_dict: dict[Any, Any] = parsed_data
    # Comment detection uses the stripped line, so indented comments are not
    # mistaken for content lines (which used to cause false alarms below).
    content_lines_amount = 0
    for line in text.split('\n'):
        stripped_line = line.strip()
        if stripped_line and not stripped_line.startswith('#'):
            content_lines_amount += 1
    if len(result_dict) < content_lines_amount:
        raise AssertionError(
            f'File {file_path} contains more code lines than'
            ' keys in the parsed data. Check the file for duplicate keys.'
        )
    for key, value in result_dict.items():
        if not isinstance(key, str) or not isinstance(value, str):
            raise TypeError(
                f'File {file_path}: key {key!r} and its value {value!r}'
                ' must both be strings.'
            )
    result_str_dict: dict[str, str] = result_dict
    return result_str_dict
# Default mappings bundled with the SDK, loaded once at import time.
path_mapping_dict = read_mapping_file(PATH_MAPPING_FILE_PATH)
name_mapping_dict = read_mapping_file(NAME_MAPPING_FILE_PATH)
# Common mappings (keys without ``__``) must map to distinct values.
# Individual mappings (``attr_name__ModelClassName``) may repeat values.
common_mappings_values = [
    v for k, v in name_mapping_dict.items() if '__' not in k
]
if len(common_mappings_values) > len(set(common_mappings_values)):
    # Collect the offending values so the error message names them.
    _seen_values: set[str] = set()
    _duplicate_values: set[str] = set()
    for _value in common_mappings_values:
        if _value in _seen_values:
            _duplicate_values.add(_value)
        _seen_values.add(_value)
    raise AssertionError(
        f'File {NAME_MAPPING_FILE_PATH} can contain duplicate values'
        ' only for individual mapping (attr_name__model_class_name),'
        ' not common. Check common mappings for duplicate values:'
        f' {", ".join(sorted(_duplicate_values))}.'
    )
@runtime_checkable
class BaseAPIFunctionProtocol(Protocol):
    """Root marker class for API mixin classes.

    Note: the protocol declares no members, so as a runtime-checkable
    protocol ``issubclass(cls, BaseAPIFunctionProtocol)`` is true for any
    class ``cls``.
    """


class BasePostAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Parent of mixins whose API functions use HTTP POST."""


class BaseGetAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Parent of mixins whose API functions use HTTP GET."""


class BaseDeleteAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Parent of mixins whose API functions use HTTP DELETE."""


class BasePatchAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Parent of mixins whose API functions use HTTP PATCH."""


class BasePutAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Parent of mixins whose API functions use HTTP PUT."""


# HTTP method chosen for an API function, keyed by the parent class of the
# mixin that declares it.
base_proto_to_http_method: dict[type[BaseAPIFunctionProtocol], HTTPMethod] = {
    BasePostAPIFunctionProtocol: HTTPMethod.POST,
    BaseGetAPIFunctionProtocol: HTTPMethod.GET,
    BaseDeleteAPIFunctionProtocol: HTTPMethod.DELETE,
    BasePatchAPIFunctionProtocol: HTTPMethod.PATCH,
    BasePutAPIFunctionProtocol: HTTPMethod.PUT,
}
class BaseModel(PydanticBaseModel, ABC):
    """Abstract pydantic base model with helpers shared by SDK models."""

    @staticmethod
    def _get_datetime_from_timestamp(timestamp: float) -> None | datetime:
        """Convert a positive POSIX timestamp to a naive local ``datetime``.

        Zero, negative (and NaN) timestamps yield ``None``.
        """
        return datetime.fromtimestamp(timestamp) if timestamp > 0 else None
class BaseModelWithFrozenStrictExtraForbid(BaseModel, ABC):
    """Base for immutable, strictly typed models that reject unknown fields."""

    model_config = ConfigDict(frozen=True, strict=True, extra='forbid')
class BaseAPIParamsModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base of the generated request-parameter models of API functions."""
class BaseAPIParamsNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base for models nested inside API parameter models (by convention)."""
@dataclass
class APIResultContext:
    """Request data attached to an API result.

    Passed as the pydantic validation context of result models and as the
    ``context`` argument of basic-type results.
    """

    api_params: BaseAPIParamsModel
    http_response: requests.Response
class BaseAPIResult(ABC):
    """Interface of API results: request parameters plus raw HTTP response.

    Concrete subclasses fill both attributes when the result is created.
    """

    _api_params: BaseAPIParamsModel
    _http_response: requests.Response
class BaseAPIResultModel(
    BaseModelWithFrozenStrictExtraForbid,
    BaseAPIResult,
    ABC,
):
    """Base for pydantic models describing an API response body.

    When validated with an :class:`APIResultContext` as context, the private
    ``_api_params`` and ``_http_response`` attributes are filled from it.
    """

    _api_params: BaseAPIParamsModel = PrivateAttr()
    _http_response: requests.Response = PrivateAttr()

    def model_post_init(self, __context: APIResultContext | None):
        # pydantic calls this hook with ``None`` when there is no validation
        # context (``model_construct()``, or ``model_validate*()`` without
        # ``context=``). Dereferencing it unconditionally raised
        # AttributeError; in that case the private attributes stay unset.
        if isinstance(__context, APIResultContext):
            self._api_params = __context.api_params
            self._http_response = __context.http_response
class BaseAPIResultNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base for models nested inside API result models (by convention)."""
class BaseAPIResultBasicType(BaseAPIResult, ABC):
    """Base for API results wrapping a plain decoded value.

    The subclasses in this module set their attributes in ``__new__`` and
    then set ``_frozen = True``. From then on, assigning or deleting any
    attribute raises ``AttributeError``.
    """

    # Class-level default; each instance sets its own ``True`` when built.
    _frozen: bool = False

    @abstractmethod
    def __new__(cls, value: Any, context: APIResultContext):
        return super().__new__(cls)

    def __setattr__(self, name: str, value: Any) -> None:
        if self._frozen:
            raise AttributeError(
                'Instance is frozen.',
            )
        return super().__setattr__(name, value)

    def __delattr__(self, name: str) -> None:
        # Without this guard ``del instance._frozen`` would expose the
        # class-level ``False`` and silently unfreeze the instance.
        if self._frozen:
            raise AttributeError(
                'Instance is frozen.',
            )
        return super().__delattr__(name)
class BaseAPIResultStr(str, BaseAPIResultBasicType, ABC):
    """API result that is a ``str`` and also carries the request context.

    The instance is frozen once constructed.
    """

    def __new__(cls, value: str, context: APIResultContext):
        obj = super().__new__(cls, value)
        obj._api_params, obj._http_response = (
            context.api_params,
            context.http_response,
        )
        obj._frozen = True
        return obj
class BaseAPIResultInt(int, BaseAPIResultBasicType, ABC):
    """API result that is an ``int`` and also carries the request context.

    The instance is frozen once constructed.
    """

    def __new__(cls, value: int, context: APIResultContext):
        obj = super().__new__(cls, value)
        obj._api_params, obj._http_response = (
            context.api_params,
            context.http_response,
        )
        obj._frozen = True
        return obj
class BaseAPIResultBool(BaseAPIResultBasicType, ABC):
    """API result holding a boolean in :attr:`value`.

    ``bool`` cannot be subclassed, so truthiness is delegated through
    ``__bool__``. The instance is frozen once constructed.
    """

    value: bool

    def __bool__(self):
        return self.value

    def __str__(self):
        return f'{self.__class__.__qualname__}: {self.value}'

    def __new__(cls, value: bool, context: APIResultContext):
        if isinstance(value, bool):
            _value = value
        elif value is None:  # BDX-8379: a null result is mapped to True.
            _value = True
        else:
            # Name the offending payload (repr truncated to 200 chars) instead
            # of raising a bare TypeError.
            raise TypeError(
                f'{cls.__qualname__} expects a bool (or None) API result,'
                f' got {type(value).__name__}: {value!r:.200}'
            )
        instance = super().__new__(cls, _value, context)
        instance.value = _value
        instance._api_params = context.api_params
        instance._http_response = context.http_response
        instance._frozen = True
        return instance
def get_alias(
    field_name: str,
    model_cls: type[BaseModel],
    name_mapping_dict: dict[str, str],
) -> str:
    """Resolve the API-side alias of ``field_name`` on ``model_cls``.

    Resolution rules:

    1. Field annotated on ``model_cls`` itself: the individual mapping
       ``<field_name>__<model qualname>`` wins if present and non-empty;
       otherwise the common mapping ``<field_name>`` is used.
    2. Computed field of ``model_cls``: the field name is returned unchanged.
    3. Otherwise, the first base in ``__bases__`` that is a ``BaseModel``
       subclass and declares the field (regular or computed) resolves it
       recursively.

    :raises KeyError: the field is annotated on ``model_cls`` but has no
        mapping.
    :raises NameError: neither ``model_cls`` nor its bases know the field.
    """
    if field_name in model_cls.__annotations__:
        individual_key = f'{field_name}__{model_cls.__qualname__}'
        individual_alias = name_mapping_dict.get(individual_key)
        if individual_alias:
            return individual_alias
        if field_name not in name_mapping_dict:
            raise KeyError(
                f'Mapping for attr {model_cls.__qualname__}.{field_name}'
                f' not found in name mapping dictionary.'
            )
        return name_mapping_dict[field_name]

    if field_name in model_cls.model_computed_fields:
        return field_name

    declaring_base = next(
        (
            base
            for base in model_cls.__bases__
            if issubclass(base, BaseModel)
            and (
                field_name in base.model_fields
                or field_name in base.model_computed_fields
            )
        ),
        None,
    )
    if declaring_base is None:
        raise NameError(
            f'Field {field_name} not found in model {model_cls.__qualname__}.'
        )
    return get_alias(
        field_name=field_name,
        model_cls=declaring_base,
        name_mapping_dict=name_mapping_dict,
    )
def set_alias_gen(
    model_cls: type[BaseModel],
    alias_type: Literal['validation', 'serialization'],
    name_mapping_dict: dict[str, str]
):
    """Install a name-mapping alias generator on ``model_cls`` recursively.

    Pydantic models referenced by the field annotations of ``model_cls`` are
    processed first. Then ``model_cls`` gets an ``AliasGenerator`` whose
    ``<alias_type>_alias`` callable resolves names via :func:`get_alias`, and
    the model is rebuilt so that the aliases take effect.

    Models whose config already has an ``alias_generator`` are skipped.

    :param model_cls: model to configure (its ``model_config`` is mutated).
    :param alias_type: which kind of alias to generate.
    :param name_mapping_dict: name mapping forwarded to :func:`get_alias`.
    """
    if model_cls.model_config.get('alias_generator'):
        return

    def get_model_classes_from_annotation(
        annotation: Any,
    ) -> list[type[BaseModel]]:
        # Collect pydantic model classes referenced by ``annotation``.
        # ``get_args`` treats every parameterised form uniformly:
        # ``X | Y``, ``list[X]``, ``typing.Optional[X]``/``Union[...]``,
        # ``Annotated[X, ...]`` and ``Literal[...]``. Only real classes are
        # passed to ``issubclass``, which raises TypeError for non-class
        # objects such as ``Literal`` values, string forward references or
        # ``typing`` special forms.
        if annotation is None or annotation is Any:
            return []
        annotation_args = get_args(annotation)
        if annotation_args:
            classes: list[type[BaseModel]] = []
            for arg in annotation_args:
                classes += get_model_classes_from_annotation(arg)
            return classes
        if isinstance(annotation, type) and issubclass(annotation, BaseModel):
            return [annotation]
        return []

    for field in model_cls.model_fields.values():
        for nested_cls in get_model_classes_from_annotation(field.annotation):
            set_alias_gen(
                model_cls=nested_cls,
                alias_type=alias_type,
                name_mapping_dict=name_mapping_dict,
            )

    def alias_gen(field_name: str):
        return get_alias(
            field_name=field_name,
            model_cls=model_cls,
            name_mapping_dict=name_mapping_dict,
        )

    kwargs = {
        f'{alias_type}_alias': alias_gen
    }
    model_cls.model_config['alias_generator'] = AliasGenerator(**kwargs)
    model_cls.model_rebuild(force=True)
# Cache of generated API parameter models, keyed by generated class name.
params_model_classes: dict[str, type[BaseAPIParamsModel]] = dict()
def gen_api_params_cls_name(api_path: str):
    """Return the params-model class name generated for ``api_path``.

    All ``[`` and ``]`` characters are removed from the path before it is
    handed to ``gen_cls_name_from_url_path`` with the ``ParamsModel`` postfix.
    """
    bracketless_path = api_path.replace('[', '').replace(']', '')
    return gen_cls_name_from_url_path(
        url_path=bracketless_path,
        postfix='ParamsModel',
    )
def create_api_params_cls(
    cls_name: str,
    module_name: str,
    protocol_method: Callable
):
    """Create a params model mirroring the signature of ``protocol_method``.

    Every parameter except ``self`` becomes a field typed with the
    parameter's annotation. Parameters without a default become required
    fields (``...``); the others keep their default value.
    """
    missing = inspect.Parameter.empty
    field_definitions = {
        param.name: (
            param.annotation,
            ... if param.default is missing else param.default,
        )
        for param in inspect.signature(protocol_method).parameters.values()
        if param.name != 'self'
    }
    return create_model(
        cls_name,
        __base__=BaseAPIParamsModel,
        __module__=module_name,
        **field_definitions,
    )
def get_func_api_path(*, func_name: str, path_mapping: dict) -> str:
    """Translate a dotted API function name into a URL path.

    Each dot-separated part is replaced by its ``path_mapping`` entry (or
    kept unchanged when it has none). The parts are then joined with ``/``
    under a leading ``/``.

    >>> get_func_api_path(func_name='a.b', path_mapping={'b': 'x'})
    '/a/x'
    """
    mapped_parts = (
        path_mapping.get(part, part) for part in func_name.split('.')
    )
    return '/' + '/'.join(mapped_parts)
# Parameters of a protocol method, reused for the generated API function.
APIParamsP = ParamSpec('APIParamsP')
# Result type of an API function. The bound must stay a union:
# ``BaseAPI._make_api_function`` unpacks it with ``get_args``.
APIResultT = TypeVar(
    'APIResultT', bound=BaseAPIResultModel | BaseAPIResultBasicType
)
class BaseAPI(ABC):
_config: Config
_parent_api_group_names: list[str]
_name_mapping_dict: dict[str, str] = name_mapping_dict
_path_mapping_dict: dict[str, str] = path_mapping_dict
_post_json: bool = True
def __init_subclass__(
cls,
path_mapping_dict: None | dict[str, str] = None,
name_mapping_dict: None | dict[str, str] = None,
post_json: None | bool = None,
) -> None:
if path_mapping_dict:
cls._path_mapping_dict = path_mapping_dict
if name_mapping_dict:
cls._name_mapping_dict = name_mapping_dict
if post_json is not None:
cls._post_json = post_json
def __init__(
self,
config: Config,
parent_api_group_names: None | list = None,
):
self._config = config
self._parent_api_group_names = parent_api_group_names or []
def __getattribute__(self, name: str) -> Any:
if name.startswith('_'):
return super().__getattribute__(name)
if name in self.__annotations__:
annotation = self.__annotations__[name]
if issubclass(annotation, BaseAPI):
api_cls = annotation
return api_cls(
config=self._config,
parent_api_group_names=(
self._parent_api_group_names + [name]
)
)
attr_value = super().__getattribute__(name)
if inspect.ismethod(attr_value):
return self._make_api_function(attr_value)
return attr_value
def _make_api_function(
self,
protocol_method: Callable[APIParamsP, APIResultT],
/,
) -> Callable[APIParamsP, APIResultT]:
return_type = inspect.signature(protocol_method).return_annotation
if not issubclass(return_type, get_args(APIResultT.__bound__)):
raise TypeError(
f'Return type annotation of {protocol_method}'
f' must be subclass of {get_args(APIResultT.__bound__)}.'
)
api_result_cls: type[APIResultT] = return_type
if not isinstance(protocol_method, MethodType):
raise TypeError
func_name_parts = (
self._parent_api_group_names + [protocol_method.__name__]
)
func_name = '.'.join(func_name_parts)
def api_function(
*args: APIParamsP.args,
**kwargs: APIParamsP.kwargs,
):
config = self._config
api_path = get_func_api_path(
func_name=func_name,
path_mapping=self._path_mapping_dict,
)
api_params_cls_name = gen_api_params_cls_name(
api_path=api_path,
)
api_params_cls = params_model_classes.get(
api_params_cls_name,
None,
)
if not api_params_cls:
api_params_cls = create_api_params_cls(
cls_name=api_params_cls_name,
module_name=api_result_cls.__module__,
protocol_method=protocol_method,
)
params_model_classes[api_params_cls_name] = api_params_cls
set_alias_gen(
model_cls=api_params_cls,
alias_type='serialization',
name_mapping_dict=self._name_mapping_dict,
)
api_params = api_params_cls(*args, **kwargs)
req_headers = None
if isinstance(config, ConfigWithAuth):
req_headers = {
'Authorization': f'bearer {config.jwt}',
}
model_data = api_params.model_dump(
by_alias=True,
exclude_none=True,
)
for api_mixin_cls in self.__class__.__bases__[1:]:
if not hasattr(api_mixin_cls, protocol_method.__name__):
continue
proto_cls_base = api_mixin_cls.__base__
if (
not proto_cls_base
or not issubclass(proto_cls_base, BaseAPIFunctionProtocol)
):
raise TypeError(
f'Class {api_mixin_cls.__qualname__} must have'
f' {BaseAPIFunctionProtocol.__qualname__}'
f' as parent class.'
)
http_method = base_proto_to_http_method[proto_cls_base]
break
else:
raise RuntimeError
req_params = None
req_json = None
req_data = None
match http_method:
case HTTPMethod.GET:
req_params = model_data
case (
HTTPMethod.POST
| HTTPMethod.DELETE
| HTTPMethod.PATCH
| HTTPMethod.PUT
):
if self._post_json:
req_json = model_data
else:
req_data = model_data
case _:
raise RuntimeError
http_request = requests.Request(
method=http_method,
url=config.get_api_url(api_path=api_path),
headers=req_headers,
params=req_params,
json=req_json,
data=req_data,
).prepare()
attempts = config.http503_attempts
status_code = None
try:
with requests.Session() as session:
while True:
http_response = session.send(
request=http_request,
verify=config.verify_ssl,
)
status_code = http_response.status_code
if status_code == 503:
if attempts < 1:
break
else:
attempts -= 1
time.sleep(config.http503_attempts_interval)
else:
break
http_response.raise_for_status()
except requests.exceptions.RequestException as e:
if config.wrap_request_exceptions:
raise RequestException(
orig_exception=e,
func_name=func_name,
func_kwargs=kwargs,
status_code=status_code,
)
else:
raise e
api_result_context = APIResultContext(
api_params=api_params,
http_response=http_response,
)
if issubclass(api_result_cls, BaseAPIResultModel):
set_alias_gen(
model_cls=api_result_cls,
alias_type='validation',
name_mapping_dict=self._name_mapping_dict,
)
result_extra_allow = (
api_result_cls.model_config.get('extra') == 'allow'
)
if config.result_extra_allow != result_extra_allow:
if config.result_extra_allow:
api_result_cls.model_config['extra'] = 'allow'
else:
api_result_cls.model_config['extra'] = (
BaseAPIResultModel.model_config.get('extra')
)
api_result_cls.model_rebuild(force=True)
api_result = api_result_cls.model_validate_json(
json_data=http_response.content,
context=api_result_context,
)
else:
try:
decoded_content = http_response.json()
except requests.JSONDecodeError:
decoded_content = http_response.text
api_result = api_result_cls(
value=decoded_content,
context=api_result_context,
)
return api_result
api_function.__name__ = func_name.replace('.', '__')
if self._config.f_decorators:
for decorator in reversed(self._config.f_decorators):
api_function = decorator(api_function)
return api_function