Coverage for bc/kwai-bc-training/src/kwai_bc_training/trainings/training_schedule_db_repository.py: 100%

61 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that implements a training schedule repository for a database.""" 

2 

3from dataclasses import replace 

4from typing import AsyncIterator, cast 

5 

6from kwai_core.db.database import Database 

7from sql_smith.functions import field 

8 

9from kwai_bc_training.coaches._tables import ( # noqa 

10 CoachRow, 

11 MemberRow, 

12 PersonRow, 

13) 

14from kwai_bc_training.coaches.coach import CoachEntity 

15from kwai_bc_training.trainings._tables import ( 

16 TrainingScheduleCoachRow, 

17 TrainingScheduleRow, 

18) 

19from kwai_bc_training.trainings.training_schedule import ( 

20 TrainingScheduleEntity, 

21 TrainingScheduleIdentifier, 

22) 

23from kwai_bc_training.trainings.training_schedule_coach_db_query import ( 

24 TrainingScheduleCoachDbQuery, 

25) 

26from kwai_bc_training.trainings.training_schedule_db_query import ( 

27 TrainingScheduleDbQuery, 

28 TrainingScheduleQueryRow, 

29) 

30from kwai_bc_training.trainings.training_schedule_query import ( 

31 TrainingScheduleQuery, 

32) 

33from kwai_bc_training.trainings.training_schedule_repository import ( 

34 TrainingScheduleNotFoundException, 

35 TrainingScheduleRepository, 

36) 

37 

38 

class TrainingScheduleDbRepository(TrainingScheduleRepository):
    """A training schedule repository for a database.

    Training schedules are stored in the table of `TrainingScheduleRow`. The
    coaches of a schedule are stored as association rows in the table of
    `TrainingScheduleCoachRow` (one row per schedule/coach pair).
    """

    def __init__(self, database: Database) -> None:
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> TrainingScheduleQuery:
        """Create a database-backed query for training schedules."""
        return TrainingScheduleDbQuery(self._database)

    async def get_by_id(
        self, id_: TrainingScheduleIdentifier
    ) -> TrainingScheduleEntity:
        """Get the training schedule with the given id.

        Args:
            id_: The id of the training schedule.

        Returns:
            The training schedule, with its coaches attached.

        Raises:
            TrainingScheduleNotFoundException: when the query for the given
                id yields no training schedule.
        """
        query = self.create_query()
        query.filter_by_id(id_)

        try:
            return await anext(self.get_all(query, 1))
        except StopAsyncIteration:
            raise TrainingScheduleNotFoundException(
                f"Training schedule with id {id_} does not exist."
            ) from None

    async def get_all(
        self,
        query: TrainingScheduleQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[TrainingScheduleEntity]:
        """Yield training schedules with their coaches attached.

        Args:
            query: The query to run. When omitted, a new query is created
                with `create_query`.
            limit: Forwarded to `query.fetch`.
            offset: Forwarded to `query.fetch`.

        Yields:
            Training schedule entities, in the order the query returned them.
        """
        if query is None:
            query = self.create_query()

        # Map all rows first, keyed on id, so the coaches of every fetched
        # schedule can be loaded in one batch afterwards.
        training_schedules: dict[
            TrainingScheduleIdentifier, TrainingScheduleEntity
        ] = {}
        async for row in query.fetch(limit, offset):
            training_schedule = TrainingScheduleQueryRow.map(row).create_entity()
            training_schedules[
                cast(TrainingScheduleIdentifier, training_schedule.id)
            ] = training_schedule

        all_coaches: dict[TrainingScheduleIdentifier, list[CoachEntity]] = {}
        if training_schedules:
            all_coaches = await (
                TrainingScheduleCoachDbQuery(self._database)
                .filter_by_schedule(*training_schedules)
                .fetch_coaches()
            )

        for schedule_id, training_schedule in training_schedules.items():
            schedule_coaches = frozenset(all_coaches.get(schedule_id, []))
            if schedule_coaches:
                yield replace(training_schedule, coaches=schedule_coaches)
            else:
                yield training_schedule

    async def create(
        self, training_schedule: TrainingScheduleEntity
    ) -> TrainingScheduleEntity:
        """Insert a training schedule together with its coach associations.

        Args:
            training_schedule: The training schedule to save.

        Returns:
            The training schedule with the id returned by the database insert.
        """
        new_id = await self._database.insert(
            TrainingScheduleRow.__table_name__,
            TrainingScheduleRow.persist(training_schedule),
        )
        result = training_schedule.set_id(TrainingScheduleIdentifier(new_id))

        # The association rows need the new id, so pass the entity returned
        # by set_id rather than the one that was given.
        await self._create_coaches(result)

        return result

    async def update(self, training_schedule: TrainingScheduleEntity) -> None:
        """Update the training schedule and replace its coach associations.

        Args:
            training_schedule: The training schedule to update.
        """
        await self._database.update(
            training_schedule.id.value,
            TrainingScheduleRow.__table_name__,
            TrainingScheduleRow.persist(training_schedule),
        )

        # Update the 1-to-n relationship for the coaches: first delete all
        # existing association rows, then recreate them from the entity.
        await self._delete_coaches(training_schedule)
        await self._create_coaches(training_schedule)

    async def delete(self, training_schedule: TrainingScheduleEntity) -> None:
        """Delete the training schedule and its coach associations.

        Args:
            training_schedule: The training schedule to delete.
        """
        # Remove the dependent association rows before the schedule row, so
        # a (non-cascading) foreign key from the association table to the
        # schedule table, if one exists, cannot reject the delete.
        await self._delete_coaches(training_schedule)
        await self._database.delete(
            training_schedule.id.value, TrainingScheduleRow.__table_name__
        )

    async def _delete_coaches(self, training_schedule: TrainingScheduleEntity) -> None:
        """Remove all coach association rows of the training schedule.

        Only the association rows are deleted; the coaches themselves are
        not touched.
        """
        delete_coaches_query = (
            self._database.create_query_factory()
            .delete(TrainingScheduleCoachRow.__table_name__)
            .where(field("training_schedule_id").eq(training_schedule.id.value))
        )
        await self._database.execute(delete_coaches_query)

    async def _create_coaches(self, training_schedule: TrainingScheduleEntity) -> None:
        """Insert one association row for each coach of the training schedule."""
        schedule_coach_rows = [
            TrainingScheduleCoachRow(
                training_schedule_id=training_schedule.id.value, coach_id=coach.id.value
            )
            for coach in training_schedule.coaches
        ]
        # Skip the insert entirely when there are no coaches to link.
        if schedule_coach_rows:
            await self._database.insert(
                TrainingScheduleCoachRow.__table_name__, *schedule_coach_rows
            )