Coverage for bc/kwai-bc-teams/src/kwai_bc_teams/repositories/member_db_repository.py: 95%
65 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module for defining a team member repository for a database."""
3from dataclasses import dataclass
4from typing import AsyncGenerator, Self
6from kwai_bc_club.domain.value_objects import Birthdate, Gender, License
7from kwai_core.db.database import Database
8from kwai_core.db.database_query import DatabaseQuery
9from kwai_core.db.table_row import JoinedTableRow
10from kwai_core.domain.value_objects.date import Date
11from kwai_core.domain.value_objects.name import Name
12from kwai_core.domain.value_objects.unique_id import UniqueId
13from sql_smith.functions import express, on
15from kwai_bc_teams.domain.team import TeamIdentifier
16from kwai_bc_teams.domain.team_member import MemberEntity, MemberIdentifier
17from kwai_bc_teams.repositories._tables import (
18 CountryRow,
19 MemberPersonRow,
20 MemberRow,
21 TeamMemberRow,
22)
23from kwai_bc_teams.repositories.member_repository import (
24 MemberNotFoundException,
25 MemberQuery,
26 MemberRepository,
27)
@dataclass(kw_only=True, frozen=True, slots=True)
class MemberQueryRow(JoinedTableRow):
    """Row object for the member query.

    Each field holds the columns of one joined table: the member itself, the
    person record linked to the member, and the country of that person.
    Instances are produced by mapping a raw query row (``MemberQueryRow.map``).
    """

    member: MemberRow
    person: MemberPersonRow
    country: CountryRow

    def create_entity(self) -> MemberEntity:
        """Convert this joined row into a team ``MemberEntity``.

        Returns:
            The member entity built from the member, person and country rows.

        Raises:
            ValueError: when the country row does not produce a nationality.
        """
        nationality = self.country.create_country()
        if nationality is None:
            raise ValueError("A member must have a nationality.")

        member_row = self.member
        person_row = self.person

        # Built in the same order the entity's keyword arguments were evaluated.
        identifier = MemberIdentifier(member_row.id)  # type:ignore
        unique_id = UniqueId.create_from_string(member_row.uuid)  # type:ignore
        full_name = Name(
            first_name=person_row.firstname, last_name=person_row.lastname
        )
        member_license = License(
            number=member_row.license,  # type:ignore
            end_date=Date.create_from_date(member_row.license_end_date),  # type:ignore
        )
        born_on = Birthdate(Date.create_from_date(person_row.birthdate))  # type:ignore
        gender = Gender(person_row.gender)

        return MemberEntity(
            id=identifier,
            uuid=unique_id,
            name=full_name,
            license=member_license,
            birthdate=born_on,
            gender=gender,
            nationality=nationality,
            # The "active" column is compared with 1 to get a boolean flag.
            active_in_club=member_row.active == 1,
        )
class MemberDbQuery(MemberQuery, DatabaseQuery):
    """Team member query executed against the database.

    The base selection joins members with their person record and with the
    country referenced by that person's ``nationality_id``. Every
    ``filter_by_*`` method adds a WHERE condition and returns the query itself,
    so filters can be chained.
    """

    def __init__(self, database: Database):
        super().__init__(database)

    def init(self):
        """Set up the FROM clause and joins of the base member selection.

        NOTE(review): presumably invoked by ``DatabaseQuery`` while the query
        is constructed — confirm in the base class.
        """
        # member.person_id -> person.id, then person.nationality_id -> country.id.
        # Both are inner joins, so rows without a matching person/country drop out.
        person_criteria = on(
            MemberPersonRow.column("id"), MemberRow.column("person_id")
        )
        country_criteria = on(
            CountryRow.column("id"), MemberPersonRow.column("nationality_id")
        )
        self._query.from_(MemberRow.__table_name__).inner_join(
            MemberPersonRow.__table_name__, person_criteria
        ).inner_join(CountryRow.__table_name__, country_criteria)

    @property
    def columns(self):
        """Columns to select: the aliases produced by ``MemberQueryRow``."""
        return MemberQueryRow.get_aliases()

    @property
    def count_column(self):
        """Column used when counting results: the member id."""
        return MemberRow.column("id")

    def filter_by_id(self, id_: MemberIdentifier) -> Self:
        """Keep only the member with the given identifier."""
        member_id = MemberRow.field("id")
        self._query.and_where(member_id.eq(id_.value))
        return self

    def filter_by_birthdate(
        self, start_date: Date, end_date: Date | None = None
    ) -> Self:
        """Keep members by birthdate.

        Without ``end_date``, members born on or after ``start_date`` are kept;
        otherwise the birthdate must lie between the two dates (SQL BETWEEN).
        """
        birthdate = MemberPersonRow.field("birthdate")
        condition = (
            birthdate.gte(start_date)
            if end_date is None
            else birthdate.between(start_date, end_date)
        )
        self._query.and_where(condition)
        return self

    def filter_by_uuid(self, uuid: UniqueId) -> Self:
        """Keep only the member with the given unique id."""
        uuid_column = MemberRow.field("uuid")
        self._query.and_where(uuid_column.eq(str(uuid)))
        return self

    def filter_by_team(self, team_id: TeamIdentifier, in_team: bool = True) -> Self:
        """Filter members on their membership of a team.

        Args:
            team_id: The team to check membership against.
            in_team: When True, keep members linked to the team in the
                team-member table; when False, keep members that are not.
        """
        factory = self._database.create_query_factory()
        team_member_ids = (
            factory.select()
            .columns(TeamMemberRow.column("member_id"))
            .from_(TeamMemberRow.__table_name__)
            .where(TeamMemberRow.field("team_id").eq(team_id.value))
        )
        # Embed the select as a sub-query for IN / NOT IN.
        sub_query = express("{}", team_member_ids)
        member_id = MemberRow.field("id")
        self._query.and_where(
            member_id.in_(sub_query) if in_team else member_id.not_in(sub_query)
        )
        return self
class MemberDbRepository(MemberRepository):
    """A member repository for a database.

    Queries are created as ``MemberDbQuery`` objects and every fetched row is
    mapped to a ``MemberEntity`` through ``MemberQueryRow``.
    """

    def __init__(self, database: Database):
        self._database = database

    def create_query(self) -> MemberQuery:
        """Create a new, unfiltered member query on this repository's database."""
        return MemberDbQuery(self._database)

    async def get(self, query: MemberQuery | None = None) -> MemberEntity:
        """Return the first member matched by the query.

        Args:
            query: The query to execute; a new unfiltered query when omitted.

        Returns:
            The member entity for the first fetched row.

        Raises:
            MemberNotFoundException: when the query yields no rows.
        """
        members = self.get_all(query)
        try:
            return await anext(members)
        except StopAsyncIteration:
            raise MemberNotFoundException("Member not found") from None
        finally:
            # Only one row is consumed. Close the generator explicitly so its
            # exit runs here, in this task, instead of whenever it is finalized.
            await members.aclose()

    async def get_all(
        self,
        query: MemberQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncGenerator[MemberEntity, None]:
        """Yield a member entity for every row matched by the query.

        Args:
            query: The query to execute; a new unfiltered query when omitted.
            limit: Passed through to ``query.fetch``.
            offset: Passed through to ``query.fetch``.

        Yields:
            One ``MemberEntity`` per fetched row.
        """
        # Compare with None (the declared "no query" value) rather than relying
        # on the query object's truthiness.
        if query is None:
            query = self.create_query()

        async for row in query.fetch(limit, offset):
            yield MemberQueryRow.map(row).create_entity()