This commit is contained in:
2025-06-06 08:20:45 +03:00
parent 346ffd4255
commit caf367262c
205 changed files with 6115 additions and 200 deletions

View File

@@ -6,8 +6,16 @@ import inspect
import os
import re
import time
from types import GenericAlias, UnionType
from typing import Any, Literal, ParamSpec, Protocol, TypeVar, get_args
from types import GenericAlias, MethodType, UnionType
from typing import (
Any,
Literal,
ParamSpec,
Protocol,
TypeVar,
get_args,
runtime_checkable,
)
import requests
from pydantic import (
@@ -77,6 +85,7 @@ if len(common_mappings_values) > len(set(common_mappings_values)):
)
@runtime_checkable
class BaseAPIFunctionProtocol(Protocol):
pass
@@ -89,6 +98,12 @@ class BaseGetAPIFunctionProtocol(BaseAPIFunctionProtocol):
pass
base_proto_to_http_method: dict[type[BaseAPIFunctionProtocol], HTTPMethod] = {
BasePostAPIFunctionProtocol: HTTPMethod.POST,
BaseGetAPIFunctionProtocol: HTTPMethod.GET,
}
class BaseModel(PydanticBaseModel, ABC):
@staticmethod
def _get_datetime_from_timestamp(timestamp: float) -> None | datetime:
@@ -109,18 +124,7 @@ class BaseAPIParamsModel(BaseModelWithFrozenStrictExtraForbid, ABC):
class BaseAPIParamsNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
def __init_subclass__(
cls,
*args: tuple[Any],
**kwargs: dict[str, Any],
) -> None:
super().__init_subclass__(*args, **kwargs)
postfix = 'APIParamsNM'
if not cls.__qualname__.endswith(postfix):
raise ValueError(
f'Name of {cls} must end with {postfix}.',
)
pass
@dataclass
@@ -194,16 +198,58 @@ class BaseAPIResultBool(BaseAPIResultBasicType, ABC):
return f'{self.__class__.__qualname__}: {self.value}'
def __new__(cls, value: bool, context: APIResultContext):
if not isinstance(value, bool):
if isinstance(value, bool):
_value = value
elif value is None: # BDX-8379
_value = True
else:
raise TypeError
instance = super().__new__(cls, value, context)
instance.value = value
instance = super().__new__(cls, _value, context)
instance.value = _value
instance._api_params = context.api_params
instance._http_response = context.http_response
instance._frozen = True
return instance
def get_alias(
field_name: str,
model_cls: type[BaseModel],
name_mapping_dict: dict[str, str],
) -> str:
if field_name in model_cls.__annotations__.keys():
individual_alias = name_mapping_dict.get(
f'{field_name}__{model_cls.__qualname__}',
)
if individual_alias:
return individual_alias
if field_name in name_mapping_dict:
return name_mapping_dict[field_name]
raise KeyError(
f'Mapping for attr {model_cls.__qualname__}.{field_name}'
f' not found in name mapping dictionary.'
)
for base_cls in model_cls.__bases__:
if not issubclass(base_cls, BaseModel):
continue
if field_name not in base_cls.model_fields.keys():
continue
return get_alias(
field_name=field_name,
model_cls=base_cls,
name_mapping_dict=name_mapping_dict,
)
raise NameError(
f'Field {field_name} not found in model {model_cls.__qualname__}.'
)
def set_alias_gen(
model_cls: type[BaseModel],
alias_type: Literal['validation', 'serialization'],
@@ -213,11 +259,11 @@ def set_alias_gen(
return
def get_model_classes_from_annotation(
annotation: type[Any] | None,
annotation: Any,
classes: list[type[BaseModel]] | None = None,
):
_classes = classes or []
if annotation:
if annotation and annotation is not Any:
if isinstance(annotation, (UnionType, GenericAlias)):
for cls in get_args(annotation):
_classes += get_model_classes_from_annotation(cls)
@@ -234,19 +280,11 @@ def set_alias_gen(
)
def alias_gen(field_name: str):
individual_alias = name_mapping_dict.get(
f'{field_name}__{model_cls.__qualname__}',
return get_alias(
field_name=field_name,
model_cls=model_cls,
name_mapping_dict=name_mapping_dict,
)
if individual_alias:
return individual_alias
else:
if field_name in name_mapping_dict:
return name_mapping_dict[field_name]
else:
raise KeyError(
f'{field_name} not found in name mapping dictionary.'
f' Class model: {model_cls.__name__}.'
)
kwargs = {
f'{alias_type}_alias': alias_gen
@@ -259,6 +297,37 @@ def set_alias_gen(
params_model_classes: dict[str, type[BaseAPIParamsModel]] = {}
def gen_api_params_cls_name(api_path: str):
return gen_cls_name_from_url_path(
url_path=re.sub(r'[\[\]]', '', api_path),
postfix='ParamsModel',
)
def create_api_params_cls(
cls_name: str,
module_name: str,
protocol_method: Callable
):
field_definitions = {}
params = inspect.signature(protocol_method).parameters.values()
inspect_empty = inspect.Parameter.empty
for param in params:
if param.name == 'self':
continue
field_definitions[param.name] = (
param.annotation,
... if param.default is inspect_empty else param.default,
)
return create_model(
cls_name,
__base__=BaseAPIParamsModel,
__module__=module_name,
**field_definitions,
)
APIParamsP = ParamSpec('APIParamsP')
APIResultT = TypeVar(
'APIResultT',
@@ -334,6 +403,9 @@ class BaseAPI(ABC):
)
api_result_cls: type[APIResultT] = return_type
if not isinstance(protocol_method, MethodType):
raise TypeError
def api_function(
*args: APIParamsP.args,
**kwargs: APIParamsP.kwargs,
@@ -345,11 +417,8 @@ class BaseAPI(ABC):
else:
raise ValueError
api_params_cls_name = gen_cls_name_from_url_path(
url_path=re.sub(r'[\[\]]', '', api_path),
postfix=BaseAPIParamsModel.__qualname__.removeprefix(
'BaseAPI'
),
api_params_cls_name = gen_api_params_cls_name(
api_path=api_path,
)
api_params_cls = params_model_classes.get(
@@ -357,20 +426,10 @@ class BaseAPI(ABC):
None,
)
if not api_params_cls:
field_definitions = {}
parameters = inspect.signature(protocol_method).parameters
inspect_empty = inspect.Parameter.empty
for p in parameters.values():
field_definitions[p.name] = (
p.annotation,
... if p.default is inspect_empty else p.default,
)
api_params_cls = create_model(
api_params_cls_name,
__base__=BaseAPIParamsModel,
__module__=api_result_cls.__module__,
**field_definitions,
api_params_cls = create_api_params_cls(
cls_name=api_params_cls_name,
module_name=api_result_cls.__module__,
protocol_method=protocol_method,
)
params_model_classes[api_params_cls_name] = api_params_cls
@@ -394,13 +453,20 @@ class BaseAPI(ABC):
)
for api_mixin_cls in self.__class__.__bases__[1:]:
if hasattr(api_mixin_cls, protocol_method.__name__):
if issubclass(api_mixin_cls, BasePostAPIFunctionProtocol):
http_method = HTTPMethod.POST
break
elif issubclass(api_mixin_cls, BaseGetAPIFunctionProtocol):
http_method = HTTPMethod.GET
break
if not hasattr(api_mixin_cls, protocol_method.__name__):
continue
proto_cls_base = api_mixin_cls.__base__
if (
not proto_cls_base
or not issubclass(proto_cls_base, BaseAPIFunctionProtocol)
):
raise TypeError(
f'Class {api_mixin_cls.__qualname__} must have'
f' {BaseAPIFunctionProtocol.__qualname__}'
f' as parent class.'
)
http_method = base_proto_to_http_method[proto_cls_base]
break
else:
raise RuntimeError