Coverage for bc/kwai-bc-training/src/kwai_bc_training/trainings/training_team_db_query.py: 100%

23 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that defines a database query to get teams of training(s).""" 

2 

3from collections import defaultdict 

4 

5from kwai_core.db.database_query import DatabaseQuery 

6from sql_smith.functions import on 

7 

8from kwai_bc_training.teams.team import TeamEntity 

9from kwai_bc_training.teams.team_tables import TeamRow 

10from kwai_bc_training.trainings._tables import TrainingTeamRow 

11from kwai_bc_training.trainings.training import TrainingIdentifier 

12 

13 

class TrainingTeamDbQuery(DatabaseQuery):
    """Query the training/team link table together with the team table.

    The query starts from the table of ``TrainingTeamRow`` and left joins the
    table of ``TeamRow``, so every fetched record carries the columns of both
    rows.
    """

    def init(self):
        """Set up the FROM clause and the join with the team table."""
        training_team_source = self._query.from_(TrainingTeamRow.__table_name__)
        # A link row is matched with its team row on team_id = id.
        join_condition = on(TrainingTeamRow.column("team_id"), TeamRow.column("id"))
        training_team_source.left_join(TeamRow.__table_name__, join_condition)

    @property
    def columns(self):
        """Return the aliases of the link row followed by those of the team row."""
        training_team_aliases = TrainingTeamRow.get_aliases()
        team_aliases = TeamRow.get_aliases()
        return training_team_aliases + team_aliases

    def filter_by_trainings(self, *ids: TrainingIdentifier) -> "TrainingTeamDbQuery":
        """Restrict the query to the given trainings.

        An extra AND condition is added so that only rows whose
        ``training_id`` is one of the given identifiers are returned.

        Args:
            ids: The identifiers of the trainings to keep.

        Returns:
            This query, which allows chaining.
        """
        # NOTE(review): when no ids are passed, ``in_`` is called without
        # arguments; confirm how sql_smith renders that condition.
        raw_ids = [identifier.value for identifier in ids]
        condition = TrainingTeamRow.field("training_id").in_(*raw_ids)
        self._query.and_where(condition)
        return self

    async def fetch_teams(self) -> dict[TrainingIdentifier, list[TeamEntity]]:
        """Fetch the records and group the teams per training.

        Each fetched record is mapped twice: once to a ``TrainingTeamRow`` to
        obtain the training id and once to a ``TeamRow`` to create the team
        entity.

        Returns:
            A dictionary with the identifier of a training as key and the
            list of teams of that training as value.
        """
        teams_by_training: dict[TrainingIdentifier, list[TeamEntity]] = defaultdict(
            list
        )

        async for record in self.fetch():
            link_row = TrainingTeamRow.map(record)
            training_id = TrainingIdentifier(link_row.training_id)
            # The defaultdict creates the list on first access of a training.
            teams = teams_by_training[training_id]
            teams.append(TeamRow.map(record).create_entity())

        return teams_by_training
52 ) 

53 

54 return result