Coverage for bc/kwai-bc-identity/src/kwai_bc_identity/tokens/refresh_token_db_repository.py: 85%

41 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that implements a refresh token repository for a database.""" 

2 

3from typing import AsyncIterator 

4 

5from kwai_core.db.database import Database 

6 

7from kwai_bc_identity.tokens.refresh_token import ( 

8 RefreshTokenEntity, 

9 RefreshTokenIdentifier, 

10) 

11from kwai_bc_identity.tokens.refresh_token_db_query import RefreshTokenDbQuery 

12from kwai_bc_identity.tokens.refresh_token_query import RefreshTokenQuery 

13from kwai_bc_identity.tokens.refresh_token_repository import ( 

14 RefreshTokenNotFoundException, 

15 RefreshTokenRepository, 

16) 

17from kwai_bc_identity.tokens.token_identifier import TokenIdentifier 

18from kwai_bc_identity.tokens.token_tables import ( 

19 AccessTokenRow, 

20 RefreshTokenRow, 

21) 

22from kwai_bc_identity.users.user_tables import UserAccountRow 

23 

24 

def _create_entity(row) -> RefreshTokenEntity:
    """Build a refresh token entity from a single query result row.

    The same row is mapped with the refresh token, access token and user
    account row classes. The user account entity is created first, passed
    to the access token entity, which in turn is passed to the refresh token
    entity that is returned.
    """
    refresh_token_row = RefreshTokenRow.map(row)
    access_token_row = AccessTokenRow.map(row)
    user_account = UserAccountRow.map(row).create_entity()
    access_token = access_token_row.create_entity(user_account)
    return refresh_token_row.create_entity(access_token)

30 

31 

class RefreshTokenDbRepository(RefreshTokenRepository):
    """Refresh token repository that is backed by a database."""

    def __init__(self, database: Database):
        """Store the database that all queries and statements will use."""
        self._database = database

    def create_query(self) -> RefreshTokenQuery:
        """Return a new database query for refresh tokens."""
        return RefreshTokenDbQuery(self._database)

    async def get_by_token_identifier(
        self, identifier: TokenIdentifier
    ) -> RefreshTokenEntity:
        """Fetch the refresh token that matches the given token identifier.

        Raises:
            RefreshTokenNotFoundException: when the query yields no row.
        """
        token_query = self.create_query()
        token_query.filter_by_token_identifier(identifier)

        record = await token_query.fetch_one()
        if not record:
            raise RefreshTokenNotFoundException(
                f"Token with identifier {identifier} not found"
            )

        return _create_entity(record)

    async def get(self, id_: RefreshTokenIdentifier) -> RefreshTokenEntity:
        """Fetch the refresh token with the given id.

        Raises:
            RefreshTokenNotFoundException: when the query yields no row.
        """
        token_query = self.create_query()
        token_query.filter_by_id(id_)

        record = await token_query.fetch_one()
        if not record:
            raise RefreshTokenNotFoundException(f"Token with id {id_} not found")

        return _create_entity(record)

    async def get_all(
        self,
        query: RefreshTokenDbQuery | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> AsyncIterator[RefreshTokenEntity]:
        """Yield a refresh token entity for each row fetched by the query.

        A new query is created when none (or a falsy one) is passed. The
        limit and offset are handed to the query's fetch method unchanged.
        """
        active_query = query if query else self.create_query()
        async for record in active_query.fetch(limit, offset):
            yield _create_entity(record)

    async def create(self, refresh_token: RefreshTokenEntity) -> RefreshTokenEntity:
        """Insert the refresh token and return the result of setting its new id."""
        table_name = RefreshTokenRow.__table_name__
        row_data = RefreshTokenRow.persist(refresh_token)
        inserted_id = await self._database.insert(table_name, row_data)
        return refresh_token.set_id(RefreshTokenIdentifier(inserted_id))

    async def update(self, refresh_token: RefreshTokenEntity):
        """Write the current state of the refresh token to its table."""
        token_id = refresh_token.id.value
        table_name = RefreshTokenRow.__table_name__
        row_data = RefreshTokenRow.persist(refresh_token)
        await self._database.update(token_id, table_name, row_data)

    async def delete(self, refresh_token: RefreshTokenEntity):
        """Delete the refresh token, identified by its id, from its table."""
        token_id = refresh_token.id.value
        await self._database.delete(token_id, RefreshTokenRow.__table_name__)