You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

491 lines
15 KiB

1 month ago
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
import inspect
import os
import re
import time
from types import GenericAlias, UnionType
from typing import Any, Literal, ParamSpec, Protocol, TypeVar, get_args
import requests
from pydantic import (
AliasGenerator,
BaseModel as PydanticBaseModel,
ConfigDict,
PrivateAttr,
create_model,
)
import yaml
from dynamix_sdk.config import Config, ConfigWithAuth
from dynamix_sdk.utils import HTTPMethod, gen_cls_name_from_url_path
# Both YAML mapping files live in the ``api`` directory next to this module.
# name_mapping.yml: model field name -> API name (used for pydantic aliases).
# path_mapping.yml: API attribute name -> URL path part.
NAME_MAPPING_FILE_PATH = os.path.join(
    os.path.dirname(__file__), 'api', 'name_mapping.yml'
)
PATH_MAPPING_FILE_PATH = os.path.join(
    os.path.dirname(__file__), 'api', 'path_mapping.yml'
)
def read_mapping_file(file_path: str) -> dict[str, str]:
    """Load a flat ``str -> str`` mapping from a YAML file.

    The file is read once as UTF-8 text. PyYAML keeps a single value for a
    repeated key without complaining, so duplicates are detected by comparing
    the number of parsed keys with the number of code lines. Code lines are
    non-blank lines whose first non-whitespace character is not ``#``. This
    assumes one ``key: value`` pair per line.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The parsed mapping.

    Raises:
        TypeError: If the document is not a mapping, or if it contains a
            non-string key or value.
        AssertionError: If the file has more code lines than parsed keys
            (most likely a duplicate key).
    """
    # Read once and reuse the text for both parsing and line counting.
    with open(file_path, encoding='utf-8') as file:
        content = file.read()
    parsed_data: Any = yaml.safe_load(content)
    if not isinstance(parsed_data, dict):
        raise TypeError(
            f'File {file_path} must contain a mapping,'
            f' got {type(parsed_data).__name__}.'
        )
    result_dict: dict[Any, Any] = parsed_data
    # Split on '\n' only: text-mode reading already normalised newlines, and
    # str.splitlines() would also split on Unicode separators inside values.
    # Strip before testing for '#' so that indented comments are not counted.
    code_lines_amount = sum(
        1
        for line in content.split('\n')
        if (stripped_line := line.strip())
        and not stripped_line.startswith('#')
    )
    if len(result_dict) < code_lines_amount:
        raise AssertionError(
            f'File {file_path} contains more code lines than'
            f' keys in the parsed data. Check the file for duplicate keys.'
        )
    for k, v in result_dict.items():
        if not isinstance(k, str) or not isinstance(v, str):
            raise TypeError(
                f'File {file_path} must map strings to strings,'
                f' got {k!r}: {v!r}.'
            )
    result_str_dict: dict[str, str] = result_dict
    return result_str_dict
# Default mappings, loaded once at import time (used by BaseAPI).
path_mapping_dict = read_mapping_file(PATH_MAPPING_FILE_PATH)
name_mapping_dict = read_mapping_file(NAME_MAPPING_FILE_PATH)

# Keys without '__' are common mappings, applied to a field name in any
# model. Keys of the form attr_name__ModelClassName are individual mappings.
# Two common mappings must never produce the same API name.
common_mappings_values = [
    alias
    for field_key, alias in name_mapping_dict.items()
    if '__' not in field_key
]
if len(set(common_mappings_values)) != len(common_mappings_values):
    raise AssertionError(
        f'File {NAME_MAPPING_FILE_PATH} can contain duplicate values'
        f' only for individual mapping (attr_name__model_class_name),'
        f' not common. Check common mappings for duplicate values.'
    )
class BaseAPIFunctionProtocol(Protocol):
    """Root marker for mixins that declare API function signatures."""


class BasePostAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Marker: API functions declared by this mixin are sent with POST."""


class BaseGetAPIFunctionProtocol(BaseAPIFunctionProtocol):
    """Marker: API functions declared by this mixin are sent with GET."""
class BaseModel(PydanticBaseModel, ABC):
    """Abstract base of the SDK's pydantic models, with shared helpers."""

    @staticmethod
    def _get_datetime_from_timestamp(timestamp: float) -> None | datetime:
        """Convert a POSIX timestamp into a naive local-time ``datetime``.

        Returns ``None`` unless ``timestamp`` is strictly positive. Zero,
        negative and NaN values all give ``None``.
        """
        return datetime.fromtimestamp(timestamp) if timestamp > 0 else None
class BaseModelWithFrozenStrictExtraForbid(BaseModel, ABC):
    """Immutable, strictly validated model that rejects unknown fields."""

    # frozen: instances cannot be modified after creation.
    # strict: validation does not coerce between types.
    # extra='forbid': unexpected input keys are a validation error.
    model_config = ConfigDict(frozen=True, strict=True, extra='forbid')
class BaseAPIParamsModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base of API function parameter models.

    BaseAPI._make_api_function creates concrete subclasses dynamically
    (``create_model``) from the signatures of protocol methods.
    """
class BaseAPIParamsNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base of models nested inside API parameter models.

    Subclass names (``__qualname__``) must end with ``APIParamsNM``.
    """

    def __init_subclass__(
        cls,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Run the inherited hook, then enforce the naming convention.

        Raises:
            ValueError: If the subclass ``__qualname__`` does not end with
                ``APIParamsNM``.
        """
        super().__init_subclass__(*args, **kwargs)
        postfix = 'APIParamsNM'
        if not cls.__qualname__.endswith(postfix):
            raise ValueError(
                f'Name of {cls} must end with {postfix}.',
            )
@dataclass
class APIResultContext:
    """Call data passed to result classes while they are being built."""

    # Validated parameters of the API call.
    api_params: BaseAPIParamsModel
    # Raw HTTP response the result is built from.
    http_response: requests.Response
class BaseAPIResult(ABC):
    """Common base of API results; keeps the data of the originating call.

    The concrete result bases in this module fill both attributes from an
    APIResultContext.
    """

    # Parameters of the call that produced this result.
    _api_params: BaseAPIParamsModel
    # HTTP response this result was decoded from.
    _http_response: requests.Response
class BaseAPIResultModel(
    BaseModelWithFrozenStrictExtraForbid,
    BaseAPIResult,
    ABC,
):
    """Base of API results that are validated pydantic models.

    The call parameters and the HTTP response are taken from the validation
    context (an ``APIResultContext``) and stored as private attributes.
    """

    _api_params: BaseAPIParamsModel = PrivateAttr()
    _http_response: requests.Response = PrivateAttr()

    def model_post_init(self, __context: None | APIResultContext) -> None:
        """Store the call data from the validation context, if there is one.

        pydantic passes ``None`` when the model is built without a context,
        for example ``Model(**data)`` or ``Model.model_construct()``. In that
        case the private attributes stay unset, instead of construction
        failing with ``AttributeError`` on ``None``.
        """
        if __context is None:
            return
        self._api_params = __context.api_params
        self._http_response = __context.http_response
class BaseAPIResultNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
    """Base of models nested inside API result models (no call data kept)."""
class BaseAPIResultBasicType(BaseAPIResult, ABC):
    """Base of API results that wrap a plain value (str, int, bool).

    The concrete subclasses in this module set ``_frozen`` to True at the end
    of ``__new__``. From then on, both assigning and deleting attributes
    raise ``AttributeError``.
    """

    # Class-level default; each instance switches it on once initialised.
    _frozen: bool = False

    @abstractmethod
    def __new__(cls, value: Any, context: APIResultContext):
        """Create the bare instance.

        Subclasses store ``value`` and the context data themselves.
        """
        return super().__new__(cls)

    def __setattr__(self, name: str, value: Any) -> None:
        if self._frozen:
            raise AttributeError(
                'Instance is frozen.',
            )
        return super().__setattr__(name, value)

    def __delattr__(self, name: str) -> None:
        # Without this guard, `del` would bypass the freeze that
        # __setattr__ enforces.
        if self._frozen:
            raise AttributeError(
                'Instance is frozen.',
            )
        return super().__delattr__(name)
class BaseAPIResultStr(str, BaseAPIResultBasicType, ABC):
    """API result that is a ``str`` carrying the data of its call."""

    def __new__(cls, value: str, context: APIResultContext):
        """Create the string and attach the call data from ``context``."""
        result = super().__new__(cls, value)
        for attr_name, attr_value in (
            ('_api_params', context.api_params),
            ('_http_response', context.http_response),
        ):
            setattr(result, attr_name, attr_value)
        # Freeze last; BaseAPIResultBasicType rejects any later assignment.
        result._frozen = True
        return result
class BaseAPIResultInt(int, BaseAPIResultBasicType, ABC):
    """API result that is an ``int`` carrying the data of its call."""

    def __new__(cls, value: int, context: APIResultContext):
        """Create the integer and attach the call data from ``context``."""
        result = super().__new__(cls, value)
        for attr_name, attr_value in (
            ('_api_params', context.api_params),
            ('_http_response', context.http_response),
        ):
            setattr(result, attr_name, attr_value)
        # Freeze last; BaseAPIResultBasicType rejects any later assignment.
        result._frozen = True
        return result
class BaseAPIResultBool(BaseAPIResultBasicType, ABC):
    """API result wrapping a ``bool``.

    ``bool`` cannot be subclassed, so the value is kept in ``value``.
    Truthiness delegates to it.
    """

    value: bool

    def __new__(cls, value: bool, context: APIResultContext):
        """Wrap ``value`` together with the call data from ``context``.

        Raises:
            TypeError: If ``value`` is not a ``bool``.
        """
        if isinstance(value, bool):
            result = super().__new__(cls, value, context)
            result.value = value
            result._api_params = context.api_params
            result._http_response = context.http_response
            # Freeze last; later assignments raise AttributeError.
            result._frozen = True
            return result
        raise TypeError

    def __bool__(self):
        return self.value

    def __str__(self):
        return f'{self.__class__.__qualname__}: {self.value}'
def _get_model_classes_from_annotation(
    annotation: Any,
) -> list[type[BaseModel]]:
    """Return the ``BaseModel`` subclasses referenced by a type annotation.

    Parameterised and special forms are unpacked recursively with
    ``typing.get_origin``/``get_args``. This covers ``X | Y``, ``list[X]``,
    ``typing.Optional``, ``typing.Union``, ``Annotated`` and ``Literal``.
    Non-class leaves (``Literal`` values, ``Annotated`` metadata, ``None``,
    strings) are skipped rather than passed to ``issubclass``, which would
    raise TypeError for them.
    """
    from typing import get_origin

    # Check the origin first: on some Python versions parameterised builtins
    # such as list[int] also pass isinstance(x, type).
    if get_origin(annotation) is not None:
        model_classes: list[type[BaseModel]] = []
        for arg in get_args(annotation):
            model_classes += _get_model_classes_from_annotation(arg)
        return model_classes
    if isinstance(annotation, type) and issubclass(annotation, BaseModel):
        return [annotation]
    return []


def _set_alias_gen_recursive(
    model_cls: type[BaseModel],
    alias_type: Literal['validation', 'serialization'],
    name_mapping_dict: dict[str, str],
    visited: set[type[BaseModel]],
) -> None:
    """Worker for ``set_alias_gen``; ``visited`` breaks reference cycles."""
    # Skip classes already seen in this traversal, because self-referencing
    # or mutually recursive models would otherwise recurse forever. Also skip
    # classes that already have an alias generator.
    if model_cls in visited or model_cls.model_config.get('alias_generator'):
        return
    visited.add(model_cls)

    # Nested models are configured before model_cls is rebuilt. Presumably
    # this lets the rebuilt schema of model_cls pick up their aliases.
    for field in model_cls.model_fields.values():
        for nested_cls in _get_model_classes_from_annotation(field.annotation):
            _set_alias_gen_recursive(
                model_cls=nested_cls,
                alias_type=alias_type,
                name_mapping_dict=name_mapping_dict,
                visited=visited,
            )

    def alias_gen(field_name: str):
        # An individual mapping (field__ModelQualname) wins over a common one.
        individual_alias = name_mapping_dict.get(
            f'{field_name}__{model_cls.__qualname__}',
        )
        if individual_alias:
            return individual_alias
        if field_name in name_mapping_dict:
            return name_mapping_dict[field_name]
        raise KeyError(
            f'{field_name} not found in name mapping dictionary.'
            f' Class model: {model_cls.__name__}.'
        )

    model_cls.model_config['alias_generator'] = AliasGenerator(
        **{f'{alias_type}_alias': alias_gen},
    )
    model_cls.model_rebuild(force=True)


def set_alias_gen(
    model_cls: type[BaseModel],
    alias_type: Literal['validation', 'serialization'],
    name_mapping_dict: dict[str, str]
):
    """Install a name-mapping alias generator on a model and its nested models.

    For every field of ``model_cls``, and recursively of the model classes
    referenced in its field annotations, the alias is looked up in
    ``name_mapping_dict``. The individual key
    ``'<field_name>__<ModelQualname>'`` is tried first, then the common key
    ``'<field_name>'``. Classes that already have an alias generator are left
    untouched, so each class is configured only once. Each configured class
    is rebuilt with ``model_rebuild(force=True)``.

    Args:
        model_cls: Model class to configure.
        alias_type: ``'validation'`` (used for API results) or
            ``'serialization'`` (used for API params).
        name_mapping_dict: Mapping from field name to API name.

    Raises:
        KeyError: Raised by the generator for a field name missing from
            ``name_mapping_dict``. It presumably surfaces from
            ``model_rebuild``; TODO confirm with the pinned pydantic version.
    """
    _set_alias_gen_recursive(
        model_cls=model_cls,
        alias_type=alias_type,
        name_mapping_dict=name_mapping_dict,
        visited=set(),
    )
# Parameters of a protocol method, mirrored by the generated API function.
APIParamsP = ParamSpec('APIParamsP')
# Result type of an API function. BaseAPI also unpacks this bound union at
# runtime (get_args) to validate the return annotations of protocol methods.
APIResultT = TypeVar(
    'APIResultT', bound=BaseAPIResultModel | BaseAPIResultBasicType
)
# Cache of dynamically created params model classes, keyed by the class name
# that BaseAPI derives from the API path.
params_model_classes: dict[str, type[BaseAPIParamsModel]] = {}
class BaseAPI(ABC):
_config: Config
_base_api_path: None | str
_path_mapping_dict: dict[str, str] = path_mapping_dict
_last_call_api_path: None | str = None
_name_mapping_dict: dict[str, str] = name_mapping_dict
_post_json: bool = True
def __init_subclass__(
cls,
path_mapping_dict: None | dict[str, str] = None,
name_mapping_dict: None | dict[str, str] = None,
post_json: None | bool = None,
) -> None:
if path_mapping_dict:
cls._path_mapping_dict = path_mapping_dict
if name_mapping_dict:
cls._name_mapping_dict = name_mapping_dict
if post_json is not None:
cls._post_json = post_json
def __init__(
self,
config: Config,
base_api_path: None | str = None
):
self._config = config
self._base_api_path = base_api_path
def __getattribute__(self, name: str) -> Any:
if name.startswith('_'):
return super().__getattribute__(name)
else:
path_part = self._path_mapping_dict.get(name, name)
if name in self.__annotations__:
annotation = self.__annotations__[name]
if issubclass(annotation, BaseAPI):
api_cls = annotation
return api_cls(
config=self._config,
base_api_path=(
f'{self._base_api_path or ""}/{path_part}'
)
)
else:
self._last_call_api_path = (
f'{self._base_api_path or ""}/{path_part}'
)
attr_value = super().__getattribute__(name)
if not inspect.ismethod(attr_value):
raise ValueError
return self._make_api_function(attr_value)
def _make_api_function(
self,
protocol_method: Callable[APIParamsP, APIResultT],
/,
) -> Callable[APIParamsP, APIResultT]:
return_type = inspect.signature(protocol_method).return_annotation
if not issubclass(return_type, get_args(APIResultT.__bound__)):
raise TypeError(
f'Return type annotation of {protocol_method}'
f' must be subclass of {get_args(APIResultT.__bound__)}.'
)
api_result_cls: type[APIResultT] = return_type
def api_function(
*args: APIParamsP.args,
**kwargs: APIParamsP.kwargs,
):
config = self._config
if self._last_call_api_path:
api_path = self._last_call_api_path
else:
raise ValueError
api_params_cls_name = gen_cls_name_from_url_path(
url_path=re.sub(r'[\[\]]', '', api_path),
postfix=BaseAPIParamsModel.__qualname__.removeprefix(
'BaseAPI'
),
)
api_params_cls = params_model_classes.get(
api_params_cls_name,
None,
)
if not api_params_cls:
field_definitions = {}
parameters = inspect.signature(protocol_method).parameters
inspect_empty = inspect.Parameter.empty
for p in parameters.values():
field_definitions[p.name] = (
p.annotation,
... if p.default is inspect_empty else p.default,
)
api_params_cls = create_model(
api_params_cls_name,
__base__=BaseAPIParamsModel,
__module__=api_result_cls.__module__,
**field_definitions,
)
params_model_classes[api_params_cls_name] = api_params_cls
set_alias_gen(
model_cls=api_params_cls,
alias_type='serialization',
name_mapping_dict=self._name_mapping_dict,
)
api_params = api_params_cls(*args, **kwargs)
req_headers = None
if isinstance(config, ConfigWithAuth):
req_headers = {
'Authorization': f'bearer {config.jwt}',
}
model_data = api_params.model_dump(
by_alias=True,
exclude_none=True,
)
for api_mixin_cls in self.__class__.__bases__[1:]:
if hasattr(api_mixin_cls, protocol_method.__name__):
if issubclass(api_mixin_cls, BasePostAPIFunctionProtocol):
http_method = HTTPMethod.POST
break
elif issubclass(api_mixin_cls, BaseGetAPIFunctionProtocol):
http_method = HTTPMethod.GET
break
else:
raise RuntimeError
req_params = None
req_json = None
req_data = None
match http_method:
case HTTPMethod.GET:
req_params = model_data
case HTTPMethod.POST:
if self._post_json:
req_json = model_data
else:
req_data = model_data
case _:
raise RuntimeError
http_request = requests.Request(
method=http_method,
url=config.get_api_url(api_path=api_path),
headers=req_headers,
params=req_params,
json=req_json,
data=req_data,
).prepare()
attempts = config.http503_attempts
with requests.Session() as session:
while True:
http_response = session.send(
request=http_request,
verify=config.verify_ssl,
)
if http_response.status_code == 503:
if attempts < 1:
break
else:
attempts -= 1
time.sleep(config.http503_attempts_interval)
else:
break
http_response.raise_for_status()
api_result_context = APIResultContext(
api_params=api_params,
http_response=http_response,
)
if issubclass(api_result_cls, BaseAPIResultModel):
set_alias_gen(
model_cls=api_result_cls,
alias_type='validation',
name_mapping_dict=self._name_mapping_dict,
)
result_extra_allow = (
api_result_cls.model_config.get('extra') == 'allow'
)
if config.result_extra_allow != result_extra_allow:
if config.result_extra_allow:
api_result_cls.model_config['extra'] = 'allow'
else:
api_result_cls.model_config['extra'] = (
BaseAPIResultModel.model_config.get('extra')
)
api_result_cls.model_rebuild(force=True)
api_result = api_result_cls.model_validate_json(
json_data=http_response.content,
context=api_result_context,
)
else:
try:
decoded_content = http_response.json()
except requests.JSONDecodeError:
decoded_content = http_response.text
api_result = api_result_cls(
value=decoded_content,
context=api_result_context,
)
return api_result
return api_function