You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

225 lines
7.2 KiB

1 week ago
from collections.abc import Callable
from dataclasses import dataclass
import inspect
import os
from types import GenericAlias, UnionType
from typing import Any, get_args
from urllib3 import disable_warnings
import pytest
import requests
import dynamix_sdk.types as sdk_types
from dynamix_sdk.base import (
gen_api_params_cls_name,
create_api_params_cls,
BaseAPIFunctionProtocol,
base_proto_to_http_method,
)
from dynamix_sdk.utils import JSON, HTTPMethod
@dataclass(kw_only=True)
class SDKFunction:
    """Description of one SDK API function, as collected by ``api_groups``."""

    # API subgroup class the function was discovered on.
    api_cls: type[sdk_types.BaseAPI]

    # (group name, subgroup name, method name) attribute chain.
    call_attrs: tuple[str, ...]

    # '/'-prefixed URL path built from ``call_attrs`` after applying the
    # subgroup's ``_path_mapping_dict`` to every part.
    url_path: str

    # Mixin (protocol) class among the subgroup's bases that has the method.
    proto_cls: type[BaseAPIFunctionProtocol]

    # The method object itself, fetched from the subgroup class.
    proto_method: Callable

    # HTTP method looked up in ``base_proto_to_http_method`` by the
    # direct base class of ``proto_cls``.
    http_method: HTTPMethod

    # Params model class produced by ``create_api_params_cls``.
    params_model_cls: type[sdk_types.BaseAPIParamsModel]

    # Return annotation of the method.
    result_cls: type[sdk_types.BaseAPIResult]
@dataclass(kw_only=True)
class APISubgroup:
    """An API subgroup together with the SDK functions found on it."""

    # Attribute name of the subgroup inside its API group.
    name: str

    # The subgroup *class* (not an instance): ``api_groups`` stores the
    # annotation object taken from the group's ``__annotations__`` here,
    # the same value that ``SDKFunction.api_cls`` receives.
    cls: type[sdk_types.BaseAPI]

    # SDK functions discovered on the subgroup class.
    functions: tuple[SDKFunction, ...]
@dataclass(kw_only=True)
class APIGroup:
    """A top-level API group together with its subgroups."""

    # Attribute name of the group on ``sdk_types.API``.
    name: str

    # The group *class* (not an instance): ``api_groups`` stores the
    # annotation object taken from ``sdk_types.API.__annotations__`` here
    # and reads its own ``__annotations__`` to find the subgroups.
    cls: type[sdk_types.BaseAPI]

    # Subgroups declared in the group's annotations.
    subgroups: tuple[APISubgroup, ...]
def _find_proto_classes(api_subgroup_cls, method_name, valid_bases):
    """Locate the mixin class that provides ``method_name``.

    The first base of ``api_subgroup_cls`` is skipped; the remaining bases
    are searched in order and the first one having ``method_name`` wins.

    Returns:
        A ``(proto_cls, proto_cls_base)`` pair: the mixin class and its
        direct base class.

    Raises:
        AssertionError: If the mixin is not a subclass of
            ``BaseAPIFunctionProtocol`` or its direct base is not one of
            ``valid_bases``.
        LookupError: If none of the searched bases has ``method_name``.
    """
    for mixin_cls in api_subgroup_cls.__bases__[1:]:
        if not hasattr(mixin_cls, method_name):
            continue

        assert issubclass(mixin_cls, BaseAPIFunctionProtocol), (
            f'Class {mixin_cls.__qualname__}'
            f' must be inherited from'
            f' {BaseAPIFunctionProtocol.__qualname__}.'
        )

        mixin_cls_base = mixin_cls.__base__
        assert (
            mixin_cls_base
            and issubclass(mixin_cls_base, BaseAPIFunctionProtocol)
            and mixin_cls_base in valid_bases
        ), (
            f'Class {mixin_cls.__qualname__}'
            f' must be inherited from one of these classes:'
            f" {', '.join(p.__qualname__ for p in valid_bases)}."
        )

        return mixin_cls, mixin_cls_base

    raise LookupError(
        f'{api_subgroup_cls.__qualname__}:'
        f'mixin class for method "{method_name}" not found.'
    )


def _build_sdk_function(
    api_group_name,
    api_subgroup_name,
    api_subgroup_cls,
    method_name,
    method,
    valid_bases,
):
    """Build the ``SDKFunction`` record for one method of a subgroup."""
    proto_cls, proto_cls_base = _find_proto_classes(
        api_subgroup_cls=api_subgroup_cls,
        method_name=method_name,
        valid_bases=valid_bases,
    )

    call_attrs = (api_group_name, api_subgroup_name, method_name)

    # Every SDK path part may be renamed through the subgroup's mapping;
    # parts without a mapping entry are used as-is.
    path_mapping = api_subgroup_cls._path_mapping_dict
    api_func_url_path = ''.join(
        f'/{path_mapping.get(part, part)}' for part in call_attrs
    )

    api_params_cls_name = gen_api_params_cls_name(
        api_path=api_func_url_path,
    )
    result_cls = inspect.signature(method).return_annotation

    return SDKFunction(
        api_cls=api_subgroup_cls,
        call_attrs=call_attrs,
        url_path=api_func_url_path,
        proto_cls=proto_cls,
        proto_method=method,
        http_method=base_proto_to_http_method[proto_cls_base],
        params_model_cls=create_api_params_cls(
            cls_name=api_params_cls_name,
            module_name=result_cls.__module__,
            protocol_method=method,
        ),
        result_cls=result_cls,
    )


@pytest.fixture(scope='session')
def api_groups() -> tuple[APIGroup, ...]:
    """Collect every API group, subgroup and function declared by the SDK.

    Groups are read from ``sdk_types.API.__annotations__``, subgroups from
    each group's ``__annotations__``, and functions from the public
    (non-underscore) callable attributes of each subgroup class.

    Returns:
        A tuple of ``APIGroup`` records.

    Raises:
        AssertionError: If a mixin class does not follow the expected
            protocol inheritance.
        LookupError: If no mixin class is found for a public method.
    """
    # Loop-invariant: the set of protocol base classes with a known
    # HTTP method.
    valid_bases = base_proto_to_http_method.keys()

    groups: list[APIGroup] = []
    for group_name, group_cls in sdk_types.API.__annotations__.items():
        subgroups: list[APISubgroup] = []
        for subgroup_name, subgroup_cls in group_cls.__annotations__.items():
            sdk_functions: list[SDKFunction] = []
            for method_name in dir(subgroup_cls):
                if method_name.startswith('_'):
                    continue

                method = getattr(subgroup_cls, method_name)
                if not callable(method):
                    continue

                sdk_functions.append(
                    _build_sdk_function(
                        api_group_name=group_name,
                        api_subgroup_name=subgroup_name,
                        api_subgroup_cls=subgroup_cls,
                        method_name=method_name,
                        method=method,
                        valid_bases=valid_bases,
                    )
                )

            subgroups.append(
                APISubgroup(
                    name=subgroup_name,
                    cls=subgroup_cls,
                    functions=tuple(sdk_functions),
                )
            )

        groups.append(
            APIGroup(
                name=group_name,
                cls=group_cls,
                subgroups=tuple(subgroups),
            )
        )

    return tuple(groups)
@pytest.fixture(scope='session')
def sdk_dx_functions(api_groups):
    """Flatten ``api_groups`` into one tuple of SDK functions.

    Functions keep the order in which they appear while walking the
    groups and, inside each group, its subgroups.
    """
    return tuple(
        sdk_function
        for group in api_groups
        for subgroup in group.subgroups
        for sdk_function in subgroup.functions
    )
@pytest.fixture(scope='session')
def dx_models(sdk_dx_functions: tuple[SDKFunction, ...]):
    """Collect every model class reachable from the SDK functions.

    For each SDK function the params model class is included, and so is
    the result class when it is a subclass of
    ``sdk_types.BaseAPIResultModel``. Models nested in the field
    annotations of those classes are included transitively.

    Returns:
        A set of model classes.

    Raises:
        TypeError: If a falsy annotation is encountered.
    """

    def get_models_from_annotation(
        annotation: Any,
    ) -> set[type[sdk_types.BaseModel]]:
        """Return the ``sdk_types.BaseModel`` subclasses named in an annotation.

        Union types (``X | Y``) and builtin generic aliases (``list[X]``)
        are unpacked recursively; ``Any`` yields nothing.
        """
        if not annotation:
            raise TypeError

        if annotation is Any:
            return set()

        if isinstance(annotation, (UnionType, GenericAlias)):
            models: set[type[sdk_types.BaseModel]] = set()
            for arg in get_args(annotation):
                models |= get_models_from_annotation(annotation=arg)
            return models

        if issubclass(annotation, sdk_types.BaseModel):
            return {annotation}

        return set()

    def get_nested_models(
        model_cls: type[sdk_types.BaseModel],
    ) -> set[type[sdk_types.BaseModel]]:
        """Return all models reachable through the fields of ``model_cls``.

        Uses an explicit worklist with a seen-set so that every model is
        expanded exactly once; this also terminates for models that
        reference themselves or each other.
        """
        found: set[type[sdk_types.BaseModel]] = set()
        pending = [model_cls]
        while pending:
            current_cls = pending.pop()
            for field_info in current_cls.model_fields.values():
                field_models = get_models_from_annotation(
                    annotation=field_info.annotation,
                )
                for nested_cls in field_models:
                    if nested_cls not in found:
                        found.add(nested_cls)
                        pending.append(nested_cls)
        return found

    collected: set[type[sdk_types.BaseModel]] = set()
    for sdk_func in sdk_dx_functions:
        root_models = [sdk_func.params_model_cls]
        # Result classes that are not models are left out.
        if issubclass(sdk_func.result_cls, sdk_types.BaseAPIResultModel):
            root_models.append(sdk_func.result_cls)

        for root_model in root_models:
            collected.add(root_model)
            collected |= get_nested_models(model_cls=root_model)

    return collected
@pytest.fixture(scope='session')
def dx_url() -> str:
    """Return the value of the ``DYNAMIX_URL`` environment variable.

    Raises:
        AssertionError: If the variable is unset or empty.
    """
    url = os.getenv('DYNAMIX_URL')
    # The message names the variable so a misconfigured environment is
    # obvious from the failure output.
    assert url, 'Environment variable DYNAMIX_URL must be set and non-empty.'
    return url
@pytest.fixture(scope='session')
def dx_api_definition(dx_url: str) -> JSON:
    """Fetch the API definition from the server at ``dx_url``.

    Sends a POST request to the ``prepareCatalog`` endpoint with TLS
    certificate verification disabled and returns the decoded JSON body.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not answer in time.
    """
    API_DEFINITION_API_PATH = '/restmachine/system/docgenerator/prepareCatalog'
    # (connect, read) timeouts in seconds; without a timeout a stalled
    # server would hang the whole test session indefinitely.
    REQUEST_TIMEOUT = (10, 300)

    # NOTE(review): verification is switched off on purpose here, so the
    # urllib3 warnings it would trigger are silenced (process-wide).
    disable_warnings()

    api_definition_resp = requests.post(
        # Tolerate a trailing slash in the configured base URL.
        url=f"{dx_url.rstrip('/')}{API_DEFINITION_API_PATH}",
        verify=False,
        timeout=REQUEST_TIMEOUT,
    )
    api_definition_resp.raise_for_status()
    return api_definition_resp.json()