Coverage for bc/kwai-bc-identity/src/kwai_bc_identity/tokens/token_tables.py: 97%

37 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that defines all tables for tokens.""" 

2 

3from dataclasses import dataclass 

4from datetime import datetime 

5from typing import Self 

6 

7from kwai_core.db.table_row import TableRow, unwrap 

8from kwai_core.domain.value_objects.timestamp import Timestamp 

9from kwai_core.domain.value_objects.traceable_time import TraceableTime 

10 

11from kwai_bc_identity.tokens.access_token import ( 

12 AccessTokenEntity, 

13 AccessTokenIdentifier, 

14) 

15from kwai_bc_identity.tokens.refresh_token import ( 

16 RefreshTokenEntity, 

17 RefreshTokenIdentifier, 

18) 

19from kwai_bc_identity.tokens.token_identifier import TokenIdentifier 

20from kwai_bc_identity.tokens.user_log import UserLogEntity, UserLogIdentifier 

21from kwai_bc_identity.tokens.value_objects import IpAddress, OpenId 

22from kwai_bc_identity.users.user_account import UserAccountEntity 

23 

24 

@dataclass(kw_only=True, frozen=True, slots=True)
class AccessTokenRow(TableRow):
    """Represent a table row in the access tokens table.

    Maps between the ``oauth_access_tokens`` table and an
    ``AccessTokenEntity``. The ``revoked`` flag is stored as an integer
    (``1`` when revoked, ``0`` otherwise).
    """

    __table_name__ = "oauth_access_tokens"

    id: int | None
    identifier: str
    expiration: datetime
    user_id: int
    revoked: int
    created_at: datetime
    updated_at: datetime | None

    def create_entity(self, user_account: UserAccountEntity) -> AccessTokenEntity:
        """Create an access token entity from the table row.

        Args:
            user_account: The user account owning the token. The row only
                stores its id (``user_id``), so the caller supplies the entity.

        Returns:
            The access token entity built from this row.
        """
        return AccessTokenEntity(
            # A row read from the database is expected to have an id;
            # unwrap presumably raises when it is None — confirm in kwai_core.
            id=AccessTokenIdentifier(unwrap(self.id)),
            identifier=TokenIdentifier(hex_string=self.identifier),
            expiration=Timestamp.create_utc(self.expiration),
            user_account=user_account,
            # The integer column is converted back into a boolean.
            revoked=self.revoked == 1,
            traceable_time=TraceableTime(
                created_at=Timestamp.create_utc(self.created_at),
                # updated_at is nullable in the row and passed through as-is.
                updated_at=Timestamp.create_utc(self.updated_at),
            ),
        )

    @classmethod
    def persist(cls, access_token: AccessTokenEntity) -> Self:
        """Transform an access token entity into a table record.

        Args:
            access_token: The entity to persist.

        Returns:
            A row instance (of ``cls``) holding the entity's values.
        """
        return cls(
            id=access_token.id.value,
            identifier=str(access_token.identifier),
            expiration=unwrap(access_token.expiration.timestamp),
            # Only the id of the owning user account is stored.
            user_id=access_token.user_account.id.value,
            revoked=1 if access_token.revoked else 0,
            created_at=unwrap(access_token.traceable_time.created_at.timestamp),
            updated_at=access_token.traceable_time.updated_at.timestamp,
        )

65 

66 

@dataclass(kw_only=True, frozen=True, slots=True)
class RefreshTokenRow(TableRow):
    """Represent a table row in the refresh token table.

    Maps between the ``oauth_refresh_tokens`` table and a
    ``RefreshTokenEntity``. The ``revoked`` flag is stored as an integer
    (``1`` when revoked, ``0`` otherwise).
    """

    __table_name__ = "oauth_refresh_tokens"

    # Defaults to None; allowed before required fields because kw_only=True.
    id: int | None = None
    identifier: str
    access_token_id: int
    expiration: datetime
    revoked: int
    created_at: datetime
    updated_at: datetime | None

    def create_entity(self, access_token: AccessTokenEntity) -> RefreshTokenEntity:
        """Create a refresh token entity from the table row.

        Args:
            access_token: The access token linked to this refresh token. The
                row only stores its id (``access_token_id``), so the caller
                supplies the entity.

        Returns:
            The refresh token entity built from this row.
        """
        return RefreshTokenEntity(
            # A row read from the database is expected to have an id;
            # unwrap presumably raises when it is None — confirm in kwai_core.
            id=RefreshTokenIdentifier(unwrap(self.id)),
            identifier=TokenIdentifier(hex_string=self.identifier),
            access_token=access_token,
            expiration=Timestamp.create_utc(self.expiration),
            revoked=self.revoked == 1,
            traceable_time=TraceableTime(
                # create_utc returns a Timestamp object, never None, so no
                # unwrap is needed here (consistent with AccessTokenRow).
                created_at=Timestamp.create_utc(self.created_at),
                updated_at=Timestamp.create_utc(self.updated_at),
            ),
        )

    @classmethod
    def persist(cls, refresh_token: RefreshTokenEntity) -> Self:
        """Transform a refresh token entity into a table record.

        Args:
            refresh_token: The entity to persist.

        Returns:
            A row instance (of ``cls``) holding the entity's values.
        """
        return cls(
            id=refresh_token.id.value,
            identifier=str(refresh_token.identifier),
            # Only the id of the linked access token is stored.
            access_token_id=refresh_token.access_token.id.value,
            expiration=unwrap(refresh_token.expiration.timestamp),
            revoked=1 if refresh_token.revoked else 0,
            created_at=unwrap(refresh_token.traceable_time.created_at.timestamp),
            updated_at=refresh_token.traceable_time.updated_at.timestamp,
        )

107 

108 

@dataclass(kw_only=True, frozen=True, slots=True)
class UserLogRow(TableRow):
    """Represent a table row in the user logs table.

    Maps between the ``user_logs`` table and a ``UserLogEntity``. The
    ``success`` flag is stored as an integer (``1`` on success, ``0``
    otherwise). The user account and refresh token are optional and stored
    only by id.
    """

    __table_name__ = "user_logs"

    id: int
    success: int
    email: str
    user_id: int | None
    refresh_token_id: int | None
    client_ip: str
    user_agent: str
    openid_sub: str
    openid_provider: str
    remark: str
    created_at: datetime

    def create_entity(
        self,
        user_account: UserAccountEntity | None,
        refresh_token: RefreshTokenEntity | None,
    ) -> UserLogEntity:
        """Create a User Log entity from the table row.

        Args:
            user_account: The user account referenced by ``user_id``, or None.
            refresh_token: The refresh token referenced by
                ``refresh_token_id``, or None.

        Returns:
            The user log entity built from this row.
        """
        return UserLogEntity(
            id=UserLogIdentifier(self.id),
            # persist() writes this column from user_log.success, so it must
            # be restored here as well; otherwise the status is lost on read.
            success=self.success == 1,
            email=self.email,
            user_account=user_account,
            refresh_token=refresh_token,
            client_ip=IpAddress.create(self.client_ip),
            user_agent=self.user_agent,
            remark=self.remark,
            openid=OpenId(sub=self.openid_sub, provider=self.openid_provider),
            created_at=Timestamp.create_utc(self.created_at),
        )

    @classmethod
    def persist(cls, user_log: UserLogEntity) -> Self:
        """Transform a user log entity into a table record.

        Args:
            user_log: The entity to persist.

        Returns:
            A row instance (of ``cls``) holding the entity's values.
        """
        return cls(
            id=user_log.id.value,
            success=1 if user_log.success else 0,
            email=user_log.email,
            # The user account and refresh token are optional; store NULL
            # when they are absent.
            user_id=None
            if user_log.user_account is None
            else user_log.user_account.id.value,
            refresh_token_id=None
            if user_log.refresh_token is None
            else user_log.refresh_token.id.value,
            client_ip=str(user_log.client_ip),
            user_agent=user_log.user_agent,
            openid_sub=user_log.openid.sub,
            openid_provider=user_log.openid.provider,
            remark=user_log.remark,
            created_at=unwrap(user_log.created_at.timestamp),
        )