Coverage for bc/kwai-bc-training/src/kwai_bc_training/trainings/training_schedule_db_repository.py: 100%
61 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module that implements a training schedule repository for a database."""
3from dataclasses import replace
4from typing import AsyncIterator, cast
6from kwai_core.db.database import Database
7from sql_smith.functions import field
9from kwai_bc_training.coaches._tables import ( # noqa
10 CoachRow,
11 MemberRow,
12 PersonRow,
13)
14from kwai_bc_training.coaches.coach import CoachEntity
15from kwai_bc_training.trainings._tables import (
16 TrainingScheduleCoachRow,
17 TrainingScheduleRow,
18)
19from kwai_bc_training.trainings.training_schedule import (
20 TrainingScheduleEntity,
21 TrainingScheduleIdentifier,
22)
23from kwai_bc_training.trainings.training_schedule_coach_db_query import (
24 TrainingScheduleCoachDbQuery,
25)
26from kwai_bc_training.trainings.training_schedule_db_query import (
27 TrainingScheduleDbQuery,
28 TrainingScheduleQueryRow,
29)
30from kwai_bc_training.trainings.training_schedule_query import (
31 TrainingScheduleQuery,
32)
33from kwai_bc_training.trainings.training_schedule_repository import (
34 TrainingScheduleNotFoundException,
35 TrainingScheduleRepository,
36)
class TrainingScheduleDbRepository(TrainingScheduleRepository):
    """A training schedule repository for a database.

    Training schedules are stored in the table of ``TrainingScheduleRow``.
    The coaches of a schedule are stored as a 1-to-n relationship in the
    table of ``TrainingScheduleCoachRow`` (keyed on ``training_schedule_id``).
    """

    def __init__(self, database: Database) -> None:
        """Initialize the repository.

        Args:
            database: The database for this repository
        """
        self._database = database

    def create_query(self) -> TrainingScheduleQuery:  # noqa
        """Create a training schedule query that runs on this database."""
        return TrainingScheduleDbQuery(self._database)

    async def get_by_id(
        self, id_: TrainingScheduleIdentifier
    ) -> TrainingScheduleEntity:
        """Get the training schedule with the given id.

        Args:
            id_: The id of the training schedule.

        Returns:
            The training schedule, with its coaches attached.

        Raises:
            TrainingScheduleNotFoundException: when no training schedule with
                the given id exists.
        """
        query = self.create_query()
        query.filter_by_id(id_)

        try:
            entity = await anext(self.get_all(query, 1))
        except StopAsyncIteration:
            raise TrainingScheduleNotFoundException(
                f"Training schedule with id {id_} does not exist."
            ) from None
        return entity

    async def get_all(
        self,
        query: TrainingScheduleQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingScheduleEntity]:
        """Yield the training schedules selected by a query.

        All matching schedules are fetched first. Their coaches are then
        loaded with a single extra query for the whole batch, instead of one
        query per schedule, and attached to each schedule.

        Args:
            query: The query to run. When omitted, a new query from
                ``create_query`` is used.
            limit: Passed to ``query.fetch`` to limit the number of rows.
            offset: Passed to ``query.fetch`` to skip rows.

        Yields:
            The training schedules, with their coaches attached.
        """
        if query is None:
            query = self.create_query()

        training_schedules: dict[
            TrainingScheduleIdentifier, TrainingScheduleEntity
        ] = {}
        async for row in query.fetch(limit, offset):
            training_schedule = TrainingScheduleQueryRow.map(row).create_entity()
            training_schedules[
                cast(TrainingScheduleIdentifier, training_schedule.id)
            ] = training_schedule

        all_coaches: dict[TrainingScheduleIdentifier, list[CoachEntity]] = {}
        if training_schedules:
            # Fetch the coaches of every schedule in this batch with one query.
            training_schedule_coach_query = TrainingScheduleCoachDbQuery(
                self._database
            ).filter_by_schedule(*training_schedules.keys())
            all_coaches = await training_schedule_coach_query.fetch_coaches()

        for schedule_id, training_schedule in training_schedules.items():
            schedule_coaches = frozenset(all_coaches.get(schedule_id, []))
            if schedule_coaches:
                yield replace(training_schedule, coaches=schedule_coaches)
            else:
                yield training_schedule

    async def create(
        self, training_schedule: TrainingScheduleEntity
    ) -> TrainingScheduleEntity:
        """Insert a new training schedule together with its coaches.

        Args:
            training_schedule: The training schedule to save.

        Returns:
            The training schedule with the id returned by the insert.
        """
        new_id = await self._database.insert(
            TrainingScheduleRow.__table_name__,
            TrainingScheduleRow.persist(training_schedule),
        )
        result = training_schedule.set_id(TrainingScheduleIdentifier(new_id))

        # The coach rows store the schedule id, so they can only be created
        # once the schedule has been inserted and has an id.
        await self._create_coaches(result)

        return result

    async def update(self, training_schedule: TrainingScheduleEntity) -> None:
        """Update a training schedule and replace its coaches.

        Args:
            training_schedule: The training schedule to update.
        """
        await self._database.update(
            training_schedule.id.value,
            TrainingScheduleRow.__table_name__,
            TrainingScheduleRow.persist(training_schedule),
        )

        # Update the 1-to-n relationships for the coaches. First delete and then
        # recreate the relationships
        await self._delete_coaches(training_schedule)
        await self._create_coaches(training_schedule)

    async def delete(self, training_schedule: TrainingScheduleEntity) -> None:
        """Delete a training schedule and its coach relationships.

        Args:
            training_schedule: The training schedule to delete.
        """
        # Remove the dependent coach rows before the schedule row they
        # reference: deleting the parent first would fail if the database
        # enforces a foreign key on training_schedule_id without
        # ON DELETE CASCADE.
        await self._delete_coaches(training_schedule)
        await self._database.delete(
            training_schedule.id.value, TrainingScheduleRow.__table_name__
        )

    async def _delete_coaches(
        self, training_schedule: TrainingScheduleEntity
    ) -> None:
        """Delete all coach relationship rows of the training schedule."""
        delete_coaches_query = (
            self._database.create_query_factory()
            .delete(TrainingScheduleCoachRow.__table_name__)
            .where(field("training_schedule_id").eq(training_schedule.id.value))
        )
        await self._database.execute(delete_coaches_query)

    async def _create_coaches(
        self, training_schedule: TrainingScheduleEntity
    ) -> None:
        """Insert one coach relationship row per coach of the schedule."""
        schedule_coach_rows = [
            TrainingScheduleCoachRow(
                training_schedule_id=training_schedule.id.value, coach_id=coach.id.value
            )
            for coach in training_schedule.coaches
        ]
        # Skip the insert entirely when the schedule has no coaches.
        if schedule_coach_rows:
            await self._database.insert(
                TrainingScheduleCoachRow.__table_name__, *schedule_coach_rows
            )