Coverage for bc/kwai-bc-training/src/kwai_bc_training/coaches/coach_db_repository.py: 96%

25 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that defines a coach repository for a database.""" 

2 

3from typing import AsyncIterator 

4 

5from kwai_core.db.database import Database 

6 

7from kwai_bc_training.coaches.coach import CoachEntity, CoachIdentifier 

8from kwai_bc_training.coaches.coach_db_query import CoachDbQuery, CoachQueryRow 

9from kwai_bc_training.coaches.coach_query import CoachQuery 

10from kwai_bc_training.coaches.coach_repository import ( 

11 CoachNotFoundException, 

12 CoachRepository, 

13) 

14 

15 

class CoachDbRepository(CoachRepository):
    """A coach repository for a database."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database for this repository.
        """
        self._database = database

    def create_query(self) -> CoachQuery:
        """Create the coach query."""
        return CoachDbQuery(self._database)

    async def get_by_id(self, id: CoachIdentifier) -> CoachEntity:
        """Get the coach with the given id.

        Args:
            id: The id of the coach to fetch.

        Returns:
            The coach entity created from the matching row.

        Raises:
            CoachNotFoundException: When the query returns no row for the id.
        """
        query = self.create_query().filter_by_id(id)
        row = await query.fetch_one()

        if not row:
            raise CoachNotFoundException(f"Coach with id {id} not found.")

        return CoachQueryRow.map(row).create_entity()

    async def get_by_ids(self, *ids: CoachIdentifier) -> AsyncIterator[CoachEntity]:
        """Yield the coaches with the given ids.

        Args:
            ids: The ids of the coaches to fetch.

        Yields:
            A coach entity for each row returned by the query. No exception is
            raised for ids that do not match a row.
        """
        if not ids:
            # Nothing was requested. Skip the database round-trip, and avoid
            # building a filter over an empty id list: an empty ``IN ()``
            # clause is likely invalid SQL.
            return

        query = self.create_query().filter_by_ids(*ids)

        async for row in query.fetch():
            yield CoachQueryRow.map(row).create_entity()

    async def get_all(
        self, query: CoachQuery | None = None
    ) -> AsyncIterator[CoachEntity]:
        """Yield all coaches returned by a query.

        Args:
            query: The query to execute. When omitted, an unfiltered query
                from ``create_query`` is used.

        Yields:
            A coach entity for each row returned by the query.
        """
        # Explicit None check: a caller-supplied query must never be
        # replaced just because the query object happens to be falsy.
        if query is None:
            query = self.create_query()

        async for row in query.fetch():
            yield CoachQueryRow.map(row).create_entity()