Coverage for bc/kwai-bc-identity/src/kwai_bc_identity/tokens/access_token_db_repository.py: 85%

41 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that implements an access token repository for a database.""" 

2 

3from typing import Any, AsyncIterator 

4 

5from kwai_core.db.database import Database 

6 

7from kwai_bc_identity.tokens.access_token import ( 

8 AccessTokenEntity, 

9 AccessTokenIdentifier, 

10) 

11from kwai_bc_identity.tokens.access_token_db_query import AccessTokenDbQuery 

12from kwai_bc_identity.tokens.access_token_query import AccessTokenQuery 

13from kwai_bc_identity.tokens.access_token_repository import ( 

14 AccessTokenNotFoundException, 

15 AccessTokenRepository, 

16) 

17from kwai_bc_identity.tokens.token_identifier import TokenIdentifier 

18from kwai_bc_identity.tokens.token_tables import AccessTokenRow 

19from kwai_bc_identity.users.user_tables import UserAccountRow 

20 

21 

def _create_entity(row: dict[str, Any]) -> AccessTokenEntity:
    """Build an access token entity from a single database row.

    The same row is mapped twice: with AccessTokenRow for the token columns and
    with UserAccountRow for the owning user account. The resulting user account
    entity is passed to the token row's create_entity.
    (Presumably the row comes from a join in AccessTokenDbQuery — confirm there.)
    """
    token_row = AccessTokenRow.map(row)
    user_account = UserAccountRow.map(row).create_entity()
    return token_row.create_entity(user_account)

27 

28 

class AccessTokenDbRepository(AccessTokenRepository):
    """Database repository for the access token entity.

    Reads are performed through AccessTokenDbQuery; writes (insert, update,
    delete) go directly to the table named by AccessTokenRow.__table_name__
    through the injected Database.
    """

    def __init__(self, database: Database):
        """Initialize the repository.

        Args:
            database: Database used both for queries and for writes.
        """
        self._database = database

    def create_query(self) -> AccessTokenQuery:
        """Return a new access token query bound to this repository's database."""
        return AccessTokenDbQuery(self._database)

    async def _get_one(self, query: AccessTokenQuery) -> AccessTokenEntity:
        """Fetch the single row selected by ``query`` and map it to an entity.

        Args:
            query: An already-filtered query.

        Raises:
            AccessTokenNotFoundException: When the query returns no row.
        """
        row = await query.fetch_one()
        # Same test as before (falsy row == not found), shared by both lookups.
        if not row:
            raise AccessTokenNotFoundException
        return _create_entity(row)

    async def get(self, id_: AccessTokenIdentifier) -> AccessTokenEntity:
        """Return the access token with the given id.

        Raises:
            AccessTokenNotFoundException: When no access token has this id.
        """
        query = self.create_query()
        query.filter_by_id(id_)
        return await self._get_one(query)

    async def get_by_identifier(self, identifier: TokenIdentifier) -> AccessTokenEntity:
        """Return the access token with the given token identifier.

        Raises:
            AccessTokenNotFoundException: When no access token matches.
        """
        query = self.create_query()
        query.filter_by_token_identifier(identifier)
        return await self._get_one(query)

    async def get_all(
        self,
        query: AccessTokenQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[AccessTokenEntity]:
        """Yield an access token entity for every row returned by a query.

        Args:
            query: Query to execute; when None, a new query from create_query()
                is used.
            limit: Passed through to ``query.fetch`` (presumably caps the rows).
            offset: Passed through to ``query.fetch`` (presumably rows to skip).
        """
        # Explicit None check: a caller-supplied query must never be replaced
        # merely because the query object happens to evaluate as falsy.
        if query is None:
            query = self.create_query()
        async for row in query.fetch(limit, offset):
            yield _create_entity(row)

    async def create(self, access_token: AccessTokenEntity) -> AccessTokenEntity:
        """Insert the access token.

        Returns:
            The result of ``access_token.set_id`` called with the id returned
            by the insert.
        """
        new_id = await self._database.insert(
            AccessTokenRow.__table_name__, AccessTokenRow.persist(access_token)
        )
        return access_token.set_id(AccessTokenIdentifier(new_id))

    async def update(self, access_token: AccessTokenEntity) -> None:
        """Persist the current state of the access token to its existing row."""
        await self._database.update(
            access_token.id.value,
            AccessTokenRow.__table_name__,
            AccessTokenRow.persist(access_token),
        )

    async def delete(self, access_token: AccessTokenEntity) -> None:
        """Delete the row of the given access token."""
        await self._database.delete(
            access_token.id.value, AccessTokenRow.__table_name__
        )