Coverage for bc/kwai-bc-teams/src/kwai_bc_teams/repositories/member_db_repository.py: 95%

65 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module for defining a team member repository for a database.""" 

2 

3from dataclasses import dataclass 

4from typing import AsyncGenerator, Self 

5 

6from kwai_bc_club.domain.value_objects import Birthdate, Gender, License 

7from kwai_core.db.database import Database 

8from kwai_core.db.database_query import DatabaseQuery 

9from kwai_core.db.table_row import JoinedTableRow 

10from kwai_core.domain.value_objects.date import Date 

11from kwai_core.domain.value_objects.name import Name 

12from kwai_core.domain.value_objects.unique_id import UniqueId 

13from sql_smith.functions import express, on 

14 

15from kwai_bc_teams.domain.team import TeamIdentifier 

16from kwai_bc_teams.domain.team_member import MemberEntity, MemberIdentifier 

17from kwai_bc_teams.repositories._tables import ( 

18 CountryRow, 

19 MemberPersonRow, 

20 MemberRow, 

21 TeamMemberRow, 

22) 

23from kwai_bc_teams.repositories.member_repository import ( 

24 MemberNotFoundException, 

25 MemberQuery, 

26 MemberRepository, 

27) 

28 

29 

@dataclass(kw_only=True, frozen=True, slots=True)
class MemberQueryRow(JoinedTableRow):
    """Joined row returned by the member query.

    Attributes:
        member: Columns read from the member table.
        person: Columns read from the person table of the member.
        country: Columns read from the country table (the person's nationality).
    """

    member: MemberRow
    person: MemberPersonRow
    country: CountryRow

    def create_entity(self) -> MemberEntity:
        """Convert this joined row into a MemberEntity.

        Returns:
            A MemberEntity built from the member, person and country columns.

        Raises:
            ValueError: When the country row does not produce a nationality.
        """
        country = self.country.create_country()
        if country is None:
            raise ValueError("A member must have a nationality.")

        member_row = self.member
        person_row = self.person

        # Values are built in the same order as the entity's keyword arguments.
        identifier = MemberIdentifier(member_row.id)  # type:ignore
        unique_id = UniqueId.create_from_string(member_row.uuid)  # type:ignore
        full_name = Name(first_name=person_row.firstname, last_name=person_row.lastname)
        member_license = License(
            number=member_row.license,  # type:ignore
            end_date=Date.create_from_date(member_row.license_end_date),  # type:ignore
        )
        birth = Birthdate(Date.create_from_date(person_row.birthdate))  # type:ignore
        gender = Gender(person_row.gender)
        # The active column is stored as an integer flag; only 1 means active.
        is_active = member_row.active == 1

        return MemberEntity(
            id=identifier,
            uuid=unique_id,
            name=full_name,
            license=member_license,
            birthdate=birth,
            gender=gender,
            nationality=country,
            active_in_club=is_active,
        )

57 

58 

class MemberDbQuery(MemberQuery, DatabaseQuery):
    """Database implementation of a member query.

    Selects members inner-joined with their person row and with the country
    of the person's nationality. Every filter method adds a condition to the
    underlying query and returns the query itself, so calls can be chained.
    """

    def __init__(self, database: Database):
        """Create a member query on the given database."""
        super().__init__(database)

    def init(self):
        """Set up the FROM clause: member -> person -> country (nationality)."""
        person_condition = on(MemberPersonRow.column("id"), MemberRow.column("person_id"))
        country_condition = on(
            CountryRow.column("id"), MemberPersonRow.column("nationality_id")
        )
        joined = self._query.from_(MemberRow.__table_name__)
        joined = joined.inner_join(MemberPersonRow.__table_name__, person_condition)
        joined.inner_join(CountryRow.__table_name__, country_condition)

    @property
    def columns(self):
        """Return the aliased columns of all tables in the joined row."""
        return MemberQueryRow.get_aliases()

    @property
    def count_column(self):
        """Return the column used when counting the results (member id)."""
        return MemberRow.column("id")

    def filter_by_id(self, id_: MemberIdentifier) -> Self:
        """Restrict the query to the member with the given id."""
        id_matches = MemberRow.field("id").eq(id_.value)
        self._query.and_where(id_matches)
        return self

    def filter_by_birthdate(
        self, start_date: Date, end_date: Date | None = None
    ) -> Self:
        """Restrict the query on the person's birthdate.

        Args:
            start_date: Earliest birthdate to include.
            end_date: Latest birthdate to include. When omitted, there is no
                upper bound; otherwise a BETWEEN condition is used.
        """
        birthdate = MemberPersonRow.field("birthdate")
        condition = (
            birthdate.gte(start_date)
            if end_date is None
            else birthdate.between(start_date, end_date)
        )
        self._query.and_where(condition)
        return self

    def filter_by_uuid(self, uuid: UniqueId) -> Self:
        """Restrict the query to the member with the given unique id."""
        uuid_matches = MemberRow.field("uuid").eq(str(uuid))
        self._query.and_where(uuid_matches)
        return self

    def filter_by_team(self, team_id: TeamIdentifier, in_team: bool = True) -> Self:
        """Restrict the query on membership of a team.

        Args:
            team_id: The team to check membership against.
            in_team: True keeps members of the team; False keeps members
                that are not part of the team.
        """
        team_member_ids = (
            self._database.create_query_factory()
            .select()
            .columns(TeamMemberRow.column("member_id"))
            .from_(TeamMemberRow.__table_name__)
            .where(TeamMemberRow.field("team_id").eq(team_id.value))
        )
        member_id = MemberRow.field("id")
        subquery = express("{}", team_member_ids)
        condition = member_id.in_(subquery) if in_team else member_id.not_in(subquery)
        self._query.and_where(condition)
        return self

115 

116 

class MemberDbRepository(MemberRepository):
    """A member repository backed by a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database used by the queries this repository creates.
        """
        self._database = database

    def create_query(self) -> MemberQuery:
        """Create a new member query on this repository's database."""
        return MemberDbQuery(self._database)

    async def get(self, query: MemberQuery | None = None) -> MemberEntity:
        """Return the first member matched by the query.

        Args:
            query: The query to run. When omitted, a new query from
                create_query is used.

        Raises:
            MemberNotFoundException: When the query returns no rows.
        """
        members = self.get_all(query)
        try:
            return await anext(members)
        except StopAsyncIteration:
            raise MemberNotFoundException("Member not found") from None
        finally:
            # Only the first item is consumed. Close the async generator
            # explicitly so it is finalized here, not later by garbage
            # collection / the event loop's async-generator finalizer.
            await members.aclose()

    async def get_all(
        self,
        query: MemberQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncGenerator[MemberEntity, None]:
        """Yield a member entity for every row returned by the query.

        Args:
            query: The query to run. When omitted, a new query from
                create_query is used.
            limit: Maximum number of rows, passed on to query.fetch.
            offset: Number of rows to skip, passed on to query.fetch.

        Raises:
            ValueError: When a row has no nationality (see MemberQueryRow).
        """
        # Explicit None check instead of `or`, so a query object is never
        # replaced just because it happens to be falsy.
        if query is None:
            query = self.create_query()

        async for row in query.fetch(limit, offset):
            yield MemberQueryRow.map(row).create_entity()