Coverage for apps/kwai-api/src/kwai_api/dependencies.py: 82%

56 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that integrates the dependencies in FastAPI.""" 

2 

3from typing import Annotated, AsyncGenerator 

4 

5import jwt 

6 

7from fastapi import Cookie, Depends, HTTPException, status 

8from fastapi.security import OAuth2PasswordBearer 

9from fastapi.templating import Jinja2Templates 

10from jwt import ExpiredSignatureError 

11from kwai_bc_identity.tokens.access_token_db_repository import AccessTokenDbRepository 

12from kwai_bc_identity.tokens.access_token_repository import AccessTokenNotFoundException 

13from kwai_bc_identity.tokens.token_identifier import TokenIdentifier 

14from kwai_bc_identity.users.user import UserEntity 

15from kwai_core.db.database import Database 

16from kwai_core.events.publisher import Publisher 

17from kwai_core.events.redis_bus import RedisBus 

18from kwai_core.settings import SecuritySettings, Settings, get_settings 

19from kwai_core.template.jinja2_engine import Jinja2Engine 

20from redis.asyncio import Redis 

21 

22 

# OAuth2 password-flow bearer scheme whose token endpoint is the login route.
# NOTE(review): not referenced by the dependencies in this module, which read the
# ``access_token`` cookie instead — confirm it is still used elsewhere.
oauth = OAuth2PasswordBearer(
    tokenUrl="/api/v1/auth/login",
)

24 

25 

async def create_database(
    settings=Depends(get_settings),
) -> AsyncGenerator[Database, None]:
    """Provide a Database built from the ``db`` settings.

    The database is handed to the dependant and is always closed afterwards,
    also when the dependant raises.
    """
    db_instance = Database(settings.db)
    try:
        yield db_instance
    finally:
        # Cleanup runs once FastAPI resumes the generator after the dependant.
        await db_instance.close()

35 

36 

async def create_templates(settings=Depends(get_settings)) -> Jinja2Templates:
    """Provide the web templates of a Jinja2Engine.

    The engine is configured with the ``website`` section of the settings.
    """
    engine = Jinja2Engine(website=settings.website)
    return engine.web_templates

40 

41 

async def get_current_user(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    access_token: Annotated[str | None, Cookie()] = None,
) -> UserEntity:
    """Return the user owning the ``access_token`` cookie.

    Token validation is delegated to ``_get_user_from_token``, which rejects
    tokens that are expired, unknown, revoked, not expired-checked as valid,
    or that belong to a revoked user account.

    Raises:
        HTTPException: 401 when the cookie is absent or empty, or when the
            token does not pass validation.
    """
    if access_token:
        return await _get_user_from_token(access_token, settings.security, db)

    # No (or an empty) cookie: authentication is mandatory for this dependency.
    raise HTTPException(
        status.HTTP_401_UNAUTHORIZED, detail="Access token cookie missing"
    )

57 

58 

# Same login-route bearer scheme, but with ``auto_error=False`` so that a missing
# credential is not rejected by the scheme itself (FastAPI then yields None).
# NOTE(review): not referenced by the dependencies in this module — confirm it is
# still used elsewhere.
optional_oauth = OAuth2PasswordBearer(
    tokenUrl="/api/v1/auth/login",
    auto_error=False,
)

60 

61 

async def get_publisher(
    settings=Depends(get_settings),
) -> AsyncGenerator[Publisher, None]:
    """Get the publisher dependency.

    A Redis client is created from the ``redis`` settings (host, port,
    password) and wrapped in a RedisBus, which is yielded as the Publisher.
    The Redis client is closed when the dependant is done with the publisher,
    so a client and its connection pool are not leaked on every resolution
    of this dependency.
    """
    redis = Redis(
        host=settings.redis.host,
        port=settings.redis.port,
        password=settings.redis.password,
    )
    try:
        yield RedisBus(redis)
    finally:
        # Release the client's connections, mirroring create_database closing
        # its Database. NOTE(review): if the publisher is ever handed to
        # background work that outlives the request, revisit this cleanup.
        await redis.aclose()

73 

74 

async def get_optional_user(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    access_token: Annotated[str | None, Cookie()] = None,
) -> UserEntity | None:
    """Try to get the current user from an access token.

    When no token is available in the request (the ``access_token`` cookie is
    absent or empty), None will be returned.

    Raises:
        HTTPException: 401 (raised by ``_get_user_from_token``) when the access
            token is expired, unknown, revoked or when the user is revoked.
    """
    # An empty cookie carries no token; treat it like a missing one (as
    # get_current_user does) instead of handing "" to the JWT decoder.
    if not access_token:
        return None

    return await _get_user_from_token(access_token, settings.security, db)

91 

92 

async def _get_user_from_token(
    token: str, security_settings: SecuritySettings, db: Database
) -> UserEntity:
    """Try to get the user from the token.

    The token is decoded and verified as a JWT using the configured secret and
    algorithm, and must carry ``jti`` and ``sub`` claims. The ``jti`` claim
    identifies the stored access token. That token must be assigned to the user
    named in ``sub`` and must not be revoked or expired, and its user account
    must not be revoked.

    Args:
        token: The encoded JWT taken from the request.
        security_settings: Provides ``jwt_secret`` and ``jwt_algorithm``.
        db: Database used to look up the stored access token.

    Returns:
        The user associated with the access token.

    Raises:
        HTTPException: 401 when the JWT is expired, malformed, has an invalid
            signature or is missing a required claim, or when any of the
            stored-token checks above fails.
    """
    try:
        payload = jwt.decode(
            token,
            security_settings.jwt_secret,
            algorithms=[security_settings.jwt_algorithm],
            # Missing claims now raise MissingRequiredClaimError (an
            # InvalidTokenError) instead of a KeyError (500) further down.
            options={"require": ["jti", "sub"]},
        )
    except ExpiredSignatureError as exc:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
    except jwt.InvalidTokenError as exc:
        # Malformed, tampered or otherwise invalid tokens are client errors:
        # answer 401 instead of letting them surface as a 500.
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail="The access token is invalid."
        ) from exc

    access_token_repo = AccessTokenDbRepository(db)
    try:
        access_token = await access_token_repo.get_by_identifier(
            TokenIdentifier(hex_string=payload["jti"])
        )
    except AccessTokenNotFoundException as exc:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail="The access token is unknown."
        ) from exc

    user_account = access_token.user_account

    # The stored token must be assigned to the user in the subject of the JWT.
    if user_account.user.uuid != payload["sub"]:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    # Checked in the same order as before; evaluation stops at the first hit.
    if access_token.revoked or user_account.revoked or access_token.expired:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    return user_account.user