Coverage for apps/kwai-api/src/kwai_api/dependencies.py: 81%

54 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that integrates the dependencies in FastAPI.""" 

2 

3from typing import Annotated, AsyncGenerator 

4 

5import jwt 

6 

7from fastapi import Depends, HTTPException, status 

8from fastapi.templating import Jinja2Templates 

9from jwt import ExpiredSignatureError 

10from kwai_bc_identity.tokens.access_token_db_repository import AccessTokenDbRepository 

11from kwai_bc_identity.tokens.access_token_repository import AccessTokenNotFoundException 

12from kwai_bc_identity.tokens.token_identifier import TokenIdentifier 

13from kwai_bc_identity.users.user import UserEntity 

14from kwai_core.db.database import Database 

15from kwai_core.events.publisher import Publisher 

16from kwai_core.events.redis_bus import RedisBus 

17from kwai_core.settings import SecuritySettings, Settings, get_settings 

18from kwai_core.template.jinja2_engine import Jinja2Engine 

19from redis.asyncio import Redis 

20 

21from kwai_api.v1.auth.cookies import use_access_token 

22 

23 

async def create_database(
    settings=Depends(get_settings),
) -> AsyncGenerator[Database, None]:
    """Provide a ``Database`` built from ``settings.db`` to the dependent.

    The database is closed in a ``finally`` clause. It is released even when
    the code that uses the dependency raises.
    """
    db = Database(settings.db)
    try:
        yield db
    finally:
        await db.close()

33 

34 

async def create_templates(settings=Depends(get_settings)) -> Jinja2Templates:
    """Return the web templates of a Jinja2 engine for the configured website.

    A new ``Jinja2Engine`` is built from ``settings.website`` on every call.
    """
    engine = Jinja2Engine(website=settings.website)
    return engine.web_templates

38 

39 

async def get_current_user(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    access_token: Annotated[str | None, Depends(use_access_token)] = None,
) -> UserEntity:
    """Resolve the authenticated user from the access token.

    Raises:
        HTTPException: 401 with detail "Access token cookie missing" when the
            token is absent or empty. Errors raised while validating the token
            propagate unchanged. Those cover expired, unknown or revoked
            tokens and revoked users.
    """
    if access_token:
        return await _get_user_from_token(access_token, settings.security, db)

    raise HTTPException(
        status.HTTP_401_UNAUTHORIZED, detail="Access token cookie missing"
    )

55 

56 

async def get_publisher(
    settings=Depends(get_settings),
) -> AsyncGenerator[Publisher, None]:
    """Get the publisher dependency.

    A Redis client is created from ``settings.redis``, wrapped in a
    ``RedisBus``, and yielded as the publisher.

    The Redis client is closed when the dependency is torn down. Without
    that, every resolution would leave an open connection pool behind.
    """
    redis = Redis(
        host=settings.redis.host,
        port=settings.redis.port,
        password=settings.redis.password,
    )
    try:
        yield RedisBus(redis)
    finally:
        # Release the client the same way create_database releases its
        # database. ``aclose`` is the redis-py (>= 5.0.1) async close API.
        await redis.aclose()

68 

69 

async def get_optional_user(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    access_token: Annotated[str | None, Depends(use_access_token)] = None,
) -> UserEntity | None:
    """Try to get the current user from an access token.

    When no token is available in the request, None is returned. An absent
    (``None``) or empty token both count as "no token".

    Not authorized will be raised when the access token is expired, revoked
    or when the user is revoked.
    """
    # Use the same test as get_current_user: an empty value carries no token
    # and must not be handed to the JWT decoder.
    if not access_token:
        return None

    return await _get_user_from_token(access_token, settings.security, db)

86 

87 

async def _get_user_from_token(
    token: str, security_settings: SecuritySettings, db: Database
) -> UserEntity:
    """Try to get the user from the token.

    Args:
        token: The encoded JWT access token.
        security_settings: Supplies the secret and algorithm used to verify
            the JWT.
        db: Database used to look up the stored access token.

    Returns:
        The user associated with the access token.

    Raises:
        HTTPException: 401 in any of these cases:
            - the JWT is expired;
            - the JWT fails verification, including a bad signature, a
              malformed token, or a missing ``jti``/``sub`` claim;
            - the access token is unknown;
            - the token does not belong to the user in ``sub``;
            - the access token is revoked or marked expired;
            - the user account is revoked.
    """
    try:
        payload = jwt.decode(
            token,
            security_settings.jwt_secret,
            algorithms=[security_settings.jwt_algorithm],
            # Both claims are read below. Let PyJWT reject tokens without
            # them instead of failing later with a KeyError (HTTP 500).
            options={"require": ["jti", "sub"]},
        )
    except ExpiredSignatureError as exc:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
    except jwt.InvalidTokenError as exc:
        # Every other verification failure is an authentication failure,
        # not a server error.
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail="The access token is invalid."
        ) from exc

    access_token_repo = AccessTokenDbRepository(db)
    try:
        access_token = await access_token_repo.get_by_identifier(
            TokenIdentifier(hex_string=payload["jti"])
        )
    except AccessTokenNotFoundException as exc:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail="The access token is unknown."
        ) from exc

    # Check if the access token is assigned to the user we have in the subject of JWT.
    if not access_token.user_account.user.uuid == payload["sub"]:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    if access_token.revoked:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    if access_token.user_account.revoked:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    if access_token.expired:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    return access_token.user_account.user