Coverage for bc/kwai-bc-training/src/kwai_bc_training/trainings/training_db_repository.py: 94%

103 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module for implementing a training repository for a database.""" 

2 

3from dataclasses import replace 

4from typing import AsyncIterator, cast 

5 

6from kwai_core.db.database import Database 

7from kwai_core.functions import async_groupby 

8from sql_smith.functions import alias, express, field 

9 

10from kwai_bc_training.teams.team import TeamEntity 

11from kwai_bc_training.trainings._tables import ( 

12 TrainingCoachRow, 

13 TrainingRow, 

14 TrainingTeamRow, 

15 TrainingTextRow, 

16) 

17from kwai_bc_training.trainings.training import ( 

18 TrainingCoachEntity, 

19 TrainingEntity, 

20 TrainingIdentifier, 

21) 

22from kwai_bc_training.trainings.training_coach_db_query import TrainingCoachDbQuery 

23from kwai_bc_training.trainings.training_db_query import ( 

24 TrainingDbQuery, 

25 TrainingQueryRow, 

26) 

27from kwai_bc_training.trainings.training_query import TrainingQuery 

28from kwai_bc_training.trainings.training_repository import ( 

29 TrainingNotFoundException, 

30 TrainingRepository, 

31) 

32from kwai_bc_training.trainings.training_schedule import TrainingScheduleEntity 

33from kwai_bc_training.trainings.training_team_db_query import TrainingTeamDbQuery 

34 

35 

class TrainingDbRepository(TrainingRepository):
    """A training repository for a database.

    A training is stored in the training table. Its texts, coaches and teams
    are stored in separate tables that reference the training through a
    ``training_id`` column.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> TrainingQuery:
        """Create a new query for trainings stored in the database."""
        return TrainingDbQuery(self._database)

    async def get_by_id(self, id: TrainingIdentifier) -> TrainingEntity:
        """Get the training with the given id.

        Args:
            id: The id of the training.

        Returns:
            The training, with its coaches and teams attached.

        Raises:
            TrainingNotFoundException: when no training with the id exists.
        """
        query = self.create_query()
        query.filter_by_id(id)

        try:
            row_iterator = self.get_all(query, 1)
            entity = await anext(row_iterator)
        except StopAsyncIteration:
            raise TrainingNotFoundException(
                f"Training with id {id} does not exist"
            ) from None
        return entity

    async def get_all(
        self,
        query: TrainingQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingEntity]:
        """Yield all trainings selected by the query.

        Args:
            query: The query to use. When None, a new query is created.
            limit: Passed on to ``query.fetch``.
            offset: Passed on to ``query.fetch``.

        Yields:
            Training entities with their coaches and teams attached.
        """
        if query is None:
            query = self.create_query()

        trainings: dict[TrainingIdentifier, TrainingEntity] = {}
        group_by_column = "training_id"

        # The query can return several rows for one training; rows are
        # grouped on training_id and each group is turned into one entity.
        row_it = query.fetch(limit, offset)
        async for _, group in async_groupby(
            row_it, key=lambda row: row[group_by_column]
        ):
            mapped = list(map(TrainingQueryRow.map, group))
            training = TrainingQueryRow.create_entity(mapped)
            trainings[training.id] = training

        training_ids = trainings.keys()
        coaches: dict[TrainingIdentifier, list[TrainingCoachEntity]] = {}
        teams: dict[TrainingIdentifier, list[TeamEntity]] = {}

        # Coaches and teams are fetched with one query each for all the
        # trainings, instead of one query per training.
        if training_ids:
            training_coach_query = TrainingCoachDbQuery(
                self._database
            ).filter_by_trainings(*training_ids)
            coaches = await training_coach_query.fetch_coaches()

            team_query = TrainingTeamDbQuery(self._database).filter_by_trainings(
                *training_ids
            )
            teams = await team_query.fetch_teams()

        for training in trainings.values():
            training_id = cast(TrainingIdentifier, training.id)
            training_coaches = frozenset(coaches.get(training_id, []))
            training_teams = frozenset(teams.get(training_id, []))
            if training_coaches or training_teams:
                yield replace(training, coaches=training_coaches, teams=training_teams)
            else:
                yield training

    async def create(self, training: TrainingEntity) -> TrainingEntity:
        """Save a new training with its texts, coaches and teams.

        Args:
            training: The training to save.

        Returns:
            The training with the id assigned by the database.
        """
        new_id = await self._database.insert(
            TrainingRow.__table_name__, TrainingRow.persist(training)
        )
        result = training.set_id(TrainingIdentifier(new_id))

        await self._insert_contents(result)
        await self._insert_coaches(result)
        await self._insert_teams(result)

        return result

    async def update(self, training: TrainingEntity) -> None:
        """Update a training with its texts, coaches and teams.

        Texts, coaches and teams are replaced: the existing rows are deleted
        and the current ones are inserted again.

        Args:
            training: The training to update.
        """
        await self._database.update(
            training.id.value,
            TrainingRow.__table_name__,
            TrainingRow.persist(training),
        )

        await self._delete_contents(training)
        await self._insert_contents(training)

        await self._delete_coaches(training)
        await self._insert_coaches(training)

        await self._delete_teams(training)
        await self._insert_teams(training)

    async def _insert_contents(self, training: TrainingEntity) -> None:
        """Insert the text contents of the training."""
        content_rows = [
            TrainingTextRow.persist(training, content) for content in training.texts
        ]
        # Like the coach/team inserts, skip the insert when there is nothing
        # to insert.
        if content_rows:
            await self._database.insert(TrainingTextRow.__table_name__, *content_rows)

    async def _insert_coaches(self, training: TrainingEntity) -> None:
        """Insert the related coaches."""
        training_coach_rows = [
            TrainingCoachRow.persist(training, training_coach)
            for training_coach in training.coaches
        ]
        if training_coach_rows:
            await self._database.insert(
                TrainingCoachRow.__table_name__, *training_coach_rows
            )

    async def _insert_teams(self, training: TrainingEntity) -> None:
        """Insert the related teams."""
        training_team_rows = [
            TrainingTeamRow.persist(training, team) for team in training.teams
        ]
        if training_team_rows:
            await self._database.insert(
                TrainingTeamRow.__table_name__, *training_team_rows
            )

    async def _delete_coaches(self, training: TrainingEntity) -> None:
        """Delete coaches of the training."""
        delete_coaches_query = (
            self._database.create_query_factory()
            .delete(TrainingCoachRow.__table_name__)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_coaches_query)

    async def _delete_contents(self, training: TrainingEntity) -> None:
        """Delete text contents of the training."""
        delete_contents_query = (
            self._database.create_query_factory()
            .delete(TrainingTextRow.__table_name__)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_contents_query)

    async def _delete_teams(self, training: TrainingEntity) -> None:
        """Delete the teams of the training."""
        delete_teams_query = (
            self._database.create_query_factory()
            .delete(TrainingTeamRow.__table_name__)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_teams_query)

    async def delete(self, training: TrainingEntity) -> None:
        """Delete a training with its texts, coaches and teams.

        Args:
            training: The training to delete.
        """
        # Remove the rows that reference the training first, so the training
        # row itself is removed last (safe when foreign keys are enforced).
        await self._delete_contents(training)
        await self._delete_coaches(training)
        await self._delete_teams(training)

        await self._database.delete(training.id.value, TrainingRow.__table_name__)

    async def reset_schedule(
        self, training_schedule: TrainingScheduleEntity, delete: bool = False
    ) -> None:
        """Reset all trainings that belong to a training schedule.

        Args:
            training_schedule: The schedule whose trainings are reset.
            delete: When True, the trainings and their teams, coaches and
                texts are deleted. When False, the trainings are kept, but
                their training_schedule_id is set to NULL.
        """
        trainings_query = (
            self._database.create_query_factory()
            .select(TrainingRow.column("id"))
            .from_(TrainingRow.__table_name__)
            .and_where(field("training_schedule_id").eq(training_schedule.id.value))
        )
        if delete:
            # The related rows are selected through trainings_query, so they
            # must be removed while the training rows still exist.
            delete_teams = (
                self._database.create_query_factory()
                .delete(TrainingTeamRow.__table_name__)
                .and_where(TrainingTeamRow.field("training_id").in_(trainings_query))
            )
            await self._database.execute(delete_teams)

            delete_coaches = (
                self._database.create_query_factory()
                .delete(TrainingCoachRow.__table_name__)
                .and_where(TrainingCoachRow.field("training_id").in_(trainings_query))
            )
            await self._database.execute(delete_coaches)

            delete_contents = (
                self._database.create_query_factory()
                .delete(TrainingTextRow.__table_name__)
                .and_where(TrainingTextRow.field("training_id").in_(trainings_query))
            )
            await self._database.execute(delete_contents)

            # Finally remove the trainings themselves; without this they
            # would remain (without texts) still linked to the schedule.
            delete_trainings = (
                self._database.create_query_factory()
                .delete(TrainingRow.__table_name__)
                .where(
                    field("training_schedule_id").eq(training_schedule.id.value)
                )
            )
            await self._database.execute(delete_trainings)
        else:
            # Because it is not allowed to update the table that is used
            # in a sub query, we need to create a copy: the subquery is
            # wrapped in a derived table aliased "t".
            copy_trainings_query = (
                self._database.create_query_factory()
                .select("t.id")
                .from_(alias(express("({})", trainings_query), "t"))
            )
            update_trainings = (
                self._database.create_query_factory()
                .update(TrainingRow.__table_name__, {"training_schedule_id": None})
                .where(TrainingRow.field("id").in_(copy_trainings_query))
            )
            await self._database.execute(update_trainings)