Coverage for bc/kwai-bc-training/src/kwai_bc_training/trainings/training_db_query.py: 100%

74 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that implements a training query for a database.""" 

2 

3from dataclasses import dataclass 

4from typing import AsyncIterator, Self 

5 

6from kwai_core.db.database import Database, Record 

7from kwai_core.db.database_query import DatabaseQuery 

8from kwai_core.db.rows import OwnerRow 

9from kwai_core.db.table_row import JoinedTableRow 

10from kwai_core.domain.value_objects.timestamp import Timestamp 

11from sql_smith.functions import alias, criteria, express, func, group, literal, on 

12 

13from kwai_bc_training.coaches.coach import CoachEntity 

14from kwai_bc_training.teams.team import TeamEntity 

15from kwai_bc_training.teams.team_tables import TeamRow 

16from kwai_bc_training.trainings._tables import ( 

17 TrainingCoachRow, 

18 TrainingRow, 

19 TrainingScheduleRow, 

20 TrainingTeamRow, 

21 TrainingTextRow, 

22) 

23from kwai_bc_training.trainings.training import TrainingEntity, TrainingIdentifier 

24from kwai_bc_training.trainings.training_query import TrainingQuery 

25from kwai_bc_training.trainings.training_schedule import TrainingScheduleEntity 

26 

27 

@dataclass(kw_only=True, frozen=True, slots=True)
class TrainingQueryRow(JoinedTableRow):
    """A data transfer object for the training query.

    It bundles the rows of all tables joined by the training query. A single
    training can yield several of these rows; each row carries one text of the
    training together with the owner of that text.
    """

    text: TrainingTextRow
    training: TrainingRow
    owner: OwnerRow  # owner (author) of `text`
    team: TeamRow  # team of the training schedule (if any)
    training_schedule: TrainingScheduleRow
    training_schedule_owner: OwnerRow

    @classmethod
    def create_entity(cls, rows: list[Self]) -> TrainingEntity:
        """Build one training entity from the rows of a single training.

        The training itself and its (optional) training schedule are read from
        the first row. Every row contributes one text, created with the owner
        found on that same row.

        Args:
            rows: The rows expected to belong to one training. Must not be
                empty (the first row is always accessed).

        Returns:
            The training entity.
        """
        first = rows[0]

        # A NULL schedule id means the left join found no training schedule.
        training_schedule = None
        if first.training_schedule.id is not None:
            training_schedule = first.training_schedule.create_entity(
                first.team.create_entity(),
                first.training_schedule_owner.create_owner(),
            )

        texts = tuple(row.text.create_text(row.owner.create_owner()) for row in rows)
        return first.training.create_entity(texts, training_schedule)

53 

54 

class TrainingDbQuery(TrainingQuery, DatabaseQuery):
    """A database query for trainings.

    Two queries are combined:

    - ``self._query`` (not assigned in this class; presumably created by
      ``DatabaseQuery``) selects only training ids. All filters, the ordering
      and limit/offset are applied to it, and it is embedded as the CTE
      ``limited``.
    - ``self._main_query`` right-joins the ``limited`` CTE and joins all
      related tables (schedule, schedule owner, team, texts, text owners) to
      select the columns of ``TrainingQueryRow``.

    Because texts are (inner) joined, one training can produce several rows.
    """

    def __init__(self, database: Database):
        # The main query must exist before super().__init__(): init() (which
        # the base class presumably calls) extends it.
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    def init(self) -> None:
        """Prepare the id (CTE) query and the main query."""
        # This query will be used as CTE, so only joins the tables that are needed
        # for counting and limiting results.
        self._query.from_(TrainingRow.__table_name__).left_join(
            TrainingScheduleRow.__table_name__,
            on(
                TrainingRow.column("training_schedule_id"),
                TrainingScheduleRow.column("id"),
            ),
        )
        self._main_query = (
            self._main_query.from_(TrainingRow.__table_name__)
            .columns(*self.columns)
            .with_("limited", self._query)
            # Only keep the trainings selected (filtered/limited) by the CTE.
            .right_join("limited", on("limited.id", TrainingRow.column("id")))
            .left_join(
                TrainingScheduleRow.__table_name__,
                on(
                    TrainingRow.column("training_schedule_id"),
                    TrainingScheduleRow.column("id"),
                ),
            )
            .left_join(
                alias(OwnerRow.__table_name__, "training_schema_owners"),
                on(TrainingScheduleRow.column("user_id"), "training_schema_owners.id"),
            )
            .left_join(
                TeamRow.__table_name__,
                on(TeamRow.column("id"), TrainingScheduleRow.column("team_id")),
            )
            .join(
                TrainingTextRow.__table_name__,
                on(
                    TrainingTextRow.column("training_id"),
                    TrainingRow.column("id"),
                ),
            )
            .join(
                OwnerRow.__table_name__,
                on(OwnerRow.column("id"), TrainingTextRow.column("user_id")),
            )
        )

    @property
    def columns(self):
        """Return the column aliases selected by the main query."""
        return TrainingQueryRow.get_aliases()

    @property
    def count_column(self) -> str:
        """Return the column used to count trainings."""
        return TrainingRow.column("id")

    def _filter_by_training_ids(self, training_ids_query) -> None:
        """Restrict the trainings to the ids returned by a sub select."""
        condition = TrainingRow.field("id").in_(express("{}", training_ids_query))
        self._query.and_where(group(condition))

    def filter_by_id(self, id_: TrainingIdentifier) -> "TrainingQuery":
        """Only select the training with the given id."""
        self._query.and_where(TrainingRow.field("id").eq(id_.value))
        return self

    def filter_by_year_month(
        self, year: int, month: int | None = None
    ) -> "TrainingQuery":
        """Only select trainings starting in the given year (and month, if given)."""
        condition = criteria(
            "{} = {}", func("YEAR", TrainingRow.column("start_date")), literal(year)
        )
        if month is not None:
            condition = condition.and_(
                criteria(
                    "{} = {}",
                    func("MONTH", TrainingRow.column("start_date")),
                    literal(month),
                )
            )
        self._query.and_where(group(condition))
        return self

    def filter_by_dates(self, start: Timestamp, end: Timestamp) -> "TrainingQuery":
        """Only select trainings with a start date between start and end."""
        self._query.and_where(
            TrainingRow.field("start_date").between(str(start), str(end))
        )
        return self

    def filter_by_coach(self, coach: CoachEntity) -> "TrainingQuery":
        """Only select trainings linked to the given coach."""
        inner_select = (
            self._database.create_query_factory()
            .select()
            .columns(TrainingCoachRow.column("training_id"))
            .from_(TrainingCoachRow.__table_name__)
            .where(TrainingCoachRow.field("coach_id").eq(coach.id.value))
        )
        self._filter_by_training_ids(inner_select)
        return self

    def filter_by_team(self, team: TeamEntity) -> "TrainingQuery":
        """Only select trainings linked to the given team."""
        inner_select = (
            self._database.create_query_factory()
            .select()
            .columns(TrainingTeamRow.column("training_id"))
            .from_(TrainingTeamRow.__table_name__)
            .where(TrainingTeamRow.field("team_id").eq(team.id.value))
        )
        self._filter_by_training_ids(inner_select)
        return self

    def filter_by_training_schedule(
        self, training_schedule: TrainingScheduleEntity
    ) -> "TrainingQuery":
        """Only select trainings that belong to the given training schedule."""
        self._query.and_where(
            TrainingRow.field("training_schedule_id").eq(training_schedule.id.value)
        )
        return self

    def filter_active(self) -> "TrainingQuery":
        """Only select active trainings."""
        self._query.and_where(TrainingRow.field("active").eq(1))
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[Record]:
        """Execute the main query.

        Limit and offset are applied to the CTE (training ids), not to the
        joined rows, so they count trainings rather than result rows.

        Args:
            limit: Maximum number of trainings.
            offset: Number of trainings to skip.

        Returns:
            An async iterator over the records of the main query.
        """
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(TrainingRow.column("id"))
        # LIMIT/OFFSET without an ORDER BY select an unspecified set of ids,
        # which makes paging unstable. Ordering by id (after any ordering added
        # by order_by_date) makes the selected page deterministic.
        self._query.order_by(TrainingRow.column("id"))
        # Ordering the joined rows by id (as last key) keeps all rows of one
        # training next to each other.
        self._main_query.order_by(TrainingRow.column("id"))

        return self._database.fetch(self._main_query)

    def order_by_date(self) -> "TrainingQuery":
        """Order the trainings by start date (ascending)."""
        self._query.order_by(TrainingRow.column("start_date"), "ASC")
        self._main_query.order_by(TrainingRow.column("start_date"), "ASC")
        return self