Coverage for bc/kwai-bc-identity/src/kwai_bc_identity/tokens/access_token_db_repository.py: 85%
41 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module that implements an access token repository for a database."""
3from typing import Any, AsyncIterator
5from kwai_core.db.database import Database
7from kwai_bc_identity.tokens.access_token import (
8 AccessTokenEntity,
9 AccessTokenIdentifier,
10)
11from kwai_bc_identity.tokens.access_token_db_query import AccessTokenDbQuery
12from kwai_bc_identity.tokens.access_token_query import AccessTokenQuery
13from kwai_bc_identity.tokens.access_token_repository import (
14 AccessTokenNotFoundException,
15 AccessTokenRepository,
16)
17from kwai_bc_identity.tokens.token_identifier import TokenIdentifier
18from kwai_bc_identity.tokens.token_tables import AccessTokenRow
19from kwai_bc_identity.users.user_tables import UserAccountRow
def _create_entity(row: dict[str, Any]) -> AccessTokenEntity:
    """Build an access token entity, with its user account, from one row.

    The same row is mapped twice: once as an access token row and once as a
    user account row. It must therefore carry the columns needed by both
    ``AccessTokenRow`` and ``UserAccountRow``.

    Args:
        row: A row as returned by an access token query.

    Returns:
        The access token entity, created with the mapped user account entity.
    """
    token_row = AccessTokenRow.map(row)
    user_account = UserAccountRow.map(row).create_entity()
    return token_row.create_entity(user_account)
class AccessTokenDbRepository(AccessTokenRepository):
    """Database repository for the access token entity."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database used to query and persist access tokens.
        """
        self._database = database

    def create_query(self) -> AccessTokenQuery:
        """Create a new access token query bound to this repository's database."""
        return AccessTokenDbQuery(self._database)

    async def _fetch_entity(self, query: AccessTokenQuery) -> AccessTokenEntity:
        """Fetch the single row matched by ``query`` and map it to an entity.

        Args:
            query: An already-filtered access token query.

        Raises:
            AccessTokenNotFoundException: When the query returns no row.
        """
        row = await query.fetch_one()
        if not row:
            raise AccessTokenNotFoundException
        return _create_entity(row)

    async def get(self, id_: AccessTokenIdentifier) -> AccessTokenEntity:
        """Get the access token with the given id.

        Raises:
            AccessTokenNotFoundException: When no access token has this id.
        """
        query = self.create_query()
        query.filter_by_id(id_)
        return await self._fetch_entity(query)

    async def get_by_identifier(self, identifier: TokenIdentifier) -> AccessTokenEntity:
        """Get the access token with the given token identifier.

        Raises:
            AccessTokenNotFoundException: When no access token matches.
        """
        query = self.create_query()
        query.filter_by_token_identifier(identifier)
        return await self._fetch_entity(query)

    async def get_all(
        self,
        query: AccessTokenQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[AccessTokenEntity]:
        """Yield the access tokens matched by ``query``.

        Args:
            query: The query to run. When omitted, a new unfiltered query is
                created.
            limit: Passed through to ``query.fetch``.
            offset: Passed through to ``query.fetch``.

        Yields:
            One access token entity per fetched row.
        """
        # Compare against None explicitly: a caller-supplied query must never
        # be replaced by a fresh (unfiltered) one just because it is falsy.
        if query is None:
            query = self.create_query()
        async for row in query.fetch(limit, offset):
            yield _create_entity(row)

    async def create(self, access_token: AccessTokenEntity) -> AccessTokenEntity:
        """Insert the access token.

        Returns:
            The result of ``set_id`` on the given entity, using the id that
            the database returned for the insert.
        """
        new_id = await self._database.insert(
            AccessTokenRow.__table_name__, AccessTokenRow.persist(access_token)
        )
        return access_token.set_id(AccessTokenIdentifier(new_id))

    async def update(self, access_token: AccessTokenEntity) -> None:
        """Persist the current state of the access token to its table row."""
        await self._database.update(
            access_token.id.value,
            AccessTokenRow.__table_name__,
            AccessTokenRow.persist(access_token),
        )

    async def delete(self, access_token: AccessTokenEntity) -> None:
        """Delete the table row belonging to the access token's id."""
        await self._database.delete(
            access_token.id.value, AccessTokenRow.__table_name__
        )