Coverage for bc/kwai-bc-training/src/kwai_bc_training/trainings/training_db_repository.py: 94%

104 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module for implementing a training repository for a database.""" 

2 

3from dataclasses import replace 

4from typing import AsyncIterator, cast 

5 

6from kwai_core.db.database import Database 

7from kwai_core.functions import async_groupby 

8from sql_smith.functions import alias, express, field 

9 

10from kwai_bc_training.teams.team import TeamEntity 

11from kwai_bc_training.trainings._tables import ( 

12 TrainingCoachRow, 

13 TrainingRow, 

14 TrainingTeamRow, 

15 TrainingTextRow, 

16) 

17from kwai_bc_training.trainings.training import TrainingEntity, TrainingIdentifier 

18from kwai_bc_training.trainings.training_coach_db_query import TrainingCoachDbQuery 

19from kwai_bc_training.trainings.training_db_query import ( 

20 TrainingDbQuery, 

21 TrainingQueryRow, 

22) 

23from kwai_bc_training.trainings.training_query import TrainingQuery 

24from kwai_bc_training.trainings.training_repository import ( 

25 TrainingNotFoundException, 

26 TrainingRepository, 

27) 

28from kwai_bc_training.trainings.training_schedule import TrainingScheduleEntity 

29from kwai_bc_training.trainings.training_team_db_query import TrainingTeamDbQuery 

30from kwai_bc_training.trainings.value_objects import TrainingCoach 

31 

32 

class TrainingDbRepository(TrainingRepository):
    """A training repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> TrainingQuery:
        """Create a training query that runs against this repository's database."""
        return TrainingDbQuery(self._database)

    async def get_by_id(self, id: TrainingIdentifier) -> TrainingEntity:
        """Get the training with the given id.

        Args:
            id: The id of the training.

        Returns:
            The training, including its coaches and teams.

        Raises:
            TrainingNotFoundException: When no training with this id exists.
        """
        query = self.create_query()
        query.filter_by_id(id)

        row_iterator = self.get_all(query, 1)
        try:
            return await anext(row_iterator)
        except StopAsyncIteration:
            raise TrainingNotFoundException(
                f"Training with id {id} does not exist"
            ) from None

    async def get_all(
        self,
        query: TrainingQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingEntity]:
        """Yield all trainings selected by the query.

        Args:
            query: The query to use. When omitted, an unfiltered query is created.
            limit: Passed on to `query.fetch`.
            offset: Passed on to `query.fetch`.

        Yields:
            The trainings, with their coaches and teams attached.
        """
        if query is None:
            query = self.create_query()

        # A training can span multiple result rows; they are grouped on
        # training_id and turned into one entity per group.
        # NOTE(review): async_groupby presumably groups consecutive rows (like
        # itertools.groupby), so the query should return rows ordered by
        # training — confirm against TrainingDbQuery.
        trainings: dict[TrainingIdentifier, TrainingEntity] = {}
        async for _, group in async_groupby(
            query.fetch(limit, offset), key=lambda row: row["training_id"]
        ):
            mapped = [TrainingQueryRow.map(row) for row in group]
            training = TrainingQueryRow.create_entity(mapped)
            trainings[training.id] = training

        if not trainings:
            return

        # Load the coaches and teams of all trainings with one query each,
        # instead of one query per training.
        training_ids = list(trainings)
        coach_query = TrainingCoachDbQuery(self._database).filter_by_trainings(
            *training_ids
        )
        coaches: dict[
            TrainingIdentifier, list[TrainingCoach]
        ] = await coach_query.fetch_coaches()

        team_query = TrainingTeamDbQuery(self._database).filter_by_trainings(
            *training_ids
        )
        teams: dict[TrainingIdentifier, list[TeamEntity]] = await team_query.fetch_teams()

        for training in trainings.values():
            training_coaches = frozenset(
                coaches.get(cast(TrainingIdentifier, training.id), [])
            )
            training_teams = frozenset(
                teams.get(cast(TrainingIdentifier, training.id), [])
            )
            if training_coaches or training_teams:
                yield replace(training, coaches=training_coaches, teams=training_teams)
            else:
                yield training

    async def create(self, training: TrainingEntity) -> TrainingEntity:
        """Save a new training together with its texts, coaches and teams.

        Args:
            training: The training to save.

        Returns:
            The training with the id returned by the database insert.
        """
        new_id = await self._database.insert(
            TrainingRow.__table_name__, TrainingRow.persist(training)
        )
        result = training.set_id(TrainingIdentifier(new_id))

        await self._insert_contents(result)
        await self._insert_coaches(result)
        await self._insert_teams(result)

        return result

    async def update(self, training: TrainingEntity) -> None:
        """Update a training and replace its texts, coaches and teams.

        Args:
            training: The training to update.
        """
        await self._database.update(
            training.id.value,
            TrainingRow.__table_name__,
            TrainingRow.persist(training),
        )

        # Related rows are replaced: first delete them, then insert them again.
        await self._delete_contents(training)
        await self._insert_contents(training)

        await self._delete_coaches(training)
        await self._insert_coaches(training)

        await self._delete_teams(training)
        await self._insert_teams(training)

    async def _insert_contents(self, training: TrainingEntity):
        """Insert the text contents of the training, if it has any."""
        content_rows = [
            TrainingTextRow.persist(training, content) for content in training.texts
        ]
        # Skip the insert for zero rows, like the coach/team helpers do.
        if content_rows:
            await self._database.insert(TrainingTextRow.__table_name__, *content_rows)

    async def _insert_coaches(self, training: TrainingEntity):
        """Insert the related coaches, if the training has any."""
        training_coach_rows = [
            TrainingCoachRow.persist(training, training_coach)
            for training_coach in training.coaches
        ]
        if training_coach_rows:
            await self._database.insert(
                TrainingCoachRow.__table_name__, *training_coach_rows
            )

    async def _insert_teams(self, training: TrainingEntity):
        """Insert the related teams, if the training has any."""
        training_team_rows = [
            TrainingTeamRow.persist(training, team) for team in training.teams
        ]
        if training_team_rows:
            await self._database.insert(
                TrainingTeamRow.__table_name__, *training_team_rows
            )

    async def _delete_related(self, table_name: str, training: TrainingEntity):
        """Delete all rows of `table_name` whose training_id is the training's id."""
        delete_query = (
            self._database.create_query_factory()
            .delete(table_name)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_query)

    async def _delete_coaches(self, training: TrainingEntity):
        """Delete coaches of the training."""
        await self._delete_related(TrainingCoachRow.__table_name__, training)

    async def _delete_contents(self, training: TrainingEntity):
        """Delete text contents of the training."""
        await self._delete_related(TrainingTextRow.__table_name__, training)

    async def _delete_teams(self, training: TrainingEntity):
        """Delete the teams of the training."""
        await self._delete_related(TrainingTeamRow.__table_name__, training)

    async def delete(self, training: TrainingEntity) -> None:
        """Delete a training together with its texts, coaches and teams.

        Args:
            training: The training to delete.
        """
        # Remove the rows that reference the training before the training
        # itself, so this does not rely on ON DELETE CASCADE foreign keys.
        await self._delete_contents(training)
        await self._delete_coaches(training)
        await self._delete_teams(training)

        await self._database.delete(training.id.value, TrainingRow.__table_name__)

    async def reset_schedule(
        self, training_schedule: TrainingScheduleEntity, delete: bool = False
    ) -> None:
        """Reset all trainings that belong to a training schedule.

        Args:
            training_schedule: The training schedule.
            delete: When True, the trainings of the schedule are removed together
                with their teams, coaches and texts. When False, the trainings
                are kept and only detached from the schedule
                (training_schedule_id is set to NULL).
        """
        trainings_query = (
            self._database.create_query_factory()
            .select(TrainingRow.column("id"))
            .from_(TrainingRow.__table_name__)
            .and_where(field("training_schedule_id").eq(training_schedule.id.value))
        )
        if delete:
            # Dependent rows first; they are selected through trainings_query,
            # so the trainings must still exist at this point.
            for row_class in (TrainingTeamRow, TrainingCoachRow, TrainingTextRow):
                delete_related = (
                    self._database.create_query_factory()
                    .delete(row_class.__table_name__)
                    .and_where(row_class.field("training_id").in_(trainings_query))
                )
                await self._database.execute(delete_related)

            # Finally remove the trainings themselves; previously they were left
            # behind without texts, coaches and teams.
            delete_trainings = (
                self._database.create_query_factory()
                .delete(TrainingRow.__table_name__)
                .and_where(
                    field("training_schedule_id").eq(training_schedule.id.value)
                )
            )
            await self._database.execute(delete_trainings)
        else:
            # Because it is not allowed to update the table that is used
            # in a sub query, the sub query is wrapped in a derived table copy.
            copy_trainings_query = (
                self._database.create_query_factory()
                .select("t.id")
                .from_(alias(express("({})", trainings_query), "t"))
            )
            update_trainings = (
                self._database.create_query_factory()
                .update(TrainingRow.__table_name__, {"training_schedule_id": None})
                .where(TrainingRow.field("id").in_(copy_trainings_query))
            )
            await self._database.execute(update_trainings)