Coverage for apps/kwai-api/src/kwai_api/v1/auth/endpoints/validation.py: 83%

36 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that defines some endpoints to validate a login.""" 

2 

3from typing import Annotated 

4 

5import jwt 

6 

7from fastapi import APIRouter, Cookie, Header, HTTPException, Request, Response, status 

8from fastapi.params import Depends 

9from kwai_bc_identity.exceptions import AuthenticationException 

10from kwai_bc_identity.refresh_access_token import ( 

11 RefreshAccessToken, 

12 RefreshAccessTokenCommand, 

13) 

14from kwai_bc_identity.tokens.access_token_db_repository import AccessTokenDbRepository 

15from kwai_bc_identity.tokens.log_user_login_db_service import LogUserLoginDbService 

16from kwai_bc_identity.tokens.refresh_token_db_repository import RefreshTokenDbRepository 

17from kwai_core.db.database import Database 

18from kwai_core.db.uow import UnitOfWork 

19from kwai_core.settings import Settings 

20 

21from kwai_api.dependencies import create_database, get_settings 

22from kwai_api.v1.auth.cookies import create_cookies 

23 

24 

# Router on which the login-validation endpoints of this module are registered.
router: APIRouter = APIRouter()

26 

27 

@router.get(
    "/validate",
    summary="Validate a current login",
    responses={
        200: {"description": "The access token is still valid."},
        401: {"description": "Not authorized."},
    },
)
async def validate(
    request: Request,
    response: Response,
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    access_token: Annotated[str | None, Cookie()] = None,
    refresh_token: Annotated[str | None, Cookie()] = None,
    x_forwarded_for: Annotated[str | None, Header()] = None,
    user_agent: Annotated[str | None, Header()] = "",
):
    """Validate the current login and renew its tokens.

    The refresh token cookie is decoded and verified with the configured
    refresh secret and algorithm. When it is valid, the RefreshAccessToken use
    case is executed for the token's ``jti`` and the result is written to the
    response with ``create_cookies``. This happens on every successful call;
    the access token cookie is accepted but its contents are not inspected here.

    Args:
        request: The incoming request; used as a fallback source of the client IP.
        response: The outgoing response on which the new cookies are set.
        settings: Application settings (JWT secrets, algorithm, token expiries).
        db: Database used by the repositories and the login log service.
        access_token: Access token cookie (not inspected by this endpoint).
        refresh_token: Refresh token cookie; required.
        x_forwarded_for: X-Forwarded-For header, preferred as client IP when set.
        user_agent: User-Agent header, passed to the login log service.

    Raises:
        HTTPException: 401 when there is no refresh token cookie, when the
            refresh token fails JWT validation (expired, bad signature,
            malformed, ...), when it carries no ``jti`` claim, or when the
            refresh use case raises an AuthenticationException.
    """
    # Renewal is driven entirely by the refresh token; without it the login
    # cannot be validated, regardless of whether an access token is present.
    if refresh_token is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)

    try:
        decoded_refresh_token = jwt.decode(
            refresh_token,
            key=settings.security.jwt_refresh_secret,
            algorithms=[settings.security.jwt_algorithm],
        )
    except jwt.InvalidTokenError as exc:
        # InvalidTokenError is PyJWT's base class for ExpiredSignatureError,
        # InvalidSignatureError, DecodeError, ... A tampered or malformed
        # cookie is a client error (401), not a server error (500).
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
        ) from exc

    # A correctly signed token without a "jti" claim cannot be matched to a
    # stored refresh token; reject it instead of failing with a KeyError.
    identifier = decoded_refresh_token.get("jti")
    if not identifier:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="The refresh token has no identifier.",
        )

    # NOTE(review): X-Forwarded-For may contain a comma-separated proxy chain;
    # the full header value is passed on as-is — confirm this is intended.
    if x_forwarded_for:
        client_ip = x_forwarded_for
    else:
        client_ip = request.client.host if request.client else ""

    command = RefreshAccessTokenCommand(
        identifier=identifier,
        access_token_expiry_minutes=settings.security.access_token_expires_in,
        refresh_token_expiry_minutes=settings.security.refresh_token_expires_in,
    )

    try:
        # always_commit=True presumably keeps the login log entry even when
        # the refresh fails — confirm against UnitOfWork.
        async with UnitOfWork(db, always_commit=True):
            new_refresh_token = await RefreshAccessToken(
                RefreshTokenDbRepository(db),
                AccessTokenDbRepository(db),
                LogUserLoginDbService(
                    db,
                    email="",
                    user_agent=user_agent or "",
                    client_ip=client_ip,
                ),
            ).execute(command)
    except AuthenticationException as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
        ) from exc

    create_cookies(response, new_refresh_token, settings)
    response.status_code = status.HTTP_200_OK