Coverage for bc/kwai-bc-portal/src/kwai_bc_portal/pages/page_db_query.py: 91%

55 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that implements a PageQuery for a database.""" 

2 

3from dataclasses import dataclass 

4from typing import AsyncIterator, Self 

5 

6from kwai_core.db.database import Database, Record 

7from kwai_core.db.database_query import DatabaseQuery 

8from kwai_core.db.rows import OwnerRow 

9from kwai_core.db.table_row import JoinedTableRow 

10from kwai_core.domain.value_objects.unique_id import UniqueId 

11from sql_smith.functions import express, group, on 

12 

13from kwai_bc_portal.applications.application_tables import ApplicationRow 

14from kwai_bc_portal.domain.page import PageEntity, PageIdentifier 

15from kwai_bc_portal.pages._tables import PageRow, PageTextRow 

16from kwai_bc_portal.pages.page_query import PageQuery 

17 

18 

@dataclass(kw_only=True, frozen=True, slots=True)
class PageQueryRow(JoinedTableRow):
    """Row returned by the page query, split into one part per joined table."""

    # One sub-row per table that takes part in the page query.
    text: PageTextRow
    page: PageRow
    application: ApplicationRow
    owner: OwnerRow

    @classmethod
    def create_entity(cls, rows: list[Self]) -> PageEntity:
        """Build a single page entity out of the rows that belong to one page.

        The page and its application are read from the first row only. Every
        row contributes one text, created together with the owner of that row.

        Args:
            rows: The rows of one page. Indexing the first element of an
                empty list raises IndexError.

        Returns:
            The entity created by the page part of the first row.
        """
        first_row = rows[0]
        application = first_row.application.create_entity()
        texts = tuple(
            row.text.create_text(row.owner.create_owner()) for row in rows
        )
        return first_row.page.create_entity(application, texts)

35 

36 

class PageDbQuery(PageQuery, DatabaseQuery):
    """Query pages in a database.

    Two query objects are used. ``self._query`` selects from the pages table
    joined with the applications table; it receives most filters and the
    limit/offset. ``self._main_query`` uses ``self._query`` as a with-clause
    named "limited" and joins the page, application, page text and owner
    tables to fetch the full rows.
    """

    def __init__(self, database: Database):
        """Create the main select query and initialize the base query.

        Args:
            database: The database used to create and execute the queries.
        """
        # NOTE(review): the main query is created before the base initializer
        # runs, presumably because that initializer calls init() - confirm.
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self):
        """Set up the tables and joins of both queries."""
        page_table = PageRow.__table_name__
        application_table = ApplicationRow.__table_name__

        # The query that ends up in the "limited" with-clause.
        page_source = self._query.from_(page_table)
        page_source.join(
            application_table,
            on(ApplicationRow.column("id"), PageRow.column("application_id")),
        )

        # The query that is executed by fetch.
        main = self._main_query.from_(page_table)
        main = main.columns(*(self.columns + PageRow.get_aliases()))
        main = main.with_("limited", self._query)
        main = main.right_join(
            "limited", on("limited.id", PageRow.column("id"))
        )
        main = main.join(
            application_table,
            on(ApplicationRow.column("id"), PageRow.column("application_id")),
        )
        main = main.join(
            PageTextRow.__table_name__,
            on(PageTextRow.column("page_id"), PageRow.column("id")),
        )
        main = main.join(
            OwnerRow.__table_name__,
            on(OwnerRow.column("id"), PageTextRow.column("user_id")),
        )
        self._main_query = main

    @property
    def columns(self):
        """Return the column aliases of PageQueryRow."""
        return PageQueryRow.get_aliases()

    @property
    def count_column(self) -> str:
        """Return the id column of the pages table."""
        return PageRow.column("id")

    def filter_by_id(self, id_: PageIdentifier) -> Self:
        """Add a condition on the id of the page.

        Args:
            id_: The identifier of the page; its value is compared with the
                id field.

        Returns:
            This query.
        """
        has_id = PageRow.field("id").eq(id_.value)
        self._query.and_where(has_id)
        return self

    def filter_by_application(self, application: int | str) -> Self:
        """Add a condition on the application of the page.

        Args:
            application: A string is compared with the name of the
                application, any other value with its id.

        Returns:
            This query.
        """
        field_name = "name" if isinstance(application, str) else "id"
        self._query.and_where(
            ApplicationRow.field(field_name).eq(application)
        )
        return self

    def filter_by_active(self) -> Self:
        """Add a condition that the enabled field of the page equals 1.

        Returns:
            This query.
        """
        is_enabled = PageRow.field("enabled").eq(1)
        self._query.and_where(is_enabled)
        return self

    def filter_by_user(self, user: int | UniqueId) -> Self:
        """Add a condition on the user of the page text.

        Unlike the other filters, this condition is added to the main query:
        the user id of the page text must be in a sub select on the owner
        table.

        Args:
            user: A UniqueId is compared, as a string, with the uuid of the
                owner, any other value with the id of the owner.

        Returns:
            This query.
        """
        owner_ids = (
            self._database.create_query_factory()
            .select(OwnerRow.column("id"))
            .from_(OwnerRow.__table_name__)
        )
        if isinstance(user, UniqueId):
            owner_match = OwnerRow.field("uuid").eq(str(user))
        else:
            owner_match = OwnerRow.field("id").eq(user)
        owner_ids.where(owner_match)

        text_written_by_user = PageTextRow.field("user_id").in_(
            express("%s", owner_ids)
        )
        self._main_query.and_where(group(text_written_by_user))
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[Record]:
        """Execute the main query.

        The limit and offset are applied to the query of the "limited"
        with-clause, which selects only the id of the page. The main query
        is ordered on the priority and then on the id of the page.

        Args:
            limit: Passed to the limit of the with-clause query.
            offset: Passed to the offset of the with-clause query.

        Returns:
            The result of fetching the main query from the database.
        """
        # NOTE(review): only the main query is ordered here; whether the
        # limited query needs an ordering of its own should be confirmed.
        limited = self._query
        limited.limit(limit)
        limited.offset(offset)
        limited.columns(PageRow.column("id"))

        for order_column in ("priority", "id"):
            self._main_query.order_by(PageRow.column(order_column))

        return self._database.fetch(self._main_query)

121 

122 return self._database.fetch(self._main_query)