Coverage for bc/kwai-bc-training/src/kwai_bc_training/trainings/training_db_repository.py: 94%
107 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module for implementing a training repository for a database."""
3from dataclasses import replace
4from typing import AsyncIterator, cast
6from kwai_core.db.database import Database
7from kwai_core.functions import async_groupby
8from sql_smith.functions import alias, express, field
10from kwai_bc_training.teams.team import TeamEntity
11from kwai_bc_training.trainings._tables import (
12 TrainingCoachRow,
13 TrainingRow,
14 TrainingTeamRow,
15 TrainingTextRow,
16)
17from kwai_bc_training.trainings.training import (
18 TrainingCoachEntity,
19 TrainingCoachIdentifier,
20 TrainingEntity,
21 TrainingIdentifier,
22)
23from kwai_bc_training.trainings.training_coach_db_query import TrainingCoachDbQuery
24from kwai_bc_training.trainings.training_db_query import (
25 TrainingDbQuery,
26 TrainingQueryRow,
27)
28from kwai_bc_training.trainings.training_query import TrainingQuery
29from kwai_bc_training.trainings.training_repository import (
30 TrainingNotFoundException,
31 TrainingRepository,
32)
33from kwai_bc_training.trainings.training_schedule import TrainingScheduleEntity
34from kwai_bc_training.trainings.training_team_db_query import TrainingTeamDbQuery
class TrainingDbRepository(TrainingRepository):
    """A training repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> TrainingQuery:
        """Create a new database query for trainings."""
        return TrainingDbQuery(self._database)

    async def get_by_id(self, id: TrainingIdentifier) -> TrainingEntity:
        """Get the training with the given id.

        Args:
            id: The id of the training.

        Returns:
            The training, with its coaches and teams loaded.

        Raises:
            TrainingNotFoundException: when no training with the id exists.
        """
        query = self.create_query()
        query.filter_by_id(id)

        try:
            entity = await anext(self.get_all(query, 1))
        except StopAsyncIteration:
            raise TrainingNotFoundException(
                f"Training with id {id} does not exist"
            ) from None
        return entity

    async def get_all(
        self,
        query: TrainingQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingEntity]:
        """Yield the trainings matching the query, with coaches and teams loaded.

        Args:
            query: The query to run. When omitted, a new query is created.
            limit: Passed on to query.fetch.
            offset: Passed on to query.fetch.

        Yields:
            One training entity per distinct training_id returned by the query.
        """
        if query is None:
            query = self.create_query()

        trainings: dict[TrainingIdentifier, TrainingEntity] = {}
        group_by_column = "training_id"

        # A training can be returned as several rows (presumably one per text;
        # verify against TrainingDbQuery). async_groupby groups consecutive
        # rows, so the query is expected to return them adjacent to each other.
        row_it = query.fetch(limit, offset)
        async for _, group in async_groupby(
            row_it, key=lambda row: row[group_by_column]
        ):
            mapped = list(map(TrainingQueryRow.map, group))
            training = TrainingQueryRow.create_entity(mapped)
            trainings[training.id] = training

        coaches: dict[TrainingIdentifier, list[TrainingCoachEntity]] = {}
        teams: dict[TrainingIdentifier, list[TeamEntity]] = {}

        if trainings:
            # Coaches and teams are fetched for all trainings at once,
            # instead of once per training.
            training_ids = trainings.keys()
            training_coach_query = TrainingCoachDbQuery(
                self._database
            ).filter_by_trainings(*training_ids)
            coaches = await training_coach_query.fetch_coaches()

            team_query = TrainingTeamDbQuery(self._database).filter_by_trainings(
                *training_ids
            )
            teams = await team_query.fetch_teams()

        for training in trainings.values():
            training_id = cast(TrainingIdentifier, training.id)
            training_coaches = frozenset(coaches.get(training_id, []))
            training_teams = frozenset(teams.get(training_id, []))
            if training_coaches or training_teams:
                yield replace(training, coaches=training_coaches, teams=training_teams)
            else:
                yield training

    async def create(self, training: TrainingEntity) -> TrainingEntity:
        """Insert a new training together with its texts, coaches and teams.

        Args:
            training: The training to insert.

        Returns:
            The training with its new id and with the new ids of its coaches.
        """
        new_id = await self._database.insert(
            TrainingRow.__table_name__, TrainingRow.persist(training)
        )
        result = training.set_id(TrainingIdentifier(new_id))

        content_rows = [
            TrainingTextRow.persist(result, content) for content in training.texts
        ]
        # Only call insert when there is something to insert (same guard as
        # _insert_teams).
        if content_rows:
            await self._database.insert(TrainingTextRow.__table_name__, *content_rows)

        result = await self._insert_coaches(result)

        await self._insert_teams(result)

        return result

    async def update(self, training: TrainingEntity) -> TrainingEntity:
        """Update a training and replace its texts, coaches and teams.

        Args:
            training: The training to update.

        Returns:
            The training with the new ids of its (re-inserted) coaches.
        """
        await self._database.update(
            training.id.value,
            TrainingRow.__table_name__,
            TrainingRow.persist(training),
        )

        # Related rows are replaced wholesale: first delete, then insert again.
        await self._delete_contents(training)
        content_rows = [
            TrainingTextRow.persist(training, content) for content in training.texts
        ]
        if content_rows:
            await self._database.insert(TrainingTextRow.__table_name__, *content_rows)

        await self._delete_coaches(training)
        training = await self._insert_coaches(training)

        await self._delete_teams(training)
        await self._insert_teams(training)

        return training

    async def _insert_coaches(self, training: TrainingEntity) -> TrainingEntity:
        """Insert the related coaches.

        Each coach is inserted separately so its new id can be stored on the
        returned entity.

        Todo: Create a training coach repository and move this code to the use case.
        """
        coaches = []
        for training_coach in training.coaches:
            training_coach_row = TrainingCoachRow.persist(training, training_coach)
            new_training_coach_id = await self._database.insert(
                TrainingCoachRow.__table_name__, training_coach_row
            )
            coaches.append(
                replace(
                    training_coach, id=TrainingCoachIdentifier(new_training_coach_id)
                )
            )
        return replace(training, coaches=frozenset(coaches))

    async def _insert_teams(self, training: TrainingEntity) -> None:
        """Insert the related teams (nothing is inserted when there are none)."""
        training_team_rows = [
            TrainingTeamRow.persist(training, team) for team in training.teams
        ]
        if training_team_rows:
            await self._database.insert(
                TrainingTeamRow.__table_name__, *training_team_rows
            )

    async def _delete_related(self, table_name: str, training: TrainingEntity) -> None:
        """Delete all rows of table_name whose training_id is the training's id."""
        delete_query = (
            self._database.create_query_factory()
            .delete(table_name)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_query)

    async def _delete_coaches(self, training: TrainingEntity) -> None:
        """Delete coaches of the training."""
        await self._delete_related(TrainingCoachRow.__table_name__, training)

    async def _delete_contents(self, training: TrainingEntity) -> None:
        """Delete text contents of the training."""
        await self._delete_related(TrainingTextRow.__table_name__, training)

    async def _delete_teams(self, training: TrainingEntity) -> None:
        """Delete the teams of the training."""
        await self._delete_related(TrainingTeamRow.__table_name__, training)

    async def delete(self, training: TrainingEntity) -> None:
        """Delete a training together with its texts, coaches and teams."""
        # Dependent rows go first so that a foreign key from those tables to
        # the training table (if one exists) is never violated.
        await self._delete_contents(training)
        await self._delete_coaches(training)
        await self._delete_teams(training)

        await self._database.delete(training.id.value, TrainingRow.__table_name__)

    async def reset_schedule(
        self, training_schedule: TrainingScheduleEntity, delete: bool = False
    ) -> None:
        """Detach all trainings from a training schedule, or delete them.

        Args:
            training_schedule: The schedule whose trainings are affected.
            delete: When True, the trainings of the schedule are deleted together
                with their teams, coaches and texts. When False, only their
                training_schedule_id is set to NULL.
        """
        trainings_query = (
            self._database.create_query_factory()
            .select(TrainingRow.column("id"))
            .from_(TrainingRow.__table_name__)
            .and_where(field("training_schedule_id").eq(training_schedule.id.value))
        )
        if delete:
            # Delete dependent rows first: they are selected through
            # trainings_query, which only finds trainings that still exist.
            for row_class in (TrainingTeamRow, TrainingCoachRow, TrainingTextRow):
                delete_related = (
                    self._database.create_query_factory()
                    .delete(row_class.__table_name__)
                    .and_where(row_class.field("training_id").in_(trainings_query))
                )
                await self._database.execute(delete_related)

            # Finally remove the trainings of the schedule themselves.
            delete_trainings = (
                self._database.create_query_factory()
                .delete(TrainingRow.__table_name__)
                .and_where(
                    field("training_schedule_id").eq(training_schedule.id.value)
                )
            )
            await self._database.execute(delete_trainings)
        else:
            # Because it is not allowed to update the table that is used
            # in a sub query, we need to create a copy.
            copy_trainings_query = (
                self._database.create_query_factory()
                .select("t.id")
                .from_(alias(express("({})", trainings_query), "t"))
            )
            update_trainings = (
                self._database.create_query_factory()
                .update(TrainingRow.__table_name__, {"training_schedule_id": None})
                .where(TrainingRow.field("id").in_(copy_trainings_query))
            )
            await self._database.execute(update_trainings)