You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
491 lines
15 KiB
491 lines
15 KiB
1 month ago
|
from abc import ABC, abstractmethod
|
||
|
from collections.abc import Callable
|
||
|
from dataclasses import dataclass
|
||
|
from datetime import datetime
|
||
|
import inspect
|
||
|
import os
|
||
|
import re
|
||
|
import time
|
||
|
from types import GenericAlias, UnionType
|
||
|
from typing import Any, Literal, ParamSpec, Protocol, TypeVar, get_args
|
||
|
|
||
|
import requests
|
||
|
from pydantic import (
|
||
|
AliasGenerator,
|
||
|
BaseModel as PydanticBaseModel,
|
||
|
ConfigDict,
|
||
|
PrivateAttr,
|
||
|
create_model,
|
||
|
)
|
||
|
import yaml
|
||
|
|
||
|
from dynamix_sdk.config import Config, ConfigWithAuth
|
||
|
from dynamix_sdk.utils import HTTPMethod, gen_cls_name_from_url_path
|
||
|
|
||
|
|
||
|
# YAML mapping files shipped in the ``api`` directory next to this module.
# name_mapping.yml: field name -> API alias (read by set_alias_gen).
NAME_MAPPING_FILE_PATH = os.path.join(
    os.path.dirname(__file__), 'api', 'name_mapping.yml',
)
# path_mapping.yml: API attribute name -> URL path part (read by BaseAPI).
PATH_MAPPING_FILE_PATH = os.path.join(
    os.path.dirname(__file__), 'api', 'path_mapping.yml',
)
|
||
|
|
||
|
|
||
|
def read_mapping_file(file_path: str) -> dict[str, str]:
    """Read a flat YAML mapping of strings to strings.

    The file is read once, as UTF-8, and parsed with ``yaml.safe_load``.
    PyYAML does not reject duplicated keys (the last one wins). Duplicates
    are therefore detected by comparing the number of parsed keys with the
    number of non-blank, non-comment lines. This assumes one
    ``key: value`` pair per line.

    Args:
        file_path: Path to the YAML mapping file.

    Returns:
        The parsed ``str -> str`` mapping.

    Raises:
        TypeError: If the document is not a mapping, or if a key or value
            is not a string.
        AssertionError: If the file has more code lines than parsed keys,
            which most likely means duplicate keys.
    """
    # Read once; the text is reused for parsing and for line counting.
    with open(file_path, encoding='utf-8') as file:
        content = file.read()

    parsed_data: Any = yaml.safe_load(content)
    if not isinstance(parsed_data, dict):
        raise TypeError(
            f'File {file_path} must contain a YAML mapping,'
            f' got {type(parsed_data).__name__}.'
        )
    result_dict: dict[Any, Any] = parsed_data

    # Text mode has already normalised line endings to '\n', so splitting on
    # it yields the same lines as iterating the file object. Comment lines
    # are recognised even when indented.
    code_lines_amount = sum(
        1
        for line in content.split('\n')
        if line.strip() and not line.lstrip().startswith('#')
    )
    if len(result_dict) < code_lines_amount:
        raise AssertionError(
            f'File {file_path} contains more code lines than'
            ' keys in the parsed data. Check the file for duplicate keys.'
        )

    for key, value in result_dict.items():
        if not isinstance(key, str) or not isinstance(value, str):
            raise TypeError(
                f'File {file_path}: entry {key!r}: {value!r}'
                ' must map a string to a string.'
            )
    result_str_dict: dict[str, str] = result_dict

    return result_str_dict
|
||
|
|
||
|
|
||
|
# API attribute name -> URL path part, used by BaseAPI.__getattribute__.
path_mapping_dict = read_mapping_file(PATH_MAPPING_FILE_PATH)

# Field name -> API alias. Keys shaped "<field>__<ModelQualname>" are
# individual (per-model) mappings; every other key is a common mapping.
name_mapping_dict = read_mapping_file(NAME_MAPPING_FILE_PATH)

# Aliases of common mappings must be unique; only individual mappings may
# share an alias with another entry.
common_mappings_values = [
    alias for key, alias in name_mapping_dict.items() if '__' not in key
]
if len(set(common_mappings_values)) < len(common_mappings_values):
    raise AssertionError(
        f'File {NAME_MAPPING_FILE_PATH} can contain duplicate values'
        ' only for individual mapping (attr_name__model_class_name),'
        ' not common. Check common mappings for duplicate values.'
    )
|
||
|
|
||
|
|
||
|
class BaseAPIFunctionProtocol(Protocol):
    """Root of the API-function protocol mixins.

    BaseAPI checks which sub-protocol (POST or GET) the mixin declaring an
    API function derives from, to pick the HTTP method of the request.
    """
|
||
|
|
||
|
|
||
|
class BasePostAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Marker base for mixins whose API functions are sent as HTTP POST."""
|
||
|
|
||
|
|
||
|
class BaseGetAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Marker base for mixins whose API functions are sent as HTTP GET."""
|
||
|
|
||
|
|
||
|
class BaseModel(PydanticBaseModel, ABC):
    """Abstract project-wide base for pydantic models."""

    @staticmethod
    def _get_datetime_from_timestamp(timestamp: float) -> None | datetime:
        """Convert a Unix timestamp to a naive local ``datetime``.

        Non-positive timestamps (and NaN) yield ``None``.
        """
        return datetime.fromtimestamp(timestamp) if timestamp > 0 else None
|
||
|
|
||
|
|
||
|
class BaseModelWithFrozenStrictExtraForbid(BaseModel, ABC):
    """Immutable, strictly validated model that rejects unknown fields."""

    model_config = ConfigDict(
        frozen=True,  # instances cannot be modified after validation
        strict=True,  # no implicit type coercion while validating
        extra='forbid',  # unknown input fields are validation errors
    )
|
||
|
|
||
|
|
||
|
class BaseAPIParamsModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base of the params models BaseAPI generates (via ``create_model``)
    from the signature of each API function."""
|
||
|
|
||
|
|
||
|
class BaseAPIParamsNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base for models nested inside API params models.

    The name of every subclass must end with ``APIParamsNM``. This is
    enforced when the subclass is created.
    """

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        """Validate the subclass name after the regular subclass hook runs.

        Raises:
            ValueError: If the subclass ``__qualname__`` does not end with
                ``APIParamsNM``.
        """
        # ``*args: Any`` / ``**kwargs: Any`` annotate each forwarded item;
        # the previous ``tuple[Any]`` / ``dict[str, Any]`` declared every
        # single argument to be a tuple or a dict.
        super().__init_subclass__(*args, **kwargs)

        postfix = 'APIParamsNM'
        if not cls.__qualname__.endswith(postfix):
            raise ValueError(
                f'Name of {cls} must end with {postfix}.',
            )
|
||
|
|
||
|
|
||
|
@dataclass
class APIResultContext:
    """Request data attached to an API result.

    BaseAPI passes it as validation context to result models and as the
    ``context`` argument to basic-type results.
    """

    # Validated params model the API function was called with.
    api_params: BaseAPIParamsModel
    # HTTP response the result is built from.
    http_response: requests.Response
|
||
|
|
||
|
|
||
|
class BaseAPIResult(ABC):
    """Common base of every API result.

    Concrete results keep the params they were requested with and the raw
    HTTP response in the private attributes declared below.
    """

    # Params model instance used for the request.
    _api_params: BaseAPIParamsModel
    # HTTP response the result was built from.
    _http_response: requests.Response
|
||
|
|
||
|
|
||
|
class BaseAPIResultModel(
    BaseModelWithFrozenStrictExtraForbid,
    BaseAPIResult,
    ABC,
):
    """Pydantic API result model that records its request context.

    BaseAPI validates these models with
    ``model_validate_json(..., context=APIResultContext(...))``.
    ``model_post_init`` copies that context into the private attributes.
    """

    _api_params: BaseAPIParamsModel = PrivateAttr()
    _http_response: requests.Response = PrivateAttr()

    def model_post_init(self, __context: APIResultContext | None):
        """Store the request context on the freshly validated instance.

        Args:
            __context: Validation context. Pydantic passes ``None`` when no
                context was supplied, e.g. on direct instantiation or
                ``model_construct``. In that case the private attributes are
                left unset instead of failing with ``AttributeError`` on
                ``None``.
        """
        if __context is None:
            return
        self._api_params = __context.api_params
        self._http_response = __context.http_response
|
||
|
|
||
|
|
||
|
class BaseAPIResultNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base intended for models nested inside API result models."""
|
||
|
|
||
|
|
||
|
class BaseAPIResultBasicType(BaseAPIResult, ABC):
    """Base for API results that wrap a basic value (str, int, bool).

    Subclasses build the instance in ``__new__``, attach the request
    context and finally set ``_frozen``. After that, attributes can be
    neither assigned nor deleted.
    """

    # Flipped to True at the end of a subclass ``__new__``.
    _frozen: bool = False

    @abstractmethod
    def __new__(cls, value: Any, context: APIResultContext):
        """Create the bare instance; subclasses must override this."""
        return super().__new__(cls)

    def __setattr__(self, name: str, value: Any) -> None:
        """Reject attribute assignment once the instance is frozen.

        Raises:
            AttributeError: If the instance is frozen.
        """
        if self._frozen:
            raise AttributeError(
                'Instance is frozen.',
            )
        return super().__setattr__(name, value)

    def __delattr__(self, name: str) -> None:
        """Reject attribute deletion once the instance is frozen.

        Without this guard, ``del result.value`` (or of the context
        attributes) would bypass the freeze.

        Raises:
            AttributeError: If the instance is frozen.
        """
        if self._frozen:
            raise AttributeError(
                'Instance is frozen.',
            )
        return super().__delattr__(name)
|
||
|
|
||
|
|
||
|
class BaseAPIResultStr(str, BaseAPIResultBasicType, ABC):
    """String API result that also carries its request context.

    The instance is a real ``str`` built by ``str.__new__(cls, value)``.
    It is frozen once the context has been attached.
    """

    def __new__(cls, value: str, context: APIResultContext):
        obj = super().__new__(cls, value)
        obj._api_params, obj._http_response = (
            context.api_params,
            context.http_response,
        )
        # Last step: from now on attribute assignment is rejected.
        obj._frozen = True
        return obj
|
||
|
|
||
|
|
||
|
class BaseAPIResultInt(int, BaseAPIResultBasicType, ABC):
    """Integer API result that also carries its request context.

    The instance is a real ``int`` built by ``int.__new__(cls, value)``.
    It is frozen once the context has been attached.
    """

    def __new__(cls, value: int, context: APIResultContext):
        obj = super().__new__(cls, value)
        obj._api_params, obj._http_response = (
            context.api_params,
            context.http_response,
        )
        # Last step: from now on attribute assignment is rejected.
        obj._frozen = True
        return obj
|
||
|
|
||
|
|
||
|
class BaseAPIResultBool(BaseAPIResultBasicType, ABC):
    """Boolean API result that also carries its request context.

    ``bool`` itself cannot be subclassed, so the wrapped value is stored in
    ``value`` and truthiness is delegated to it.
    """

    value: bool

    def __new__(cls, value: bool, context: APIResultContext):
        # Only genuine booleans are accepted (no 0/1, no strings).
        if not isinstance(value, bool):
            raise TypeError
        obj = super().__new__(cls, value, context)
        obj.value = value
        obj._api_params, obj._http_response = (
            context.api_params,
            context.http_response,
        )
        # Last step: from now on attribute assignment is rejected.
        obj._frozen = True
        return obj

    def __bool__(self):
        return self.value

    def __str__(self):
        return f'{type(self).__qualname__}: {self.value}'
|
||
|
|
||
|
|
||
|
def _collect_model_classes(annotation: Any) -> list[type[BaseModel]]:
    """Return the ``BaseModel`` subclasses referenced by a field annotation.

    Parameterised annotations are walked through ``typing.get_args``. This
    covers ``X | None``, ``list[X]``, ``dict[str, X]``, ``Optional[X]``,
    ``Union[...]``, ``Annotated[X, ...]`` and ``Literal[...]``. Leaves that
    are not classes are skipped instead of being passed to ``issubclass``,
    which would raise ``TypeError``. Such leaves include literal values,
    ``Annotated`` metadata, type variables and string forward references.
    """
    if annotation is None:
        return []
    args = get_args(annotation)
    if args:
        return [
            model_cls
            for arg in args
            for model_cls in _collect_model_classes(arg)
        ]
    if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
        return [annotation]
    return []


def _make_alias_gen(
    model_cls: type[BaseModel],
    name_mapping_dict: dict[str, str],
) -> Callable[[str], str]:
    """Build the field-name -> alias function for *model_cls*.

    An individual mapping ``"<field>__<ModelQualname>"`` takes precedence
    over the common mapping ``"<field>"``.

    The returned function raises ``KeyError`` when neither key exists.
    """

    def alias_gen(field_name: str) -> str:
        individual_alias = name_mapping_dict.get(
            f'{field_name}__{model_cls.__qualname__}',
        )
        if individual_alias:
            return individual_alias
        if field_name in name_mapping_dict:
            return name_mapping_dict[field_name]
        raise KeyError(
            f'{field_name} not found in name mapping dictionary.'
            f' Class model: {model_cls.__name__}.'
        )

    return alias_gen


def set_alias_gen(
    model_cls: type[BaseModel],
    alias_type: Literal['validation', 'serialization'],
    name_mapping_dict: dict[str, str]
):
    """Install a name-mapping alias generator on a model and its nested models.

    Models referenced in field annotations are configured and rebuilt before
    the model that references them. Models that already have an
    ``alias_generator`` are left untouched. Models still being configured
    further up the recursion are skipped, so self- or mutually-referencing
    models no longer recurse without end.

    Args:
        model_cls: Model class to configure. Its ``model_config`` is mutated
            and the model is rebuilt.
        alias_type: Which alias the generator provides, ``validation_alias``
            or ``serialization_alias``.
        name_mapping_dict: Field name -> alias mapping. Individual keys
            ``"<field>__<ModelQualname>"`` win over common ``"<field>"`` keys.

    Raises:
        KeyError: Raised while rebuilding, if a field has no mapping.
    """
    in_progress: set[type[BaseModel]] = set()

    def configure(target: type[BaseModel]) -> None:
        if target.model_config.get('alias_generator') or target in in_progress:
            return
        in_progress.add(target)

        for field in target.model_fields.values():
            for nested_cls in _collect_model_classes(field.annotation):
                configure(nested_cls)

        alias_gen = _make_alias_gen(target, name_mapping_dict)
        target.model_config['alias_generator'] = AliasGenerator(
            **{f'{alias_type}_alias': alias_gen},
        )
        target.model_rebuild(force=True)

    configure(model_cls)
|
||
|
|
||
|
|
||
|
# Module-level cache of generated API params models, keyed by the class name
# BaseAPI derives from the API path, so each model is created only once.
params_model_classes: dict[str, type[BaseAPIParamsModel]] = {}

# Parameters of an API function declaration, forwarded to the params model.
APIParamsP = ParamSpec('APIParamsP')
# Result of an API function: a result model or a basic-type result wrapper.
APIResultT = TypeVar(
    'APIResultT', bound=BaseAPIResultModel | BaseAPIResultBasicType,
)
|
||
|
|
||
|
|
||
|
class BaseAPI(ABC):
|
||
|
_config: Config
|
||
|
_base_api_path: None | str
|
||
|
_path_mapping_dict: dict[str, str] = path_mapping_dict
|
||
|
_last_call_api_path: None | str = None
|
||
|
_name_mapping_dict: dict[str, str] = name_mapping_dict
|
||
|
_post_json: bool = True
|
||
|
|
||
|
def __init_subclass__(
|
||
|
cls,
|
||
|
path_mapping_dict: None | dict[str, str] = None,
|
||
|
name_mapping_dict: None | dict[str, str] = None,
|
||
|
post_json: None | bool = None,
|
||
|
) -> None:
|
||
|
if path_mapping_dict:
|
||
|
cls._path_mapping_dict = path_mapping_dict
|
||
|
if name_mapping_dict:
|
||
|
cls._name_mapping_dict = name_mapping_dict
|
||
|
if post_json is not None:
|
||
|
cls._post_json = post_json
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
config: Config,
|
||
|
base_api_path: None | str = None
|
||
|
):
|
||
|
self._config = config
|
||
|
self._base_api_path = base_api_path
|
||
|
|
||
|
def __getattribute__(self, name: str) -> Any:
|
||
|
if name.startswith('_'):
|
||
|
return super().__getattribute__(name)
|
||
|
else:
|
||
|
path_part = self._path_mapping_dict.get(name, name)
|
||
|
if name in self.__annotations__:
|
||
|
annotation = self.__annotations__[name]
|
||
|
if issubclass(annotation, BaseAPI):
|
||
|
api_cls = annotation
|
||
|
return api_cls(
|
||
|
config=self._config,
|
||
|
base_api_path=(
|
||
|
f'{self._base_api_path or ""}/{path_part}'
|
||
|
)
|
||
|
)
|
||
|
else:
|
||
|
self._last_call_api_path = (
|
||
|
f'{self._base_api_path or ""}/{path_part}'
|
||
|
)
|
||
|
|
||
|
attr_value = super().__getattribute__(name)
|
||
|
if not inspect.ismethod(attr_value):
|
||
|
raise ValueError
|
||
|
|
||
|
return self._make_api_function(attr_value)
|
||
|
|
||
|
def _make_api_function(
|
||
|
self,
|
||
|
protocol_method: Callable[APIParamsP, APIResultT],
|
||
|
/,
|
||
|
) -> Callable[APIParamsP, APIResultT]:
|
||
|
return_type = inspect.signature(protocol_method).return_annotation
|
||
|
if not issubclass(return_type, get_args(APIResultT.__bound__)):
|
||
|
raise TypeError(
|
||
|
f'Return type annotation of {protocol_method}'
|
||
|
f' must be subclass of {get_args(APIResultT.__bound__)}.'
|
||
|
)
|
||
|
api_result_cls: type[APIResultT] = return_type
|
||
|
|
||
|
def api_function(
|
||
|
*args: APIParamsP.args,
|
||
|
**kwargs: APIParamsP.kwargs,
|
||
|
):
|
||
|
config = self._config
|
||
|
|
||
|
if self._last_call_api_path:
|
||
|
api_path = self._last_call_api_path
|
||
|
else:
|
||
|
raise ValueError
|
||
|
|
||
|
api_params_cls_name = gen_cls_name_from_url_path(
|
||
|
url_path=re.sub(r'[\[\]]', '', api_path),
|
||
|
postfix=BaseAPIParamsModel.__qualname__.removeprefix(
|
||
|
'BaseAPI'
|
||
|
),
|
||
|
)
|
||
|
|
||
|
api_params_cls = params_model_classes.get(
|
||
|
api_params_cls_name,
|
||
|
None,
|
||
|
)
|
||
|
if not api_params_cls:
|
||
|
field_definitions = {}
|
||
|
parameters = inspect.signature(protocol_method).parameters
|
||
|
inspect_empty = inspect.Parameter.empty
|
||
|
for p in parameters.values():
|
||
|
field_definitions[p.name] = (
|
||
|
p.annotation,
|
||
|
... if p.default is inspect_empty else p.default,
|
||
|
)
|
||
|
|
||
|
api_params_cls = create_model(
|
||
|
api_params_cls_name,
|
||
|
__base__=BaseAPIParamsModel,
|
||
|
__module__=api_result_cls.__module__,
|
||
|
**field_definitions,
|
||
|
)
|
||
|
params_model_classes[api_params_cls_name] = api_params_cls
|
||
|
|
||
|
set_alias_gen(
|
||
|
model_cls=api_params_cls,
|
||
|
alias_type='serialization',
|
||
|
name_mapping_dict=self._name_mapping_dict,
|
||
|
)
|
||
|
|
||
|
api_params = api_params_cls(*args, **kwargs)
|
||
|
|
||
|
req_headers = None
|
||
|
if isinstance(config, ConfigWithAuth):
|
||
|
req_headers = {
|
||
|
'Authorization': f'bearer {config.jwt}',
|
||
|
}
|
||
|
|
||
|
model_data = api_params.model_dump(
|
||
|
by_alias=True,
|
||
|
exclude_none=True,
|
||
|
)
|
||
|
|
||
|
for api_mixin_cls in self.__class__.__bases__[1:]:
|
||
|
if hasattr(api_mixin_cls, protocol_method.__name__):
|
||
|
if issubclass(api_mixin_cls, BasePostAPIFunctionProtocol):
|
||
|
http_method = HTTPMethod.POST
|
||
|
break
|
||
|
elif issubclass(api_mixin_cls, BaseGetAPIFunctionProtocol):
|
||
|
http_method = HTTPMethod.GET
|
||
|
break
|
||
|
else:
|
||
|
raise RuntimeError
|
||
|
|
||
|
req_params = None
|
||
|
req_json = None
|
||
|
req_data = None
|
||
|
match http_method:
|
||
|
case HTTPMethod.GET:
|
||
|
req_params = model_data
|
||
|
case HTTPMethod.POST:
|
||
|
if self._post_json:
|
||
|
req_json = model_data
|
||
|
else:
|
||
|
req_data = model_data
|
||
|
case _:
|
||
|
raise RuntimeError
|
||
|
|
||
|
http_request = requests.Request(
|
||
|
method=http_method,
|
||
|
url=config.get_api_url(api_path=api_path),
|
||
|
headers=req_headers,
|
||
|
params=req_params,
|
||
|
json=req_json,
|
||
|
data=req_data,
|
||
|
).prepare()
|
||
|
|
||
|
attempts = config.http503_attempts
|
||
|
with requests.Session() as session:
|
||
|
while True:
|
||
|
http_response = session.send(
|
||
|
request=http_request,
|
||
|
verify=config.verify_ssl,
|
||
|
)
|
||
|
|
||
|
if http_response.status_code == 503:
|
||
|
if attempts < 1:
|
||
|
break
|
||
|
else:
|
||
|
attempts -= 1
|
||
|
time.sleep(config.http503_attempts_interval)
|
||
|
else:
|
||
|
break
|
||
|
|
||
|
http_response.raise_for_status()
|
||
|
|
||
|
api_result_context = APIResultContext(
|
||
|
api_params=api_params,
|
||
|
http_response=http_response,
|
||
|
)
|
||
|
|
||
|
if issubclass(api_result_cls, BaseAPIResultModel):
|
||
|
set_alias_gen(
|
||
|
model_cls=api_result_cls,
|
||
|
alias_type='validation',
|
||
|
name_mapping_dict=self._name_mapping_dict,
|
||
|
)
|
||
|
|
||
|
result_extra_allow = (
|
||
|
api_result_cls.model_config.get('extra') == 'allow'
|
||
|
)
|
||
|
if config.result_extra_allow != result_extra_allow:
|
||
|
if config.result_extra_allow:
|
||
|
api_result_cls.model_config['extra'] = 'allow'
|
||
|
else:
|
||
|
api_result_cls.model_config['extra'] = (
|
||
|
BaseAPIResultModel.model_config.get('extra')
|
||
|
)
|
||
|
api_result_cls.model_rebuild(force=True)
|
||
|
|
||
|
api_result = api_result_cls.model_validate_json(
|
||
|
json_data=http_response.content,
|
||
|
context=api_result_context,
|
||
|
)
|
||
|
else:
|
||
|
try:
|
||
|
decoded_content = http_response.json()
|
||
|
except requests.JSONDecodeError:
|
||
|
decoded_content = http_response.text
|
||
|
|
||
|
api_result = api_result_cls(
|
||
|
value=decoded_content,
|
||
|
context=api_result_context,
|
||
|
)
|
||
|
|
||
|
return api_result
|
||
|
|
||
|
return api_function
|