From af4a344c85bc187cf677848058de4e554fe7cc0a Mon Sep 17 00:00:00 2001 From: Taha Yassine Kraiem Date: Mon, 10 Mar 2025 13:14:10 +0100 Subject: [PATCH] fix(chalice): fix multi-refresh token fix(chalice): fix spot multi-refresh token --- api/chalicelib/core/authorizers.py | 3 --- api/chalicelib/core/spot.py | 12 +++++------ api/chalicelib/core/users.py | 32 ++++++++++++++++++++---------- api/routers/subs/spot.py | 6 +++++- ee/api/chalicelib/core/users.py | 31 ++++++++++++++++++++--------- ee/api/routers/subs/spot.py | 6 +++++- 6 files changed, 60 insertions(+), 30 deletions(-) diff --git a/api/chalicelib/core/authorizers.py b/api/chalicelib/core/authorizers.py index c0e874d86..26aa38127 100644 --- a/api/chalicelib/core/authorizers.py +++ b/api/chalicelib/core/authorizers.py @@ -28,9 +28,6 @@ def jwt_authorizer(scheme: str, token: str, leeway=0) -> dict | None: if scheme.lower() != "bearer": return None try: - logger.warning("Checking JWT token: %s", token) - logger.warning("Against: %s", config("JWT_SECRET") if not is_spot_token(token) else config("JWT_SPOT_SECRET")) - logger.warning(get_supported_audience()) payload = jwt.decode(jwt=token, key=config("JWT_SECRET") if not is_spot_token(token) else config("JWT_SPOT_SECRET"), algorithms=config("JWT_ALGORITHM"), diff --git a/api/chalicelib/core/spot.py b/api/chalicelib/core/spot.py index 4dab51a41..12b16acef 100644 --- a/api/chalicelib/core/spot.py +++ b/api/chalicelib/core/spot.py @@ -18,7 +18,7 @@ def refresh_spot_jwt_iat_jti(user_id): {"user_id": user_id}) cur.execute(query) row = cur.fetchone() - return row.get("spot_jwt_iat"), row.get("spot_jwt_refresh_jti"), row.get("spot_jwt_refresh_iat") + return users.RefreshSpotJWTs(**row) def logout(user_id: int): @@ -26,13 +26,13 @@ def logout(user_id: int): def refresh(user_id: int, tenant_id: int = -1) -> dict: - spot_jwt_iat, spot_jwt_r_jti, spot_jwt_r_iat = refresh_spot_jwt_iat_jti(user_id=user_id) + j = refresh_spot_jwt_iat_jti(user_id=user_id) return { - "jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=spot_jwt_iat, + "jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=j.spot_jwt_iat, aud=AUDIENCE, for_spot=True), - "refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=spot_jwt_r_iat, - aud=AUDIENCE, jwt_jti=spot_jwt_r_jti, for_spot=True), - "refreshTokenMaxAge": config("JWT_SPOT_REFRESH_EXPIRATION", cast=int) - (spot_jwt_iat - spot_jwt_r_iat) + "refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=j.spot_jwt_refresh_iat, + aud=AUDIENCE, jwt_jti=j.spot_jwt_refresh_jti, for_spot=True), + "refreshTokenMaxAge": config("JWT_SPOT_REFRESH_EXPIRATION", cast=int) - (j.spot_jwt_iat - j.spot_jwt_refresh_iat) } diff --git a/api/chalicelib/core/users.py b/api/chalicelib/core/users.py index a02a2241c..c8fe3c4bf 100644 --- a/api/chalicelib/core/users.py +++ b/api/chalicelib/core/users.py @@ -1,5 +1,6 @@ import json import secrets +from typing import Optional from decouple import config from fastapi import BackgroundTasks @@ -83,7 +84,6 @@ def restore_member(user_id, email, invitation_token, admin, name, owner=False): "name": name, "invitation_token": invitation_token}) cur.execute(query) result = cur.fetchone() - cur.execute(query) result["created_at"] = TimeUTC.datetime_to_timestamp(result["created_at"]) return helper.dict_to_camel_case(result) @@ -552,7 +552,7 @@ def refresh_auth_exists(user_id, jwt_jti=None): return r is not None -class ChangeJwt(BaseModel): +class FullLoginJWTs(BaseModel): jwt_iat: int jwt_refresh_jti: str jwt_refresh_iat: int @@ -565,11 +565,23 @@ class ChangeJwt(BaseModel): def _transform_data(cls, values): if values.get("jwt_refresh_jti") is not None: values["jwt_refresh_jti"] = str(values["jwt_refresh_jti"]) - if values.get("jwt_refresh_jti") is not None: + if values.get("spot_jwt_refresh_jti") is not None: values["spot_jwt_refresh_jti"] = str(values["spot_jwt_refresh_jti"]) return values +class RefreshLoginJWTs(FullLoginJWTs): + spot_jwt_iat: Optional[int] = None + spot_jwt_refresh_jti: Optional[str] = None + spot_jwt_refresh_iat: Optional[int] = None + + +class RefreshSpotJWTs(FullLoginJWTs): + jwt_iat: Optional[int] = None + jwt_refresh_jti: Optional[str] = None + jwt_refresh_iat: Optional[int] = None + + def change_jwt_iat_jti(user_id): with pg_client.PostgresClient() as cur: query = cur.mogrify(f"""UPDATE public.users @@ -589,7 +601,7 @@ def change_jwt_iat_jti(user_id): {"user_id": user_id}) cur.execute(query) row = cur.fetchone() - return ChangeJwt(**row) + return FullLoginJWTs(**row) def refresh_jwt_iat_jti(user_id): @@ -604,7 +616,7 @@ def refresh_jwt_iat_jti(user_id): {"user_id": user_id}) cur.execute(query) row = cur.fetchone() - return row.get("jwt_iat"), row.get("jwt_refresh_jti"), row.get("jwt_refresh_iat") + return RefreshLoginJWTs(**row) def authenticate(email, password, for_change_password=False) -> dict | bool | None: @@ -672,13 +684,13 @@ def logout(user_id: int): def refresh(user_id: int, tenant_id: int = -1) -> dict: - jwt_iat, jwt_r_jti, jwt_r_iat = refresh_jwt_iat_jti(user_id=user_id) + j = refresh_jwt_iat_jti(user_id=user_id) return { - "jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=jwt_iat, + "jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_iat, aud=AUDIENCE), - "refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=jwt_r_iat, - aud=AUDIENCE, jwt_jti=jwt_r_jti), - "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (jwt_iat - jwt_r_iat) + "refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_refresh_iat, + aud=AUDIENCE, jwt_jti=j.jwt_refresh_jti), + "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (j.jwt_iat - j.jwt_refresh_iat), } diff --git a/api/routers/subs/spot.py b/api/routers/subs/spot.py index 42519ac2e..fad1c9332 100644 --- a/api/routers/subs/spot.py +++ b/api/routers/subs/spot.py @@ -1,3 +1,4 @@ +from decouple import config from fastapi import Depends from starlette.responses import JSONResponse, Response @@ -8,7 +9,10 @@ from routers.base import get_routers public_app, app, app_apikey = get_routers(prefix="/spot", tags=["spot"]) -COOKIE_PATH = "/api/spot/refresh" +if config("LOCAL_DEV", cast=bool, default=False): + COOKIE_PATH = "/spot/refresh" +else: + COOKIE_PATH = "/api/spot/refresh" @app.get('/logout') diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index 8f322826d..80ebd9271 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -1,6 +1,7 @@ import json import logging import secrets +from typing import Optional from decouple import config from fastapi import BackgroundTasks, HTTPException @@ -657,7 +658,7 @@ def refresh_auth_exists(user_id, tenant_id, jwt_jti=None): return r is not None -class ChangeJwt(BaseModel): +class FullLoginJWTs(BaseModel): jwt_iat: int jwt_refresh_jti: str jwt_refresh_iat: int @@ -670,11 +671,23 @@ class ChangeJwt(BaseModel): def _transform_data(cls, values): if values.get("jwt_refresh_jti") is not None: values["jwt_refresh_jti"] = str(values["jwt_refresh_jti"]) - if values.get("jwt_refresh_jti") is not None: + if values.get("spot_jwt_refresh_jti") is not None: values["spot_jwt_refresh_jti"] = str(values["spot_jwt_refresh_jti"]) return values +class RefreshLoginJWTs(FullLoginJWTs): + spot_jwt_iat: Optional[int] = None + spot_jwt_refresh_jti: Optional[str] = None + spot_jwt_refresh_iat: Optional[int] = None + + +class RefreshSpotJWTs(FullLoginJWTs): + jwt_iat: Optional[int] = None + jwt_refresh_jti: Optional[str] = None + jwt_refresh_iat: Optional[int] = None + + def change_jwt_iat_jti(user_id): with pg_client.PostgresClient() as cur: query = cur.mogrify(f"""UPDATE public.users @@ -694,7 +707,7 @@ def change_jwt_iat_jti(user_id): {"user_id": user_id}) cur.execute(query) row = cur.fetchone() - return ChangeJwt(**row) + return FullLoginJWTs(**row) def refresh_jwt_iat_jti(user_id): @@ -709,7 +722,7 @@ def refresh_jwt_iat_jti(user_id): {"user_id": user_id}) cur.execute(query) row = cur.fetchone() - return row.get("jwt_iat"), row.get("jwt_refresh_jti"), row.get("jwt_refresh_iat") + return RefreshLoginJWTs(**row) def authenticate(email, password, for_change_password=False) -> dict | bool | None: @@ -869,13 +882,13 @@ def logout(user_id: int): def refresh(user_id: int, tenant_id: int = -1) -> dict: - jwt_iat, jwt_r_jti, jwt_r_iat = refresh_jwt_iat_jti(user_id=user_id) + j = refresh_jwt_iat_jti(user_id=user_id) return { - "jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=jwt_iat, + "jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_iat, aud=AUDIENCE), - "refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=jwt_r_iat, - aud=AUDIENCE, jwt_jti=jwt_r_jti), - "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (jwt_iat - jwt_r_iat) + "refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_refresh_iat, + aud=AUDIENCE, jwt_jti=j.jwt_refresh_jti), + "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (j.jwt_iat - j.jwt_refresh_iat), } diff --git a/ee/api/routers/subs/spot.py b/ee/api/routers/subs/spot.py index 45210c75c..6814942ab 100644 --- a/ee/api/routers/subs/spot.py +++ b/ee/api/routers/subs/spot.py @@ -1,3 +1,4 @@ +from decouple import config from fastapi import Depends from starlette.responses import JSONResponse, Response @@ -8,7 +9,10 @@ from routers.base import get_routers public_app, app, app_apikey = get_routers(prefix="/spot", tags=["spot"]) -COOKIE_PATH = "/api/spot/refresh" +if config("LOCAL_DEV", cast=bool, default=False): + COOKIE_PATH = "/spot/refresh" +else: + COOKIE_PATH = "/api/spot/refresh" @app.get('/logout')