1.0.1
This commit is contained in:
@@ -6,8 +6,16 @@ import inspect
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from types import GenericAlias, UnionType
|
||||
from typing import Any, Literal, ParamSpec, Protocol, TypeVar, get_args
|
||||
from types import GenericAlias, MethodType, UnionType
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
get_args,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
import requests
|
||||
from pydantic import (
|
||||
@@ -77,6 +85,7 @@ if len(common_mappings_values) > len(set(common_mappings_values)):
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BaseAPIFunctionProtocol(Protocol):
|
||||
pass
|
||||
|
||||
@@ -89,6 +98,12 @@ class BaseGetAPIFunctionProtocol(BaseAPIFunctionProtocol):
|
||||
pass
|
||||
|
||||
|
||||
base_proto_to_http_method: dict[type[BaseAPIFunctionProtocol], HTTPMethod] = {
|
||||
BasePostAPIFunctionProtocol: HTTPMethod.POST,
|
||||
BaseGetAPIFunctionProtocol: HTTPMethod.GET,
|
||||
}
|
||||
|
||||
|
||||
class BaseModel(PydanticBaseModel, ABC):
|
||||
@staticmethod
|
||||
def _get_datetime_from_timestamp(timestamp: float) -> None | datetime:
|
||||
@@ -109,18 +124,7 @@ class BaseAPIParamsModel(BaseModelWithFrozenStrictExtraForbid, ABC):
|
||||
|
||||
|
||||
class BaseAPIParamsNestedModel(BaseModelWithFrozenStrictExtraForbid, ABC):
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
*args: tuple[Any],
|
||||
**kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
super().__init_subclass__(*args, **kwargs)
|
||||
|
||||
postfix = 'APIParamsNM'
|
||||
if not cls.__qualname__.endswith(postfix):
|
||||
raise ValueError(
|
||||
f'Name of {cls} must end with {postfix}.',
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -194,16 +198,58 @@ class BaseAPIResultBool(BaseAPIResultBasicType, ABC):
|
||||
return f'{self.__class__.__qualname__}: {self.value}'
|
||||
|
||||
def __new__(cls, value: bool, context: APIResultContext):
|
||||
if not isinstance(value, bool):
|
||||
if isinstance(value, bool):
|
||||
_value = value
|
||||
elif value is None: # BDX-8379
|
||||
_value = True
|
||||
else:
|
||||
raise TypeError
|
||||
instance = super().__new__(cls, value, context)
|
||||
instance.value = value
|
||||
instance = super().__new__(cls, _value, context)
|
||||
instance.value = _value
|
||||
instance._api_params = context.api_params
|
||||
instance._http_response = context.http_response
|
||||
instance._frozen = True
|
||||
return instance
|
||||
|
||||
|
||||
def get_alias(
|
||||
field_name: str,
|
||||
model_cls: type[BaseModel],
|
||||
name_mapping_dict: dict[str, str],
|
||||
) -> str:
|
||||
if field_name in model_cls.__annotations__.keys():
|
||||
individual_alias = name_mapping_dict.get(
|
||||
f'{field_name}__{model_cls.__qualname__}',
|
||||
)
|
||||
if individual_alias:
|
||||
return individual_alias
|
||||
|
||||
if field_name in name_mapping_dict:
|
||||
return name_mapping_dict[field_name]
|
||||
|
||||
raise KeyError(
|
||||
f'Mapping for attr {model_cls.__qualname__}.{field_name}'
|
||||
f' not found in name mapping dictionary.'
|
||||
)
|
||||
|
||||
for base_cls in model_cls.__bases__:
|
||||
if not issubclass(base_cls, BaseModel):
|
||||
continue
|
||||
|
||||
if field_name not in base_cls.model_fields.keys():
|
||||
continue
|
||||
|
||||
return get_alias(
|
||||
field_name=field_name,
|
||||
model_cls=base_cls,
|
||||
name_mapping_dict=name_mapping_dict,
|
||||
)
|
||||
|
||||
raise NameError(
|
||||
f'Field {field_name} not found in model {model_cls.__qualname__}.'
|
||||
)
|
||||
|
||||
|
||||
def set_alias_gen(
|
||||
model_cls: type[BaseModel],
|
||||
alias_type: Literal['validation', 'serialization'],
|
||||
@@ -213,11 +259,11 @@ def set_alias_gen(
|
||||
return
|
||||
|
||||
def get_model_classes_from_annotation(
|
||||
annotation: type[Any] | None,
|
||||
annotation: Any,
|
||||
classes: list[type[BaseModel]] | None = None,
|
||||
):
|
||||
_classes = classes or []
|
||||
if annotation:
|
||||
if annotation and annotation is not Any:
|
||||
if isinstance(annotation, (UnionType, GenericAlias)):
|
||||
for cls in get_args(annotation):
|
||||
_classes += get_model_classes_from_annotation(cls)
|
||||
@@ -234,19 +280,11 @@ def set_alias_gen(
|
||||
)
|
||||
|
||||
def alias_gen(field_name: str):
|
||||
individual_alias = name_mapping_dict.get(
|
||||
f'{field_name}__{model_cls.__qualname__}',
|
||||
return get_alias(
|
||||
field_name=field_name,
|
||||
model_cls=model_cls,
|
||||
name_mapping_dict=name_mapping_dict,
|
||||
)
|
||||
if individual_alias:
|
||||
return individual_alias
|
||||
else:
|
||||
if field_name in name_mapping_dict:
|
||||
return name_mapping_dict[field_name]
|
||||
else:
|
||||
raise KeyError(
|
||||
f'{field_name} not found in name mapping dictionary.'
|
||||
f' Class model: {model_cls.__name__}.'
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
f'{alias_type}_alias': alias_gen
|
||||
@@ -259,6 +297,37 @@ def set_alias_gen(
|
||||
params_model_classes: dict[str, type[BaseAPIParamsModel]] = {}
|
||||
|
||||
|
||||
def gen_api_params_cls_name(api_path: str):
|
||||
return gen_cls_name_from_url_path(
|
||||
url_path=re.sub(r'[\[\]]', '', api_path),
|
||||
postfix='ParamsModel',
|
||||
)
|
||||
|
||||
|
||||
def create_api_params_cls(
|
||||
cls_name: str,
|
||||
module_name: str,
|
||||
protocol_method: Callable
|
||||
):
|
||||
field_definitions = {}
|
||||
params = inspect.signature(protocol_method).parameters.values()
|
||||
inspect_empty = inspect.Parameter.empty
|
||||
for param in params:
|
||||
if param.name == 'self':
|
||||
continue
|
||||
|
||||
field_definitions[param.name] = (
|
||||
param.annotation,
|
||||
... if param.default is inspect_empty else param.default,
|
||||
)
|
||||
return create_model(
|
||||
cls_name,
|
||||
__base__=BaseAPIParamsModel,
|
||||
__module__=module_name,
|
||||
**field_definitions,
|
||||
)
|
||||
|
||||
|
||||
APIParamsP = ParamSpec('APIParamsP')
|
||||
APIResultT = TypeVar(
|
||||
'APIResultT',
|
||||
@@ -334,6 +403,9 @@ class BaseAPI(ABC):
|
||||
)
|
||||
api_result_cls: type[APIResultT] = return_type
|
||||
|
||||
if not isinstance(protocol_method, MethodType):
|
||||
raise TypeError
|
||||
|
||||
def api_function(
|
||||
*args: APIParamsP.args,
|
||||
**kwargs: APIParamsP.kwargs,
|
||||
@@ -345,11 +417,8 @@ class BaseAPI(ABC):
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
api_params_cls_name = gen_cls_name_from_url_path(
|
||||
url_path=re.sub(r'[\[\]]', '', api_path),
|
||||
postfix=BaseAPIParamsModel.__qualname__.removeprefix(
|
||||
'BaseAPI'
|
||||
),
|
||||
api_params_cls_name = gen_api_params_cls_name(
|
||||
api_path=api_path,
|
||||
)
|
||||
|
||||
api_params_cls = params_model_classes.get(
|
||||
@@ -357,20 +426,10 @@ class BaseAPI(ABC):
|
||||
None,
|
||||
)
|
||||
if not api_params_cls:
|
||||
field_definitions = {}
|
||||
parameters = inspect.signature(protocol_method).parameters
|
||||
inspect_empty = inspect.Parameter.empty
|
||||
for p in parameters.values():
|
||||
field_definitions[p.name] = (
|
||||
p.annotation,
|
||||
... if p.default is inspect_empty else p.default,
|
||||
)
|
||||
|
||||
api_params_cls = create_model(
|
||||
api_params_cls_name,
|
||||
__base__=BaseAPIParamsModel,
|
||||
__module__=api_result_cls.__module__,
|
||||
**field_definitions,
|
||||
api_params_cls = create_api_params_cls(
|
||||
cls_name=api_params_cls_name,
|
||||
module_name=api_result_cls.__module__,
|
||||
protocol_method=protocol_method,
|
||||
)
|
||||
params_model_classes[api_params_cls_name] = api_params_cls
|
||||
|
||||
@@ -394,13 +453,20 @@ class BaseAPI(ABC):
|
||||
)
|
||||
|
||||
for api_mixin_cls in self.__class__.__bases__[1:]:
|
||||
if hasattr(api_mixin_cls, protocol_method.__name__):
|
||||
if issubclass(api_mixin_cls, BasePostAPIFunctionProtocol):
|
||||
http_method = HTTPMethod.POST
|
||||
break
|
||||
elif issubclass(api_mixin_cls, BaseGetAPIFunctionProtocol):
|
||||
http_method = HTTPMethod.GET
|
||||
break
|
||||
if not hasattr(api_mixin_cls, protocol_method.__name__):
|
||||
continue
|
||||
proto_cls_base = api_mixin_cls.__base__
|
||||
if (
|
||||
not proto_cls_base
|
||||
or not issubclass(proto_cls_base, BaseAPIFunctionProtocol)
|
||||
):
|
||||
raise TypeError(
|
||||
f'Class {api_mixin_cls.__qualname__} must have'
|
||||
f' {BaseAPIFunctionProtocol.__qualname__}'
|
||||
f' as parent class.'
|
||||
)
|
||||
http_method = base_proto_to_http_method[proto_cls_base]
|
||||
break
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
|
||||
Reference in New Issue
Block a user