From ebeff746cb80b4d26f263b08c377e1eaf08de3e6 Mon Sep 17 00:00:00 2001 From: Jonathan Griffin Date: Tue, 22 Apr 2025 09:00:47 +0200 Subject: [PATCH] added new fields to user endpoints --- ee/api/chalicelib/core/users.py | 279 ++++++++++++++++++-------------- ee/api/routers/scim.py | 86 ++++++---- 2 files changed, 221 insertions(+), 144 deletions(-) diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index a57a194e5..8310c8d10 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -162,37 +162,6 @@ def reset_member(tenant_id, editor_id, user_id_to_update): return {"data": {"invitationLink": generate_new_invitation(user_id_to_update)}} -def update_scim_user( - user_id: int, - tenant_id: int, - email: str, -): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - WITH u AS ( - UPDATE public.users - SET email = %(email)s - WHERE - users.user_id = %(user_id)s - AND users.tenant_id = %(tenant_id)s - AND users.deleted_at IS NULL - RETURNING * - ) - SELECT * - FROM u; - """, - { - "tenant_id": tenant_id, - "user_id": user_id, - "email": email, - }, - ) - ) - return helper.dict_to_camel_case(cur.fetchone()) - - def update(tenant_id, user_id, changes, output=True): AUTH_KEYS = [ "password", @@ -381,13 +350,39 @@ def get(user_id, tenant_id): return helper.dict_to_camel_case(r) +def get_scim_users_paginated(start_index, tenant_id, count=None): + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT + users.*, + roles.name AS role_name + FROM public.users + LEFT JOIN public.roles USING (role_id) + WHERE + users.tenant_id = %(tenant_id)s + AND users.deleted_at IS NULL + LIMIT %(limit)s + OFFSET %(offset)s; + """, + {"offset": start_index - 1, "limit": count, "tenant_id": tenant_id}, + ) + ) + r = cur.fetchall() + return helper.list_to_camel_case(r) + + def get_scim_user_by_id(user_id, tenant_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( """ - SELECT * + SELECT + users.*, + roles.name AS role_name FROM public.users + LEFT JOIN public.roles USING (role_id) WHERE users.user_id = %(user_id)s AND users.tenant_id = %(tenant_id)s @@ -403,6 +398,140 @@ def get_scim_user_by_id(user_id, tenant_id): return helper.dict_to_camel_case(cur.fetchone()) +def create_scim_user( + email: str, + tenant_id: int, + name: str = "", + internal_id: str | None = None, + role_id: int | None = None, +): + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + WITH u AS ( + INSERT INTO public.users ( + tenant_id, + email, + name, + internal_id, + role_id + ) + VALUES ( + %(tenant_id)s, + %(email)s, + %(name)s, + %(internal_id)s, + %(role_id)s + ) + RETURNING * + ) + SELECT + u.*, + roles.name as role_name + FROM u LEFT JOIN public.roles USING (role_id); + """, + { + "tenant_id": tenant_id, + "email": email, + "name": name, + "internal_id": internal_id, + "role_id": role_id, + }, + ) + ) + return helper.dict_to_camel_case(cur.fetchone()) + + +def restore_scim_user( + user_id: int, + tenant_id: int, + email: str, + name: str = "", + internal_id: str | None = None, + role_id: int | None = None, +): + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + WITH u AS ( + UPDATE public.users + SET + tenant_id = %(tenant_id)s, + email = %(email)s, + name = %(name)s, + internal_id = %(internal_id)s, + role_id = %(role_id)s, + deleted_at = NULL, + created_at = default, + api_key = default, + jwt_iat = NULL, + weekly_report = default + WHERE users.user_id = %(user_id)s + RETURNING * + ) + SELECT + u.*, + roles.name as role_name + FROM u LEFT JOIN public.roles USING (role_id); + """, + { + "tenant_id": tenant_id, + "user_id": user_id, + "email": email, + "name": name, + "internal_id": internal_id, + "role_id": role_id, + }, + ) + ) + return helper.dict_to_camel_case(cur.fetchone()) + + +def update_scim_user( + user_id: int, + tenant_id: int, + email: str, + name: str = "", + internal_id: str | None = None, + role_id: int | None = None, +): + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + WITH u AS ( + UPDATE public.users + SET + email = %(email)s, + name = %(name)s, + internal_id = %(internal_id)s, + role_id = %(role_id)s + WHERE + users.user_id = %(user_id)s + AND users.tenant_id = %(tenant_id)s + AND users.deleted_at IS NULL + RETURNING * + ) + SELECT + u.*, + roles.name as role_name + FROM u LEFT JOIN public.roles USING (role_id); + """, + { + "tenant_id": tenant_id, + "user_id": user_id, + "email": email, + "name": name, + "internal_id": internal_id, + "role_id": role_id, + }, + ) + ) + return helper.dict_to_camel_case(cur.fetchone()) + + def generate_new_api_key(user_id): with pg_client.PostgresClient() as cur: cur.execute( @@ -513,7 +642,7 @@ def edit_member( return {"data": user} -def get_existing_scim_user_by_unique_values(email): +def get_existing_scim_user_by_unique_values_from_all_users(email): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( @@ -558,26 +687,6 @@ def get_by_email_only(email): return helper.dict_to_camel_case(r) -def get_users_paginated(start_index, tenant_id, count=None): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - SELECT * - FROM public.users - WHERE - users.tenant_id = %(tenant_id)s - AND users.deleted_at IS NULL - LIMIT %(limit)s - OFFSET %(offset)s; - """, - {"offset": start_index - 1, "limit": count, "tenant_id": tenant_id}, - ) - ) - r = cur.fetchall() - return helper.list_to_camel_case(r) - - def get_member(tenant_id, user_id): with pg_client.PostgresClient() as cur: cur.execute( @@ -1093,41 +1202,6 @@ def create_sso_user(tenant_id, email, admin, name, origin, role_id, internal_id= return helper.dict_to_camel_case(cur.fetchone()) -def create_scim_user( - email, - name, - tenant_id, -): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - WITH u AS ( - INSERT INTO public.users ( - tenant_id, - email, - name - ) - VALUES ( - %(tenant_id)s, - %(email)s, - %(name)s - ) - RETURNING * - ) - SELECT * - FROM u; - """, - { - "tenant_id": tenant_id, - "email": email, - "name": name, - }, - ) - ) - return helper.dict_to_camel_case(cur.fetchone()) - - def soft_delete_scim_user_by_id(user_id, tenant_id): with pg_client.PostgresClient() as cur: cur.execute( @@ -1314,35 +1388,6 @@ def restore_sso_user( return helper.dict_to_camel_case(cur.fetchone()) -def restore_scim_user( - user_id, - tenant_id, -): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - WITH u AS ( - UPDATE public.users - SET - tenant_id = %(tenant_id)s, - deleted_at = NULL, - created_at = default, - api_key = default, - jwt_iat = NULL, - weekly_report = default - WHERE users.user_id = %(user_id)s - RETURNING * - ) - SELECT * - FROM u; - """, - {"tenant_id": tenant_id, "user_id": user_id}, - ) - ) - return helper.dict_to_camel_case(cur.fetchone()) - - def get_user_settings(user_id): # read user settings from users.settings:jsonb column with pg_client.PostgresClient() as cur: diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py index a33aed847..3e832dcf8 100644 --- a/ee/api/routers/scim.py +++ b/ee/api/routers/scim.py @@ -1,3 +1,4 @@ +from copy import deepcopy import logging from typing import Any @@ -7,7 +8,7 @@ from fastapi.responses import JSONResponse from fastapi.security import OAuth2PasswordRequestForm from pydantic import BaseModel -from chalicelib.core import users, tenants +from chalicelib.core import users, roles, tenants from chalicelib.utils.scim_auth import ( auth_optional, auth_required, @@ -171,13 +172,21 @@ async def get_schemas(filter_param: str | None = Query(None, alias="filter")): ) -@public_app.get("/Schemas/{schema_id}", dependencies=[Depends(auth_required)]) -async def get_schema(schema_id: str): +@public_app.get("/Schemas/{schema_id}") +async def get_schema(schema_id: str, tenant_id=Depends(auth_required)): if schema_id not in SCHEMA_IDS_TO_SCHEMA_DETAILS: return _not_found_error_response(schema_id) + schema = deepcopy(SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id]) + if schema_id == "urn:ietf:params:scim:schemas:core:2.0:User": + db_roles = roles.get_roles(tenant_id) + role_names = [role["name"] for role in db_roles] + user_type_attribute = next( + filter(lambda x: x["name"] == "userType", schema["attributes"]) + ) + user_type_attribute["canonicalValues"] = role_names return JSONResponse( status_code=200, - content=SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id], + content=schema, ) @@ -205,7 +214,22 @@ async def get_service_provider_config( MAX_USERS_PER_PAGE = 10 -def _convert_db_user_to_scim_user( +def _parse_scim_user_input(data: dict[str, Any], tenant_id: str) -> dict[str, Any]: + role_id = None + if "userType" in data: + role = roles.get_role_by_name(tenant_id, data["userType"]) + role_id = role["roleId"] if role else None + result = { + "email": data["userName"], + "internal_id": data.get("externalId"), + "name": data.get("name", {}).get("formatted") or data.get("displayName"), + "role_id": role_id, + } + result = {k: v for k, v in result.items() if v is not None} + return result + + +def _serialize_db_user_to_scim_user( db_user: dict[str, Any], attributes: list[str] | None = None, excluded_attributes: list[str] | None = None, @@ -239,7 +263,8 @@ def _convert_db_user_to_scim_user( "formatted": db_user["name"], }, "displayName": db_user["name"] or db_user["email"], - + "userType": db_user.get("roleName"), + "active": db_user["deletedAt"] is None, } scim_user = scim_helpers.filter_attributes(scim_user, included_attributes) scim_user = scim_helpers.exclude_attributes(scim_user, excluded_attributes) @@ -260,12 +285,12 @@ async def get_users( ) # todo(jon): this might not be the most efficient thing to do. could be better to just do a count. # but this is the fastest thing at the moment just to test that it's working - total_resources = users.get_users_paginated(1, tenant_id) - db_resources = users.get_users_paginated( + total_resources = users.get_scim_users_paginated(1, tenant_id) + db_resources = users.get_scim_users_paginated( start_index, tenant_id, count=items_per_page ) scim_resources = [ - _convert_db_user_to_scim_user(resource, attributes, excluded_attributes) + _serialize_db_user_to_scim_user(resource, attributes, excluded_attributes) for resource in db_resources ] return JSONResponse( @@ -289,7 +314,7 @@ async def get_user( db_resource = users.get_scim_user_by_id(user_id, tenant_id) if not db_resource: return _not_found_error_response(user_id) - scim_resource = _convert_db_user_to_scim_user( + scim_resource = _serialize_db_user_to_scim_user( db_resource, attributes, excluded_attributes ) return JSONResponse(status_code=200, content=scim_resource) @@ -297,26 +322,28 @@ async def get_user( @public_app.post("/Users") async def create_user(r: Request, tenant_id=Depends(auth_required)): - payload = await r.json() - if "userName" not in payload: + scim_payload = await r.json() + try: + db_payload = _parse_scim_user_input(scim_payload, tenant_id) + except KeyError: return _invalid_value_error_response() - # note(jon): this method will return soft deleted users as well - existing_db_resource = users.get_existing_scim_user_by_unique_values( - payload["userName"] + existing_db_resource = users.get_existing_scim_user_by_unique_values_from_all_users( + db_payload["email"] ) if existing_db_resource and existing_db_resource["deletedAt"] is None: return _uniqueness_error_response() if existing_db_resource and existing_db_resource["deletedAt"] is not None: - db_resource = users.restore_scim_user(existing_db_resource["userId"], tenant_id) + db_resource = users.restore_scim_user( + user_id=existing_db_resource["userId"], + tenant_id=tenant_id, + **db_payload, + ) else: db_resource = users.create_scim_user( - email=payload["userName"], - # note(jon): scim schema does not require the `name.formatted` attribute, but we require `name`. - # so, we have to define the value ourselves here - name="", tenant_id=tenant_id, + **db_payload, ) - scim_resource = _convert_db_user_to_scim_user(db_resource) + scim_resource = _serialize_db_user_to_scim_user(db_resource) response = JSONResponse(status_code=201, content=scim_resource) response.headers["Location"] = scim_resource["meta"]["location"] return response @@ -327,22 +354,26 @@ async def update_user(user_id: str, r: Request, tenant_id=Depends(auth_required) db_resource = users.get_scim_user_by_id(user_id, tenant_id) if not db_resource: return _not_found_error_response(user_id) - current_scim_resource = _convert_db_user_to_scim_user(db_resource) - changes = await r.json() + current_scim_resource = _serialize_db_user_to_scim_user(db_resource) + requested_scim_changes = await r.json() schema = SCHEMA_IDS_TO_SCHEMA_DETAILS["urn:ietf:params:scim:schemas:core:2.0:User"] try: - valid_mutable_changes = scim_helpers.filter_mutable_attributes( - schema, changes, current_scim_resource + valid_mutable_scim_changes = scim_helpers.filter_mutable_attributes( + schema, requested_scim_changes, current_scim_resource ) except ValueError: return _mutability_error_response() + valid_mutable_db_changes = _parse_scim_user_input( + valid_mutable_scim_changes, + tenant_id, + ) try: updated_db_resource = users.update_scim_user( user_id, tenant_id, - email=valid_mutable_changes["userName"], + **valid_mutable_db_changes, ) - updated_scim_resource = _convert_db_user_to_scim_user(updated_db_resource) + updated_scim_resource = _serialize_db_user_to_scim_user(updated_db_resource) return JSONResponse(status_code=200, content=updated_scim_resource) except Exception: # note(jon): for now, this is the only error that would happen when updating the scim user @@ -351,6 +382,7 @@ async def update_user(user_id: str, r: Request, tenant_id=Depends(auth_required) @public_app.delete("/Users/{user_id}") async def delete_user(user_id: str, tenant_id=Depends(auth_required)): + # note(jon): this is a soft delete db_resource = users.get_scim_user_by_id(user_id, tenant_id) if not db_resource: return _not_found_error_response(user_id)