Coverage for bc/kwai-bc-teams/src/kwai_bc_teams/repositories/team_db_repository.py: 100%

76 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that implements a team repository for a database.""" 

2 

3from dataclasses import dataclass 

4from typing import Any, AsyncGenerator, Self 

5 

6from kwai_bc_club.domain.value_objects import Birthdate, Gender, License 

7from kwai_core.db.database import Database 

8from kwai_core.db.database_query import DatabaseQuery 

9from kwai_core.db.table_row import JoinedTableRow 

10from kwai_core.domain.value_objects.date import Date 

11from kwai_core.domain.value_objects.name import Name 

12from kwai_core.domain.value_objects.unique_id import UniqueId 

13from kwai_core.functions import async_groupby 

14from sql_smith.functions import field, on 

15 

16from kwai_bc_teams.domain.team import TeamEntity, TeamIdentifier 

17from kwai_bc_teams.domain.team_member import ( 

18 MemberEntity, 

19 MemberIdentifier, 

20 TeamMember, 

21) 

22from kwai_bc_teams.repositories._tables import ( 

23 CountryRow, 

24 MemberPersonRow, 

25 MemberRow, 

26 TeamMemberRow, 

27 TeamRow, 

28) 

29from kwai_bc_teams.repositories.team_repository import ( 

30 TeamNotFoundException, 

31 TeamQuery, 

32 TeamRepository, 

33) 

34 

35 

@dataclass(kw_only=True, frozen=True, slots=True)
class MemberPersonCountryMixin:
    """Mixin holding the member, person and country rows of a joined result.

    The three rows together contain everything needed to build a member
    entity, see `create_member_entity`.
    """

    member: MemberRow
    member_person: MemberPersonRow
    country: CountryRow

    def create_member_entity(self) -> MemberEntity:
        """Assemble a member entity from the member, person and country rows.

        Returns:
            The member entity built from the rows held by this object.
        """
        member_row = self.member
        person_row = self.member_person

        identifier = MemberIdentifier(member_row.id)  # type: ignore
        full_name = Name(
            first_name=person_row.firstname,
            last_name=person_row.lastname,
        )
        member_license = License(
            number=member_row.license,  # type: ignore
            end_date=Date.create_from_date(member_row.license_end_date),  # type: ignore
        )
        birthdate = Birthdate(
            date=Date.create_from_date(person_row.birthdate)  # type: ignore
        )
        nationality = self.country.create_country()  # type: ignore
        gender = Gender(person_row.gender)
        uuid = UniqueId.create_from_string(member_row.uuid)  # type: ignore
        # The row stores the active flag as a number: only 1 means active.
        is_active = member_row.active == 1

        return MemberEntity(
            id=identifier,
            name=full_name,
            license=member_license,
            birthdate=birthdate,
            nationality=nationality,
            gender=gender,
            uuid=uuid,
            active_in_club=is_active,
        )

64 

65 

@dataclass(kw_only=True, frozen=True, slots=True)
class TeamQueryRow(MemberPersonCountryMixin, JoinedTableRow):
    """Joined row returned by the team query: team, team member and member data."""

    team: TeamRow
    team_member: TeamMemberRow

    @classmethod
    def create_entity(cls, rows: list[dict[str, Any]]) -> TeamEntity:
        """Build one team entity from the rows that belong to that team.

        The team data is read from the first row. Every row with a member id
        contributes one team member, stored under the uuid of the member.

        Args:
            rows: The raw rows of a single team. Must contain at least one row.

        Returns:
            The team entity with its team members.

        Raises:
            IndexError: When rows is empty.
        """
        first_row = cls.map(rows[0])
        members_by_uuid = {}
        for raw_row in rows:
            joined_row = cls.map(raw_row)
            # A row without a member id carries no member data to convert
            # (presumably a team without members, given the left joins).
            if joined_row.member.id is not None:
                member_entity = joined_row.create_member_entity()
                members_by_uuid[member_entity.uuid] = (
                    joined_row.team_member.create_team_member(member_entity)
                )
        return first_row.team.create_entity(members_by_uuid)

88 

89 

class TeamDbQuery(TeamQuery, DatabaseQuery):
    """Database implementation of the team query."""

    def __init__(self, database: Database):
        """Initialize the query.

        Args:
            database: The database that will execute the query.
        """
        super().__init__(database)

    def init(self):
        """Set up the select: teams left joined with their members.

        The team table is left joined with the team member, member, person
        and country tables.
        """
        select = self._query.from_(TeamRow.__table_name__)
        select = select.left_join(
            TeamMemberRow.__table_name__,
            on(TeamRow.column("id"), TeamMemberRow.column("team_id")),
        )
        select = select.left_join(
            MemberRow.__table_name__,
            on(MemberRow.column("id"), TeamMemberRow.column("member_id")),
        )
        select = select.left_join(
            MemberPersonRow.__table_name__,
            on(MemberPersonRow.column("id"), MemberRow.column("person_id")),
        )
        select.left_join(
            CountryRow.__table_name__,
            on(CountryRow.column("id"), MemberPersonRow.column("nationality_id")),
        )

    @property
    def columns(self):
        """Return the column aliases of the joined team query row."""
        aliases = TeamQueryRow.get_aliases()
        return aliases

    @property
    def count_column(self) -> str:
        """Return the id column of the team table, used for counting."""
        id_column = TeamRow.column("id")
        return id_column

    def filter_by_id(self, id_: TeamIdentifier) -> Self:
        """Add a condition that only selects the team with the given id.

        Args:
            id_: The identifier of the team.

        Returns:
            This query, so that filters can be chained.
        """
        id_condition = TeamRow.field("id").eq(id_.value)
        self._query.and_where(id_condition)
        return self

122 

123 

class TeamDbRepository(TeamRepository):
    """A team repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database used for all queries and statements.
        """
        self._database = database

    def create_query(self) -> TeamQuery:
        """Create a team query bound to the database of this repository."""
        return TeamDbQuery(self._database)

    async def get(self, query: TeamQuery | None = None) -> TeamEntity:
        """Return the first team produced by the query.

        Args:
            query: The query to execute. When None, a default query is created.

        Returns:
            The first team entity yielded for the query.

        Raises:
            TeamNotFoundException: When the query yields no team.
        """
        team_iterator = self.get_all(query)
        try:
            return await anext(team_iterator)
        except StopAsyncIteration:
            raise TeamNotFoundException("Team not found") from None
        finally:
            # Only the first team is consumed. Close the generator explicitly
            # so it (and the row iterator it drives) is finalized here instead
            # of staying suspended until garbage collection.
            await team_iterator.aclose()

    async def get_all(
        self,
        query: TeamQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncGenerator[TeamEntity, None]:
        """Yield a team entity for each group of rows returned by the query.

        Args:
            query: The query to execute. When None, a default query is created.
            limit: The limit passed to the fetch of the query.
            offset: The offset passed to the fetch of the query.

        Yields:
            A team entity for each group of rows with the same team id.
        """
        if query is None:
            query = self.create_query()

        # NOTE(review): limit and offset are forwarded to the row fetch, so
        # they presumably count joined rows and not teams -- confirm.
        row_iterator = query.fetch(limit=limit, offset=offset)

        # Rows are grouped on the value of the team id column.
        # NOTE(review): this assumes the rows of one team arrive consecutively;
        # verify against async_groupby and the ordering of the query.
        group_by_column = "team_id"
        async for _, group in async_groupby(
            row_iterator, key=lambda row: row[group_by_column]
        ):
            yield TeamQueryRow.create_entity(group)

    async def create(self, team: TeamEntity) -> TeamEntity:
        """Insert the team and return it with the newly assigned id.

        Args:
            team: The team to store.

        Returns:
            The result of setting the id from the insert on the team.
        """
        new_team_id = await self._database.insert(
            TeamRow.__table_name__, TeamRow.persist(team)
        )
        return team.set_id(TeamIdentifier(new_team_id))

    async def delete(self, team: TeamEntity) -> None:
        """Delete the team together with its team member rows.

        Args:
            team: The team to delete.
        """
        # The team member rows are removed first, then the team row itself.
        delete_team_members_query = (
            self._database.create_query_factory()
            .delete(TeamMemberRow.__table_name__)
            .where(field("team_id").eq(team.id.value))
        )
        await self._database.execute(delete_team_members_query)
        await self._database.delete(team.id.value, TeamRow.__table_name__)

    async def update(self, team: TeamEntity) -> None:
        """Update the row of the team with the current state of the entity.

        Args:
            team: The team to update.
        """
        await self._database.update(
            team.id.value, TeamRow.__table_name__, TeamRow.persist(team)
        )

    async def add_team_member(self, team: TeamEntity, member: TeamMember) -> None:
        """Insert a team member row that links the member to the team.

        Args:
            team: The team that gets the member.
            member: The team member to add.
        """
        team_member_row = TeamMemberRow.persist(team, member)
        await self._database.insert(TeamMemberRow.__table_name__, team_member_row)