Coverage for bc/kwai-bc-training/src/kwai_bc_training/trainings/training_db_repository.py: 94%
104 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module for implementing a training repository for a database."""
3from dataclasses import replace
4from typing import AsyncIterator, cast
6from kwai_core.db.database import Database
7from kwai_core.functions import async_groupby
8from sql_smith.functions import alias, express, field
10from kwai_bc_training.teams.team import TeamEntity
11from kwai_bc_training.trainings._tables import (
12 TrainingCoachRow,
13 TrainingRow,
14 TrainingTeamRow,
15 TrainingTextRow,
16)
17from kwai_bc_training.trainings.training import TrainingEntity, TrainingIdentifier
18from kwai_bc_training.trainings.training_coach_db_query import TrainingCoachDbQuery
19from kwai_bc_training.trainings.training_db_query import (
20 TrainingDbQuery,
21 TrainingQueryRow,
22)
23from kwai_bc_training.trainings.training_query import TrainingQuery
24from kwai_bc_training.trainings.training_repository import (
25 TrainingNotFoundException,
26 TrainingRepository,
27)
28from kwai_bc_training.trainings.training_schedule import TrainingScheduleEntity
29from kwai_bc_training.trainings.training_team_db_query import TrainingTeamDbQuery
30from kwai_bc_training.trainings.value_objects import TrainingCoach
class TrainingDbRepository(TrainingRepository):
    """A training repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> TrainingQuery:
        """Create a new database query for trainings."""
        return TrainingDbQuery(self._database)

    async def get_by_id(self, id: TrainingIdentifier) -> TrainingEntity:
        """Get the training with the given id.

        Args:
            id: The id of the training.

        Returns:
            The training entity, with its coaches and teams attached.

        Raises:
            TrainingNotFoundException: When no training with the id exists.
        """
        query = self.create_query()
        query.filter_by_id(id)

        row_iterator = self.get_all(query, 1)
        try:
            entity = await anext(row_iterator)
        except StopAsyncIteration:
            raise TrainingNotFoundException(
                f"Training with id {id} does not exist"
            ) from None
        finally:
            # Only the first entity is needed: close the async generator
            # explicitly instead of leaving it suspended for later finalization.
            await row_iterator.aclose()
        return entity

    async def get_all(
        self,
        query: TrainingQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingEntity]:
        """Yield the trainings matching the query, with coaches and teams.

        Args:
            query: The query to use. When omitted, a new query is created.
            limit: Limit passed on to `query.fetch`.
            offset: Offset passed on to `query.fetch`.

        Yields:
            Training entities.
        """
        if query is None:
            query = self.create_query()

        # The query can return several rows for one training (presumably one
        # per text — confirm in TrainingDbQuery). Rows are grouped on
        # training_id, which assumes the query orders them by that column.
        trainings: dict[TrainingIdentifier, TrainingEntity] = {}
        group_by_column = "training_id"
        async for _, group in async_groupby(
            query.fetch(limit, offset), key=lambda row: row[group_by_column]
        ):
            training = TrainingQueryRow.create_entity(
                [TrainingQueryRow.map(row) for row in group]
            )
            trainings[training.id] = training

        coaches: dict[TrainingIdentifier, list[TrainingCoach]] = {}
        teams: dict[TrainingIdentifier, list[TeamEntity]] = {}
        if trainings:
            training_ids = trainings.keys()
            # Fetch coaches and teams for all trainings in one query each,
            # instead of one query per training.
            coach_query = TrainingCoachDbQuery(self._database).filter_by_trainings(
                *training_ids
            )
            coaches = await coach_query.fetch_coaches()

            team_query = TrainingTeamDbQuery(self._database).filter_by_trainings(
                *training_ids
            )
            teams = await team_query.fetch_teams()

        for training in trainings.values():
            training_id = cast(TrainingIdentifier, training.id)
            training_coaches = frozenset(coaches.get(training_id, []))
            training_teams = frozenset(teams.get(training_id, []))
            if training_coaches or training_teams:
                yield replace(training, coaches=training_coaches, teams=training_teams)
            else:
                yield training

    async def create(self, training: TrainingEntity) -> TrainingEntity:
        """Save a new training with its texts, coaches and teams.

        Args:
            training: The training to save.

        Returns:
            The training with the id assigned by the database.
        """
        new_id = await self._database.insert(
            TrainingRow.__table_name__, TrainingRow.persist(training)
        )
        result = training.set_id(TrainingIdentifier(new_id))

        await self._insert_contents(result)
        await self._insert_coaches(result)
        await self._insert_teams(result)

        return result

    async def update(self, training: TrainingEntity) -> None:
        """Update the training and replace its texts, coaches and teams.

        Args:
            training: The training to update.
        """
        await self._database.update(
            training.id.value,
            TrainingRow.__table_name__,
            TrainingRow.persist(training),
        )

        # Related rows are replaced: first delete, then insert again.
        await self._delete_contents(training)
        await self._insert_contents(training)

        await self._delete_coaches(training)
        await self._insert_coaches(training)

        await self._delete_teams(training)
        await self._insert_teams(training)

    async def _insert_contents(self, training: TrainingEntity) -> None:
        """Insert the text contents of the training.

        The insert is skipped when there are no texts, like the coach and team
        inserts below, so a training without texts can be saved.
        """
        content_rows = [
            TrainingTextRow.persist(training, content) for content in training.texts
        ]
        if content_rows:
            await self._database.insert(TrainingTextRow.__table_name__, *content_rows)

    async def _insert_coaches(self, training: TrainingEntity) -> None:
        """Insert the related coaches, if any."""
        training_coach_rows = [
            TrainingCoachRow.persist(training, training_coach)
            for training_coach in training.coaches
        ]
        if training_coach_rows:
            await self._database.insert(
                TrainingCoachRow.__table_name__, *training_coach_rows
            )

    async def _insert_teams(self, training: TrainingEntity) -> None:
        """Insert the related teams, if any."""
        training_team_rows = [
            TrainingTeamRow.persist(training, team) for team in training.teams
        ]
        if training_team_rows:
            await self._database.insert(
                TrainingTeamRow.__table_name__, *training_team_rows
            )

    async def _delete_coaches(self, training: TrainingEntity) -> None:
        """Delete the coaches of the training."""
        delete_coaches_query = (
            self._database.create_query_factory()
            .delete(TrainingCoachRow.__table_name__)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_coaches_query)

    async def _delete_contents(self, training: TrainingEntity) -> None:
        """Delete the text contents of the training."""
        delete_contents_query = (
            self._database.create_query_factory()
            .delete(TrainingTextRow.__table_name__)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_contents_query)

    async def _delete_teams(self, training: TrainingEntity) -> None:
        """Delete the teams of the training."""
        delete_teams_query = (
            self._database.create_query_factory()
            .delete(TrainingTeamRow.__table_name__)
            .where(field("training_id").eq(training.id.value))
        )
        await self._database.execute(delete_teams_query)

    async def delete(self, training: TrainingEntity) -> None:
        """Delete the training together with its texts, coaches and teams.

        Args:
            training: The training to delete.
        """
        # Remove the dependent rows before the training row itself, so this
        # does not rely on foreign keys cascading (or being absent).
        await self._delete_contents(training)
        await self._delete_coaches(training)
        await self._delete_teams(training)
        await self._database.delete(training.id.value, TrainingRow.__table_name__)

    async def reset_schedule(
        self, training_schedule: TrainingScheduleEntity, delete: bool = False
    ) -> None:
        """Delete or detach all trainings of a training schedule.

        Args:
            training_schedule: The schedule whose trainings are reset.
            delete: When True, the trainings and their teams, coaches and
                texts are deleted. When False, the trainings are kept and
                their training_schedule_id is set to NULL.
        """
        schedule_id = training_schedule.id.value
        trainings_query = (
            self._database.create_query_factory()
            .select(TrainingRow.column("id"))
            .from_(TrainingRow.__table_name__)
            .and_where(field("training_schedule_id").eq(schedule_id))
        )
        if delete:
            # The dependent rows are selected through the trainings of the
            # schedule, so they must be deleted while those trainings exist.
            for row_class in (TrainingTeamRow, TrainingCoachRow, TrainingTextRow):
                delete_related = (
                    self._database.create_query_factory()
                    .delete(row_class.__table_name__)
                    .and_where(row_class.field("training_id").in_(trainings_query))
                )
                await self._database.execute(delete_related)

            # Finally delete the trainings themselves. A plain WHERE on the
            # schedule id is used rather than the sub query: the table being
            # modified cannot be used in a sub query (the same restriction the
            # update branch below works around).
            delete_trainings = (
                self._database.create_query_factory()
                .delete(TrainingRow.__table_name__)
                .where(field("training_schedule_id").eq(schedule_id))
            )
            await self._database.execute(delete_trainings)
        else:
            # Because it is not allowed to update the table that is used
            # in a sub query, we need to create a copy.
            copy_trainings_query = (
                self._database.create_query_factory()
                .select("t.id")
                .from_(alias(express("({})", trainings_query), "t"))
            )
            update_trainings = (
                self._database.create_query_factory()
                .update(TrainingRow.__table_name__, {"training_schedule_id": None})
                .where(TrainingRow.field("id").in_(copy_trainings_query))
            )
            await self._database.execute(update_trainings)