Coverage for bc/kwai-bc-portal/src/kwai_bc_portal/pages/page_db_query.py: 91%
55 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module that implements a PageQuery for a database."""
3from dataclasses import dataclass
4from typing import AsyncIterator, Self
6from kwai_core.db.database import Database, Record
7from kwai_core.db.database_query import DatabaseQuery
8from kwai_core.db.rows import OwnerRow
9from kwai_core.db.table_row import JoinedTableRow
10from kwai_core.domain.value_objects.unique_id import UniqueId
11from sql_smith.functions import express, group, on
13from kwai_bc_portal.applications.application_tables import ApplicationRow
14from kwai_bc_portal.domain.page import PageEntity, PageIdentifier
15from kwai_bc_portal.pages._tables import PageRow, PageTextRow
16from kwai_bc_portal.pages.page_query import PageQuery
@dataclass(kw_only=True, frozen=True, slots=True)
class PageQueryRow(JoinedTableRow):
    """Data transfer object for one record returned by the page query.

    A record combines a page, its application, one page text and the owner
    row that is turned into the owner of that text.
    """

    text: PageTextRow
    page: PageRow
    application: ApplicationRow
    owner: OwnerRow

    @classmethod
    def create_entity(cls, rows: list[Self]) -> PageEntity:
        """Convert all rows of one page into a page entity.

        The page and its application are read from the first row; every row
        supplies one text, created together with its owner.

        Args:
            rows: The rows belonging to a single page. Must not be empty
                (the first row is indexed unconditionally).

        Returns:
            The page entity holding all texts of the rows.
        """
        head = rows[0]
        application = head.application.create_entity()
        texts = tuple(
            row.text.create_text(row.owner.create_owner()) for row in rows
        )
        return head.page.create_entity(application, texts)
class PageDbQuery(PageQuery, DatabaseQuery):
    """A database query for pages.

    Two select queries are combined:

    * ``self._query`` (managed by ``DatabaseQuery``) selects pages joined with
      their application. The page/application filters, ``limit``, ``offset``
      and the ordering used for pagination are applied to it, and ``fetch``
      reduces its columns to the page id.
    * ``self._main_query`` uses ``self._query`` as the CTE ``limited`` and
      joins pages, applications, page texts and owners to load the full
      records. Because texts are joined, it returns one row per page text.
    """

    def __init__(self, database: Database):
        """Create the query.

        Args:
            database: The database used to build and execute the queries.
        """
        # The main query must exist before DatabaseQuery.__init__ runs, since
        # init() extends it (presumably init() is called from the base
        # initializer — verify in DatabaseQuery).
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self) -> None:
        """Set up the FROM/JOIN clauses of the inner and the main query."""
        self._query.from_(PageRow.__table_name__).join(
            ApplicationRow.__table_name__,
            on(
                ApplicationRow.column("id"),
                PageRow.column("application_id"),
            ),
        )
        self._main_query = (
            self._main_query.from_(PageRow.__table_name__)
            .columns(*(self.columns + PageRow.get_aliases()))
            .with_("limited", self._query)
            # Keep only the page ids selected (and paginated) by the inner
            # query.
            .right_join("limited", on("limited.id", PageRow.column("id")))
            .join(
                ApplicationRow.__table_name__,
                on(
                    ApplicationRow.column("id"),
                    PageRow.column("application_id"),
                ),
            )
            .join(
                PageTextRow.__table_name__,
                on(PageTextRow.column("page_id"), PageRow.column("id")),
            )
            .join(
                OwnerRow.__table_name__,
                on(OwnerRow.column("id"), PageTextRow.column("user_id")),
            )
        )

    @property
    def columns(self):
        """Return the column aliases of ``PageQueryRow``."""
        return PageQueryRow.get_aliases()

    @property
    def count_column(self) -> str:
        """Return the page id column (presumably used by ``DatabaseQuery`` to count)."""
        return PageRow.column("id")

    def filter_by_id(self, id_: PageIdentifier) -> Self:
        """Only select the page with the given id."""
        self._query.and_where(PageRow.field("id").eq(id_.value))
        return self

    def filter_by_application(self, application: int | str) -> Self:
        """Only select pages of an application, given by name (str) or id (int)."""
        if isinstance(application, str):
            self._query.and_where(ApplicationRow.field("name").eq(application))
        else:
            self._query.and_where(ApplicationRow.field("id").eq(application))
        return self

    def filter_by_active(self) -> Self:
        """Only select pages whose ``enabled`` column is 1."""
        self._query.and_where(PageRow.field("enabled").eq(1))
        return self

    def filter_by_user(self, user: int | UniqueId) -> Self:
        """Only return texts whose owner matches the given user.

        Args:
            user: The owner's uuid (UniqueId) or numeric id (int).
        """
        inner_select = (
            self._database.create_query_factory()
            .select(OwnerRow.column("id"))
            .from_(OwnerRow.__table_name__)
        )
        if isinstance(user, UniqueId):
            inner_select.where(OwnerRow.field("uuid").eq(str(user)))
        else:
            inner_select.where(OwnerRow.field("id").eq(user))

        # NOTE(review): this condition is added to the main query, while
        # fetch() applies limit/offset to the inner query. The user filter
        # therefore acts after pagination, so a result page can hold fewer
        # pages than `limit`. Confirm this is intended.
        self._main_query.and_where(
            group(PageTextRow.field("user_id").in_(express("%s", inner_select)))
        )
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[Record]:
        """Fetch the records of the selected pages.

        Pagination is applied to page ids in the inner query, so all texts of
        a selected page are returned. Records are ordered by page priority,
        then by page id.

        Args:
            limit: The maximum number of pages to return, or None.
            offset: The number of pages to skip, or None.

        Returns:
            An async iterator over the joined page/application/text/owner
            records.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(PageRow.column("id"))
        # LIMIT/OFFSET without ORDER BY picks an unspecified set of rows, so
        # the paginated inner query must use the same ordering as the outer
        # query. Otherwise a result page is not the next pages by priority.
        self._query.order_by(PageRow.column("priority"))
        self._query.order_by(PageRow.column("id"))
        self._main_query.order_by(PageRow.column("priority"))
        self._main_query.order_by(PageRow.column("id"))

        return self._database.fetch(self._main_query)