Coverage for bc/kwai-bc-training/src/kwai_bc_training/trainings/training_schedule_db_query.py: 100%

33 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that implements a TrainingScheduleQuery for a database.""" 

2 

3from dataclasses import dataclass 

4from typing import Self 

5 

6from kwai_core.db.database_query import DatabaseQuery 

7from kwai_core.db.rows import OwnerRow 

8from kwai_core.db.table_row import JoinedTableRow 

9from sql_smith.functions import on 

10 

11from kwai_bc_training.teams.team_tables import TeamRow 

12from kwai_bc_training.trainings._tables import ( 

13 TrainingScheduleRow, 

14) 

15from kwai_bc_training.trainings.training_schedule import ( 

16 TrainingScheduleEntity, 

17 TrainingScheduleIdentifier, 

18) 

19from kwai_bc_training.trainings.training_schedule_query import ( 

20 TrainingScheduleQuery, 

21) 

22 

23 

@dataclass(kw_only=True, frozen=True, slots=True)
class TrainingScheduleQueryRow(JoinedTableRow):
    """Holds one result row of the training schedule query.

    Attributes are grouped per joined table:

    - ``training_schedule``: the training schedule columns.
    - ``team``: the team columns; the team is optional (its id can be None).
    - ``owner``: the owner columns.
    """

    training_schedule: TrainingScheduleRow
    team: TeamRow
    owner: OwnerRow

    def create_entity(self) -> TrainingScheduleEntity:
        """Turn this joined row into a training schedule entity.

        A team entity is only built when the team row has an id; otherwise
        None is passed as the team. The owner is always converted.
        """
        team = self.team.create_entity() if self.team.id is not None else None
        return self.training_schedule.create_entity(team, self.owner.create_owner())

40 

41 

class TrainingScheduleDbQuery(DatabaseQuery, TrainingScheduleQuery):
    """A database query for training schedules.

    Each schedule is joined with its owner (on ``user_id``) and left-joined
    with its team (on ``team_id``).
    """

    def init(self):
        """Build the base query over schedules, owners and teams."""
        return (
            self._query.from_(TrainingScheduleRow.__table_name__)
            .join(
                OwnerRow.__table_name__,
                on(OwnerRow.column("id"), TrainingScheduleRow.column("user_id")),
            )
            # Left join: schedules without a matching team are still returned.
            .left_join(
                TeamRow.__table_name__,
                on(TeamRow.column("id"), TrainingScheduleRow.column("team_id")),
            )
        )

    @property
    def columns(self):
        """Return the column aliases defined by TrainingScheduleQueryRow."""
        return TrainingScheduleQueryRow.get_aliases()

    @property
    def count_column(self) -> str:
        """Return the column used to count the matching training schedules."""
        return TrainingScheduleRow.column("id")

    def filter_by_id(self, id_: TrainingScheduleIdentifier) -> Self:
        """Restrict the query to the training schedule with the given id.

        Args:
            id_: The identifier of the training schedule.

        Returns:
            This query, to allow chaining.
        """
        self._query.and_where(TrainingScheduleRow.field("id").eq(id_.value))
        return self

    def filter_by_ids(self, *ids: TrainingScheduleIdentifier) -> Self:
        """Restrict the query to the training schedules with the given ids.

        When no ids are given, the query matches no rows.

        Args:
            ids: The identifiers of the training schedules.

        Returns:
            This query, to allow chaining.
        """
        if not ids:
            # An empty "IN ()" list is invalid SQL. Use a condition that is
            # never true instead (assumes "id" is the non-nullable primary key).
            self._query.and_where(TrainingScheduleRow.field("id").is_null())
            return self

        unpacked_ids = [id_.value for id_ in ids]
        self._query.and_where(TrainingScheduleRow.field("id").in_(*unpacked_ids))
        return self