Coverage for bc/kwai-bc-identity/src/kwai_bc_identity/tokens/refresh_token_db_repository.py: 85%
41 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module that implements a refresh token repository for a database."""
3from typing import AsyncIterator
5from kwai_core.db.database import Database
7from kwai_bc_identity.tokens.refresh_token import (
8 RefreshTokenEntity,
9 RefreshTokenIdentifier,
10)
11from kwai_bc_identity.tokens.refresh_token_db_query import RefreshTokenDbQuery
12from kwai_bc_identity.tokens.refresh_token_query import RefreshTokenQuery
13from kwai_bc_identity.tokens.refresh_token_repository import (
14 RefreshTokenNotFoundException,
15 RefreshTokenRepository,
16)
17from kwai_bc_identity.tokens.token_identifier import TokenIdentifier
18from kwai_bc_identity.tokens.token_tables import (
19 AccessTokenRow,
20 RefreshTokenRow,
21)
22from kwai_bc_identity.users.user_tables import UserAccountRow
def _create_entity(row) -> RefreshTokenEntity:
    """Build a refresh token entity, including its nested access token and user.

    The same row is mapped three times: once as a RefreshTokenRow, once as an
    AccessTokenRow and once as a UserAccountRow. The user account entity is
    passed into the access token entity, which is in turn passed into the
    refresh token entity.

    Args:
        row: A database row as returned by the refresh token query
            (presumably containing the columns of all three tables; confirm
            against RefreshTokenDbQuery).

    Returns:
        The refresh token entity built from the row.
    """
    refresh_token_row = RefreshTokenRow.map(row)
    access_token_row = AccessTokenRow.map(row)
    user_account_row = UserAccountRow.map(row)

    user_account = user_account_row.create_entity()
    access_token = access_token_row.create_entity(user_account)
    return refresh_token_row.create_entity(access_token)
class RefreshTokenDbRepository(RefreshTokenRepository):
    """Database repository for the refresh token entity."""

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: The database used to query and persist refresh tokens.
        """
        self._database = database

    def create_query(self) -> RefreshTokenQuery:
        """Create a new refresh token query bound to this repository's database."""
        return RefreshTokenDbQuery(self._database)

    async def get_by_token_identifier(
        self, identifier: TokenIdentifier
    ) -> RefreshTokenEntity:
        """Get the refresh token with the given token identifier.

        Args:
            identifier: The token identifier to search for.

        Returns:
            The matching refresh token entity.

        Raises:
            RefreshTokenNotFoundException: When the query returns no row.
        """
        query = self.create_query()
        query.filter_by_token_identifier(identifier)

        row = await query.fetch_one()
        if row:
            return _create_entity(row)

        raise RefreshTokenNotFoundException(
            f"Token with identifier {identifier} not found"
        )

    async def get(self, id_: RefreshTokenIdentifier) -> RefreshTokenEntity:
        """Get the refresh token with the given id.

        Args:
            id_: The id of the refresh token.

        Returns:
            The matching refresh token entity.

        Raises:
            RefreshTokenNotFoundException: When the query returns no row.
        """
        query = self.create_query()
        query.filter_by_id(id_)

        row = await query.fetch_one()
        if row:
            return _create_entity(row)

        raise RefreshTokenNotFoundException(f"Token with id {id_} not found")

    async def get_all(
        self,
        query: RefreshTokenDbQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[RefreshTokenEntity]:
        """Yield refresh token entities for every row returned by a query.

        Args:
            query: The query to run. When None, a new unfiltered query from
                create_query is used.
            limit: Passed through unchanged to query.fetch.
            offset: Passed through unchanged to query.fetch.

        Yields:
            A refresh token entity for each fetched row.
        """
        # Test against the None default explicitly. `query or ...` would
        # silently replace a caller-supplied query that happens to be falsy.
        if query is None:
            query = self.create_query()
        async for row in query.fetch(limit, offset):
            yield _create_entity(row)

    async def create(self, refresh_token: RefreshTokenEntity) -> RefreshTokenEntity:
        """Insert a refresh token into the database.

        Args:
            refresh_token: The refresh token entity to persist.

        Returns:
            The entity returned by set_id, carrying the id that the database
            insert returned.
        """
        new_id = await self._database.insert(
            RefreshTokenRow.__table_name__, RefreshTokenRow.persist(refresh_token)
        )
        return refresh_token.set_id(RefreshTokenIdentifier(new_id))

    async def update(self, refresh_token: RefreshTokenEntity) -> None:
        """Write the refresh token's current state to the row with its id.

        Args:
            refresh_token: The refresh token entity to persist. Its id selects
                the row that is updated.
        """
        await self._database.update(
            refresh_token.id.value,
            RefreshTokenRow.__table_name__,
            RefreshTokenRow.persist(refresh_token),
        )

    async def delete(self, refresh_token: RefreshTokenEntity) -> None:
        """Delete the row of the refresh token from the database.

        Args:
            refresh_token: The refresh token entity to delete. Its id selects
                the row that is deleted.
        """
        await self._database.delete(
            refresh_token.id.value, RefreshTokenRow.__table_name__
        )