From bcfa421b8fed5c58d7fbeda3d47585d7e836505c Mon Sep 17 00:00:00 2001 From: Jonathan Griffin Date: Thu, 17 Apr 2025 14:01:59 +0200 Subject: [PATCH] update GET Users api to be minimal working rfc version --- ee/api/chalicelib/core/users.py | 262 ++++++++++++++----------------- ee/api/routers/scim.py | 182 ++++++++++++--------- ee/api/routers/scim_constants.py | 34 +++- ee/api/routers/scim_helpers.py | 105 +++++++++++++ 4 files changed, 363 insertions(+), 220 deletions(-) create mode 100644 ee/api/routers/scim_helpers.py diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index 94d1e8d41..4fd1e8b09 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -30,7 +30,7 @@ def create_new_member(tenant_id, email, invitation_token, admin, name, owner=Fal query = cur.mogrify(f"""\ WITH u AS ( INSERT INTO public.users (tenant_id, email, role, name, data, role_id) - VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, + VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, (SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1)))) @@ -78,7 +78,7 @@ def restore_member(tenant_id, user_id, email, invitation_token, admin, name, own (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1))) WHERE user_id=%(user_id)s - RETURNING + RETURNING tenant_id, user_id, email, @@ -104,7 +104,7 @@ def restore_member(tenant_id, user_id, email, invitation_token, admin, name, own u.role_id, roles.name AS role_name, TRUE AS has_password - FROM au,u LEFT JOIN roles USING(tenant_id) + FROM au,u LEFT JOIN roles USING(tenant_id) WHERE roles.role_id IS NULL OR roles.role_id = (SELECT u.role_id FROM u);""", {"tenant_id": tenant_id, "user_id": user_id, "email": email, "role": "owner" if owner else "admin" if admin else "member", "name": name, @@ -240,7 +240,7 @@ def __get_invitation_link(invitation_token): def allow_password_change(user_id, delta_min=10): pass_token = secrets.token_urlsafe(8) with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""UPDATE public.basic_authentication + query = cur.mogrify(f"""UPDATE public.basic_authentication SET change_pwd_expire_at = timezone('utc'::text, now()+INTERVAL '%(delta)s MINUTES'), change_pwd_token = %(pass_token)s WHERE user_id = %(user_id)s""", @@ -255,11 +255,11 @@ def get(user_id, tenant_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + f"""SELECT users.user_id, users.tenant_id, - email, - role, + email, + role, users.name, (CASE WHEN role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, (CASE WHEN role = 'admin' THEN TRUE ELSE FALSE END) AS admin, @@ -283,38 +283,25 @@ def get(user_id, tenant_id): ) r = cur.fetchone() return helper.dict_to_camel_case(r) - + def get_by_uuid(user_uuid, tenant_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT - users.user_id, - users.tenant_id, - email, - role, - users.name, - users.data, - users.internal_id, - (CASE WHEN role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, - (CASE WHEN role = 'admin' THEN TRUE ELSE FALSE END) AS admin, - (CASE WHEN role = 'member' THEN TRUE ELSE FALSE END) AS member, - origin, - role_id, - roles.name AS role_name, - roles.permissions, - roles.all_projects, - basic_authentication.password IS NOT NULL AS has_password, - users.service_account - FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id - LEFT JOIN public.roles USING (role_id) - WHERE - users.data->>'user_id' = %(user_uuid)s - AND users.tenant_id = %(tenant_id)s - AND users.deleted_at IS NULL - AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s) - LIMIT 1;""", - {"user_uuid": user_uuid, "tenant_id": tenant_id}) + """ + SELECT * + FROM public.users + WHERE + users.deleted_at IS NULL + AND users.user_id = %(user_id)s + AND users.tenant_id = %(tenant_id)s + LIMIT 1; + """, + { + "user_id": user_uuid, + "tenant_id": tenant_id, + }, + ) ) r = cur.fetchone() return helper.dict_to_camel_case(r) @@ -323,11 +310,11 @@ def get_deleted_by_uuid(user_uuid, tenant_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + f"""SELECT users.user_id, users.tenant_id, - email, - role, + email, + role, users.name, users.data, users.internal_id, @@ -375,8 +362,8 @@ def __get_account_info(tenant_id, user_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT users.name, - tenants.name AS tenant_name, + f"""SELECT users.name, + tenants.name AS tenant_name, tenants.opt_out FROM public.users INNER JOIN public.tenants USING (tenant_id) WHERE users.user_id = %(userId)s @@ -457,11 +444,11 @@ def get_by_email_only(email): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + f"""SELECT users.user_id, users.tenant_id, - users.email, - users.role, + users.email, + users.role, users.name, (CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, @@ -473,7 +460,7 @@ def get_by_email_only(email): roles.name AS role_name FROM public.users LEFT JOIN public.basic_authentication USING(user_id) INNER JOIN public.roles USING(role_id) - WHERE users.email = %(email)s + WHERE users.email = %(email)s AND users.deleted_at IS NULL LIMIT 1;""", {"email": email}) @@ -481,50 +468,39 @@ def get_by_email_only(email): r = cur.fetchone() return helper.dict_to_camel_case(r) -def get_users_paginated(start_index, count=None, email=None): +def get_users_paginated(start_index, tenant_id, count=None): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT - users.user_id AS id, - users.tenant_id, - users.email AS email, - users.data AS data, - users.role, - users.name AS name, - (CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, - (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, - (CASE WHEN users.role = 'member' THEN TRUE ELSE FALSE END) AS member, - origin, - basic_authentication.password IS NOT NULL AS has_password, - role_id, - internal_id, - roles.name AS role_name - FROM public.users LEFT JOIN public.basic_authentication USING(user_id) - INNER JOIN public.roles USING(role_id) - WHERE users.deleted_at IS NULL - AND users.data ? 'user_id' - AND email = COALESCE(%(email)s, email) - LIMIT %(count)s - OFFSET %(startIndex)s;;""", - {"startIndex": start_index - 1, "count": count, "email": email}) + """ + SELECT * + FROM public.users + WHERE + users.deleted_at IS NULL + AND users.tenant_id = %(tenant_id)s + LIMIT %(limit)s + OFFSET %(offset)s; + """, + { + "offset": start_index - 1, + "limit": count, + "tenant_id": tenant_id + }, + ) ) r = cur.fetchall() - if len(r): - r = helper.list_to_camel_case(r) - return r - return [] + return helper.list_to_camel_case(r) def get_member(tenant_id, user_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + f"""SELECT users.user_id, - users.email, - users.role, - users.name, + users.email, + users.role, + users.name, users.created_at, (CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, @@ -535,7 +511,7 @@ def get_member(tenant_id, user_id): invitation_token, role_id, roles.name AS role_name - FROM public.users + FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id LEFT JOIN public.roles USING (role_id) WHERE users.tenant_id = %(tenant_id)s AND users.deleted_at IS NULL AND users.user_id = %(user_id)s @@ -557,11 +533,11 @@ def get_members(tenant_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + f"""SELECT users.user_id, - users.email, - users.role, - users.name, + users.email, + users.role, + users.name, users.created_at, (CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, @@ -572,10 +548,10 @@ def get_members(tenant_id): invitation_token, role_id, roles.name AS role_name - FROM public.users + FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id LEFT JOIN public.roles USING (role_id) - WHERE users.tenant_id = %(tenant_id)s + WHERE users.tenant_id = %(tenant_id)s AND users.deleted_at IS NULL AND NOT users.service_account ORDER BY name, user_id""", @@ -614,7 +590,7 @@ def delete_member(user_id, tenant_id, id_to_delete): cur.execute( cur.mogrify(f"""UPDATE public.users SET deleted_at = timezone('utc'::text, now()), - jwt_iat= NULL, jwt_refresh_jti= NULL, + jwt_iat= NULL, jwt_refresh_jti= NULL, jwt_refresh_iat= NULL, role_id=NULL WHERE user_id=%(user_id)s AND tenant_id=%(tenant_id)s;""", @@ -634,10 +610,10 @@ def delete_member_as_admin(tenant_id, id_to_delete): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + f"""SELECT users.user_id AS user_id, users.tenant_id, - email, + email, role, users.name, origin, @@ -654,12 +630,12 @@ def delete_member_as_admin(tenant_id, id_to_delete): role = 'owner' AND users.tenant_id = %(tenant_id)s AND users.deleted_at IS NULL - AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s) + AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s) LIMIT 1;""", {"tenant_id": tenant_id, "user_uuid": id_to_delete}) ) r = cur.fetchone() - + if r["user_id"] == id_to_delete: return {"errors": ["unauthorized, cannot delete self"]} @@ -677,7 +653,7 @@ def delete_member_as_admin(tenant_id, id_to_delete): cur.execute( cur.mogrify(f"""UPDATE public.users SET deleted_at = timezone('utc'::text, now()), - jwt_iat= NULL, jwt_refresh_jti= NULL, + jwt_iat= NULL, jwt_refresh_jti= NULL, jwt_refresh_iat= NULL, role_id=NULL WHERE user_id=%(user_id)s AND tenant_id=%(tenant_id)s;""", @@ -743,8 +719,8 @@ def email_exists(email): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT - count(user_id) + f"""SELECT + count(user_id) FROM public.users WHERE email = %(email)s @@ -760,8 +736,8 @@ def get_deleted_user_by_email(email): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT - * + f"""SELECT + * FROM public.users WHERE email = %(email)s @@ -777,7 +753,7 @@ def get_by_invitation_token(token, pass_token=None): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + f"""SELECT *, DATE_PART('day',timezone('utc'::text, now()) \ - COALESCE(basic_authentication.invited_at,'2000-01-01'::timestamp ))>=1 AS expired_invitation, @@ -797,15 +773,15 @@ def auth_exists(user_id, tenant_id, jwt_iat) -> bool: cur.execute( cur.mogrify( f"""SELECT user_id, - EXTRACT(epoch FROM jwt_iat)::BIGINT AS jwt_iat, + EXTRACT(epoch FROM jwt_iat)::BIGINT AS jwt_iat, changed_at, service_account, basic_authentication.user_id IS NOT NULL AS has_basic_auth - FROM public.users - LEFT JOIN public.basic_authentication USING(user_id) - WHERE user_id = %(userId)s - AND tenant_id = %(tenant_id)s - AND deleted_at IS NULL + FROM public.users + LEFT JOIN public.basic_authentication USING(user_id) + WHERE user_id = %(userId)s + AND tenant_id = %(tenant_id)s + AND deleted_at IS NULL LIMIT 1;""", {"userId": user_id, "tenant_id": tenant_id}) ) @@ -819,9 +795,9 @@ def auth_exists(user_id, tenant_id, jwt_iat) -> bool: def refresh_auth_exists(user_id, tenant_id, jwt_jti=None): with pg_client.PostgresClient() as cur: cur.execute( - cur.mogrify(f"""SELECT user_id - FROM public.users - WHERE user_id = %(userId)s + cur.mogrify(f"""SELECT user_id + FROM public.users + WHERE user_id = %(userId)s AND tenant_id= %(tenant_id)s AND deleted_at IS NULL AND jwt_refresh_jti = %(jwt_jti)s @@ -866,17 +842,17 @@ def change_jwt_iat_jti(user_id): with pg_client.PostgresClient() as cur: query = cur.mogrify(f"""UPDATE public.users SET jwt_iat = timezone('utc'::text, now()-INTERVAL '10s'), - jwt_refresh_jti = 0, + jwt_refresh_jti = 0, jwt_refresh_iat = timezone('utc'::text, now()-INTERVAL '10s'), spot_jwt_iat = timezone('utc'::text, now()-INTERVAL '10s'), - spot_jwt_refresh_jti = 0, - spot_jwt_refresh_iat = timezone('utc'::text, now()-INTERVAL '10s') - WHERE user_id = %(user_id)s - RETURNING EXTRACT (epoch FROM jwt_iat)::BIGINT AS jwt_iat, - jwt_refresh_jti, + spot_jwt_refresh_jti = 0, + spot_jwt_refresh_iat = timezone('utc'::text, now()-INTERVAL '10s') + WHERE user_id = %(user_id)s + RETURNING EXTRACT (epoch FROM jwt_iat)::BIGINT AS jwt_iat, + jwt_refresh_jti, EXTRACT (epoch FROM jwt_refresh_iat)::BIGINT AS jwt_refresh_iat, - EXTRACT (epoch FROM spot_jwt_iat)::BIGINT AS spot_jwt_iat, - spot_jwt_refresh_jti, + EXTRACT (epoch FROM spot_jwt_iat)::BIGINT AS spot_jwt_iat, + spot_jwt_refresh_jti, EXTRACT (epoch FROM spot_jwt_refresh_iat)::BIGINT AS spot_jwt_refresh_iat;""", {"user_id": user_id}) cur.execute(query) @@ -888,10 +864,10 @@ def refresh_jwt_iat_jti(user_id): with pg_client.PostgresClient() as cur: query = cur.mogrify(f"""UPDATE public.users SET jwt_iat = timezone('utc'::text, now()-INTERVAL '10s'), - jwt_refresh_jti = jwt_refresh_jti + 1 - WHERE user_id = %(user_id)s - RETURNING EXTRACT (epoch FROM jwt_iat)::BIGINT AS jwt_iat, - jwt_refresh_jti, + jwt_refresh_jti = jwt_refresh_jti + 1 + WHERE user_id = %(user_id)s + RETURNING EXTRACT (epoch FROM jwt_iat)::BIGINT AS jwt_iat, + jwt_refresh_jti, EXTRACT (epoch FROM jwt_refresh_iat)::BIGINT AS jwt_refresh_iat;""", {"user_id": user_id}) cur.execute(query) @@ -904,7 +880,7 @@ def authenticate(email, password, for_change_password=False) -> dict | bool | No return {"errors": ["must sign-in with SSO, enforced by admin"]} with pg_client.PostgresClient() as cur: query = cur.mogrify( - f"""SELECT + f"""SELECT users.user_id, users.tenant_id, users.role, @@ -919,7 +895,7 @@ def authenticate(email, password, for_change_password=False) -> dict | bool | No users.service_account FROM public.users AS users INNER JOIN public.basic_authentication USING(user_id) LEFT JOIN public.roles ON (roles.role_id = users.role_id AND roles.tenant_id = users.tenant_id) - WHERE users.email = %(email)s + WHERE users.email = %(email)s AND basic_authentication.password = crypt(%(password)s, basic_authentication.password) AND basic_authentication.user_id = (SELECT su.user_id FROM public.users AS su WHERE su.email=%(email)s AND su.deleted_at IS NULL LIMIT 1) AND (roles.role_id IS NULL OR roles.deleted_at IS NULL) @@ -932,7 +908,7 @@ def authenticate(email, password, for_change_password=False) -> dict | bool | No query = cur.mogrify( f"""SELECT 1 FROM public.users - WHERE users.email = %(email)s + WHERE users.email = %(email)s AND users.deleted_at IS NULL AND users.origin IS NOT NULL LIMIT 1;""", @@ -983,17 +959,17 @@ def get_user_role(tenant_id, user_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + f"""SELECT users.user_id, - users.email, - users.role, - users.name, + users.email, + users.role, + users.name, users.created_at, (CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, (CASE WHEN users.role = 'member' THEN TRUE ELSE FALSE END) AS member - FROM public.users - WHERE users.deleted_at IS NULL + FROM public.users + WHERE users.deleted_at IS NULL AND users.user_id=%(user_id)s AND users.tenant_id=%(tenant_id)s LIMIT 1""", @@ -1007,7 +983,7 @@ def create_sso_user(tenant_id, email, admin, name, origin, role_id, internal_id= query = cur.mogrify(f"""\ WITH u AS ( INSERT INTO public.users (tenant_id, email, role, name, data, origin, internal_id, role_id) - VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, %(origin)s, %(internal_id)s, + VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, %(origin)s, %(internal_id)s, (SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1)))) @@ -1033,7 +1009,7 @@ def create_sso_user(tenant_id, email, admin, name, origin, role_id, internal_id= query ) return helper.dict_to_camel_case(cur.fetchone()) - + def create_scim_user( tenant_id, user_uuid, @@ -1094,7 +1070,7 @@ def __hard_delete_user_uuid(user_uuid): with pg_client.PostgresClient() as cur: query = cur.mogrify( f"""DELETE FROM public.users - WHERE users.data->>'user_id' = %(user_uuid)s;""", # removed this: AND users.deleted_at IS NOT NULL + WHERE users.data->>'user_id' = %(user_uuid)s;""", # removed this: AND users.deleted_at IS NOT NULL {"user_uuid": user_uuid}) cur.execute(query) @@ -1124,7 +1100,7 @@ def refresh(user_id: int, tenant_id: int = -1) -> dict: def authenticate_sso(email: str, internal_id: str): with pg_client.PostgresClient() as cur: query = cur.mogrify( - f"""SELECT + f"""SELECT users.user_id, users.tenant_id, users.role, @@ -1173,13 +1149,13 @@ def restore_sso_user(user_id, tenant_id, email, admin, name, origin, role_id, in with pg_client.PostgresClient() as cur: query = cur.mogrify(f"""\ WITH u AS ( - UPDATE public.users + UPDATE public.users SET tenant_id= %(tenant_id)s, - role= %(role)s, + role= %(role)s, name= %(name)s, - data= %(data)s, - origin= %(origin)s, - internal_id= %(internal_id)s, + data= %(data)s, + origin= %(origin)s, + internal_id= %(internal_id)s, role_id= (SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1))), @@ -1198,7 +1174,7 @@ def restore_sso_user(user_id, tenant_id, email, admin, name, origin, role_id, in invited_at= default, change_pwd_token= default, change_pwd_expire_at= default, - changed_at= NULL + changed_at= NULL WHERE user_id = %(user_id)s RETURNING user_id ) @@ -1237,13 +1213,13 @@ def restore_scim_user( with pg_client.PostgresClient() as cur: query = cur.mogrify(f"""\ WITH u AS ( - UPDATE public.users + UPDATE public.users SET tenant_id= %(tenant_id)s, - role= %(role)s, + role= %(role)s, name= %(name)s, - data= %(data)s, - origin= %(origin)s, - internal_id= %(internal_id)s, + data= %(data)s, + origin= %(origin)s, + internal_id= %(internal_id)s, role_id= (SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1))), @@ -1262,7 +1238,7 @@ def restore_scim_user( invited_at= default, change_pwd_token= default, change_pwd_expire_at= default, - changed_at= NULL + changed_at= NULL WHERE user_id = %(user_id)s RETURNING user_id ) @@ -1290,10 +1266,10 @@ def get_user_settings(user_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + f"""SELECT settings - FROM public.users - WHERE users.deleted_at IS NULL + FROM public.users + WHERE users.deleted_at IS NULL AND users.user_id=%(user_id)s LIMIT 1""", {"user_id": user_id}) diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py index 2787f185b..61889c82e 100644 --- a/ee/api/routers/scim.py +++ b/ee/api/routers/scim.py @@ -1,20 +1,22 @@ import logging import re import uuid -from typing import Optional +from typing import Any, Literal, Optional import copy +from datetime import datetime from decouple import config from fastapi import Depends, HTTPException, Header, Query, Response, Request from fastapi.responses import JSONResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_serializer import schemas from chalicelib.core import users, roles, tenants from chalicelib.utils.scim_auth import auth_optional, auth_required, create_tokens, verify_refresh_token from routers.base import get_routers from routers.scim_constants import RESOURCE_TYPES, SCHEMAS, SERVICE_PROVIDER_CONFIG +from routers import scim_helpers logger = logging.getLogger(__name__) @@ -189,90 +191,124 @@ class UserRequest(BaseModel): password: str = Field(default=None) active: bool -class UserResponse(BaseModel): - schemas: list[str] - id: str - userName: str - name: Name - emails: list[Email] # ignore for now - displayName: str - locale: str - externalId: str - active: bool - groups: list[dict] - meta: dict = Field(default={"resourceType": "User"}) class PatchUserRequest(BaseModel): schemas: list[str] Operations: list[dict] -@public_app.get("/Users", dependencies=[Depends(auth_required)]) -async def get_users( - start_index: int = Query(1, alias="startIndex"), - count: Optional[int] = Query(None, alias="count"), - email: Optional[str] = Query(None, alias="filter"), -): - """Get SCIM Users""" - if email: - email = email.split(" ")[2].strip('"') - result_users = users.get_users_paginated(start_index, count, email) +class ResourceMetaResponse(BaseModel): + resourceType: Literal["ServiceProviderConfig", "ResourceType", "Schema", "User"] | None = None + created: datetime | None = None + lastModified: datetime | None = None + location: str | None = None + version: str | None = None - serialized_users = [] - for user in result_users: - serialized_users.append( - UserResponse( - schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], - id = user["data"]["userId"], - userName = user["email"], - name = Name.model_validate(user["data"]["name"]), - emails = [Email.model_validate(user["data"]["emails"])], - displayName = user["name"], - locale = user["data"]["locale"], - externalId = user["internalId"], - active = True, # ignore for now, since, can't insert actual timestamp - groups = [], # ignore - ).model_dump(mode='json') - ) + @field_serializer("created", "lastModified") + def serialize_datetime(self, dt: datetime) -> str | None: + if not dt: + return None + return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + + +class CommonResourceResponse(BaseModel): + id: str + externalId: str | None = None + schemas: list[ + Literal[ + "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig", + "urn:ietf:params:scim:schemas:core:2.0:ResourceType", + "urn:ietf:params:scim:schemas:core:2.0:Schema", + "urn:ietf:params:scim:schemas:core:2.0:User", + ] + ] + meta: ResourceMetaResponse | None = None + + +class UserResponse(CommonResourceResponse): + schemas: list[Literal["urn:ietf:params:scim:schemas:core:2.0:User"]] = ["urn:ietf:params:scim:schemas:core:2.0:User"] + userName: str | None = None + + +class QueryResourceResponse(BaseModel): + schemas: list[Literal["urn:ietf:params:scim:api:messages:2.0:ListResponse"]] = ["urn:ietf:params:scim:api:messages:2.0:ListResponse"] + totalResults: int + # todo(jon): add the other schemas + Resources: list[UserResponse] + startIndex: int + itemsPerPage: int + + +MAX_USERS_PER_PAGE = 10 + + +def _convert_db_user_to_scim_user(db_user: dict[str, Any], attributes: list[str] | None = None, excluded_attributes: list[str] | None = None) -> UserResponse: + user_schema = SCHEMA_IDS_TO_SCHEMA_DETAILS["urn:ietf:params:scim:schemas:core:2.0:User"] + all_attributes = scim_helpers.get_all_attribute_names(user_schema) + attributes = attributes or all_attributes + always_returned_attributes = scim_helpers.get_all_attribute_names_where_returned_is_always(user_schema) + included_attributes = list(set(attributes).union(set(always_returned_attributes))) + excluded_attributes = excluded_attributes or [] + excluded_attributes = list(set(excluded_attributes).difference(set(always_returned_attributes))) + scim_user = { + "id": str(db_user["userId"]), + "meta": { + "resourceType": "User", + "created": db_user["createdAt"], + "lastModified": db_user["createdAt"], # todo(jon): we currently don't keep track of this in the db + "location": f"Users/{db_user['userId']}" + }, + "userName": db_user["email"], + } + scim_user = scim_helpers.filter_attributes(scim_user, included_attributes) + scim_user = scim_helpers.exclude_attributes(scim_user, excluded_attributes) + return UserResponse(**scim_user) + + +@public_app.get("/Users") +async def get_users( + tenant_id = Depends(auth_required), + requested_start_index: int = Query(1, alias="startIndex"), + requested_items_per_page: int | None = Query(None, alias="count"), + attributes: list[str] | None = Query(None), + excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"), +): + start_index = max(1, requested_start_index) + items_per_page = min(max(0, requested_items_per_page or MAX_USERS_PER_PAGE), MAX_USERS_PER_PAGE) + # todo(jon): this might not be the most efficient thing to do. could be better to just do a count. + # but this is the fastest thing at the moment just to test that it's working + total_users = users.get_users_paginated(1, tenant_id) + db_users = users.get_users_paginated(start_index, tenant_id, count=items_per_page) + scim_users = [ + _convert_db_user_to_scim_user(user, attributes, excluded_attributes) + for user in db_users + ] return JSONResponse( status_code=200, - content={ - "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], - "totalResults": len(serialized_users), - "startIndex": start_index, - "itemsPerPage": len(serialized_users), - "Resources": serialized_users, - }, + content=QueryResourceResponse( + totalResults=len(total_users), + startIndex=start_index, + itemsPerPage=len(scim_users), + Resources=scim_users, + ).model_dump(mode="json", exclude_none=True), ) -@public_app.get("/Users/{user_id}", dependencies=[Depends(auth_required)]) -def get_user(user_id: str): - """Get SCIM User""" - tenant_id = 1 - user = users.get_by_uuid(user_id, tenant_id) - if not user: - return JSONResponse( - status_code=404, - content={ - "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], - "detail": "User not found", - "status": 404, - } - ) - res = UserResponse( - schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], - id = user["data"]["userId"], - userName = user["email"], - name = Name.model_validate(user["data"]["name"]), - emails = [Email.model_validate(user["data"]["emails"])], - displayName = user["name"], - locale = user["data"]["locale"], - externalId = user["internalId"], - active = True, # ignore for now, since, can't insert actual timestamp - groups = [], # ignore +@public_app.get("/Users/{user_id}") +def get_user( + user_id: str, + tenant_id = Depends(auth_required), + attributes: list[str] | None = Query(None), + excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"), +): + db_user = users.get_by_uuid(user_id, tenant_id) + if not db_user: + return _not_found_error_response(user_id) + scim_user = _convert_db_user_to_scim_user(db_user, attributes, excluded_attributes) + return JSONResponse( + status_code=200, + content=scim_user.model_dump(mode="json", exclude_none=True) ) - return JSONResponse(status_code=201, content=res.model_dump(mode='json')) @public_app.post("/Users", dependencies=[Depends(auth_required)]) diff --git a/ee/api/routers/scim_constants.py b/ee/api/routers/scim_constants.py index ee11cfee9..255bfba3e 100644 --- a/ee/api/routers/scim_constants.py +++ b/ee/api/routers/scim_constants.py @@ -1,5 +1,6 @@ # note(jon): please see https://datatracker.ietf.org/doc/html/rfc7643 for details on these constants -from typing import Any +from typing import Any, Literal + def _attribute_characteristics( name: str, @@ -102,12 +103,12 @@ def _common_resource_attributes(id_required: bool=True, id_uniqueness: str="none "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig", "urn:ietf:params:scim:schemas:core:2.0:ResourceType", "urn:ietf:params:scim:schemas:core:2.0:Schema", - # todo(jon): add the user and group schem when completed + "urn:ietf:params:scim:schemas:core:2.0:User", ], case_exact=True, mutability="readOnly", - returned="default", required=True, + returned="always", ), _attribute_characteristics( name="meta", @@ -670,13 +671,38 @@ SCHEMA_SCHEMA = { } +USER_SCHEMA = { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"], + "id": "urn:ietf:params:scim:schemas:core:2.0:User", + "name": "User", + "description": "User account.", + "meta": { + "resourceType": "Schema", + "created": "2025-04-16T14:48:00Z", + # note(jon): we might want to think about adding this resource as part of our db + # and then updating these timestamps from an api and such. for now, if we update + # the configuration, we should update the timestamp here. + "lastModified": "2025-04-16T14:48:00Z", + "location": "Schemas/urn:ietf:params:scim:schemas:core:2.0:User", + }, + "attributes": [ + *_common_resource_attributes(), + _attribute_characteristics( + name="userName", + description="A service provider's unique identifier for the user.", + required=True, + ), + ], +} + + SCHEMAS = sorted( - # todo(jon): add the user schema [ SERVICE_PROVIDER_CONFIG_SCHEMA, RESOURCE_TYPE_SCHEMA, SCHEMA_SCHEMA, + USER_SCHEMA, ], key=lambda x: x["id"], ) diff --git a/ee/api/routers/scim_helpers.py b/ee/api/routers/scim_helpers.py new file mode 100644 index 000000000..d1cc0b651 --- /dev/null +++ b/ee/api/routers/scim_helpers.py @@ -0,0 +1,105 @@ +from typing import Any +from copy import deepcopy + + +def get_all_attribute_names(schema: dict[str, Any]) -> list[str]: + result = [] + def _walk(attrs, prefix=None): + for attr in attrs: + name = attr["name"] + path = f"{prefix}.{name}" if prefix else name + result.append(path) + if attr["type"] == "complex": + sub = attr.get("subAttributes") or attr.get("attributes") or [] + _walk(sub, path) + _walk(schema["attributes"]) + return result + + +def get_all_attribute_names_where_returned_is_always(schema: dict[str, Any]) -> list[str]: + result = [] + def _walk(attrs, prefix=None): + for attr in attrs: + name = attr["name"] + path = f"{prefix}.{name}" if prefix else name + if attr["returned"] == "always": + result.append(path) + if attr["type"] == "complex": + sub = attr.get("subAttributes") or attr.get("attributes") or [] + _walk(sub, path) + _walk(schema["attributes"]) + return result + + +def filter_attributes(resource: dict[str, Any], include_list: list[str]) -> dict[str, Any]: + result = {} + for attr in include_list: + parts = attr.split(".", 1) + key = parts[0] + if key not in resource: + continue + + if len(parts) == 1: + # top‑level attr + result[key] = resource[key] + else: + # nested attr + sub = resource[key] + rest = parts[1] + if isinstance(sub, dict): + filtered = filter_attributes(sub, [rest]) + if filtered: + result.setdefault(key, {}).update(filtered) + elif isinstance(sub, list): + # apply to each element + new_list = [] + for item in sub: + if isinstance(item, dict): + f = filter_attributes(item, [rest]) + if f: + new_list.append(f) + if new_list: + result[key] = new_list + return result + + +def exclude_attributes(resource: dict[str, Any], exclude_list: list[str]) -> dict[str, Any]: + exclude_map = {} + for attr in exclude_list: + parts = attr.split(".", 1) + key = parts[0] + # rest is empty string for top-level exclusion + rest = parts[1] if len(parts) == 2 else "" + exclude_map.setdefault(key, []).append(rest) + + new_resource = {} + for key, value in resource.items(): + if key in exclude_map: + subs = exclude_map[key] + # If any attr has no rest, exclude entire key + if "" in subs: + continue + # Exclude nested attributes + if isinstance(value, dict): + new_sub = exclude_attributes(value, subs) + if not new_sub: + continue + new_resource[key] = new_sub + elif isinstance(value, list): + new_list = [] + for item in value: + if isinstance(item, dict): + new_item = exclude_attributes(item, subs) + new_list.append(new_item) + else: + new_list.append(item) + new_resource[key] = new_list + else: + new_resource[key] = value + else: + # No exclusion for this key: copy safely + if isinstance(value, (dict, list)): + new_resource[key] = deepcopy(value) + else: + new_resource[key] = value + return new_resource