Coverage for bc/kwai-bc-portal/src/kwai_bc_portal/news/news_item_db_repository.py: 98%

48 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that implements a news item repository for a database.""" 

2 

3from typing import AsyncIterator 

4 

5from kwai_core.db.database import Database 

6from kwai_core.functions import async_groupby 

7from sql_smith.functions import field 

8 

9from kwai_bc_portal.domain.news_item import NewsItemEntity, NewsItemIdentifier 

10from kwai_bc_portal.news._tables import ( 

11 NewsItemRow, 

12 NewsItemTextRow, 

13) 

14from kwai_bc_portal.news.news_item_db_query import ( 

15 NewsItemDbQuery, 

16 NewsItemQueryRow, 

17) 

18from kwai_bc_portal.news.news_item_query import NewsItemQuery 

19from kwai_bc_portal.news.news_item_repository import ( 

20 NewsItemNotFoundException, 

21 NewsItemRepository, 

22) 

23 

24 

class NewsItemDbRepository(NewsItemRepository):
    """A news item database repository.

    Attributes:
        _database: the database for the repository.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database used for all statements of this repository.
        """
        self._database = database

    async def _insert_texts(self, owner: NewsItemEntity, texts) -> None:
        """Insert the text rows of a news item.

        Args:
            owner: The news item entity passed to NewsItemTextRow.persist
                together with each text.
            texts: The texts to persist.
        """
        text_rows = [NewsItemTextRow.persist(owner, text) for text in texts]
        # Without rows there is nothing to insert; skip the call instead of
        # invoking insert with only a table name.
        if text_rows:
            await self._database.insert(NewsItemTextRow.__table_name__, *text_rows)

    async def _delete_texts(self, news_item: NewsItemEntity) -> None:
        """Delete all text rows whose news_id matches the given news item."""
        delete_texts_query = (
            self._database.create_query_factory()
            .delete(NewsItemTextRow.__table_name__)
            .where(field("news_id").eq(news_item.id.value))
        )
        await self._database.execute(delete_texts_query)

    async def create(self, news_item: NewsItemEntity) -> NewsItemEntity:
        """Insert a news item and its texts, then commit.

        Args:
            news_item: The news item to store.

        Returns:
            The entity returned by set_id, carrying the identifier of the
            inserted row.
        """
        new_id = await self._database.insert(
            NewsItemRow.__table_name__, NewsItemRow.persist(news_item)
        )
        created = news_item.set_id(NewsItemIdentifier(new_id))

        # The text rows are built with the entity that carries the new id.
        await self._insert_texts(created, news_item.texts)

        await self._database.commit()
        return created

    async def update(self, news_item: NewsItemEntity):
        """Update a news item and replace its texts, then commit.

        The existing text rows are deleted and the current texts of the entity
        are inserted again.

        Args:
            news_item: The news item to update.
        """
        await self._database.update(
            news_item.id.value,
            NewsItemRow.__table_name__,
            NewsItemRow.persist(news_item),
        )

        await self._delete_texts(news_item)
        await self._insert_texts(news_item, news_item.texts)
        await self._database.commit()

    async def delete(self, news_item: NewsItemEntity):
        """Delete a news item together with its texts, then commit.

        Args:
            news_item: The news item to delete.
        """
        # The text rows are removed before the news item row itself
        # (presumably because they reference it - confirm against the schema).
        await self._delete_texts(news_item)
        await self._database.delete(news_item.id.value, NewsItemRow.__table_name__)
        await self._database.commit()

    def create_query(self) -> NewsItemQuery:
        """Create a query bound to the database of this repository."""
        return NewsItemDbQuery(self._database)

    async def get_by_id(self, id_: NewsItemIdentifier) -> NewsItemEntity:
        """Get the news item with the given identifier.

        Args:
            id_: The identifier of the news item.

        Returns:
            The first entity produced by get_all for a query filtered on id_.

        Raises:
            NewsItemNotFoundException: Raised when the query yields no entity.
        """
        query = self.create_query()
        query.filter_by_id(id_)

        entity = await anext(self.get_all(query, 1), None)
        if entity is None:
            raise NewsItemNotFoundException(f"News item with {id_} does not exist.")

        return entity

    async def get_all(
        self,
        query: NewsItemQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[NewsItemEntity]:
        """Yield the news items selected by a query.

        Rows fetched from the query are grouped on the "news_item_id" column;
        each group is mapped and turned into one entity.

        Args:
            query: The query to execute. When None, a new query is created
                with create_query.
            limit: Passed as the limit to the fetch of the query.
            offset: Passed as the offset to the fetch of the query.

        Yields:
            One entity for each group of rows.
        """
        if query is None:
            query = self.create_query()

        group_by_column = "news_item_id"

        row_iterator = query.fetch(limit, offset)
        async for _, group in async_groupby(
            row_iterator, key=lambda row: row[group_by_column]
        ):
            mapped_rows = [NewsItemQueryRow.map(row) for row in group]
            yield NewsItemQueryRow.create_entity(mapped_rows)
108 yield NewsItemQueryRow.create_entity(mapped)