from collections import OrderedDict
from functools import wraps
try:
import coreschema
except ImportError:
coreschema = None
from .. import openapi
from ..utils import force_real_str
from .base import FilterInspector, NotHandled, PaginatorInspector
def ignore_assert_decorator(func):
    """Wrap ``func`` so that an ``AssertionError`` becomes ``NotHandled``.

    The wrapper forwards every positional and keyword argument to ``func``
    unchanged. Any exception other than ``AssertionError`` propagates to the
    caller.

    :param func: the callable to guard
    :return: the guarded callable, carrying ``func``'s metadata
    """

    @wraps(func)
    def guarded(*call_args, **call_kwargs):
        try:
            outcome = func(*call_args, **call_kwargs)
        except AssertionError:
            # A failed assertion means "this inspector cannot handle it".
            outcome = NotHandled
        return outcome

    return guarded
class DrfAPICompatInspector(PaginatorInspector, FilterInspector):
    """Builds :class:`.openapi.Parameter` objects for paginators and filter
    backends that implement a ``get_schema_operation_parameters`` method, and
    for filter backends that only expose a ``get_filterset_class`` method.
    """

    def param_to_schema(self, param):
        """Convert one parameter dictionary to a swagger :class:`.Parameter`
        object.

        :param dict param: mapping with ``name``, ``in`` and ``schema`` keys and
            optional ``description`` and ``required`` keys
        :rtype: openapi.Parameter
        """
        explicit = {
            "name": param["name"],
            "in_": param["in"],
            "description": param.get("description"),
            "required": param.get("required", False),
        }
        schema = dict(param["schema"])
        if explicit["description"] is None:
            # Fall back to a description carried by the schema itself.
            explicit["description"] = schema.get("description")
        # A schema key that duplicates one of the explicit keywords would make
        # the call fail with "got multiple values for keyword argument", so
        # the parameter-level value wins and the duplicate is dropped.
        extra = {key: value for key, value in schema.items() if key not in explicit}
        return openapi.Parameter(**explicit, **extra)

    def get_paginator_parameters(self, paginator):
        """Return the parameters declared by ``paginator``.

        :param paginator: the paginator instance to inspect
        :return: a list of :class:`.openapi.Parameter`, or ``NotHandled`` when
            the paginator has no ``get_schema_operation_parameters`` method
        """
        if not hasattr(paginator, "get_schema_operation_parameters"):
            return NotHandled
        return [
            self.param_to_schema(param)
            for param in paginator.get_schema_operation_parameters(self.view)
        ]

    def get_filter_parameters(self, filter_backend):
        """Return the parameters declared by ``filter_backend``.

        ``get_schema_operation_parameters`` is preferred; a backend that only
        offers ``get_filterset_class`` is handled by
        :meth:`get_filter_backend_params`.

        :param filter_backend: the filter backend instance to inspect
        :return: a list of :class:`.openapi.Parameter`, or ``NotHandled`` when
            the backend offers neither method
        """
        if hasattr(filter_backend, "get_schema_operation_parameters"):
            return [
                self.param_to_schema(param)
                for param in filter_backend.get_schema_operation_parameters(
                    self.view
                )
            ]
        if hasattr(filter_backend, "get_filterset_class"):
            return self.get_filter_backend_params(filter_backend)
        return NotHandled

    def get_filter_backend_params(self, filter_backend):
        """Build one string query parameter per filter of the backend's
        filterset class.

        :param filter_backend: backend exposing ``get_filterset_class``
        :return: a list of :class:`.openapi.Parameter`, or ``NotHandled`` when
            the backend returns no filterset class
        """
        filterset_class = filter_backend.get_filterset_class(
            self.view,
            # Not every view defines ``queryset``; treat a missing attribute
            # like an unset one instead of raising AttributeError.
            getattr(self.view, "queryset", None),
        )
        if filterset_class is None:
            return NotHandled

        return [
            openapi.Parameter(
                name=field_name,
                in_=openapi.IN_QUERY,
                required=field.extra.get("required", False),
                type=openapi.TYPE_STRING,
                # ``str()`` resolves labels that are not plain strings.
                description=str(field.label) if field.label else "",
            )
            for field_name, field in filterset_class.base_filters.items()
        ]
class CoreAPICompatInspector(PaginatorInspector, FilterInspector):
    """Converts ``coreapi.Field``\\ s to :class:`.openapi.Parameter`\\ s for filters and
    paginators that implement a ``get_schema_fields`` method.
    """

    def _schema_fields_to_parameters(self, provider):
        """Convert the fields reported by ``provider.get_schema_fields`` to
        parameters; a provider without that method yields an empty list.
        """
        fields = []
        if hasattr(provider, "get_schema_fields"):
            fields = provider.get_schema_fields(self.view)
        return [self.coreapi_field_to_parameter(field) for field in fields]

    @ignore_assert_decorator
    def get_paginator_parameters(self, paginator):
        """Return the parameters for the fields declared by ``paginator``.

        :param paginator: the paginator instance to inspect
        :return: a list of :class:`.openapi.Parameter` (empty when the
            paginator has no ``get_schema_fields`` method), or ``NotHandled``
            when an ``AssertionError`` is raised along the way
        """
        return self._schema_fields_to_parameters(paginator)

    @ignore_assert_decorator
    def get_filter_parameters(self, filter_backend):
        """Return the parameters for the fields declared by ``filter_backend``.

        :param filter_backend: the filter backend instance to inspect
        :return: a list of :class:`.openapi.Parameter` (empty when the backend
            has no ``get_schema_fields`` method), or ``NotHandled`` when an
            ``AssertionError`` is raised along the way
        """
        return self._schema_fields_to_parameters(filter_backend)

    def coreapi_field_to_parameter(self, field):
        """Convert an instance of `coreapi.Field` to a swagger :class:`.Parameter`
        object.

        :param coreapi.Field field:
        :rtype: openapi.Parameter
        """
        location_to_in = {
            "query": openapi.IN_QUERY,
            "path": openapi.IN_PATH,
            "form": openapi.IN_FORM,
            "body": openapi.IN_FORM,
        }
        # ``coreschema`` is an optional import and is None when unavailable;
        # with an empty table every field falls back to TYPE_STRING below
        # instead of failing on ``None.Integer``.
        coreapi_types = {}
        if coreschema is not None:
            coreapi_types = {
                coreschema.Integer: openapi.TYPE_INTEGER,
                coreschema.Number: openapi.TYPE_NUMBER,
                coreschema.String: openapi.TYPE_STRING,
                coreschema.Boolean: openapi.TYPE_BOOLEAN,
            }
        coreschema_attrs = ["format", "pattern", "enum", "min_length", "max_length"]
        schema = field.schema
        return openapi.Parameter(
            name=field.name,
            in_=location_to_in[field.location],
            required=field.required,
            description=force_real_str(schema.description) if schema else None,
            # Lookup is by exact schema class; anything else is a string.
            type=coreapi_types.get(type(schema), openapi.TYPE_STRING),
            # Attributes the schema lacks are passed through as None.
            **OrderedDict(
                (attr, getattr(schema, attr, None)) for attr in coreschema_attrs
            ),
        )
class DjangoRestResponsePagination(PaginatorInspector):
"""Provides response schema pagination wrapping for django-rest-framework'
LimitOffsetPagination, PageNumberPagination and CursorPagination
"""