Coverage for bc/kwai-bc-training/src/kwai_bc_training/trainings/training_db_repository.py: 94%

107 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module for implementing a training repository for a database.""" 

2 

3from dataclasses import replace 

4from typing import AsyncIterator, cast 

5 

6from kwai_core.db.database import Database 

7from kwai_core.functions import async_groupby 

8from sql_smith.functions import alias, express, field 

9 

10from kwai_bc_training.teams.team import TeamEntity 

11from kwai_bc_training.trainings._tables import ( 

12 TrainingCoachRow, 

13 TrainingRow, 

14 TrainingTeamRow, 

15 TrainingTextRow, 

16) 

17from kwai_bc_training.trainings.training import ( 

18 TrainingCoachEntity, 

19 TrainingCoachIdentifier, 

20 TrainingEntity, 

21 TrainingIdentifier, 

22) 

23from kwai_bc_training.trainings.training_coach_db_query import TrainingCoachDbQuery 

24from kwai_bc_training.trainings.training_db_query import ( 

25 TrainingDbQuery, 

26 TrainingQueryRow, 

27) 

28from kwai_bc_training.trainings.training_query import TrainingQuery 

29from kwai_bc_training.trainings.training_repository import ( 

30 TrainingNotFoundException, 

31 TrainingRepository, 

32) 

33from kwai_bc_training.trainings.training_schedule import TrainingScheduleEntity 

34from kwai_bc_training.trainings.training_team_db_query import TrainingTeamDbQuery 

35 

36 

class TrainingDbRepository(TrainingRepository):
    """A training repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> TrainingQuery:
        """Create a new database query for trainings."""
        return TrainingDbQuery(self._database)

    async def get_by_id(self, id: TrainingIdentifier) -> TrainingEntity:
        """Get the training with the given id.

        Args:
            id: The id of the training.

        Returns:
            The training, with its coaches and teams attached.

        Raises:
            TrainingNotFoundException: when no training with this id exists.
        """
        query = self.create_query()
        query.filter_by_id(id)

        try:
            row_iterator = self.get_all(query, 1)
            entity = await anext(row_iterator)
        except StopAsyncIteration:
            raise TrainingNotFoundException(
                f"Training with id {id} does not exist"
            ) from None
        return entity

    async def get_all(
        self,
        query: TrainingQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingEntity]:
        """Get all trainings that match the query.

        The rows returned by the query are grouped on "training_id", and each
        group is mapped to one training entity. Afterwards the coaches and the
        teams of all fetched trainings are loaded with one query each.

        Args:
            query: The query to use. When None, a new query is created.
            limit: Passed unchanged to ``query.fetch``.
            offset: Passed unchanged to ``query.fetch``.

        Yields:
            Training entities with their coaches and teams attached.
        """
        if query is None:
            query = self.create_query()

        trainings: dict[TrainingIdentifier, TrainingEntity] = {}
        group_by_column = "training_id"

        # async_groupby groups consecutive rows; this presumably relies on the
        # query returning the rows of one training next to each other (verify
        # against TrainingDbQuery's ordering).
        row_it = query.fetch(limit, offset)
        async for _, group in async_groupby(
            row_it, key=lambda row: row[group_by_column]
        ):
            mapped = list(map(TrainingQueryRow.map, group))
            training = TrainingQueryRow.create_entity(mapped)
            trainings[training.id] = training

        training_ids = trainings.keys()
        coaches: dict[TrainingIdentifier, list[TrainingCoachEntity]] = {}
        teams: dict[TrainingIdentifier, list[TeamEntity]] = {}

        if len(training_ids) > 0:
            # Get the coaches of all the trainings in one query.
            training_coach_query = TrainingCoachDbQuery(
                self._database
            ).filter_by_trainings(*training_ids)
            coaches = await training_coach_query.fetch_coaches()

            # Get the teams of all the trainings in one query.
            team_query = TrainingTeamDbQuery(self._database).filter_by_trainings(
                *training_ids
            )
            teams = await team_query.fetch_teams()

        for training in trainings.values():
            training_coaches = frozenset(
                coaches.get(cast(TrainingIdentifier, training.id), [])
            )
            training_teams = frozenset(
                teams.get(cast(TrainingIdentifier, training.id), [])
            )
            if len(training_coaches) > 0 or len(training_teams) > 0:
                yield replace(training, coaches=training_coaches, teams=training_teams)
            else:
                yield training

    async def create(self, training: TrainingEntity) -> TrainingEntity:
        """Save a new training together with its texts, coaches and teams.

        Args:
            training: The training to save.

        Returns:
            The training with its new id and with ids assigned to its coaches.
        """
        new_id = await self._database.insert(
            TrainingRow.__table_name__, TrainingRow.persist(training)
        )
        result = training.set_id(TrainingIdentifier(new_id))

        await self._insert_contents(result)
        result = await self._insert_coaches(result)
        await self._insert_teams(result)

        return result

    async def update(self, training: TrainingEntity) -> TrainingEntity:
        """Update a training and replace its texts, coaches and teams.

        Texts, coaches and teams are replaced: all existing rows are deleted
        and the current ones are inserted again.

        Args:
            training: The training to update.

        Returns:
            The training, with newly assigned ids on its coaches.
        """
        await self._database.update(
            training.id.value,
            TrainingRow.__table_name__,
            TrainingRow.persist(training),
        )

        await self._delete_contents(training)
        await self._insert_contents(training)

        await self._delete_coaches(training)
        training = await self._insert_coaches(training)

        await self._delete_teams(training)
        await self._insert_teams(training)

        return training

    async def _insert_contents(self, training: TrainingEntity) -> None:
        """Insert the text contents of the training, when it has any."""
        content_rows = [
            TrainingTextRow.persist(training, content) for content in training.texts
        ]
        # Only call Database.insert when there is at least one row, the same
        # guard _insert_teams uses.
        if content_rows:
            await self._database.insert(TrainingTextRow.__table_name__, *content_rows)

    async def _insert_coaches(self, training: TrainingEntity) -> TrainingEntity:
        """Insert the related coaches.

        Each coach row is inserted separately so that its new id can be set on
        the returned coach entity.

        Todo: Create a training coach repository and move this code to the use case.
        """
        coaches = []
        for training_coach in training.coaches:
            training_coach_row = TrainingCoachRow.persist(training, training_coach)
            new_training_coach_id = await self._database.insert(
                TrainingCoachRow.__table_name__, training_coach_row
            )
            coaches.append(
                replace(
                    training_coach, id=TrainingCoachIdentifier(new_training_coach_id)
                )
            )
        return replace(training, coaches=frozenset(coaches))

    async def _insert_teams(self, training: TrainingEntity):
        """Insert the related teams, when the training has any."""
        training_team_rows = [
            TrainingTeamRow.persist(training, team) for team in training.teams
        ]
        if training_team_rows:
            await self._database.insert(
                TrainingTeamRow.__table_name__, *training_team_rows
            )

    async def _delete_coaches(self, training: TrainingEntity):
        """Delete all coach rows of the training."""
        delete_coaches_query = (
            self._database.create_query_factory()
            .delete(TrainingCoachRow.__table_name__)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_coaches_query)

    async def _delete_contents(self, training: TrainingEntity):
        """Delete all text content rows of the training."""
        delete_contents_query = (
            self._database.create_query_factory()
            .delete(TrainingTextRow.__table_name__)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_contents_query)

    async def _delete_teams(self, training: TrainingEntity):
        """Delete all team rows of the training."""
        delete_teams_query = (
            self._database.create_query_factory()
            .delete(TrainingTeamRow.__table_name__)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_teams_query)

    async def delete(self, training: TrainingEntity) -> None:
        """Delete a training together with its texts, coaches and teams.

        Args:
            training: The training to delete.
        """
        # Dependent rows are removed before the training row itself, so the
        # delete does not trip foreign keys if the schema enforces them
        # (the schema is not visible here; NOTE(review): confirm).
        await self._delete_contents(training)
        await self._delete_coaches(training)
        await self._delete_teams(training)

        await self._database.delete(training.id.value, TrainingRow.__table_name__)

    async def reset_schedule(
        self, training_schedule: TrainingScheduleEntity, delete: bool = False
    ) -> None:
        """Detach or delete all trainings that belong to a training schedule.

        Args:
            training_schedule: The schedule whose trainings are processed.
            delete: When True, the trainings and their teams, coaches and texts
                are deleted. When False, the trainings are kept and their
                training_schedule_id is set to NULL.
        """
        trainings_query = (
            self._database.create_query_factory()
            .select(TrainingRow.column("id"))
            .from_(TrainingRow.__table_name__)
            .and_where(field("training_schedule_id").eq(training_schedule.id.value))
        )
        if delete:
            # The child rows are deleted first: their filter is a subquery on
            # the trainings table, which must still contain the trainings.
            delete_teams = (
                self._database.create_query_factory()
                .delete(TrainingTeamRow.__table_name__)
                .and_where(TrainingTeamRow.field("training_id").in_(trainings_query))
            )
            await self._database.execute(delete_teams)

            delete_coaches = (
                self._database.create_query_factory()
                .delete(TrainingCoachRow.__table_name__)
                .and_where(TrainingCoachRow.field("training_id").in_(trainings_query))
            )
            await self._database.execute(delete_coaches)

            delete_contents = (
                self._database.create_query_factory()
                .delete(TrainingTextRow.__table_name__)
                .and_where(TrainingTextRow.field("training_id").in_(trainings_query))
            )
            await self._database.execute(delete_contents)

            # Finally remove the trainings themselves; otherwise they would be
            # left behind without texts, still pointing to the schedule.
            delete_trainings = (
                self._database.create_query_factory()
                .delete(TrainingRow.__table_name__)
                .where(field("training_schedule_id").eq(training_schedule.id.value))
            )
            await self._database.execute(delete_trainings)
        else:
            # Because it is not allowed to update the table that is used
            # in a sub query, we need to create a copy.
            copy_trainings_query = (
                self._database.create_query_factory()
                .select("t.id")
                .from_(alias(express("({})", trainings_query), "t"))
            )
            update_trainings = (
                self._database.create_query_factory()
                .update(TrainingRow.__table_name__, {"training_schedule_id": None})
                .where(TrainingRow.field("id").in_(copy_trainings_query))
            )
            await self._database.execute(update_trainings)