1.0.0
This commit is contained in:
490
src/dynamix_sdk/base.py
Normal file
490
src/dynamix_sdk/base.py
Normal file
@@ -0,0 +1,490 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from types import GenericAlias, UnionType
|
||||
from typing import Any, Literal, ParamSpec, Protocol, TypeVar, get_args
|
||||
|
||||
import requests
|
||||
from pydantic import (
|
||||
AliasGenerator,
|
||||
BaseModel as PydanticBaseModel,
|
||||
ConfigDict,
|
||||
PrivateAttr,
|
||||
create_model,
|
||||
)
|
||||
import yaml
|
||||
|
||||
from dynamix_sdk.config import Config, ConfigWithAuth
|
||||
from dynamix_sdk.utils import HTTPMethod, gen_cls_name_from_url_path
|
||||
|
||||
|
||||
NAME_MAPPING_FILE_PATH = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'api',
|
||||
'name_mapping.yml',
|
||||
)
|
||||
PATH_MAPPING_FILE_PATH = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'api',
|
||||
'path_mapping.yml',
|
||||
)
|
||||
|
||||
|
||||
def read_mapping_file(file_path: str):
|
||||
with open(file_path) as file:
|
||||
parsed_data: Any = yaml.safe_load(file)
|
||||
|
||||
if not isinstance(parsed_data, dict):
|
||||
raise TypeError
|
||||
result_dict: dict[Any, Any] = parsed_data
|
||||
|
||||
with open(file_path) as file:
|
||||
name_mapping_file_lines_amount = sum(
|
||||
1 for s in file if s.lstrip() and not s.startswith('#')
|
||||
)
|
||||
|
||||
if len(result_dict) < name_mapping_file_lines_amount:
|
||||
raise AssertionError(
|
||||
f'File {file_path} contains more code lines than'
|
||||
f' keys in the parsed data. Check the file for duplicate keys.'
|
||||
)
|
||||
|
||||
for k, v in result_dict.items():
|
||||
if not isinstance(k, str) or not isinstance(v, str):
|
||||
raise TypeError
|
||||
result_str_dict: dict[str, str] = result_dict
|
||||
|
||||
return result_str_dict
|
||||
|
||||
|
||||
path_mapping_dict = read_mapping_file(PATH_MAPPING_FILE_PATH)
|
||||
|
||||
|
||||
name_mapping_dict = read_mapping_file(NAME_MAPPING_FILE_PATH)
|
||||
common_mappings_values = [
|
||||
v for k, v in name_mapping_dict.items() if '__' not in k
|
||||
]
|
||||
if len(common_mappings_values) > len(set(common_mappings_values)):
|
||||
raise AssertionError(
|
||||
f'File {NAME_MAPPING_FILE_PATH} can contain duplicate values'
|
||||
f' only for individual mapping (attr_name__model_class_name),'
|
||||
f' not common. Check common mappings for duplicate values.'
|
||||
)
|
||||
|
||||
|
||||
class BaseAPIFunctionProtocol(Protocol):
|
||||
pass
|
||||
|
||||
|
||||
class BasePostAPIFunctionProtocol(BaseAPIFunctionProtocol):
|
||||
pass
|
||||
|
||||
|
||||
class BaseGetAPIFunctionProtocol(BaseAPIFunctionProtocol):
|
||||
pass
|
||||
|
||||
|
||||
class BaseModel(PydanticBaseModel, ABC):
|
||||
@staticmethod
|
||||
def _get_datetime_from_timestamp(timestamp: float) -> None | datetime:
|
||||
if timestamp > 0:
|
||||
return datetime.fromtimestamp(timestamp)
|
||||
|
||||
|
||||
class BaseModelWithFrozenStrictExtraForbid(BaseModel, ABC):
|
||||
model_config = ConfigDict(
|
||||
extra='forbid',
|
||||
strict=True,
|
||||
frozen=True,
|
||||
)
|
||||
|
||||
|
||||
class BaseAPIParamsModel(BaseModelWithFrozenStrictExtraForbid, ABC):
|
||||
pass
|
||||
|
||||
|
||||
class BaseAPIParamsNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
*args: tuple[Any],
|
||||
**kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
super().__init_subclass__(*args, **kwargs)
|
||||
|
||||
postfix = 'APIParamsNM'
|
||||
if not cls.__qualname__.endswith(postfix):
|
||||
raise ValueError(
|
||||
f'Name of {cls} must end with {postfix}.',
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIResultContext:
|
||||
api_params: BaseAPIParamsModel
|
||||
http_response: requests.Response
|
||||
|
||||
|
||||
class BaseAPIResult(ABC):
|
||||
_api_params: BaseAPIParamsModel
|
||||
_http_response: requests.Response
|
||||
|
||||
|
||||
class BaseAPIResultModel(
|
||||
BaseModelWithFrozenStrictExtraForbid,
|
||||
BaseAPIResult,
|
||||
ABC,
|
||||
):
|
||||
_api_params: BaseAPIParamsModel = PrivateAttr()
|
||||
_http_response: requests.Response = PrivateAttr()
|
||||
|
||||
def model_post_init(self, __context: APIResultContext):
|
||||
self._api_params = __context.api_params
|
||||
self._http_response = __context.http_response
|
||||
|
||||
|
||||
class BaseAPIResultNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
|
||||
pass
|
||||
|
||||
|
||||
class BaseAPIResultBasicType(BaseAPIResult, ABC):
|
||||
_frozen: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def __new__(cls, value: Any, context: APIResultContext):
|
||||
return super().__new__(cls)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if self._frozen:
|
||||
raise AttributeError(
|
||||
'Instance is frozen.',
|
||||
)
|
||||
return super().__setattr__(name, value)
|
||||
|
||||
|
||||
class BaseAPIResultStr(str, BaseAPIResultBasicType, ABC):
|
||||
def __new__(cls, value: str, context: APIResultContext):
|
||||
instance = super().__new__(cls, value)
|
||||
instance._api_params = context.api_params
|
||||
instance._http_response = context.http_response
|
||||
instance._frozen = True
|
||||
return instance
|
||||
|
||||
|
||||
class BaseAPIResultInt(int, BaseAPIResultBasicType, ABC):
|
||||
def __new__(cls, value: int, context: APIResultContext):
|
||||
instance = super().__new__(cls, value)
|
||||
instance._api_params = context.api_params
|
||||
instance._http_response = context.http_response
|
||||
instance._frozen = True
|
||||
return instance
|
||||
|
||||
|
||||
class BaseAPIResultBool(BaseAPIResultBasicType, ABC):
|
||||
value: bool
|
||||
|
||||
def __bool__(self):
|
||||
return self.value
|
||||
|
||||
def __str__(self):
|
||||
return f'{self.__class__.__qualname__}: {self.value}'
|
||||
|
||||
def __new__(cls, value: bool, context: APIResultContext):
|
||||
if not isinstance(value, bool):
|
||||
raise TypeError
|
||||
instance = super().__new__(cls, value, context)
|
||||
instance.value = value
|
||||
instance._api_params = context.api_params
|
||||
instance._http_response = context.http_response
|
||||
instance._frozen = True
|
||||
return instance
|
||||
|
||||
|
||||
def set_alias_gen(
|
||||
model_cls: type[BaseModel],
|
||||
alias_type: Literal['validation', 'serialization'],
|
||||
name_mapping_dict: dict[str, str]
|
||||
):
|
||||
if model_cls.model_config.get('alias_generator'):
|
||||
return
|
||||
|
||||
def get_model_classes_from_annotation(
|
||||
annotation: type[Any] | None,
|
||||
classes: list[type[BaseModel]] | None = None,
|
||||
):
|
||||
_classes = classes or []
|
||||
if annotation:
|
||||
if isinstance(annotation, (UnionType, GenericAlias)):
|
||||
for cls in get_args(annotation):
|
||||
_classes += get_model_classes_from_annotation(cls)
|
||||
elif issubclass(annotation, BaseModel):
|
||||
_classes.append(annotation)
|
||||
return _classes
|
||||
|
||||
for field in model_cls.model_fields.values():
|
||||
for cls in get_model_classes_from_annotation(field.annotation):
|
||||
set_alias_gen(
|
||||
model_cls=cls,
|
||||
alias_type=alias_type,
|
||||
name_mapping_dict=name_mapping_dict,
|
||||
)
|
||||
|
||||
def alias_gen(field_name: str):
|
||||
individual_alias = name_mapping_dict.get(
|
||||
f'{field_name}__{model_cls.__qualname__}',
|
||||
)
|
||||
if individual_alias:
|
||||
return individual_alias
|
||||
else:
|
||||
if field_name in name_mapping_dict:
|
||||
return name_mapping_dict[field_name]
|
||||
else:
|
||||
raise KeyError(
|
||||
f'{field_name} not found in name mapping dictionary.'
|
||||
f' Class model: {model_cls.__name__}.'
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
f'{alias_type}_alias': alias_gen
|
||||
}
|
||||
|
||||
model_cls.model_config['alias_generator'] = AliasGenerator(**kwargs)
|
||||
model_cls.model_rebuild(force=True)
|
||||
|
||||
|
||||
params_model_classes: dict[str, type[BaseAPIParamsModel]] = {}
|
||||
|
||||
|
||||
APIParamsP = ParamSpec('APIParamsP')
|
||||
APIResultT = TypeVar(
|
||||
'APIResultT',
|
||||
bound=BaseAPIResultModel | BaseAPIResultBasicType,
|
||||
)
|
||||
|
||||
|
||||
class BaseAPI(ABC):
|
||||
_config: Config
|
||||
_base_api_path: None | str
|
||||
_path_mapping_dict: dict[str, str] = path_mapping_dict
|
||||
_last_call_api_path: None | str = None
|
||||
_name_mapping_dict: dict[str, str] = name_mapping_dict
|
||||
_post_json: bool = True
|
||||
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
path_mapping_dict: None | dict[str, str] = None,
|
||||
name_mapping_dict: None | dict[str, str] = None,
|
||||
post_json: None | bool = None,
|
||||
) -> None:
|
||||
if path_mapping_dict:
|
||||
cls._path_mapping_dict = path_mapping_dict
|
||||
if name_mapping_dict:
|
||||
cls._name_mapping_dict = name_mapping_dict
|
||||
if post_json is not None:
|
||||
cls._post_json = post_json
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
base_api_path: None | str = None
|
||||
):
|
||||
self._config = config
|
||||
self._base_api_path = base_api_path
|
||||
|
||||
def __getattribute__(self, name: str) -> Any:
|
||||
if name.startswith('_'):
|
||||
return super().__getattribute__(name)
|
||||
else:
|
||||
path_part = self._path_mapping_dict.get(name, name)
|
||||
if name in self.__annotations__:
|
||||
annotation = self.__annotations__[name]
|
||||
if issubclass(annotation, BaseAPI):
|
||||
api_cls = annotation
|
||||
return api_cls(
|
||||
config=self._config,
|
||||
base_api_path=(
|
||||
f'{self._base_api_path or ""}/{path_part}'
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._last_call_api_path = (
|
||||
f'{self._base_api_path or ""}/{path_part}'
|
||||
)
|
||||
|
||||
attr_value = super().__getattribute__(name)
|
||||
if not inspect.ismethod(attr_value):
|
||||
raise ValueError
|
||||
|
||||
return self._make_api_function(attr_value)
|
||||
|
||||
def _make_api_function(
|
||||
self,
|
||||
protocol_method: Callable[APIParamsP, APIResultT],
|
||||
/,
|
||||
) -> Callable[APIParamsP, APIResultT]:
|
||||
return_type = inspect.signature(protocol_method).return_annotation
|
||||
if not issubclass(return_type, get_args(APIResultT.__bound__)):
|
||||
raise TypeError(
|
||||
f'Return type annotation of {protocol_method}'
|
||||
f' must be subclass of {get_args(APIResultT.__bound__)}.'
|
||||
)
|
||||
api_result_cls: type[APIResultT] = return_type
|
||||
|
||||
def api_function(
|
||||
*args: APIParamsP.args,
|
||||
**kwargs: APIParamsP.kwargs,
|
||||
):
|
||||
config = self._config
|
||||
|
||||
if self._last_call_api_path:
|
||||
api_path = self._last_call_api_path
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
api_params_cls_name = gen_cls_name_from_url_path(
|
||||
url_path=re.sub(r'[\[\]]', '', api_path),
|
||||
postfix=BaseAPIParamsModel.__qualname__.removeprefix(
|
||||
'BaseAPI'
|
||||
),
|
||||
)
|
||||
|
||||
api_params_cls = params_model_classes.get(
|
||||
api_params_cls_name,
|
||||
None,
|
||||
)
|
||||
if not api_params_cls:
|
||||
field_definitions = {}
|
||||
parameters = inspect.signature(protocol_method).parameters
|
||||
inspect_empty = inspect.Parameter.empty
|
||||
for p in parameters.values():
|
||||
field_definitions[p.name] = (
|
||||
p.annotation,
|
||||
... if p.default is inspect_empty else p.default,
|
||||
)
|
||||
|
||||
api_params_cls = create_model(
|
||||
api_params_cls_name,
|
||||
__base__=BaseAPIParamsModel,
|
||||
__module__=api_result_cls.__module__,
|
||||
**field_definitions,
|
||||
)
|
||||
params_model_classes[api_params_cls_name] = api_params_cls
|
||||
|
||||
set_alias_gen(
|
||||
model_cls=api_params_cls,
|
||||
alias_type='serialization',
|
||||
name_mapping_dict=self._name_mapping_dict,
|
||||
)
|
||||
|
||||
api_params = api_params_cls(*args, **kwargs)
|
||||
|
||||
req_headers = None
|
||||
if isinstance(config, ConfigWithAuth):
|
||||
req_headers = {
|
||||
'Authorization': f'bearer {config.jwt}',
|
||||
}
|
||||
|
||||
model_data = api_params.model_dump(
|
||||
by_alias=True,
|
||||
exclude_none=True,
|
||||
)
|
||||
|
||||
for api_mixin_cls in self.__class__.__bases__[1:]:
|
||||
if hasattr(api_mixin_cls, protocol_method.__name__):
|
||||
if issubclass(api_mixin_cls, BasePostAPIFunctionProtocol):
|
||||
http_method = HTTPMethod.POST
|
||||
break
|
||||
elif issubclass(api_mixin_cls, BaseGetAPIFunctionProtocol):
|
||||
http_method = HTTPMethod.GET
|
||||
break
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
req_params = None
|
||||
req_json = None
|
||||
req_data = None
|
||||
match http_method:
|
||||
case HTTPMethod.GET:
|
||||
req_params = model_data
|
||||
case HTTPMethod.POST:
|
||||
if self._post_json:
|
||||
req_json = model_data
|
||||
else:
|
||||
req_data = model_data
|
||||
case _:
|
||||
raise RuntimeError
|
||||
|
||||
http_request = requests.Request(
|
||||
method=http_method,
|
||||
url=config.get_api_url(api_path=api_path),
|
||||
headers=req_headers,
|
||||
params=req_params,
|
||||
json=req_json,
|
||||
data=req_data,
|
||||
).prepare()
|
||||
|
||||
attempts = config.http503_attempts
|
||||
with requests.Session() as session:
|
||||
while True:
|
||||
http_response = session.send(
|
||||
request=http_request,
|
||||
verify=config.verify_ssl,
|
||||
)
|
||||
|
||||
if http_response.status_code == 503:
|
||||
if attempts < 1:
|
||||
break
|
||||
else:
|
||||
attempts -= 1
|
||||
time.sleep(config.http503_attempts_interval)
|
||||
else:
|
||||
break
|
||||
|
||||
http_response.raise_for_status()
|
||||
|
||||
api_result_context = APIResultContext(
|
||||
api_params=api_params,
|
||||
http_response=http_response,
|
||||
)
|
||||
|
||||
if issubclass(api_result_cls, BaseAPIResultModel):
|
||||
set_alias_gen(
|
||||
model_cls=api_result_cls,
|
||||
alias_type='validation',
|
||||
name_mapping_dict=self._name_mapping_dict,
|
||||
)
|
||||
|
||||
result_extra_allow = (
|
||||
api_result_cls.model_config.get('extra') == 'allow'
|
||||
)
|
||||
if config.result_extra_allow != result_extra_allow:
|
||||
if config.result_extra_allow:
|
||||
api_result_cls.model_config['extra'] = 'allow'
|
||||
else:
|
||||
api_result_cls.model_config['extra'] = (
|
||||
BaseAPIResultModel.model_config.get('extra')
|
||||
)
|
||||
api_result_cls.model_rebuild(force=True)
|
||||
|
||||
api_result = api_result_cls.model_validate_json(
|
||||
json_data=http_response.content,
|
||||
context=api_result_context,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
decoded_content = http_response.json()
|
||||
except requests.JSONDecodeError:
|
||||
decoded_content = http_response.text
|
||||
|
||||
api_result = api_result_cls(
|
||||
value=decoded_content,
|
||||
context=api_result_context,
|
||||
)
|
||||
|
||||
return api_result
|
||||
|
||||
return api_function
|
||||
Reference in New Issue
Block a user