Coverage for bc/kwai-bc-portal/src/kwai_bc_portal/news/news_item_db_repository.py: 98%
48 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module that implements a news item repository for a database."""
3from typing import AsyncIterator
5from kwai_core.db.database import Database
6from kwai_core.functions import async_groupby
7from sql_smith.functions import field
9from kwai_bc_portal.domain.news_item import NewsItemEntity, NewsItemIdentifier
10from kwai_bc_portal.news._tables import (
11 NewsItemRow,
12 NewsItemTextRow,
13)
14from kwai_bc_portal.news.news_item_db_query import (
15 NewsItemDbQuery,
16 NewsItemQueryRow,
17)
18from kwai_bc_portal.news.news_item_query import NewsItemQuery
19from kwai_bc_portal.news.news_item_repository import (
20 NewsItemNotFoundException,
21 NewsItemRepository,
22)
class NewsItemDbRepository(NewsItemRepository):
    """A news item database repository.

    News items are stored with NewsItemRow, their texts with NewsItemTextRow.
    Text rows are linked to their news item through the ``news_id`` column.

    Attributes:
        _database: the database for the repository.
    """

    def __init__(self, database: Database):
        self._database = database

    async def create(self, news_item: NewsItemEntity) -> NewsItemEntity:
        """Insert a news item together with its texts and commit.

        Args:
            news_item: the news item to store.

        Returns:
            The news item with the id returned by the insert.
        """
        new_id = await self._database.insert(
            NewsItemRow.__table_name__, NewsItemRow.persist(news_item)
        )
        result = news_item.set_id(NewsItemIdentifier(new_id))

        # Persist the texts against `result`, the entity that carries the new id.
        await self._insert_text_rows(
            [NewsItemTextRow.persist(result, text) for text in news_item.texts]
        )

        await self._database.commit()
        return result

    async def update(self, news_item: NewsItemEntity):
        """Update a news item, replace all of its texts, and commit.

        The existing text rows are deleted and the current texts of the
        news item are inserted again.

        Args:
            news_item: the news item to update.
        """
        await self._database.update(
            news_item.id.value,
            NewsItemRow.__table_name__,
            NewsItemRow.persist(news_item),
        )

        await self._delete_text_rows(news_item)
        await self._insert_text_rows(
            [NewsItemTextRow.persist(news_item, text) for text in news_item.texts]
        )

        await self._database.commit()

    async def delete(self, news_item: NewsItemEntity):
        """Delete a news item and its texts, then commit.

        The text rows are removed first, then the news item row itself.

        Args:
            news_item: the news item to delete.
        """
        await self._delete_text_rows(news_item)
        await self._database.delete(news_item.id.value, NewsItemRow.__table_name__)
        await self._database.commit()

    async def _delete_text_rows(self, news_item: NewsItemEntity) -> None:
        """Delete all text rows of the given news item (does not commit)."""
        delete_texts_query = (
            self._database.create_query_factory()
            .delete(NewsItemTextRow.__table_name__)
            .where(field("news_id").eq(news_item.id.value))
        )
        await self._database.execute(delete_texts_query)

    async def _insert_text_rows(self, text_rows: list) -> None:
        """Insert the given text rows (does not commit).

        Nothing is sent to the database when the list is empty, so an INSERT
        without any rows is never issued.
        """
        if text_rows:
            await self._database.insert(NewsItemTextRow.__table_name__, *text_rows)

    def create_query(self) -> NewsItemQuery:
        """Return a new database query for news items."""
        return NewsItemDbQuery(self._database)

    async def get_by_id(self, id_: NewsItemIdentifier) -> NewsItemEntity:
        """Return the news item with the given id.

        Args:
            id_: the id of the news item.

        Raises:
            NewsItemNotFoundException: when no news item with the id exists.
        """
        from contextlib import aclosing

        query = self.create_query()
        query.filter_by_id(id_)

        # Only the first entity is read; aclosing closes the generator right
        # away instead of leaving it suspended until garbage collection.
        # NOTE(review): limit=1 is passed to query.fetch; if fetch applies it
        # to joined rows rather than to news items, texts could be cut off —
        # confirm in NewsItemDbQuery.
        async with aclosing(self.get_all(query, 1)) as entities:
            entity = await anext(entities, None)

        if entity is None:
            raise NewsItemNotFoundException(f"News item with {id_} does not exist.")

        return entity

    async def get_all(
        self,
        query: NewsItemQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[NewsItemEntity]:
        """Yield the news items returned by a query.

        Args:
            query: the query to run; when omitted, a fresh query from
                create_query is used.
            limit: passed as-is to query.fetch.
            offset: passed as-is to query.fetch.

        Yields:
            One entity for each group of rows sharing the same ``news_item_id``.
        """
        if query is None:
            query = self.create_query()

        group_by_column = "news_item_id"

        row_iterator = query.fetch(limit, offset)
        # NOTE(review): async_groupby presumably groups only consecutive rows
        # (like itertools.groupby), so this assumes the rows of one news item
        # are returned next to each other — confirm in NewsItemDbQuery.
        async for _, group in async_groupby(
            row_iterator, key=lambda row: row[group_by_column]
        ):
            mapped = list(map(NewsItemQueryRow.map, group))
            yield NewsItemQueryRow.create_entity(mapped)