From df5a5cf44936c7f4efa4bc9d4ee02efeeb8331fb Mon Sep 17 00:00:00 2001 From: Simon Date: Sun, 9 Feb 2025 17:41:12 +0700 Subject: [PATCH] add drf_spectacular, type and serialize channel app --- backend/channel/serializers.py | 111 +++++++++++++++++ backend/channel/views.py | 181 +++++++++++++++++++++------- backend/common/serializers.py | 29 +++++ backend/common/src/index_generic.py | 2 +- backend/common/views_base.py | 14 +-- backend/config/settings.py | 13 ++ backend/config/urls.py | 7 ++ backend/requirements-dev.txt | 4 +- backend/requirements.txt | 7 +- 9 files changed, 305 insertions(+), 63 deletions(-) create mode 100644 backend/channel/serializers.py create mode 100644 backend/common/serializers.py diff --git a/backend/channel/serializers.py b/backend/channel/serializers.py new file mode 100644 index 00000000..940530eb --- /dev/null +++ b/backend/channel/serializers.py @@ -0,0 +1,111 @@ +"""channel serializers""" + +# pylint: disable=abstract-method + +from common.serializers import PaginationSerializer +from rest_framework import serializers + + +class ChannelOverwriteSerializer(serializers.Serializer): + """serialize channel overwrites""" + + download_format = serializers.CharField(required=False, allow_null=True) + autodelete_days = serializers.IntegerField(required=False, allow_null=True) + index_playlists = serializers.BooleanField(required=False, allow_null=True) + integrate_sponsorblock = serializers.BooleanField( + required=False, allow_null=True + ) + subscriptions_channel_size = serializers.IntegerField( + required=False, allow_null=True + ) + subscriptions_live_channel_size = serializers.IntegerField( + required=False, allow_null=True + ) + subscriptions_shorts_channel_size = serializers.IntegerField( + required=False, allow_null=True + ) + + def to_internal_value(self, data): + """Override this method to detect unknown fields.""" + allowed_fields = set(self.fields.keys()) + input_fields = set(data.keys()) + + unknown_fields = input_fields - allowed_fields + + if unknown_fields: + raise serializers.ValidationError( + {"error": f"Unknown fields: {', '.join(unknown_fields)}"} + ) + + return super().to_internal_value(data) + + +class ChannelSerializer(serializers.Serializer): + """serialize channel""" + + channel_id = serializers.CharField() + channel_active = serializers.BooleanField() + channel_banner_url = serializers.CharField() + channel_thumb_url = serializers.CharField() + channel_tvart_url = serializers.CharField() + channel_description = serializers.CharField() + channel_last_refresh = serializers.CharField() + channel_name = serializers.CharField() + channel_overwrites = ChannelOverwriteSerializer(required=False) + channel_subs = serializers.IntegerField() + channel_subscribed = serializers.BooleanField() + channel_tags = serializers.ListField(child=serializers.CharField()) + channel_views = serializers.IntegerField() + _index = serializers.CharField(required=False) + _score = serializers.IntegerField(required=False) + + +class ChannelListSerializer(serializers.Serializer): + """serialize channel list""" + + data = ChannelSerializer(many=True) + paginate = PaginationSerializer() + + +class ChannelListQuerySerializer(serializers.Serializer): + """serialize list query""" + + filter = serializers.ChoiceField(choices=["subscribed"], required=False) + + +class ChannelUpdateSerializer(serializers.Serializer): + """update channel""" + + channel_subscribed = serializers.BooleanField(required=False) + channel_overwrites = ChannelOverwriteSerializer(required=False) + + +class ChannelAggBucketSerializer(serializers.Serializer): + """serialize channel agg bucket""" + + value = serializers.IntegerField() + value_str = serializers.CharField(required=False) + + +class ChannelAggSerializer(serializers.Serializer): + """serialize channel aggregation""" + + total_items = ChannelAggBucketSerializer() + total_size = ChannelAggBucketSerializer() + total_duration = ChannelAggBucketSerializer() + + +class ChannelNavSerializer(serializers.Serializer): + """serialize channel navigation""" + + has_pending = serializers.BooleanField() + has_playlists = serializers.BooleanField() + has_videos = serializers.BooleanField() + has_streams = serializers.BooleanField() + has_shorts = serializers.BooleanField() + + +class ChannelSearchQuerySerializer(serializers.Serializer): + """serialize query parameters for searching""" + + q = serializers.CharField() diff --git a/backend/channel/views.py b/backend/channel/views.py index 7fe0860a..5e01320e 100644 --- a/backend/channel/views.py +++ b/backend/channel/views.py @@ -1,10 +1,25 @@ """all channel API views""" +from channel.serializers import ( + ChannelAggSerializer, + ChannelListQuerySerializer, + ChannelListSerializer, + ChannelNavSerializer, + ChannelSearchQuerySerializer, + ChannelSerializer, + ChannelUpdateSerializer, +) from channel.src.index import YoutubeChannel, channel_overwrites from channel.src.nav import ChannelNav +from common.serializers import ErrorResponseSerializer from common.src.urlparser import Parser from common.views_base import AdminWriteOnly, ApiBaseView from download.src.subscriptions import ChannelSubscription +from drf_spectacular.utils import ( + OpenApiParameter, + OpenApiResponse, + extend_schema, +) from rest_framework.response import Response from task.tasks import index_channel_playlists, subscribe_to @@ -19,26 +34,38 @@ class ChannelApiListView(ApiBaseView): valid_filter = ["subscribed"] permission_classes = [AdminWriteOnly] + @extend_schema( + responses={ + 200: OpenApiResponse(ChannelListSerializer()), + }, + parameters=[ + OpenApiParameter( + name="filter", + description="Filter by Subscribed", + type=ChannelListQuerySerializer(), + ), + ], + ) def get(self, request): """get request""" self.data.update( {"sort": [{"channel_name.keyword": {"order": "asc"}}]} ) - query_filter = request.GET.get("filter", False) - must_list = [] - if query_filter: - if query_filter not in self.valid_filter: - message = f"invalid url query filter: {query_filter}" - print(message) - return Response({"message": message}, status=400) + serializer = ChannelListQuerySerializer(data=request.query_params) + serializer.is_valid(raise_exception=True) + validated_data = serializer.validated_data + must_list = [] + query_filter = validated_data.get("filter") + if query_filter: must_list.append({"term": {"channel_subscribed": {"value": True}}}) self.data["query"] = {"bool": {"must": must_list}} self.get_document_list(request) + serializer = ChannelListSerializer(self.response) - return Response(self.response) + return Response(serializer.data) def post(self, request): """subscribe/unsubscribe to list of channels""" @@ -81,53 +108,79 @@ class ChannelApiView(ApiBaseView): search_base = "ta_channel/_doc/" permission_classes = [AdminWriteOnly] + @extend_schema( + responses={ + 200: OpenApiResponse(ChannelSerializer()), + 404: OpenApiResponse( + ErrorResponseSerializer(), description="Channel not found" + ), + } + ) def get(self, request, channel_id): # pylint: disable=unused-argument - """get request""" + """get channel detail""" self.get_document(channel_id) - return Response(self.response, status=self.status_code) + if not self.response: + error = ErrorResponseSerializer({"error": "channel not found"}) + return Response(error.data, status=404) + response_serializer = ChannelSerializer(self.response) + return Response(response_serializer.data, status=self.status_code) + + @extend_schema( + request=ChannelUpdateSerializer(), + responses={ + 200: OpenApiResponse(ChannelUpdateSerializer()), + 400: OpenApiResponse( + ErrorResponseSerializer(), description="Bad request" + ), + 404: OpenApiResponse( + ErrorResponseSerializer(), description="Channel not found" + ), + }, + ) def post(self, request, channel_id): - """modify channel overwrites""" + """modify channel""" self.get_document(channel_id) - if not self.response["data"]: - return Response({"error": "channel not found"}, status=404) + if not self.response: + error = ErrorResponseSerializer({"error": "channel not found"}) + return Response(error.data, status=404) - data = request.data - subscribed = data.get("channel_subscribed") + serializer = ChannelUpdateSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + validated_data = serializer.validated_data + + subscribed = validated_data.get("channel_subscribed") if subscribed is not None: - channel_sub = ChannelSubscription() - json_data = channel_sub.change_subscribe(channel_id, subscribed) - return Response(json_data, status=200) + ChannelSubscription().change_subscribe(channel_id, subscribed) - if "channel_overwrites" not in data: - return Response({"error": "invalid payload"}, status=400) - - overwrites = data["channel_overwrites"] - - try: - json_data = channel_overwrites(channel_id, overwrites) + overwrites = validated_data.get("channel_overwrites") + if overwrites: + channel_overwrites(channel_id, overwrites) if overwrites.get("index_playlists"): index_channel_playlists.delay(channel_id) - except ValueError as err: - return Response({"error": str(err)}, status=400) - - return Response(json_data, status=200) + return Response(serializer.data, status=200) + @extend_schema( + responses={ + 204: OpenApiResponse(description="Channel deleted"), + 404: OpenApiResponse( + ErrorResponseSerializer(), description="Channel not found" + ), + }, + ) def delete(self, request, channel_id): # pylint: disable=unused-argument """delete channel""" - message = {"channel": channel_id} try: YoutubeChannel(channel_id).delete_channel() - status_code = 200 - message.update({"state": "delete"}) + return Response(status=204) except FileNotFoundError: - status_code = 404 - message.update({"state": "not found"}) + pass - return Response(message, status=status_code) + error = ErrorResponseSerializer({"error": "channel not found"}) + return Response(error.data, status=404) class ChannelAggsApiView(ApiBaseView): @@ -137,8 +190,13 @@ class ChannelAggsApiView(ApiBaseView): search_base = "ta_video/_search" + @extend_schema( + responses={ + 200: OpenApiResponse(ChannelAggSerializer()), + }, + ) def get(self, request, channel_id): - """get aggs""" + """get channel aggregations""" self.data.update( { "query": { @@ -152,8 +210,9 @@ class ChannelAggsApiView(ApiBaseView): } ) self.get_aggs() + serializer = ChannelAggSerializer(self.response) - return Response(self.response) + return Response(serializer.data) class ChannelNavApiView(ApiBaseView): @@ -161,11 +220,17 @@ class ChannelNavApiView(ApiBaseView): GET: get channel nav """ + @extend_schema( + responses={ + 200: OpenApiResponse(ChannelNavSerializer()), + }, + ) def get(self, request, channel_id): - """get nav""" + """get navigation""" nav = ChannelNav(channel_id).get_nav() - return Response(nav) + serializer = ChannelNavSerializer(nav) + return Response(serializer.data) class ChannelApiSearchView(ApiBaseView): @@ -175,10 +240,31 @@ class ChannelApiSearchView(ApiBaseView): search_base = "ta_channel/_doc/" + @extend_schema( + responses={ + 200: OpenApiResponse(ChannelSerializer()), + 400: OpenApiResponse(description="Bad Request"), + 404: OpenApiResponse( + ErrorResponseSerializer(), description="Channel not found" + ), + }, + parameters=[ + OpenApiParameter( + name="q", + description="Search query string", + required=True, + type=str, + ), + ], + ) def get(self, request): - """handle get request, search with s parameter""" + """search for local channel ID""" - query = request.GET.get("q") + serializer = ChannelSearchQuerySerializer(data=request.query_params) + serializer.is_valid(raise_exception=True) + validated_data = serializer.validated_data + + query = validated_data.get("q") if not query: message = "missing expected q parameter" return Response({"message": message, "data": False}, status=400) @@ -186,13 +272,16 @@ class ChannelApiSearchView(ApiBaseView): try: parsed = Parser(query).parse()[0] except (ValueError, IndexError, AttributeError): - message = f"channel not found: {query}" - return Response({"message": message, "data": False}, status=404) + error = ErrorResponseSerializer( + {"error": f"channel not found: {query}"} + ) + return Response(error.data, status=404) if not parsed["type"] == "channel": - message = "expected type channel" - return Response({"message": message, "data": False}, status=400) + error = ErrorResponseSerializer({"error": "expected channel data"}) + return Response(error.data, status=400) self.get_document(parsed["url"]) + serializer = ChannelSerializer(self.response) - return Response(self.response, status=self.status_code) + return Response(serializer.data, status=self.status_code) diff --git a/backend/common/serializers.py b/backend/common/serializers.py new file mode 100644 index 00000000..8b9079a5 --- /dev/null +++ b/backend/common/serializers.py @@ -0,0 +1,29 @@ +"""common serializers""" + +# pylint: disable=abstract-method + +from rest_framework import serializers + + +class ErrorResponseSerializer(serializers.Serializer): + """error message""" + + error = serializers.CharField() + + +class PaginationSerializer(serializers.Serializer): + """serialize paginate response""" + + page_size = serializers.IntegerField() + page_from = serializers.IntegerField() + prev_pages = serializers.ListField( + child=serializers.IntegerField(), allow_null=True + ) + current_page = serializers.IntegerField() + max_hits = serializers.BooleanField() + params = serializers.CharField() + last_page = serializers.BooleanField() + next_pages = serializers.ListField( + child=serializers.IntegerField(), allow_null=True + ) + total_hits = serializers.IntegerField() diff --git a/backend/common/src/index_generic.py b/backend/common/src/index_generic.py index 3daeefe7..450c1f88 100644 --- a/backend/common/src/index_generic.py +++ b/backend/common/src/index_generic.py @@ -106,7 +106,7 @@ class Pagination: page_get = self.page_get page_from = 0 if page_get in [0, 1]: - prev_pages = False + prev_pages = None elif page_get > 1: page_from = (page_get - 1) * self.page_size prev_pages = [ diff --git a/backend/common/views_base.py b/backend/common/views_base.py index 324ecb05..9382dbae 100644 --- a/backend/common/views_base.py +++ b/backend/common/views_base.py @@ -1,7 +1,5 @@ """base classes to inherit from""" -from appsettings.src.config import AppConfig -from common.src.env_settings import EnvironmentSettings from common.src.es_connect import ElasticWrap from common.src.index_generic import Pagination from common.src.search_processor import SearchProcess, process_aggs @@ -45,13 +43,7 @@ class ApiBaseView(APIView): def __init__(self): super().__init__() - self.response = { - "data": False, - "config": { - "enable_cast": EnvironmentSettings.ENABLE_CAST, - "downloads": AppConfig().config["downloads"], - }, - } + self.response = {} self.data = {"query": {"match_all": {}}} self.status_code = False self.context = False @@ -62,12 +54,12 @@ class ApiBaseView(APIView): path = f"{self.search_base}{document_id}" response, status_code = ElasticWrap(path).get() try: - self.response["data"] = SearchProcess( + self.response = SearchProcess( response, match_video_user_progress=progress_match ).process() except KeyError: print(f"item not found: {document_id}") - self.response["data"] = False + self.status_code = status_code def initiate_pagination(self, request): diff --git a/backend/config/settings.py b/backend/config/settings.py index 5652efc2..9ea6ddd4 100644 --- a/backend/config/settings.py +++ b/backend/config/settings.py @@ -59,6 +59,7 @@ INSTALLED_APPS = [ "django.contrib.humanize", "rest_framework", "rest_framework.authtoken", + "drf_spectacular", "common", "video", "channel", @@ -295,3 +296,15 @@ CORS_ALLOW_HEADERS = list(default_headers) + [ # TA application settings TA_UPSTREAM = "https://github.com/tubearchivist/tubearchivist" TA_VERSION = "v0.5.0-unstable" + +# API +REST_FRAMEWORK = { + "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", +} + +SPECTACULAR_SETTINGS = { + "TITLE": "Tube Archivist API", + "DESCRIPTION": "API documentation for Tube Archivist backend.", + "VERSION": TA_VERSION, + "SERVE_INCLUDE_SCHEMA": False, +} diff --git a/backend/config/urls.py b/backend/config/urls.py index 2fbf84b3..cc56d743 100644 --- a/backend/config/urls.py +++ b/backend/config/urls.py @@ -16,6 +16,7 @@ Including another URLconf from django.contrib import admin from django.urls import include, path +from drf_spectacular.views import SpectacularAPIView, SpectacularSwaggerView urlpatterns = [ path("api/", include("common.urls")), @@ -27,5 +28,11 @@ urlpatterns = [ path("api/appsettings/", include("appsettings.urls")), path("api/stats/", include("stats.urls")), path("api/user/", include("user.urls")), + path("api/schema/", SpectacularAPIView.as_view(), name="schema"), + path( + "api/docs/", + SpectacularSwaggerView.as_view(url_name="schema"), + name="swagger-ui", + ), path("admin/", admin.site.urls), ] diff --git a/backend/requirements-dev.txt b/backend/requirements-dev.txt index bd641c5c..0a934098 100644 --- a/backend/requirements-dev.txt +++ b/backend/requirements-dev.txt @@ -1,8 +1,8 @@ -r requirements.txt -ipython==8.31.0 +ipython==8.32.0 pre-commit==4.1.0 pylint-django==2.6.1 -pylint==3.3.3 +pylint==3.3.4 pytest-django==4.9.0 pytest==8.3.4 python-dotenv==1.0.1 diff --git a/backend/requirements.txt b/backend/requirements.txt index b4c6c09e..70c9a25b 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -2,13 +2,14 @@ apprise==1.9.2 celery==5.4.0 django-auth-ldap==5.1.0 django-celery-beat==2.7.0 -django-cors-headers==4.6.0 -Django==5.1.5 +django-cors-headers==4.7.0 +Django==5.1.6 djangorestframework==3.15.2 +drf-spectacular==0.28.0 Pillow==11.1.0 redis==5.2.1 requests==2.32.3 ryd-client==0.0.6 uvicorn==0.34.0 -whitenoise==6.8.2 +whitenoise==6.9.0 yt-dlp[default]==2025.1.26