Coverage for bc/kwai-bc-training/src/kwai_bc_training/trainings/training_db_query.py: 100%
74 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module that implements a training query for a database."""
3from dataclasses import dataclass
4from typing import AsyncIterator, Self
6from kwai_core.db.database import Database, Record
7from kwai_core.db.database_query import DatabaseQuery
8from kwai_core.db.rows import OwnerRow
9from kwai_core.db.table_row import JoinedTableRow
10from kwai_core.domain.value_objects.timestamp import Timestamp
11from sql_smith.functions import alias, criteria, express, func, group, literal, on
13from kwai_bc_training.coaches.coach import CoachEntity
14from kwai_bc_training.teams.team import TeamEntity
15from kwai_bc_training.teams.team_tables import TeamRow
16from kwai_bc_training.trainings._tables import (
17 TrainingCoachRow,
18 TrainingRow,
19 TrainingScheduleRow,
20 TrainingTeamRow,
21 TrainingTextRow,
22)
23from kwai_bc_training.trainings.training import TrainingEntity, TrainingIdentifier
24from kwai_bc_training.trainings.training_query import TrainingQuery
25from kwai_bc_training.trainings.training_schedule import TrainingScheduleEntity
@dataclass(kw_only=True, frozen=True, slots=True)
class TrainingQueryRow(JoinedTableRow):
    """Row object holding the columns of every table joined by the training query.

    Each field contains the columns of one joined table. The fields are kept in
    this order because the joined-row aliases are derived from them.
    """

    text: TrainingTextRow
    training: TrainingRow
    owner: OwnerRow
    team: TeamRow
    training_schedule: TrainingScheduleRow
    training_schedule_owner: OwnerRow

    @classmethod
    def create_entity(cls, rows: list[Self]) -> TrainingEntity:
        """Build a training entity from a group of rows.

        The training, its schedule, the schedule's team and the schedule's owner
        are taken from the first row. Every row contributes one text, together
        with that text's owner. The rows are expected to belong to the same
        training and the list must not be empty.
        """
        first = rows[0]

        # A training schedule only exists when the left join found a match.
        schedule = None
        if first.training_schedule.id is not None:
            schedule = first.training_schedule.create_entity(
                first.team.create_entity(),
                first.training_schedule_owner.create_owner(),
            )

        texts = tuple(row.text.create_text(row.owner.create_owner()) for row in rows)
        return first.training.create_entity(texts, schedule)
class TrainingDbQuery(TrainingQuery, DatabaseQuery):
    """A database query for trainings.

    Filters, limit and offset are applied to ``self._query``, which is embedded
    in the outer query as the CTE named ``limited``. In ``fetch``, only the
    training id is selected from that CTE. ``self._main_query`` right joins the
    CTE on the training id and adds:

    - the training schedule,
    - the schedule owner,
    - the team,
    - the training texts and their owners.
    """

    def __init__(self, database: Database):
        # The outer query is created before the base initializer runs.
        # NOTE(review): presumably DatabaseQuery.__init__ calls init(), which
        # needs _main_query — confirm against DatabaseQuery.
        self._main_query = database.create_query_factory().select()
        super().__init__(database)

    @staticmethod
    def _schedule_join_condition():
        """Return the ON condition that links a training to its schedule."""
        return on(
            TrainingRow.column("training_schedule_id"),
            TrainingScheduleRow.column("id"),
        )

    def init(self):
        # This query will be used as CTE, so it only joins the tables that are
        # needed for counting and limiting results.
        limited = self._query.from_(TrainingRow.__table_name__)
        limited.left_join(
            TrainingScheduleRow.__table_name__,
            self._schedule_join_condition(),
        )

        # The outer query joins the limited trainings to all related tables.
        query = self._main_query.from_(TrainingRow.__table_name__)
        query = query.columns(*self.columns)
        query = query.with_("limited", self._query)
        query = query.right_join("limited", on("limited.id", TrainingRow.column("id")))
        query = query.left_join(
            TrainingScheduleRow.__table_name__,
            self._schedule_join_condition(),
        )
        query = query.left_join(
            alias(OwnerRow.__table_name__, "training_schema_owners"),
            on(TrainingScheduleRow.column("user_id"), "training_schema_owners.id"),
        )
        query = query.left_join(
            TeamRow.__table_name__,
            on(TeamRow.column("id"), TrainingScheduleRow.column("team_id")),
        )
        # `join` instead of `left_join`: presumably an inner join, so a training
        # without any text is not returned — confirm with sql_smith.
        query = query.join(
            TrainingTextRow.__table_name__,
            on(TrainingTextRow.column("training_id"), TrainingRow.column("id")),
        )
        query = query.join(
            OwnerRow.__table_name__,
            on(OwnerRow.column("id"), TrainingTextRow.column("user_id")),
        )
        self._main_query = query

    @property
    def columns(self):
        """Return the aliased columns of all tables in the outer query."""
        return TrainingQueryRow.get_aliases()

    @property
    def count_column(self) -> str:
        """Return the column used to count trainings."""
        return TrainingRow.column("id")

    def _filter_by_link(self, link_row, link_column: str, value) -> "TrainingQuery":
        """Keep trainings whose id is listed in a link table for the given value.

        Args:
            link_row: The row class of a link table that has a `training_id` column.
            link_column: The column of the link table that is compared with value.
            value: The value that link_column must equal.
        """
        inner_select = self._database.create_query_factory().select()
        inner_select = inner_select.columns(link_row.column("training_id"))
        inner_select = inner_select.from_(link_row.__table_name__)
        inner_select = inner_select.where(link_row.field(link_column).eq(value))

        in_linked = TrainingRow.field("id").in_(express("{}", inner_select))
        self._query.and_where(group(in_linked))
        return self

    def filter_by_id(self, id_: TrainingIdentifier) -> "TrainingQuery":
        """Keep only the training with the given id."""
        self._query.and_where(TrainingRow.field("id").eq(id_.value))
        return self

    def filter_by_year_month(
        self, year: int, month: int | None = None
    ) -> "TrainingQuery":
        """Keep trainings starting in the given year, and month when given."""
        start_date = TrainingRow.column("start_date")
        condition = criteria("{} = {}", func("YEAR", start_date), literal(year))
        if month is not None:
            month_condition = criteria(
                "{} = {}", func("MONTH", start_date), literal(month)
            )
            condition = condition.and_(month_condition)
        self._query.and_where(group(condition))
        return self

    def filter_by_dates(self, start: Timestamp, end: Timestamp) -> "TrainingQuery":
        """Keep trainings whose start date lies between start and end."""
        in_period = TrainingRow.field("start_date").between(str(start), str(end))
        self._query.and_where(in_period)
        return self

    def filter_by_coach(self, coach: CoachEntity) -> "TrainingQuery":
        """Keep trainings that are linked to the given coach."""
        return self._filter_by_link(TrainingCoachRow, "coach_id", coach.id.value)

    def filter_by_team(self, team: TeamEntity) -> "TrainingQuery":
        """Keep trainings that are linked to the given team."""
        return self._filter_by_link(TrainingTeamRow, "team_id", team.id.value)

    def filter_by_training_schedule(
        self, training_schedule: TrainingScheduleEntity
    ) -> "TrainingQuery":
        """Keep trainings that belong to the given training schedule."""
        schedule_id = training_schedule.id.value
        self._query.and_where(TrainingRow.field("training_schedule_id").eq(schedule_id))
        return self

    def filter_active(self) -> "TrainingQuery":
        """Keep only active trainings."""
        self._query.and_where(TrainingRow.field("active").eq(1))
        return self

    def fetch(
        self, limit: int | None = None, offset: int | None = None
    ) -> AsyncIterator[Record]:
        """Fetch the joined rows of the selected trainings.

        Limit, offset and the selected id column are applied to the CTE. The
        outer query only adds an ordering on the training id.
        """
        training_id = TrainingRow.column("id")
        self._query.limit(limit)
        self._query.offset(offset)
        self._query.columns(training_id)
        self._main_query.order_by(training_id)

        return self._database.fetch(self._main_query)

    def order_by_date(self) -> "TrainingQuery":
        """Order the trainings by start date, ascending, in both queries."""
        start_date = TrainingRow.column("start_date")
        self._query.order_by(start_date, "ASC")
        self._main_query.order_by(start_date, "ASC")
        return self