Coverage for bc/kwai-bc-training/src/kwai_bc_training/trainings/training_db_repository.py: 94%
103 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module for implementing a training repository for a database."""
3from dataclasses import replace
4from typing import AsyncIterator, cast
6from kwai_core.db.database import Database
7from kwai_core.functions import async_groupby
8from sql_smith.functions import alias, express, field
10from kwai_bc_training.teams.team import TeamEntity
11from kwai_bc_training.trainings._tables import (
12 TrainingCoachRow,
13 TrainingRow,
14 TrainingTeamRow,
15 TrainingTextRow,
16)
17from kwai_bc_training.trainings.training import (
18 TrainingCoachEntity,
19 TrainingEntity,
20 TrainingIdentifier,
21)
22from kwai_bc_training.trainings.training_coach_db_query import TrainingCoachDbQuery
23from kwai_bc_training.trainings.training_db_query import (
24 TrainingDbQuery,
25 TrainingQueryRow,
26)
27from kwai_bc_training.trainings.training_query import TrainingQuery
28from kwai_bc_training.trainings.training_repository import (
29 TrainingNotFoundException,
30 TrainingRepository,
31)
32from kwai_bc_training.trainings.training_schedule import TrainingScheduleEntity
33from kwai_bc_training.trainings.training_team_db_query import TrainingTeamDbQuery
class TrainingDbRepository(TrainingRepository):
    """A training repository for a database.

    A training is stored in the training table. Its text contents, coaches
    and teams are stored in separate tables that refer to the training with
    a ``training_id`` column.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> TrainingQuery:
        """Create a query for trainings that runs against the database."""
        return TrainingDbQuery(self._database)

    async def get_by_id(self, id: TrainingIdentifier) -> TrainingEntity:
        """Get the training with the given id.

        Args:
            id: The identifier of the training.

        Returns:
            The training entity.

        Raises:
            TrainingNotFoundException: Raised when the query yields no training.
        """
        query = self.create_query()
        query.filter_by_id(id)

        try:
            entity = await anext(self.get_all(query, 1))
        except StopAsyncIteration:
            raise TrainingNotFoundException(
                f"Training with id {id} does not exist"
            ) from None
        return entity

    async def get_all(
        self,
        query: TrainingQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingEntity]:
        """Yield the trainings selected by a query.

        All rows of the query are read first. After that, the coaches and the
        teams of all found trainings are fetched with one query each and
        attached to the entities that are yielded.

        Args:
            query: The query to execute. When None, the query returned by
                create_query is used.
            limit: Passed as limit to the fetch method of the query.
            offset: Passed as offset to the fetch method of the query.

        Yields:
            A training entity for each group of rows with the same training id.
        """
        if query is None:
            query = self.create_query()

        # Rows that share the same training_id are combined into one entity.
        trainings: dict[TrainingIdentifier, TrainingEntity] = {}
        async for _, group in async_groupby(
            query.fetch(limit, offset), key=lambda row: row["training_id"]
        ):
            training = TrainingQueryRow.create_entity(
                [TrainingQueryRow.map(row) for row in group]
            )
            trainings[training.id] = training

        coaches: dict[TrainingIdentifier, list[TrainingCoachEntity]] = {}
        teams: dict[TrainingIdentifier, list[TeamEntity]] = {}

        # Only query the related tables when there is at least one training,
        # so the filters are never built from an empty list of ids.
        if trainings:
            training_ids = trainings.keys()

            # Get the coaches of all the trainings.
            coaches = await (
                TrainingCoachDbQuery(self._database)
                .filter_by_trainings(*training_ids)
                .fetch_coaches()
            )

            # Get the teams of all trainings
            teams = await (
                TrainingTeamDbQuery(self._database)
                .filter_by_trainings(*training_ids)
                .fetch_teams()
            )

        for training in trainings.values():
            training_id = cast(TrainingIdentifier, training.id)
            training_coaches = frozenset(coaches.get(training_id, []))
            training_teams = frozenset(teams.get(training_id, []))
            if training_coaches or training_teams:
                yield replace(training, coaches=training_coaches, teams=training_teams)
            else:
                # Nothing to attach: yield the entity as it was created.
                yield training

    async def create(self, training: TrainingEntity) -> TrainingEntity:
        """Insert a new training together with its texts, coaches and teams.

        Args:
            training: The training to store.

        Returns:
            The training with the identifier that was returned by the insert.
        """
        new_id = await self._database.insert(
            TrainingRow.__table_name__, TrainingRow.persist(training)
        )
        result = training.set_id(TrainingIdentifier(new_id))

        await self._insert_contents(result)
        await self._insert_coaches(result)
        await self._insert_teams(result)

        return result

    async def update(self, training: TrainingEntity) -> None:
        """Update a training and replace its texts, coaches and teams.

        Args:
            training: The training with the new state.
        """
        # Update the training
        await self._database.update(
            training.id.value,
            TrainingRow.__table_name__,
            TrainingRow.persist(training),
        )

        # Update the text, first delete, then insert again.
        await self._delete_contents(training)
        await self._insert_contents(training)

        # Update coaches, first delete, then insert again.
        await self._delete_coaches(training)
        await self._insert_coaches(training)

        # Update teams, first delete, then insert again.
        await self._delete_teams(training)
        await self._insert_teams(training)

    async def _insert_contents(self, training: TrainingEntity):
        """Insert the text contents of the training.

        Nothing is executed when the training has no texts, which matches the
        way coaches and teams are inserted.
        """
        content_rows = [
            TrainingTextRow.persist(training, content) for content in training.texts
        ]
        if content_rows:
            await self._database.insert(TrainingTextRow.__table_name__, *content_rows)

    async def _insert_coaches(self, training: TrainingEntity):
        """Insert the related coaches."""
        training_coach_rows = [
            TrainingCoachRow.persist(training, training_coach)
            for training_coach in training.coaches
        ]
        if training_coach_rows:
            await self._database.insert(
                TrainingCoachRow.__table_name__, *training_coach_rows
            )

    async def _insert_teams(self, training: TrainingEntity):
        """Insert the related teams."""
        training_team_rows = [
            TrainingTeamRow.persist(training, team) for team in training.teams
        ]
        if training_team_rows:
            await self._database.insert(
                TrainingTeamRow.__table_name__, *training_team_rows
            )

    async def _delete_related_rows(self, table_name: str, training: TrainingEntity):
        """Delete all rows of a table with the training_id of the training."""
        delete_query = (
            self._database.create_query_factory()
            .delete(table_name)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_query)

    async def _delete_coaches(self, training: TrainingEntity):
        """Delete coaches of the training."""
        await self._delete_related_rows(TrainingCoachRow.__table_name__, training)

    async def _delete_contents(self, training: TrainingEntity):
        """Delete text contents of the training."""
        await self._delete_related_rows(TrainingTextRow.__table_name__, training)

    async def _delete_teams(self, training: TrainingEntity):
        """Delete the teams of the training."""
        await self._delete_related_rows(TrainingTeamRow.__table_name__, training)

    async def delete(self, training: TrainingEntity) -> None:
        """Delete a training together with its texts, coaches and teams.

        Args:
            training: The training to delete.
        """
        # The rows that refer to the training are removed before the training
        # itself, so a referencing row never outlives its training.
        await self._delete_contents(training)
        await self._delete_coaches(training)
        await self._delete_teams(training)

        await self._database.delete(training.id.value, TrainingRow.__table_name__)

    async def reset_schedule(
        self, training_schedule: TrainingScheduleEntity, delete: bool = False
    ) -> None:
        """Handle the trainings that are linked to a training schedule.

        Args:
            training_schedule: The schedule used to select the trainings.
            delete: When True, the team, coach and text rows of the selected
                trainings are deleted. When False, training_schedule_id of the
                selected trainings is set to None.
        """
        trainings_query = (
            self._database.create_query_factory()
            .select(TrainingRow.column("id"))
            .from_(TrainingRow.__table_name__)
            .and_where(field("training_schedule_id").eq(training_schedule.id.value))
        )

        if delete:
            # NOTE(review): only the related rows are deleted here, the rows
            # of the trainings themselves are kept and still refer to the
            # schedule - confirm this is what the callers expect.
            for row_class in (TrainingTeamRow, TrainingCoachRow, TrainingTextRow):
                delete_query = (
                    self._database.create_query_factory()
                    .delete(row_class.__table_name__)
                    .and_where(row_class.field("training_id").in_(trainings_query))
                )
                await self._database.execute(delete_query)
            return

        # Because it is not allowed to update the table that is used
        # in a sub query, we need to create a copy.
        copy_trainings_query = (
            self._database.create_query_factory()
            .select("t.id")
            .from_(alias(express("({})", trainings_query), "t"))
        )
        update_trainings = (
            self._database.create_query_factory()
            .update(TrainingRow.__table_name__, {"training_schedule_id": None})
            .where(TrainingRow.field("id").in_(copy_trainings_query))
        )
        await self._database.execute(update_trainings)