Coverage for bc/kwai-bc-portal/src/kwai_bc_portal/news/news_item_db_query.py: 92%
72 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module that implements a NewsItemQuery for a database."""
3from dataclasses import dataclass
4from typing import AsyncIterator, Self
6from kwai_core.db.database import Database, Record
7from kwai_core.db.database_query import DatabaseQuery
8from kwai_core.db.rows import OwnerRow
9from kwai_core.db.table_row import JoinedTableRow
10from kwai_core.domain.value_objects.timestamp import Timestamp
11from kwai_core.domain.value_objects.unique_id import UniqueId
12from sql_smith.functions import criteria, express, func, group, literal, on
14from kwai_bc_portal.applications.application_tables import ApplicationRow
15from kwai_bc_portal.domain.news_item import NewsItemEntity, NewsItemIdentifier
16from kwai_bc_portal.news._tables import NewsItemRow, NewsItemTextRow
17from kwai_bc_portal.news.news_item_query import NewsItemQuery
@dataclass(kw_only=True, frozen=True, slots=True)
class NewsItemQueryRow(JoinedTableRow):
    """Data transfer object for one row of the news item query.

    Each row combines a news item with its application, one of its texts and
    the owner of that text.
    """

    text: NewsItemTextRow
    news_item: NewsItemRow
    application: ApplicationRow
    owner: OwnerRow

    @classmethod
    def create_entity(cls, rows: list[Self]) -> NewsItemEntity:
        """Build a news item entity from all rows that belong to one news item.

        The news item and its application are taken from the first row; every
        row contributes one text, created together with its owner.

        Args:
            rows: The rows of a single news item; must not be empty (the first
                row is indexed unconditionally).

        Returns:
            The news item entity with all its texts.
        """
        head = rows[0]
        application_entity = head.application.create_entity()
        texts = tuple(
            entry.text.create_text(entry.owner.create_owner()) for entry in rows
        )
        return head.news_item.create_entity(application_entity, texts)
class NewsItemDbQuery(NewsItemQuery, DatabaseQuery):
    """A database query for news stories.

    The query is built in two parts:

    * ``self._query`` (presumably created by ``DatabaseQuery``) selects the
      ids of the matching news items. It is used as the CTE ``limited`` and
      receives most filters, the limit and the offset.
    * ``self._main_query`` joins the CTE back to the news items and fetches
      the related application, texts and text owners.
    """

    def __init__(self, database: Database):
        # The main query must exist before the base initializer runs, because
        # init() (presumably invoked from DatabaseQuery.__init__) extends it.
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self):
        """Build the CTE and the main query."""
        # This query will be used as CTE, so only join the tables that are needed
        # for counting and limiting the results.
        self._query.from_(NewsItemRow.__table_name__).join(
            ApplicationRow.__table_name__,
            on(
                ApplicationRow.column("id"),
                NewsItemRow.column("application_id"),
            ),
        )

        # The main query keeps only the news items selected by the CTE and
        # joins all data needed to create the entities.
        self._main_query = (
            self._main_query.from_(NewsItemRow.__table_name__)
            .columns(*(self.columns + NewsItemRow.get_aliases()))
            .with_("limited", self._query)
            .right_join("limited", on("limited.id", NewsItemRow.column("id")))
            .join(
                ApplicationRow.__table_name__,
                on(
                    ApplicationRow.column("id"),
                    NewsItemRow.column("application_id"),
                ),
            )
            .join(
                NewsItemTextRow.__table_name__,
                on(NewsItemTextRow.column("news_id"), NewsItemRow.column("id")),
            )
            .join(
                OwnerRow.__table_name__,
                on(OwnerRow.column("id"), NewsItemTextRow.column("user_id")),
            )
            .order_by(NewsItemRow.column("created_at"), "desc")
        )

    @property
    def columns(self):
        """Return the column aliases of all tables of the query row."""
        return NewsItemQueryRow.get_aliases()

    @property
    def count_column(self) -> str:
        """Return the column used to count news items."""
        return NewsItemRow.column("id")

    def filter_by_id(self, id_: NewsItemIdentifier) -> Self:
        """Only select the news item with the given id."""
        self._query.and_where(NewsItemRow.field("id").eq(id_.value))
        return self

    def filter_by_publication_date(self, year: int, month: int | None = None) -> Self:
        """Only select news items published in the given year (and month).

        Args:
            year: The year of the publication date.
            month: When given, also restrict on the month of the publication date.
        """
        condition = criteria(
            "{} = {}",
            func("YEAR", NewsItemRow.column("publish_date")),
            literal(year),
        )
        if month is not None:
            # Keep the combined criteria returned by and_; discarding it
            # silently drops the month condition.
            condition = condition.and_(
                criteria(
                    "{} = {}",
                    func("MONTH", NewsItemRow.column("publish_date")),
                    literal(month),
                )
            )
        self._query.and_where(condition)
        return self

    def filter_by_promoted(self) -> Self:
        """Only select promoted news items whose promotion has not ended.

        The CTE is also ordered on the promotion column.
        """
        now = str(Timestamp.create_now())
        condition = (
            NewsItemRow.field("promotion")
            .gt(0)
            .and_(
                group(
                    NewsItemRow.field("promotion_end_date")
                    .is_null()
                    .or_(NewsItemRow.field("promotion_end_date").gt(now))
                )
            )
        )
        self._query.and_where(condition)
        self._query.order_by(NewsItemRow.column("promotion"))
        return self

    def filter_by_application(self, application: int | str) -> Self:
        """Only select news items of an application.

        Args:
            application: The name (str) or id (int) of the application.
        """
        if isinstance(application, str):
            self._query.and_where(ApplicationRow.field("name").eq(application))
        else:
            self._query.and_where(ApplicationRow.field("id").eq(application))

        return self

    def filter_by_active(self) -> Self:
        """Only select active news items."""
        now = str(Timestamp.create_now())
        # NOTE(review): assuming and_/or_ add no parentheses, SQL precedence
        # makes this (enabled AND publish_date <= now) OR
        # (end_date IS NOT NULL AND end_date > now). A disabled item with a
        # future end_date then matches, and an enabled item whose end_date has
        # passed still matches. Confirm this is intended.
        self._query.and_where(
            group(
                NewsItemRow.field("enabled")
                .eq(True)
                .and_(NewsItemRow.field("publish_date").lte(now))
                .or_(
                    group(
                        NewsItemRow.field("end_date")
                        .is_not_null()
                        .and_(NewsItemRow.field("end_date").gt(now))
                    )
                )
            )
        )

        return self

    def filter_by_user(self, user: int | UniqueId) -> Self:
        """Only select texts written by the given user.

        Args:
            user: The uuid (UniqueId) or id (int) of the user.
        """
        inner_select = (
            self._database.create_query_factory()
            .select(OwnerRow.column("id"))
            .from_(OwnerRow.__table_name__)
        )
        if isinstance(user, UniqueId):
            inner_select.where(OwnerRow.field("uuid").eq(str(user)))
        else:
            inner_select.where(OwnerRow.field("id").eq(user))

        # NOTE(review): this condition is added to the main query only, so the
        # CTE (which carries limit/offset in fetch) does not see it. Confirm
        # that paging before filtering on the user is intended.
        self._main_query.and_where(
            group(NewsItemTextRow.field("user_id").in_(express("%s", inner_select)))
        )
        return self

    def order_by_publication_date(self) -> Self:
        """Order the news items on publication date, newest first."""
        # NOTE(review): init() already orders the main query on created_at; if
        # order_by appends rather than replaces, that ordering still comes first.
        self._main_query.order_by(NewsItemRow.column("publish_date"), "DESC")
        # Also add the order to the CTE
        self._query.order_by(NewsItemRow.column("publish_date"), "DESC")
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[Record]:
        """Fetch the rows of the selected news items.

        The limit and offset are applied to the CTE, so they apply to the news
        items selected there rather than to the joined text rows.

        Args:
            limit: The maximum number of news items to select.
            offset: The number of news items to skip.

        Returns:
            An async iterator over the records of the main query.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        # The CTE only needs the id to join back to the news items.
        self._query.columns(NewsItemRow.column("id"))
        self._main_query.order_by(NewsItemRow.column("id"))

        return self._database.fetch(self._main_query)