Files
dynamix-python-sdk/src/dynamix_sdk/base.py

602 lines
18 KiB
Python
Raw Normal View History

2025-03-21 17:47:09 +03:00
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
import inspect
import os
import re
import time
2025-06-06 08:20:45 +03:00
from types import GenericAlias, MethodType, UnionType
from typing import (
Any,
Literal,
ParamSpec,
Protocol,
TypeVar,
get_args,
runtime_checkable,
)
2025-03-21 17:47:09 +03:00
import requests
from pydantic import (
AliasGenerator,
BaseModel as PydanticBaseModel,
ConfigDict,
PrivateAttr,
create_model,
)
import yaml
from dynamix_sdk.config import Config, ConfigWithAuth
2025-11-25 18:09:46 +03:00
from dynamix_sdk.exceptions import RequestException
2025-03-21 17:47:09 +03:00
from dynamix_sdk.utils import HTTPMethod, gen_cls_name_from_url_path
# YAML mapping files stored in the ``api`` directory next to this module:
# the first holds attribute-name aliases, the second URL path parts.
NAME_MAPPING_FILE_PATH, PATH_MAPPING_FILE_PATH = (
    os.path.join(os.path.dirname(__file__), 'api', mapping_file_name)
    for mapping_file_name in ('name_mapping.yml', 'path_mapping.yml')
)
def read_mapping_file(file_path: str) -> dict[str, str]:
    """Load a flat ``str -> str`` YAML mapping file and validate it.

    The file is read once (as UTF-8) and parsed with ``yaml.safe_load``.

    A repeated key collapses into a single entry in the parsed mapping, so
    duplicates are detected by comparing the number of parsed keys with the
    number of code lines (non-blank lines that are not comments, including
    indented comments). This assumes one ``key: value`` pair per line.

    Args:
        file_path: Path to the YAML mapping file.

    Returns:
        The parsed mapping.

    Raises:
        TypeError: if the top-level document is not a mapping, or a key or
            value is not a string.
        AssertionError: if the file has more code lines than parsed keys
            (probable duplicate keys).
    """
    with open(file_path, encoding='utf-8') as file:
        file_text = file.read()

    parsed_data: Any = yaml.safe_load(file_text)
    if not isinstance(parsed_data, dict):
        raise TypeError(
            f'File {file_path} must contain a mapping at the top level,'
            f' got {type(parsed_data).__name__}.'
        )
    result_dict: dict[Any, Any] = parsed_data

    # Strip before testing for '#', otherwise indented comment lines would
    # be counted as code lines and trigger a false duplicate-key error.
    code_lines_amount = sum(
        1
        for line in file_text.splitlines()
        if (stripped_line := line.strip())
        and not stripped_line.startswith('#')
    )
    if len(result_dict) < code_lines_amount:
        raise AssertionError(
            f'File {file_path} contains more code lines than'
            f' keys in the parsed data. Check the file for duplicate keys.'
        )

    for k, v in result_dict.items():
        if not isinstance(k, str) or not isinstance(v, str):
            raise TypeError(
                f'File {file_path} must map strings to strings,'
                f' got {k!r}: {v!r}.'
            )
    result_str_dict: dict[str, str] = result_dict
    return result_str_dict
path_mapping_dict = read_mapping_file(PATH_MAPPING_FILE_PATH)
name_mapping_dict = read_mapping_file(NAME_MAPPING_FILE_PATH)

# Keys without "__" are common (model-independent) mappings; keys of the
# form "attr_name__ModelClassName" are individual ones. Only individual
# mappings may share a target value.
common_mappings_values = [
    value
    for key, value in name_mapping_dict.items()
    if '__' not in key
]
if len(set(common_mappings_values)) != len(common_mappings_values):
    raise AssertionError(
        f'File {NAME_MAPPING_FILE_PATH} can contain duplicate values'
        f' only for individual mapping (attr_name__model_class_name),'
        f' not common. Check common mappings for duplicate values.'
    )
2025-06-06 08:20:45 +03:00
@runtime_checkable
class BaseAPIFunctionProtocol(Protocol):
    """Root marker protocol for API function mixin classes.

    It declares no members. Mixins inherit from one of the
    HTTP-method-specific subclasses below.
    """


class BasePostAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Parent for mixins whose API functions are sent with POST."""


class BaseGetAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Parent for mixins whose API functions are sent with GET."""


class BaseDeleteAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Parent for mixins whose API functions are sent with DELETE."""


class BasePatchAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Parent for mixins whose API functions are sent with PATCH."""


class BasePutAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Parent for mixins whose API functions are sent with PUT."""
2025-06-06 08:20:45 +03:00
# HTTP method for API functions, keyed by the protocol class that is the
# direct parent of the mixin declaring those functions.
base_proto_to_http_method: dict[
    type[BaseAPIFunctionProtocol],
    HTTPMethod,
] = {
    BasePostAPIFunctionProtocol: HTTPMethod.POST,
    BaseGetAPIFunctionProtocol: HTTPMethod.GET,
    BaseDeleteAPIFunctionProtocol: HTTPMethod.DELETE,
    BasePatchAPIFunctionProtocol: HTTPMethod.PATCH,
    BasePutAPIFunctionProtocol: HTTPMethod.PUT,
}
2025-03-21 17:47:09 +03:00
class BaseModel(PydanticBaseModel, ABC):
    """Abstract pydantic base model of the SDK with shared helpers."""

    @staticmethod
    def _get_datetime_from_timestamp(timestamp: float) -> None | datetime:
        """Convert a Unix timestamp to a naive local ``datetime``.

        Returns ``None`` unless ``timestamp`` is strictly positive.
        """
        if not timestamp > 0:
            return None
        return datetime.fromtimestamp(timestamp)
class BaseModelWithFrozenStrictExtraForbid(BaseModel, ABC):
    """Immutable model with strict validation that rejects unknown fields."""

    model_config = ConfigDict(extra='forbid', strict=True, frozen=True)
class BaseAPIParamsModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base class of the generated API request parameter models."""


class BaseAPIParamsNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base class for models nested inside API request parameters."""
2025-03-21 17:47:09 +03:00
@dataclass
class APIResultContext:
    """Request parameters and HTTP response attached to an API result."""

    # Validated parameters the request was built from.
    api_params: BaseAPIParamsModel
    # Raw response the result was parsed from.
    http_response: requests.Response
class BaseAPIResult(ABC):
    """Common base of every API result type.

    Declares the private attributes holding the request parameters and the
    HTTP response a result was produced from.
    """

    _api_params: BaseAPIParamsModel
    _http_response: requests.Response
class BaseAPIResultModel(
    BaseModelWithFrozenStrictExtraForbid,
    BaseAPIResult,
    ABC,
):
    """Base for pydantic API result models.

    The request parameters and HTTP response are stored in private
    attributes, copied from the ``APIResultContext`` passed as the
    validation context.
    """

    _api_params: BaseAPIParamsModel = PrivateAttr()
    _http_response: requests.Response = PrivateAttr()

    def model_post_init(self, __context: APIResultContext | None) -> None:
        """Copy the request params and HTTP response from the context.

        Pydantic calls this hook with ``None`` when no validation context
        is supplied (e.g. direct instantiation or ``model_construct``). In
        that case the private attributes are left unset instead of failing
        with ``AttributeError`` on ``None``.
        """
        if __context is None:
            return
        self._api_params = __context.api_params
        self._http_response = __context.http_response
class BaseAPIResultNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base class for models nested inside API result models."""
class BaseAPIResultBasicType(BaseAPIResult, ABC):
    """Base for API results that wrap a basic (non-model) value.

    Attribute assignment is rejected once ``_frozen`` has been set to
    ``True``; subclasses set it at the end of ``__new__``.
    """

    _frozen: bool = False

    @abstractmethod
    def __new__(cls, value: Any, context: APIResultContext):
        return super().__new__(cls)

    def __setattr__(self, name: str, value: Any) -> None:
        if not self._frozen:
            return super().__setattr__(name, value)
        raise AttributeError(
            'Instance is frozen.',
        )
class BaseAPIResultStr(str, BaseAPIResultBasicType, ABC):
    """Read-only ``str`` API result that also carries its request context."""

    def __new__(cls, value: str, context: APIResultContext):
        result = super().__new__(cls, value)
        # Assigned in order; setting ``_frozen`` last locks the instance.
        for attr_name, attr_value in (
            ('_api_params', context.api_params),
            ('_http_response', context.http_response),
            ('_frozen', True),
        ):
            setattr(result, attr_name, attr_value)
        return result
class BaseAPIResultInt(int, BaseAPIResultBasicType, ABC):
    """Read-only ``int`` API result that also carries its request context."""

    def __new__(cls, value: int, context: APIResultContext):
        result = super().__new__(cls, value)
        # Assigned in order; setting ``_frozen`` last locks the instance.
        for attr_name, attr_value in (
            ('_api_params', context.api_params),
            ('_http_response', context.http_response),
            ('_frozen', True),
        ):
            setattr(result, attr_name, attr_value)
        return result
class BaseAPIResultBool(BaseAPIResultBasicType, ABC):
    """Boolean API result that also carries its request context.

    ``bool`` cannot be subclassed, so the wrapped value is stored in
    ``value`` and exposed through ``__bool__``.
    """

    value: bool

    def __bool__(self):
        return self.value

    def __str__(self):
        return f'{self.__class__.__qualname__}: {self.value}'

    def __new__(cls, value: bool, context: APIResultContext):
        """Wrap ``value``; a ``None`` payload is accepted as ``True``.

        Raises:
            TypeError: if ``value`` is neither a ``bool`` nor ``None``.
        """
        if isinstance(value, bool):
            _value = value
        elif value is None:  # BDX-8379
            _value = True
        else:
            raise TypeError(
                f'{cls.__qualname__} expects a bool or None value,'
                f' got {type(value).__qualname__}.'
            )
        instance = super().__new__(cls, _value, context)
        instance.value = _value
        instance._api_params = context.api_params
        instance._http_response = context.http_response
        # Set last: further attribute assignment raises AttributeError.
        instance._frozen = True
        return instance
2025-06-06 08:20:45 +03:00
def get_alias(
    field_name: str,
    model_cls: type[BaseModel],
    name_mapping_dict: dict[str, str],
) -> str:
    """Return the API-side alias of ``field_name`` for ``model_cls``.

    For the class whose own annotations declare the field, the lookup
    order is:

    1. the individual mapping ``'<field_name>__<ClassQualname>'``
       (used only if its value is non-empty);
    2. the common mapping ``'<field_name>'``.

    If the field is inherited, the search continues in the first
    ``BaseModel`` base whose ``model_fields`` contain it.

    Raises:
        KeyError: the declaring class has no mapping for the field.
        NameError: the field is not found in the class hierarchy.
    """
    if field_name not in model_cls.__annotations__.keys():
        next_cls = next(
            (
                base_cls
                for base_cls in model_cls.__bases__
                if issubclass(base_cls, BaseModel)
                and field_name in base_cls.model_fields.keys()
            ),
            None,
        )
        if next_cls is None:
            raise NameError(
                f'Field {field_name} not found in model {model_cls.__qualname__}.'
            )
        return get_alias(
            field_name=field_name,
            model_cls=next_cls,
            name_mapping_dict=name_mapping_dict,
        )

    individual_key = f'{field_name}__{model_cls.__qualname__}'
    individual_alias = name_mapping_dict.get(individual_key)
    if individual_alias:
        return individual_alias
    if field_name not in name_mapping_dict:
        raise KeyError(
            f'Mapping for attr {model_cls.__qualname__}.{field_name}'
            f' not found in name mapping dictionary.'
        )
    return name_mapping_dict[field_name]
2025-03-21 17:47:09 +03:00
def set_alias_gen(
    model_cls: type[BaseModel],
    alias_type: Literal['validation', 'serialization'],
    name_mapping_dict: dict[str, str]
):
    """Install a name-mapping alias generator on a model and its nested models.

    Every ``BaseModel`` class reachable through field annotations is
    processed before the model that references it. Reachable annotations
    include union members, type arguments such as ``list[X]``/``dict[K, V]``,
    and ``typing.Optional``/``Annotated`` arguments; non-class arguments such
    as ``Literal`` values are skipped. Each processed model gets an
    ``AliasGenerator`` whose ``<alias_type>_alias`` callable resolves names
    with ``get_alias``, and is then rebuilt with ``model_rebuild(force=True)``.

    Models that already have an ``alias_generator`` are left untouched.
    Cycles (a model referencing itself directly or indirectly) are cut: a
    model already being processed in this call is not entered again. For
    indirect cycles the schema of the earlier-rebuilt model may reference a
    not-yet-rebuilt one. NOTE(review): confirm aliases for such graphs.

    Args:
        model_cls: Root model to configure.
        alias_type: Which alias the generator provides.
        name_mapping_dict: Mapping passed to ``get_alias``.
    """
    in_progress: set[type[BaseModel]] = set()

    def collect_model_classes(annotation: Any) -> list[type[BaseModel]]:
        # Parametrised annotations (X | Y, list[X], Optional[X], ...) are
        # searched through their type arguments. Only real classes are
        # passed to issubclass(), which raises for non-class objects.
        if annotation is None or annotation is Any:
            return []
        type_args = get_args(annotation)
        if type_args:
            found: list[type[BaseModel]] = []
            for type_arg in type_args:
                found.extend(collect_model_classes(type_arg))
            return found
        if isinstance(annotation, type) and issubclass(annotation, BaseModel):
            return [annotation]
        return []

    def apply(cls: type[BaseModel]) -> None:
        if cls.model_config.get('alias_generator') or cls in in_progress:
            return
        in_progress.add(cls)

        # Nested models first, so this model's rebuilt schema picks up
        # their aliases.
        for field in cls.model_fields.values():
            for nested_cls in collect_model_classes(field.annotation):
                apply(nested_cls)

        def alias_gen(field_name: str) -> str:
            return get_alias(
                field_name=field_name,
                model_cls=cls,
                name_mapping_dict=name_mapping_dict,
            )

        cls.model_config['alias_generator'] = AliasGenerator(
            **{f'{alias_type}_alias': alias_gen}
        )
        cls.model_rebuild(force=True)

    apply(model_cls)
# Generated API params model classes, keyed by generated class name and
# filled lazily when API functions are first called.
params_model_classes: dict[str, type[BaseAPIParamsModel]] = dict()
2025-06-06 08:20:45 +03:00
def gen_api_params_cls_name(api_path: str):
    """Return the params-model class name for ``api_path``.

    Square brackets are removed from the path, and the result is passed to
    ``gen_cls_name_from_url_path`` with the ``ParamsModel`` postfix.
    """
    path_without_brackets = re.sub(r'[\[\]]', '', api_path)
    return gen_cls_name_from_url_path(
        url_path=path_without_brackets,
        postfix='ParamsModel',
    )
def create_api_params_cls(
    cls_name: str,
    module_name: str,
    protocol_method: Callable
):
    """Create a params model mirroring the signature of ``protocol_method``.

    Every parameter except ``self`` becomes a field typed with the
    parameter's annotation. Parameters without a default become required
    fields (``...``); the others keep their default. The model subclasses
    ``BaseAPIParamsModel`` and reports ``module_name`` as its module.
    """
    no_default = inspect.Parameter.empty
    signature_params = inspect.signature(protocol_method).parameters
    field_definitions = {
        param_name: (
            param.annotation,
            ... if param.default is no_default else param.default,
        )
        for param_name, param in signature_params.items()
        if param_name != 'self'
    }
    return create_model(
        cls_name,
        __base__=BaseAPIParamsModel,
        __module__=module_name,
        **field_definitions,
    )
2025-11-25 18:09:46 +03:00
def get_func_api_path(*, func_name: str, path_mapping: dict) -> str:
    """Translate a dotted API function name into a URL path.

    Each dot-separated part is replaced by its ``path_mapping`` entry
    (unmapped parts are kept as-is), and the parts are joined with ``/``
    under a leading ``/``.

    >>> get_func_api_path(func_name='a.b', path_mapping={'a': 'x'})
    '/x/b'
    """
    mapped_parts = (
        path_mapping.get(part, part) for part in func_name.split('.')
    )
    return '/' + '/'.join(mapped_parts)
2025-03-21 17:47:09 +03:00
# Parameters of a protocol method, reused by the generated API function.
APIParamsP = ParamSpec('APIParamsP')
# Result type of an API function. The union bound is also read at runtime
# via ``get_args(APIResultT.__bound__)`` to validate return annotations.
APIResultT = TypeVar(
    'APIResultT',
    bound=BaseAPIResultModel | BaseAPIResultBasicType,
)
class BaseAPI(ABC):
_config: Config
2025-11-25 18:09:46 +03:00
_parent_api_group_names: list[str]
2025-03-21 17:47:09 +03:00
_name_mapping_dict: dict[str, str] = name_mapping_dict
2025-11-25 18:09:46 +03:00
_path_mapping_dict: dict[str, str] = path_mapping_dict
2025-03-21 17:47:09 +03:00
_post_json: bool = True
def __init_subclass__(
cls,
path_mapping_dict: None | dict[str, str] = None,
name_mapping_dict: None | dict[str, str] = None,
post_json: None | bool = None,
) -> None:
if path_mapping_dict:
cls._path_mapping_dict = path_mapping_dict
if name_mapping_dict:
cls._name_mapping_dict = name_mapping_dict
if post_json is not None:
cls._post_json = post_json
def __init__(
self,
config: Config,
2025-11-25 18:09:46 +03:00
parent_api_group_names: None | list = None,
2025-03-21 17:47:09 +03:00
):
self._config = config
2025-11-25 18:09:46 +03:00
self._parent_api_group_names = parent_api_group_names or []
2025-03-21 17:47:09 +03:00
def __getattribute__(self, name: str) -> Any:
if name.startswith('_'):
return super().__getattribute__(name)
else:
if name in self.__annotations__:
annotation = self.__annotations__[name]
if issubclass(annotation, BaseAPI):
api_cls = annotation
return api_cls(
config=self._config,
2025-11-25 18:09:46 +03:00
parent_api_group_names=(
self._parent_api_group_names + [name]
2025-03-21 17:47:09 +03:00
)
)
else:
attr_value = super().__getattribute__(name)
if not inspect.ismethod(attr_value):
raise ValueError
return self._make_api_function(attr_value)
def _make_api_function(
self,
protocol_method: Callable[APIParamsP, APIResultT],
/,
) -> Callable[APIParamsP, APIResultT]:
return_type = inspect.signature(protocol_method).return_annotation
if not issubclass(return_type, get_args(APIResultT.__bound__)):
raise TypeError(
f'Return type annotation of {protocol_method}'
f' must be subclass of {get_args(APIResultT.__bound__)}.'
)
api_result_cls: type[APIResultT] = return_type
2025-06-06 08:20:45 +03:00
if not isinstance(protocol_method, MethodType):
raise TypeError
2025-11-25 18:09:46 +03:00
func_name_parts = (
self._parent_api_group_names + [protocol_method.__name__]
)
func_name = '.'.join(func_name_parts)
2025-03-21 17:47:09 +03:00
def api_function(
*args: APIParamsP.args,
**kwargs: APIParamsP.kwargs,
):
config = self._config
2025-11-25 18:09:46 +03:00
api_path = get_func_api_path(
func_name=func_name,
path_mapping=self._path_mapping_dict,
)
2025-03-21 17:47:09 +03:00
2025-06-06 08:20:45 +03:00
api_params_cls_name = gen_api_params_cls_name(
api_path=api_path,
2025-03-21 17:47:09 +03:00
)
api_params_cls = params_model_classes.get(
api_params_cls_name,
None,
)
if not api_params_cls:
2025-06-06 08:20:45 +03:00
api_params_cls = create_api_params_cls(
cls_name=api_params_cls_name,
module_name=api_result_cls.__module__,
protocol_method=protocol_method,
2025-03-21 17:47:09 +03:00
)
params_model_classes[api_params_cls_name] = api_params_cls
set_alias_gen(
model_cls=api_params_cls,
alias_type='serialization',
name_mapping_dict=self._name_mapping_dict,
)
api_params = api_params_cls(*args, **kwargs)
req_headers = None
if isinstance(config, ConfigWithAuth):
req_headers = {
'Authorization': f'bearer {config.jwt}',
}
model_data = api_params.model_dump(
by_alias=True,
exclude_none=True,
)
for api_mixin_cls in self.__class__.__bases__[1:]:
2025-06-06 08:20:45 +03:00
if not hasattr(api_mixin_cls, protocol_method.__name__):
continue
proto_cls_base = api_mixin_cls.__base__
if (
not proto_cls_base
or not issubclass(proto_cls_base, BaseAPIFunctionProtocol)
):
raise TypeError(
f'Class {api_mixin_cls.__qualname__} must have'
f' {BaseAPIFunctionProtocol.__qualname__}'
f' as parent class.'
)
http_method = base_proto_to_http_method[proto_cls_base]
break
2025-03-21 17:47:09 +03:00
else:
raise RuntimeError
req_params = None
req_json = None
req_data = None
match http_method:
case HTTPMethod.GET:
req_params = model_data
2025-11-25 18:09:46 +03:00
case (
HTTPMethod.POST
| HTTPMethod.DELETE
| HTTPMethod.PATCH
| HTTPMethod.PUT
):
2025-03-21 17:47:09 +03:00
if self._post_json:
req_json = model_data
else:
req_data = model_data
case _:
raise RuntimeError
http_request = requests.Request(
method=http_method,
url=config.get_api_url(api_path=api_path),
headers=req_headers,
params=req_params,
json=req_json,
data=req_data,
).prepare()
attempts = config.http503_attempts
2025-11-25 18:09:46 +03:00
try:
with requests.Session() as session:
while True:
http_response = session.send(
request=http_request,
verify=config.verify_ssl,
)
2025-03-21 17:47:09 +03:00
2025-11-25 18:09:46 +03:00
if http_response.status_code == 503:
if attempts < 1:
break
else:
attempts -= 1
time.sleep(config.http503_attempts_interval)
2025-03-21 17:47:09 +03:00
else:
2025-11-25 18:09:46 +03:00
break
2025-03-21 17:47:09 +03:00
2025-11-25 18:09:46 +03:00
http_response.raise_for_status()
except requests.exceptions.RequestException as e:
if config.wrap_request_exceptions:
raise RequestException(
orig_exception=e,
func_name=func_name,
func_kwargs=kwargs,
)
else:
raise e
2025-03-21 17:47:09 +03:00
api_result_context = APIResultContext(
api_params=api_params,
http_response=http_response,
)
if issubclass(api_result_cls, BaseAPIResultModel):
set_alias_gen(
model_cls=api_result_cls,
alias_type='validation',
name_mapping_dict=self._name_mapping_dict,
)
result_extra_allow = (
api_result_cls.model_config.get('extra') == 'allow'
)
if config.result_extra_allow != result_extra_allow:
if config.result_extra_allow:
api_result_cls.model_config['extra'] = 'allow'
else:
api_result_cls.model_config['extra'] = (
BaseAPIResultModel.model_config.get('extra')
)
api_result_cls.model_rebuild(force=True)
api_result = api_result_cls.model_validate_json(
json_data=http_response.content,
context=api_result_context,
)
else:
try:
decoded_content = http_response.json()
except requests.JSONDecodeError:
decoded_content = http_response.text
api_result = api_result_cls(
value=decoded_content,
context=api_result_context,
)
return api_result
2025-11-25 18:09:46 +03:00
api_function.__name__ = func_name.replace('.', '__')
if self._config.f_decorators:
for decorator in reversed(self._config.f_decorators):
api_function = decorator(api_function)
2025-03-21 17:47:09 +03:00
return api_function