Coverage for bc/kwai-bc-portal/src/kwai_bc_portal/pages/page_db_repository.py: 100%
48 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module that implements a page repository for a database."""
3from typing import AsyncIterator
5from kwai_core.db.database import Database
6from kwai_core.functions import async_groupby
7from sql_smith.functions import field
9from kwai_bc_portal.domain.page import PageEntity, PageIdentifier
10from kwai_bc_portal.pages._tables import (
11 PageRow,
12 PageTextRow,
13)
14from kwai_bc_portal.pages.page_db_query import PageDbQuery, PageQueryRow
15from kwai_bc_portal.pages.page_query import PageQuery
16from kwai_bc_portal.pages.page_repository import (
17 PageNotFoundException,
18 PageRepository,
19)
class PageDbRepository(PageRepository):
    """Page repository for a database.

    A page is stored as one row in the ``PageRow`` table. Each of its texts is
    stored as a separate row in the ``PageTextRow`` table, linked to the page
    through the ``page_id`` column.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database used to persist and query pages.
        """
        self._database = database

    async def _insert_text_rows(self, rows: list[PageTextRow]) -> None:
        """Insert the given text rows, doing nothing when there are none."""
        # A page without texts must not trigger an INSERT without any rows.
        if rows:
            await self._database.insert(PageTextRow.__table_name__, *rows)

    async def _delete_text_rows(self, page: PageEntity) -> None:
        """Delete all text rows that belong to the given page."""
        delete_contents_query = (
            self._database.create_query_factory()
            .delete(PageTextRow.__table_name__)
            .where(field("page_id").eq(page.id.value))
        )
        await self._database.execute(delete_contents_query)

    async def create(self, page: PageEntity) -> PageEntity:
        """Insert a new page together with its texts and commit.

        Args:
            page: The page to create.

        Returns:
            The page with the identifier assigned by the database.
        """
        new_id = await self._database.insert(
            PageRow.__table_name__, PageRow.persist(page)
        )
        result = page.set_id(PageIdentifier(new_id))

        # The text rows need the new page id, so they are persisted with
        # `result` (the page carrying the generated identifier).
        await self._insert_text_rows(
            [PageTextRow.persist(result, content) for content in page.texts]
        )

        await self._database.commit()
        return result

    async def update(self, page: PageEntity):
        """Update a page and replace all of its texts, then commit.

        The existing text rows are deleted and the current texts of the page
        are inserted again.

        Args:
            page: The page to update. Its id identifies the row to update.
        """
        await self._database.update(
            page.id.value, PageRow.__table_name__, PageRow.persist(page)
        )

        await self._delete_text_rows(page)
        await self._insert_text_rows(
            [PageTextRow.persist(page, content) for content in page.texts]
        )
        await self._database.commit()

    async def delete(self, page: PageEntity):
        """Delete a page and all of its texts, then commit.

        The text rows are removed before the page row itself.

        Args:
            page: The page to delete.
        """
        await self._delete_text_rows(page)
        await self._database.delete(page.id.value, PageRow.__table_name__)
        await self._database.commit()

    def create_query(self) -> PageQuery:
        """Create a new database query for pages."""
        return PageDbQuery(self._database)

    async def get_by_id(self, id_: PageIdentifier) -> PageEntity:
        """Return the page with the given id.

        Args:
            id_: The identifier of the page.

        Returns:
            The matching page entity.

        Raises:
            PageNotFoundException: When no page with the given id exists.
        """
        query = self.create_query()
        query.filter_by_id(id_)

        entity = await anext(self.get_all(query, 1), None)
        if entity is None:
            raise PageNotFoundException(f"Page with {id_} does not exist.")

        return entity

    async def get_all(
        self,
        query: PageQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[PageEntity]:
        """Yield the page entities returned by a query.

        Args:
            query: The query to execute. When omitted, a new query is created.
            limit: Passed on to `query.fetch`.
            offset: Passed on to `query.fetch`.

        Yields:
            One page entity per group of rows that share the same `page_id`.
        """
        if query is None:
            query = self.create_query()

        group_by_column = "page_id"

        row_iterator = query.fetch(limit, offset)
        # NOTE(review): async_groupby presumably behaves like itertools.groupby
        # and only groups *consecutive* rows. The query is therefore expected
        # to return all rows of one page contiguously — TODO confirm that
        # PageDbQuery orders by page_id.
        async for _, group in async_groupby(
            row_iterator, key=lambda row: row[group_by_column]
        ):
            mapped = list(map(PageQueryRow.map, group))
            yield PageQueryRow.create_entity(mapped)