Coverage for bc/kwai-bc-portal/src/kwai_bc_portal/news/news_item_db_query.py: 92%

72 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that implements a NewsItemQuery for a database.""" 

2 

3from dataclasses import dataclass 

4from typing import AsyncIterator, Self 

5 

6from kwai_core.db.database import Database, Record 

7from kwai_core.db.database_query import DatabaseQuery 

8from kwai_core.db.rows import OwnerRow 

9from kwai_core.db.table_row import JoinedTableRow 

10from kwai_core.domain.value_objects.timestamp import Timestamp 

11from kwai_core.domain.value_objects.unique_id import UniqueId 

12from sql_smith.functions import criteria, express, func, group, literal, on 

13 

14from kwai_bc_portal.applications.application_tables import ApplicationRow 

15from kwai_bc_portal.domain.news_item import NewsItemEntity, NewsItemIdentifier 

16from kwai_bc_portal.news._tables import NewsItemRow, NewsItemTextRow 

17from kwai_bc_portal.news.news_item_query import NewsItemQuery 

18 

19 

@dataclass(kw_only=True, frozen=True, slots=True)
class NewsItemQueryRow(JoinedTableRow):
    """A data transfer object for the news item query.

    One instance bundles the rows joined for a single news item text: the text,
    the news item it belongs to, the news item's application and the owner of
    the text.
    """

    text: NewsItemTextRow
    news_item: NewsItemRow
    application: ApplicationRow
    owner: OwnerRow

    @classmethod
    def create_entity(cls, rows: list[Self]) -> NewsItemEntity:
        """Create a news item entity from a group of rows.

        The news item and its application are taken from the first row; every
        row contributes one text (with its owner). The rows are expected to
        belong to the same news item.
        """
        first_row = rows[0]
        application_entity = first_row.application.create_entity()
        texts = tuple(
            current.text.create_text(current.owner.create_owner())
            for current in rows
        )
        return first_row.news_item.create_entity(application_entity, texts)

36 

37 

class NewsItemDbQuery(NewsItemQuery, DatabaseQuery):
    """A database query for news stories.

    Two queries are built:

    - ``self._query`` (created by ``DatabaseQuery``) is used as the CTE named
      ``limited``. It only joins the tables needed for filtering, counting and
      limiting; ``fetch`` makes it select only news item ids and applies
      limit/offset to it.
    - ``self._main_query`` selects all columns needed to build the entities and
      right joins the ``limited`` CTE, so only the selected news items are
      loaded together with all their texts and text owners.
    """

    def __init__(self, database: Database):
        # The main query must exist before the base class is initialised,
        # because init() reads and replaces self._main_query.
        # NOTE(review): this relies on DatabaseQuery.__init__ calling
        # self.init() — confirm.
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self):
        """Build the CTE query and the main query."""
        # This query will be used as CTE, so only join the tables that are needed
        # for counting and limiting the results.
        self._query.from_(NewsItemRow.__table_name__).join(
            ApplicationRow.__table_name__,
            on(
                ApplicationRow.column("id"),
                NewsItemRow.column("application_id"),
            ),
        )

        # The right join on "limited" restricts the result to the news items
        # selected by the CTE. Joining the texts produces one row per text,
        # which create_entity groups back into one entity.
        self._main_query = (
            self._main_query.from_(NewsItemRow.__table_name__)
            .columns(*(self.columns + NewsItemRow.get_aliases()))
            .with_("limited", self._query)
            .right_join("limited", on("limited.id", NewsItemRow.column("id")))
            .join(
                ApplicationRow.__table_name__,
                on(
                    ApplicationRow.column("id"),
                    NewsItemRow.column("application_id"),
                ),
            )
            .join(
                NewsItemTextRow.__table_name__,
                on(NewsItemTextRow.column("news_id"), NewsItemRow.column("id")),
            )
            .join(
                OwnerRow.__table_name__,
                on(OwnerRow.column("id"), NewsItemTextRow.column("user_id")),
            )
            .order_by(NewsItemRow.column("created_at"), "desc")
        )

    @property
    def columns(self):
        """The column aliases of all rows selected by the main query."""
        return NewsItemQueryRow.get_aliases()

    @property
    def count_column(self) -> str:
        """The column used for counting news items (the news item id)."""
        return NewsItemRow.column("id")

    def filter_by_id(self, id_: NewsItemIdentifier) -> Self:
        """Only keep the news item with the given id."""
        self._query.and_where(NewsItemRow.field("id").eq(id_.value))
        return self

    def filter_by_publication_date(self, year: int, month: int | None = None) -> Self:
        """Only keep news items published in the given year and, optionally, month.

        Args:
            year: Year of the publish date.
            month: When given, the month of the publish date must match as well.
        """
        condition = criteria(
            "{} = {}",
            func("YEAR", NewsItemRow.column("publish_date")),
            literal(year),
        )
        if month is not None:
            # and_ returns the combined criteria; keep the result, otherwise
            # the month condition is silently dropped.
            condition = condition.and_(
                criteria(
                    "{} = {}",
                    func("MONTH", NewsItemRow.column("publish_date")),
                    literal(month),
                )
            )
        self._query.and_where(condition)
        return self

    def filter_by_promoted(self) -> Self:
        """Only keep promoted news items whose promotion has not ended.

        The selected items are also ordered by their promotion value.
        """
        now = str(Timestamp.create_now())
        condition = (
            NewsItemRow.field("promotion")
            .gt(0)
            .and_(
                group(
                    NewsItemRow.field("promotion_end_date")
                    .is_null()
                    .or_(NewsItemRow.field("promotion_end_date").gt(now))
                )
            )
        )
        self._query.and_where(condition)
        self._query.order_by(NewsItemRow.column("promotion"))
        return self

    def filter_by_application(self, application: int | str) -> Self:
        """Only keep news items of an application.

        Args:
            application: The application name (str) or the application id (int).
        """
        if isinstance(application, str):
            self._query.and_where(ApplicationRow.field("name").eq(application))
        else:
            self._query.and_where(ApplicationRow.field("id").eq(application))

        return self

    def filter_by_active(self) -> Self:
        """Only keep active news items.

        A news item is active when it is enabled, its publish date has passed,
        and it has no end date or its end date is still in the future.
        """
        now = str(Timestamp.create_now())
        self._query.and_where(
            group(
                NewsItemRow.field("enabled")
                .eq(True)
                .and_(NewsItemRow.field("publish_date").lte(now))
                .and_(
                    group(
                        NewsItemRow.field("end_date")
                        .is_null()
                        .or_(NewsItemRow.field("end_date").gt(now))
                    )
                )
            )
        )

        return self

    def filter_by_user(self, user: int | UniqueId) -> Self:
        """Only keep texts written by the given user.

        Args:
            user: The owner's uuid (UniqueId) or the owner's id (int).
        """
        inner_select = (
            self._database.create_query_factory()
            .select(OwnerRow.column("id"))
            .from_(OwnerRow.__table_name__)
        )
        if isinstance(user, UniqueId):
            inner_select.where(OwnerRow.field("uuid").eq(str(user)))
        else:
            inner_select.where(OwnerRow.field("id").eq(user))

        # NOTE(review): this filter is applied to the main query, not to the
        # "limited" CTE. limit/offset (see fetch) are therefore applied before
        # this filter, so a page can contain fewer items than requested.
        # Confirm whether counting is affected as well.
        self._main_query.and_where(
            group(NewsItemTextRow.field("user_id").in_(express("%s", inner_select)))
        )
        return self

    def order_by_publication_date(self) -> Self:
        """Order the news items by publish date, most recent first."""
        # NOTE(review): init() already orders the main query by created_at
        # desc. If order_by appends, publish_date is only a secondary sort key
        # in the main query — confirm this is intended.
        self._main_query.order_by(NewsItemRow.column("publish_date"), "DESC")
        # Also add the order to the CTE, so limit/offset select the right items.
        self._query.order_by(NewsItemRow.column("publish_date"), "DESC")
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[Record]:
        """Fetch the records of the selected news items.

        Limit and offset are applied to the CTE, which only selects news item
        ids. The main query then loads all rows of those news items.

        Args:
            limit: Maximum number of news items.
            offset: Number of news items to skip.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(NewsItemRow.column("id"))
        # Order by id last so the rows belonging to one news item (one row per
        # text) follow each other.
        self._main_query.order_by(NewsItemRow.column("id"))

        return self._database.fetch(self._main_query)