Coverage for bc/kwai-bc-portal/src/kwai_bc_portal/pages/page_db_repository.py: 100%

48 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that implements a page repository for a database.""" 

2 

3from typing import AsyncIterator 

4 

5from kwai_core.db.database import Database 

6from kwai_core.functions import async_groupby 

7from sql_smith.functions import field 

8 

9from kwai_bc_portal.domain.page import PageEntity, PageIdentifier 

10from kwai_bc_portal.pages._tables import ( 

11 PageRow, 

12 PageTextRow, 

13) 

14from kwai_bc_portal.pages.page_db_query import PageDbQuery, PageQueryRow 

15from kwai_bc_portal.pages.page_query import PageQuery 

16from kwai_bc_portal.pages.page_repository import ( 

17 PageNotFoundException, 

18 PageRepository, 

19) 

20 

21 

class PageDbRepository(PageRepository):
    """Page repository for a database.

    A page is stored as one row in the table of ``PageRow`` and its texts as
    rows in the table of ``PageTextRow``, linked through the ``page_id``
    column.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database connection used for all statements.
        """
        self._database = database

    async def _insert_texts(self, page: PageEntity, texts) -> None:
        """Insert the given texts as text rows of the page.

        Args:
            page: The page that owns the texts; it is passed to
                ``PageTextRow.persist`` together with each text.
            texts: The texts to store (as exposed by ``PageEntity.texts``).
        """
        text_rows = [PageTextRow.persist(page, text) for text in texts]
        # A page without texts has nothing to insert. Skipping the call avoids
        # handing an empty row list to the database.
        # NOTE(review): Database.insert presumably expects at least one row —
        # confirm.
        if not text_rows:
            return
        await self._database.insert(PageTextRow.__table_name__, *text_rows)

    async def _delete_texts(self, page: PageEntity) -> None:
        """Delete all text rows that belong to the page."""
        delete_texts_query = (
            self._database.create_query_factory()
            .delete(PageTextRow.__table_name__)
            .where(field("page_id").eq(page.id.value))
        )
        await self._database.execute(delete_texts_query)

    async def create(self, page: PageEntity) -> PageEntity:
        """Store a new page together with its texts and commit.

        Args:
            page: The page to store.

        Returns:
            The page with the identifier returned by the insert.
        """
        new_id = await self._database.insert(
            PageRow.__table_name__, PageRow.persist(page)
        )
        result = page.set_id(PageIdentifier(new_id))

        # The text rows need the new identifier, so they are persisted with
        # the entity that carries it.
        await self._insert_texts(result, page.texts)

        await self._database.commit()
        return result

    async def update(self, page: PageEntity) -> None:
        """Update the page row, replace all its texts and commit.

        Args:
            page: The page with the new state.
        """
        await self._database.update(
            page.id.value, PageRow.__table_name__, PageRow.persist(page)
        )

        # The texts are not updated in place: the existing rows are removed
        # and the current texts are inserted again.
        await self._delete_texts(page)
        await self._insert_texts(page, page.texts)

        await self._database.commit()

    async def delete(self, page: PageEntity) -> None:
        """Delete the texts of the page, then the page itself, and commit.

        Args:
            page: The page to remove.
        """
        await self._delete_texts(page)
        await self._database.delete(page.id.value, PageRow.__table_name__)
        await self._database.commit()

    def create_query(self) -> PageQuery:
        """Create a query for pages that uses the database of this repository."""
        return PageDbQuery(self._database)

    async def get_by_id(self, id_: PageIdentifier) -> PageEntity:
        """Get the page with the given id.

        Args:
            id_: The identifier of the page.

        Returns:
            The page built from the rows selected by the id filter.

        Raises:
            PageNotFoundException: When the query does not return a page.
        """
        query = self.create_query()
        query.filter_by_id(id_)

        # No row limit is passed here. get_all builds one entity from a group
        # of rows, so a limit of 1 on the rows could drop all but the first
        # row of the page. The id filter already restricts the result to a
        # single page.
        # NOTE(review): confirm how PageDbQuery.fetch applies its limit.
        entity = await anext(self.get_all(query), None)
        if entity is None:
            raise PageNotFoundException(f"Page with {id_} does not exist.")

        return entity

    async def get_all(
        self,
        query: PageQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[PageEntity]:
        """Yield the pages selected by the query.

        Args:
            query: The query to execute. When None, a new query without
                filters is created.
            limit: Passed unchanged to ``query.fetch``.
            offset: Passed unchanged to ``query.fetch``.

        Yields:
            One page entity for each group of rows that share a ``page_id``.
        """
        if query is None:
            query = self.create_query()

        group_by_column = "page_id"

        # NOTE(review): this presumably relies on the query returning the rows
        # of one page consecutively — confirm the ordering in PageDbQuery.
        row_iterator = query.fetch(limit, offset)
        async for _, group in async_groupby(
            row_iterator, key=lambda row: row[group_by_column]
        ):
            mapped = [PageQueryRow.map(row) for row in group]
            yield PageQueryRow.create_entity(mapped)

93 ): 

94 mapped = list(map(PageQueryRow.map, group)) 

95 yield PageQueryRow.create_entity(mapped)