Coverage for apps/kwai-api/src/kwai_api/dependencies.py: 82%
56 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module that integrates the dependencies in FastAPI."""
3from typing import Annotated, AsyncGenerator
5import jwt
7from fastapi import Cookie, Depends, HTTPException, status
8from fastapi.security import OAuth2PasswordBearer
9from fastapi.templating import Jinja2Templates
10from jwt import ExpiredSignatureError
11from kwai_bc_identity.tokens.access_token_db_repository import AccessTokenDbRepository
12from kwai_bc_identity.tokens.access_token_repository import AccessTokenNotFoundException
13from kwai_bc_identity.tokens.token_identifier import TokenIdentifier
14from kwai_bc_identity.users.user import UserEntity
15from kwai_core.db.database import Database
16from kwai_core.events.publisher import Publisher
17from kwai_core.events.redis_bus import RedisBus
18from kwai_core.settings import SecuritySettings, Settings, get_settings
19from kwai_core.template.jinja2_engine import Jinja2Engine
20from redis.asyncio import Redis
# OAuth2 bearer-token scheme whose token URL is the login endpoint.
oauth = OAuth2PasswordBearer(
    tokenUrl="/api/v1/auth/login",
)
async def create_database(
    settings=Depends(get_settings),
) -> AsyncGenerator[Database, None]:
    """Provide a Database built from ``settings.db`` as a dependency.

    The database is yielded to the dependent code and closed when this
    generator is finalized, whether or not the dependent code raised.
    """
    db = Database(settings.db)
    try:
        yield db
    finally:
        # Always release the database, also when an error propagated.
        await db.close()
async def create_templates(settings=Depends(get_settings)) -> Jinja2Templates:
    """Provide the web templates as a dependency.

    Returns:
        The ``web_templates`` of a Jinja2Engine configured with
        ``settings.website``.
    """
    engine = Jinja2Engine(website=settings.website)
    return engine.web_templates
async def get_current_user(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    access_token: Annotated[str | None, Cookie()] = None,
) -> UserEntity:
    """Resolve the authenticated user from the ``access_token`` cookie.

    Returns:
        The user that owns the access token.

    Raises:
        HTTPException: 401 when the cookie is absent or empty, or when the
            token cannot be accepted (not found, expired, revoked, or the
            user is revoked).
    """
    if access_token:
        return await _get_user_from_token(access_token, settings.security, db)

    # No usable cookie at all: the request is unauthenticated.
    raise HTTPException(
        status.HTTP_401_UNAUTHORIZED, detail="Access token cookie missing"
    )
# Same login endpoint as ``oauth``, but with ``auto_error`` disabled so a
# missing bearer token does not raise on its own (per FastAPI, the scheme then
# yields None).
optional_oauth = OAuth2PasswordBearer(
    tokenUrl="/api/v1/auth/login",
    auto_error=False,
)
async def get_publisher(
    settings=Depends(get_settings),
) -> AsyncGenerator[Publisher, None]:
    """Get the publisher dependency.

    A Redis client is created from ``settings.redis`` and wrapped in a
    RedisBus, which is yielded as the publisher. The Redis client is closed
    when this generator is finalized, so every request no longer leaves an
    open client and connection pool behind.
    """
    redis = Redis(
        host=settings.redis.host,
        port=settings.redis.port,
        password=settings.redis.password,
    )
    try:
        yield RedisBus(redis)
    finally:
        # Like create_database: release the client (and the connection pool
        # it created) even when the dependent code raised.
        # NOTE(review): ``aclose`` needs redis-py >= 5.0.1; older versions
        # only provide ``close()``.
        await redis.aclose()
async def get_optional_user(
    settings: Annotated[Settings, Depends(get_settings)],
    db: Annotated[Database, Depends(create_database)],
    access_token: Annotated[str | None, Cookie()] = None,
) -> UserEntity | None:
    """Try to get the current user from an access token.

    When no token is available in the request (the cookie is absent or
    empty), None will be returned.

    Not authorized will be raised when the access token is expired, revoked
    or when the user is revoked.
    """
    # Same "missing" test as get_current_user: an empty cookie carries no
    # token, so treat the request as anonymous instead of decoding "".
    if not access_token:
        return None

    return await _get_user_from_token(access_token, settings.security, db)
async def _get_user_from_token(
    token: str, security_settings: SecuritySettings, db: Database
) -> UserEntity:
    """Try to get the user from the token.

    Args:
        token: The encoded JWT taken from the request.
        security_settings: Supplies the secret and algorithm used to verify
            the JWT.
        db: Database used to look up the stored access token.

    Returns: The user associated with the access token.

    Raises:
        HTTPException: 401 when the JWT is expired, malformed, wrongly signed
            or lacks the ``jti``/``sub`` claims; when the access token is
            unknown; when it does not belong to the JWT subject; when it is
            revoked or expired; or when the user account is revoked.
    """
    try:
        payload = jwt.decode(
            token,
            security_settings.jwt_secret,
            algorithms=[security_settings.jwt_algorithm],
            # Both claims are read below; without them the lookups would
            # raise KeyError (a 500) instead of rejecting the token.
            options={"require": ["jti", "sub"]},
        )
    except ExpiredSignatureError as exc:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
    except jwt.InvalidTokenError as exc:
        # Any other token problem (garbage input, bad signature, missing
        # required claim) is the client's fault, not a server error.
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail="The access token is invalid."
        ) from exc

    access_token_repo = AccessTokenDbRepository(db)
    try:
        access_token = await access_token_repo.get_by_identifier(
            TokenIdentifier(hex_string=payload["jti"])
        )
    except AccessTokenNotFoundException as exc:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail="The access token is unknown."
        ) from exc

    # Check if the access token is assigned to the user we have in the subject of JWT.
    if access_token.user_account.user.uuid != payload["sub"]:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    if access_token.revoked:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    if access_token.user_account.revoked:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    if access_token.expired:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    return access_token.user_account.user