From 9057637b84e0f91b3012b4cd467e2b4da29271e2 Mon Sep 17 00:00:00 2001 From: Jonathan Griffin Date: Tue, 22 Apr 2025 16:40:37 +0200 Subject: [PATCH] added groups endpoints and reformatted code --- ee/api/chalicelib/core/users.py | 11 +- ee/api/routers/fixtures/group_schema.json | 107 +++++- ee/api/routers/scim.py | 304 +++++++++++++----- ee/api/routers/scim_constants.py | 3 +- ee/api/routers/scim_groups.py | 211 ++++++++++++ ee/api/routers/scim_helpers.py | 38 ++- .../db/init_dbs/postgresql/init_schema.sql | 13 +- 7 files changed, 579 insertions(+), 108 deletions(-) create mode 100644 ee/api/routers/scim_groups.py diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index 8310c8d10..f7e084b8a 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -444,12 +444,13 @@ def create_scim_user( def restore_scim_user( - user_id: int, - tenant_id: int, + userId: int, + tenantId: int, email: str, name: str = "", internal_id: str | None = None, role_id: int | None = None, + **kwargs, ): with pg_client.PostgresClient() as cur: cur.execute( @@ -477,8 +478,8 @@ def restore_scim_user( FROM u LEFT JOIN public.roles USING (role_id); """, { - "tenant_id": tenant_id, - "user_id": user_id, + "tenant_id": tenantId, + "user_id": userId, "email": email, "name": name, "internal_id": internal_id, @@ -642,7 +643,7 @@ def edit_member( return {"data": user} -def get_existing_scim_user_by_unique_values_from_all_users(email): +def get_existing_scim_user_by_unique_values_from_all_users(email: str, **kwargs): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( diff --git a/ee/api/routers/fixtures/group_schema.json b/ee/api/routers/fixtures/group_schema.json index f1ef0f71f..1a56a3cdf 100644 --- a/ee/api/routers/fixtures/group_schema.json +++ b/ee/api/routers/fixtures/group_schema.json @@ -3,12 +3,107 @@ "name": "Group", "description": "Group", "attributes": [ + { + "name": "schemas", + "type": "string", + "multiValued": true, + "description": "An array of Strings containing URI that are used to indicate the namespaces of the SCIM schemas that define the attributes present in the current JSON structure.", + "required": true, + "caseExact": false, + "mutability": "immutable", + "returned": "always", + "uniqueness": "none" + }, + { + "name": "id", + "type": "string", + "multiValued": false, + "description": "Unique identifier for the resource, assigned by the service provider. MUST be non-empty, unique, stable, and non-reassignable. Clients MUST NOT specify this value.", + "required": true, + "caseExact": true, + "mutability": "readOnly", + "returned": "always", + "uniqueness": "server" + }, + { + "name": "externalId", + "type": "string", + "multiValued": false, + "description": "Identifier for the resource as defined by the provisioning client. OPTIONAL; clients MAY include a non-empty value.", + "required": false, + "caseExact": true, + "mutability": "readWrite", + "returned": "default", + "uniqueness": "none" + }, + { + "name": "meta", + "type": "complex", + "multiValued": false, + "description": "Resource metadata. MUST be ignored when provided by clients.", + "required": false, + "mutability": "readOnly", + "returned": "default", + "subAttributes": [ + { + "name": "resourceType", + "type": "string", + "multiValued": false, + "description": "The resource type name.", + "required": false, + "caseExact": true, + "mutability": "readOnly", + "returned": "default", + "uniqueness": "none" + }, + { + "name": "created", + "type": "dateTime", + "multiValued": false, + "description": "The date and time the resource was added.", + "required": false, + "mutability": "readOnly", + "returned": "default" + }, + { + "name": "lastModified", + "type": "dateTime", + "multiValued": false, + "description": "The most recent date and time the resource was modified.", + "required": false, + "mutability": "readOnly", + "returned": "default" + }, + { + "name": "location", + "type": "reference", + "referenceTypes": ["external"], + "multiValued": false, + "description": "The URI of the resource being returned.", + "required": false, + "mutability": "readOnly", + "returned": "default", + "uniqueness": "none" + }, + { + "name": "version", + "type": "string", + "multiValued": false, + "description": "The version (ETag) of the resource being returned.", + "required": false, + "caseExact": true, + "mutability": "readOnly", + "returned": "default", + "uniqueness": "none" + } + ] + }, { "name": "displayName", "type": "string", "multiValued": false, "description": "Human readable name for the Group. REQUIRED.", - "required": false, + "required": true, "caseExact": false, "mutability": "readWrite", "returned": "default", @@ -26,8 +121,8 @@ "type": "string", "multiValued": false, "description": "Identifier of the member of this Group.", - "required": false, - "caseExact": false, + "required": true, + "caseExact": true, "mutability": "immutable", "returned": "default", "uniqueness": "none" @@ -35,7 +130,7 @@ { "name": "$ref", "type": "reference", - "referenceTypes": ["User", "Group"], + "referenceTypes": ["User"], "multiValued": false, "description": "The URI of the corresponding member resource of this Group.", "required": false, @@ -48,10 +143,10 @@ "name": "type", "type": "string", "multiValued": false, - "description": "A label indicating the type of resource; e.g., 'User' or 'Group'.", + "description": "A label indicating the type of resource; e.g., 'User'.", "required": false, "caseExact": false, - "canonicalValues": ["User", "Group"], + "canonicalValues": ["User"], "mutability": "immutable", "returned": "default", "uniqueness": "none" diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py index 3e832dcf8..c50ea41a5 100644 --- a/ee/api/routers/scim.py +++ b/ee/api/routers/scim.py @@ -1,12 +1,14 @@ from copy import deepcopy import logging -from typing import Any +from typing import Any, Callable +from enum import Enum from decouple import config from fastapi import Depends, HTTPException, Header, Query, Response, Request from fastapi.responses import JSONResponse from fastapi.security import OAuth2PasswordRequestForm from pydantic import BaseModel +from psycopg2 import errors from chalicelib.core import users, roles, tenants from chalicelib.utils.scim_auth import ( @@ -17,7 +19,7 @@ from chalicelib.utils.scim_auth import ( ) from routers.base import get_routers from routers.scim_constants import RESOURCE_TYPES, SCHEMAS, SERVICE_PROVIDER_CONFIG -from routers import scim_helpers +from routers import scim_helpers, scim_groups logger = logging.getLogger(__name__) @@ -65,7 +67,7 @@ RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS = { } -def _not_found_error_response(resource_id: str): +def _not_found_error_response(resource_id: int): return JSONResponse( status_code=404, content={ @@ -123,6 +125,17 @@ def _invalid_value_error_response(): ) +def _internal_server_error_response(detail: str): + return JSONResponse( + status_code=500, + content={ + "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], + "detail": detail, + "status": "500", + }, + ) + + @public_app.get("/ResourceTypes", dependencies=[Depends(auth_required)]) async def get_resource_types(filter_param: str | None = Query(None, alias="filter")): if filter_param is not None: @@ -211,7 +224,28 @@ async def get_service_provider_config( return JSONResponse(status_code=200, content=SERVICE_PROVIDER_CONFIG) -MAX_USERS_PER_PAGE = 10 +def _serialize_db_resource_to_scim_resource_with_attribute_awareness( + db_resource: dict[str, Any], + schema_id: str, + serialize_db_resource_to_scim_resource: Callable[[dict[str, Any]], dict[str, Any]], + attributes: list[str] | None = None, + excluded_attributes: list[str] | None = None, +) -> dict[str, Any]: + schema = SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id] + all_attributes = scim_helpers.get_all_attribute_names(schema) + attributes = attributes or all_attributes + always_returned_attributes = ( + scim_helpers.get_all_attribute_names_where_returned_is_always(schema) + ) + included_attributes = list(set(attributes).union(set(always_returned_attributes))) + excluded_attributes = excluded_attributes or [] + excluded_attributes = list( + set(excluded_attributes).difference(set(always_returned_attributes)) + ) + scim_resource = serialize_db_resource_to_scim_resource(db_resource) + scim_resource = scim_helpers.filter_attributes(scim_resource, included_attributes) + scim_resource = scim_helpers.exclude_attributes(scim_resource, excluded_attributes) + return scim_resource def _parse_scim_user_input(data: dict[str, Any], tenant_id: str) -> dict[str, Any]: @@ -229,25 +263,8 @@ def _parse_scim_user_input(data: dict[str, Any], tenant_id: str) -> dict[str, An return result -def _serialize_db_user_to_scim_user( - db_user: dict[str, Any], - attributes: list[str] | None = None, - excluded_attributes: list[str] | None = None, -) -> dict[str, Any]: - user_schema = SCHEMA_IDS_TO_SCHEMA_DETAILS[ - "urn:ietf:params:scim:schemas:core:2.0:User" - ] - all_attributes = scim_helpers.get_all_attribute_names(user_schema) - attributes = attributes or all_attributes - always_returned_attributes = ( - scim_helpers.get_all_attribute_names_where_returned_is_always(user_schema) - ) - included_attributes = list(set(attributes).union(set(always_returned_attributes))) - excluded_attributes = excluded_attributes or [] - excluded_attributes = list( - set(excluded_attributes).difference(set(always_returned_attributes)) - ) - scim_user = { +def _serialize_db_user_to_scim_user(db_user: dict[str, Any]) -> dict[str, Any]: + return { "id": str(db_user["userId"]), "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], "meta": { @@ -266,32 +283,107 @@ def _serialize_db_user_to_scim_user( "userType": db_user.get("roleName"), "active": db_user["deletedAt"] is None, } - scim_user = scim_helpers.filter_attributes(scim_user, included_attributes) - scim_user = scim_helpers.exclude_attributes(scim_user, excluded_attributes) - return scim_user -@public_app.get("/Users") -async def get_users( +def _serialize_db_group_to_scim_group(db_resource: dict[str, Any]) -> dict[str, Any]: + members = db_resource["users"] or [] + return { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], + "id": str(db_resource["groupId"]), + "externalId": db_resource["externalId"], + "meta": { + "resourceType": "Group", + "created": db_resource["createdAt"].strftime("%Y-%m-%dT%H:%M:%SZ"), + "lastModified": db_resource["updatedAt"].strftime("%Y-%m-%dT%H:%M:%SZ"), + "location": f"Groups/{db_resource['groupId']}", + }, + "displayName": db_resource["name"], + "members": [ + { + "value": str(member["userId"]), + "$ref": f"Users/{member['userId']}", + "type": "User", + } + for member in members + ], + } + + +def _parse_scim_group_input(data: dict[str, Any], tenant_id: int) -> dict[str, Any]: + return { + "name": data["displayName"], + "external_id": data.get("externalId"), + "user_ids": [int(member["value"]) for member in data.get("members", [])], + } + + +RESOURCE_TYPE_TO_RESOURCE_CONFIG = { + "Users": { + "max_items_per_page": 10, + "schema_id": "urn:ietf:params:scim:schemas:core:2.0:User", + "db_to_scim_serializer": _serialize_db_user_to_scim_user, + "get_paginated_resources": users.get_scim_users_paginated, + "get_unique_resource": users.get_scim_user_by_id, + "parse_post_payload": _parse_scim_user_input, + "get_resource_by_unique_values": users.get_existing_scim_user_by_unique_values_from_all_users, + "restore_resource": users.restore_scim_user, + "create_resource": users.create_scim_user, + "delete_resource": users.soft_delete_scim_user_by_id, + "parse_put_payload": _parse_scim_user_input, + "update_resource": users.update_scim_user, + }, + "Groups": { + "max_items_per_page": 10, + "schema_id": "urn:ietf:params:scim:schemas:core:2.0:Group", + "db_to_scim_serializer": _serialize_db_group_to_scim_group, + "get_paginated_resources": scim_groups.get_resources_paginated, + "get_unique_resource": scim_groups.get_resource_by_id, + "parse_post_payload": _parse_scim_group_input, + "get_resource_by_unique_values": scim_groups.get_existing_resource_by_unique_values_from_all_resources, + "restore_resource": scim_groups.restore_resource, + "create_resource": scim_groups.create_resource, + "delete_resource": scim_groups.delete_resource, + "parse_put_payload": _parse_scim_group_input, + "update_resource": scim_groups.update_resource, + }, +} + + +class ListResourceType(str, Enum): + USERS = "Users" + GROUPS = "Groups" + + +@public_app.get("/{resource_type}") +async def get_resources( + resource_type: ListResourceType, tenant_id=Depends(auth_required), requested_start_index: int = Query(1, alias="startIndex"), requested_items_per_page: int | None = Query(None, alias="count"), attributes: list[str] | None = Query(None), excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"), ): + resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] start_index = max(1, requested_start_index) + max_items_per_page = resource_config["max_items_per_page"] items_per_page = min( - max(0, requested_items_per_page or MAX_USERS_PER_PAGE), MAX_USERS_PER_PAGE + max(0, requested_items_per_page or max_items_per_page), max_items_per_page ) # todo(jon): this might not be the most efficient thing to do. could be better to just do a count. # but this is the fastest thing at the moment just to test that it's working - total_resources = users.get_scim_users_paginated(1, tenant_id) - db_resources = users.get_scim_users_paginated( - start_index, tenant_id, count=items_per_page + total_resources = resource_config["get_paginated_resources"](1, tenant_id) + db_resources = resource_config["get_paginated_resources"]( + start_index, tenant_id, items_per_page ) scim_resources = [ - _serialize_db_user_to_scim_user(resource, attributes, excluded_attributes) - for resource in db_resources + _serialize_db_resource_to_scim_resource_with_attribute_awareness( + db_resource, + resource_config["schema_id"], + resource_config["db_to_scim_serializer"], + attributes, + excluded_attributes, + ) + for db_resource in db_resources ] return JSONResponse( status_code=200, @@ -304,87 +396,145 @@ async def get_users( ) -@public_app.get("/Users/{user_id}") -async def get_user( - user_id: str, +class GetResourceType(str, Enum): + USERS = "Users" + GROUPS = "Groups" + + +@public_app.get("/{resource_type}/{resource_id}") +async def get_resource( + resource_type: GetResourceType, + resource_id: int, tenant_id=Depends(auth_required), attributes: list[str] | None = Query(None), excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"), ): - db_resource = users.get_scim_user_by_id(user_id, tenant_id) + resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] + db_resource = resource_config["get_unique_resource"](resource_id, tenant_id) if not db_resource: - return _not_found_error_response(user_id) - scim_resource = _serialize_db_user_to_scim_user( - db_resource, attributes, excluded_attributes + return _not_found_error_response(resource_id) + scim_resource = _serialize_db_resource_to_scim_resource_with_attribute_awareness( + db_resource, + resource_config["schema_id"], + resource_config["db_to_scim_serializer"], + attributes, + excluded_attributes, ) return JSONResponse(status_code=200, content=scim_resource) -@public_app.post("/Users") -async def create_user(r: Request, tenant_id=Depends(auth_required)): +class PostResourceType(str, Enum): + USERS = "Users" + GROUPS = "Groups" + + +@public_app.post("/{resource_type}") +async def create_resource( + resource_type: PostResourceType, + r: Request, + tenant_id=Depends(auth_required), +): + resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] scim_payload = await r.json() try: - db_payload = _parse_scim_user_input(scim_payload, tenant_id) + db_payload = resource_config["parse_post_payload"](scim_payload, tenant_id) except KeyError: return _invalid_value_error_response() - existing_db_resource = users.get_existing_scim_user_by_unique_values_from_all_users( - db_payload["email"] + existing_db_resource = resource_config["get_resource_by_unique_values"]( + **db_payload ) - if existing_db_resource and existing_db_resource["deletedAt"] is None: + if existing_db_resource and existing_db_resource.get("deletedAt") is None: return _uniqueness_error_response() - if existing_db_resource and existing_db_resource["deletedAt"] is not None: - db_resource = users.restore_scim_user( - user_id=existing_db_resource["userId"], - tenant_id=tenant_id, - **db_payload, - ) + if existing_db_resource and existing_db_resource.get("deletedAt") is not None: + # todo(jon): not a super elegant solution overwriting the existing db resource. + # maybe we should try something else. + existing_db_resource.update(db_payload) + db_resource = resource_config["restore_resource"](**existing_db_resource) else: - db_resource = users.create_scim_user( + db_resource = resource_config["create_resource"]( tenant_id=tenant_id, **db_payload, ) - scim_resource = _serialize_db_user_to_scim_user(db_resource) + scim_resource = _serialize_db_resource_to_scim_resource_with_attribute_awareness( + db_resource, + resource_config["schema_id"], + resource_config["db_to_scim_serializer"], + ) response = JSONResponse(status_code=201, content=scim_resource) response.headers["Location"] = scim_resource["meta"]["location"] return response -@public_app.put("/Users/{user_id}") -async def update_user(user_id: str, r: Request, tenant_id=Depends(auth_required)): - db_resource = users.get_scim_user_by_id(user_id, tenant_id) +class DeleteResourceType(str, Enum): + USERS = "Users" + GROUPS = "Groups" + + +@public_app.delete("/{resource_type}/{resource_id}") +async def delete_resource( + resource_type: DeleteResourceType, + resource_id: str, + tenant_id=Depends(auth_required), +): + # note(jon): this can be a soft or a hard delete + resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] + db_resource = resource_config["get_unique_resource"](resource_id, tenant_id) if not db_resource: - return _not_found_error_response(user_id) - current_scim_resource = _serialize_db_user_to_scim_user(db_resource) + return _not_found_error_response(resource_id) + resource_config["delete_resource"](resource_id, tenant_id) + return Response(status_code=204, content="") + + +class PutResourceType(str, Enum): + USERS = "Users" + GROUPS = "Groups" + + +@public_app.put("/{resource_type}/{resource_id}") +async def update_resource( + resource_type: PutResourceType, + resource_id: str, + r: Request, + tenant_id=Depends(auth_required), +): + resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] + db_resource = resource_config["get_unique_resource"](resource_id, tenant_id) + if not db_resource: + return _not_found_error_response(resource_id) + current_scim_resource = ( + _serialize_db_resource_to_scim_resource_with_attribute_awareness( + db_resource, + resource_config["schema_id"], + resource_config["db_to_scim_serializer"], + ) + ) requested_scim_changes = await r.json() - schema = SCHEMA_IDS_TO_SCHEMA_DETAILS["urn:ietf:params:scim:schemas:core:2.0:User"] + schema = SCHEMA_IDS_TO_SCHEMA_DETAILS[resource_config["schema_id"]] try: valid_mutable_scim_changes = scim_helpers.filter_mutable_attributes( schema, requested_scim_changes, current_scim_resource ) except ValueError: return _mutability_error_response() - valid_mutable_db_changes = _parse_scim_user_input( + valid_mutable_db_changes = resource_config["parse_put_payload"]( valid_mutable_scim_changes, tenant_id, ) try: - updated_db_resource = users.update_scim_user( - user_id, + updated_db_resource = resource_config["update_resource"]( + resource_id, tenant_id, **valid_mutable_db_changes, ) - updated_scim_resource = _serialize_db_user_to_scim_user(updated_db_resource) + updated_scim_resource = ( + _serialize_db_resource_to_scim_resource_with_attribute_awareness( + updated_db_resource, + resource_config["schema_id"], + resource_config["db_to_scim_serializer"], + ) + ) return JSONResponse(status_code=200, content=updated_scim_resource) - except Exception: - # note(jon): for now, this is the only error that would happen when updating the scim user + except errors.UniqueViolation: return _uniqueness_error_response() - - -@public_app.delete("/Users/{user_id}") -async def delete_user(user_id: str, tenant_id=Depends(auth_required)): - # note(jon): this is a soft delete - db_resource = users.get_scim_user_by_id(user_id, tenant_id) - if not db_resource: - return _not_found_error_response(user_id) - users.soft_delete_scim_user_by_id(user_id, tenant_id) - return Response(status_code=204, content="") + except Exception as e: + return _internal_server_error_response(str(e)) diff --git a/ee/api/routers/scim_constants.py b/ee/api/routers/scim_constants.py index 090135ffc..4b52b1a8e 100644 --- a/ee/api/routers/scim_constants.py +++ b/ee/api/routers/scim_constants.py @@ -7,8 +7,7 @@ SCHEMAS = sorted( json.load(open("routers/fixtures/resource_type_schema.json", "r")), json.load(open("routers/fixtures/schema_schema.json", "r")), json.load(open("routers/fixtures/user_schema.json", "r")), - # todo(jon): add this when we have groups - # json.load(open("routers/schemas/group_schema.json", "r")), + json.load(open("routers/fixtures/group_schema.json", "r")), ], key=lambda x: x["id"], ) diff --git a/ee/api/routers/scim_groups.py b/ee/api/routers/scim_groups.py new file mode 100644 index 000000000..d80bf818d --- /dev/null +++ b/ee/api/routers/scim_groups.py @@ -0,0 +1,211 @@ +from typing import Any + +from chalicelib.utils import helper, pg_client + + +def get_resources_paginated( + offset_one_indexed: int, tenant_id: int, limit: int | None = None +) -> list[dict[str, Any]]: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT + groups.*, + users_data.array as users + FROM public.groups + LEFT JOIN LATERAL ( + SELECT json_agg(users) AS array + FROM public.users + WHERE users.group_id = groups.group_id + ) users_data ON true + WHERE groups.tenant_id = %(tenant_id)s + LIMIT %(limit)s + OFFSET %(offset)s; + """, + { + "offset": offset_one_indexed - 1, + "limit": limit, + "tenant_id": tenant_id, + }, + ) + ) + return helper.list_to_camel_case(cur.fetchall()) + + +def get_resource_by_id(group_id: int, tenant_id: int) -> dict[str, Any]: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT + groups.*, + users_data.array as users + FROM public.groups + LEFT JOIN LATERAL ( + SELECT json_agg(users) AS array + FROM public.users + WHERE users.group_id = groups.group_id + ) users_data ON true + WHERE + groups.tenant_id = %(tenant_id)s + AND groups.group_id = %(group_id)s + LIMIT 1; + """, + {"group_id": group_id, "tenant_id": tenant_id}, + ) + ) + return helper.dict_to_camel_case(cur.fetchone()) + + +def get_existing_resource_by_unique_values_from_all_resources( + **kwargs, +) -> dict[str, Any] | None: + # note(jon): we do not really use this for groups as we don't have unique values outside + # of the primary key + return None + + +def restore_resource(**kwargs: dict[str, Any]) -> dict[str, Any] | None: + # note(jon): we're not soft deleting groups, so we don't need this + return None + + +def create_resource( + name: str, tenant_id: int, **kwargs: dict[str, Any] +) -> dict[str, Any]: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + WITH g AS( + INSERT INTO public.groups + (tenant_id, name, external_id) + VALUES (%(tenant_id)s, %(name)s, %(external_id)s) + RETURNING * + ) + SELECT g.group_id + FROM g; + """, + { + "tenant_id": tenant_id, + "name": name, + "external_id": kwargs.get("external_id"), + }, + ) + ) + group_id = cur.fetchone()["group_id"] + user_ids = kwargs.get("user_ids", []) + if user_ids: + cur.execute( + cur.mogrify( + """ + UPDATE public.users + SET group_id = %s + WHERE users.user_id = ANY(%s) + """, + (group_id, user_ids), + ) + ) + cur.execute( + cur.mogrify( + """ + SELECT + groups.*, + users_data.array as users + FROM public.groups + LEFT JOIN LATERAL ( + SELECT json_agg(users) AS array + FROM public.users + WHERE users.group_id = %(group_id)s + ) users_data ON true + WHERE + groups.group_id = %(group_id)s + AND groups.tenant_id = %(tenant_id)s + LIMIT 1; + """, + { + "group_id": group_id, + "tenant_id": tenant_id, + "name": name, + "external_id": kwargs.get("external_id"), + }, + ) + ) + return helper.dict_to_camel_case(cur.fetchone()) + + +def delete_resource(group_id: int, tenant_id: int) -> None: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + DELETE FROM public.groups + WHERE groups.group_id = %(group_id)s AND groups.tenant_id = %(tenant_id)s; + """ + ), + {"tenant_id": tenant_id, "group_id": group_id}, + ) + + +def update_resource( + group_id: int, tenant_id: int, name: str, **kwargs: dict[str, Any] +) -> dict[str, Any]: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + UPDATE public.users + SET group_id = null + WHERE users.group_id = %(group_id)s; + """, + {"group_id": group_id}, + ) + ) + user_ids = kwargs.get("user_ids", []) + if user_ids: + cur.execute( + cur.mogrify( + """ + UPDATE public.users + SET group_id = %s + WHERE users.user_id = ANY(%s); + """, + (group_id, user_ids), + ) + ) + cur.execute( + cur.mogrify( + """ + WITH g AS ( + UPDATE public.groups + SET + tenant_id = %(tenant_id)s, + name = %(name)s, + external_id = %(external_id)s, + updated_at = default + WHERE + groups.group_id = %(group_id)s + AND groups.tenant_id = %(tenant_id)s + RETURNING * + ) + SELECT + g.*, + users_data.array as users + FROM g + LEFT JOIN LATERAL ( + SELECT json_agg(users) AS array + FROM public.users + WHERE users.group_id = g.group_id + ) users_data ON true + LIMIT 1; + """, + { + "group_id": group_id, + "tenant_id": tenant_id, + "name": name, + "external_id": kwargs.get("external_id"), + }, + ) + ) + return helper.dict_to_camel_case(cur.fetchone()) diff --git a/ee/api/routers/scim_helpers.py b/ee/api/routers/scim_helpers.py index 6c04ecab8..cda66b29c 100644 --- a/ee/api/routers/scim_helpers.py +++ b/ee/api/routers/scim_helpers.py @@ -41,31 +41,35 @@ def filter_attributes( resource: dict[str, Any], include_list: list[str] ) -> dict[str, Any]: result = {} - for attr in include_list: - parts = attr.split(".", 1) + + # Group include paths by top-level key + includes_by_key = {} + for path in include_list: + parts = path.split(".", 1) key = parts[0] + rest = parts[1] if len(parts) == 2 else None + includes_by_key.setdefault(key, []).append(rest) + + for key, subpaths in includes_by_key.items(): if key not in resource: continue - if len(parts) == 1: - # top‑level attr - result[key] = resource[key] + value = resource[key] + if all(p is None for p in subpaths): + result[key] = value else: - # nested attr - sub = resource[key] - rest = parts[1] - if isinstance(sub, dict): - filtered = filter_attributes(sub, [rest]) + nested_paths = [p for p in subpaths if p is not None] + if isinstance(value, dict): + filtered = filter_attributes(value, nested_paths) if filtered: - result.setdefault(key, {}).update(filtered) - elif isinstance(sub, list): - # apply to each element + result[key] = filtered + elif isinstance(value, list): new_list = [] - for item in sub: + for item in value: if isinstance(item, dict): - f = filter_attributes(item, [rest]) - if f: - new_list.append(f) + filtered_item = filter_attributes(item, nested_paths) + if filtered_item: + new_list.append(filtered_item) if new_list: result[key] = new_list return result diff --git a/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql b/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql index caf4e7467..96fb5a23b 100644 --- a/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql +++ b/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql @@ -122,6 +122,16 @@ CREATE TABLE public.roles service_role bool NOT NULL DEFAULT FALSE ); +CREATE TABLE public.groups +( + group_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY, + tenant_id integer NOT NULL REFERENCES public.tenants (tenant_id) ON DELETE CASCADE, + external_id text, + name text NOT NULL, + created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'), + updated_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc') +); + CREATE TYPE user_role AS ENUM ('owner','admin','member','service'); CREATE TABLE public.users @@ -151,7 +161,8 @@ CREATE TABLE public.users origin text NULL DEFAULT NULL, role_id integer REFERENCES public.roles (role_id) ON DELETE SET NULL, internal_id text NULL DEFAULT NULL, - service_account bool NOT NULL DEFAULT FALSE + service_account bool NOT NULL DEFAULT FALSE, + group_id integer REFERENCES public.groups (group_id) ON DELETE SET NULL ); CREATE INDEX users_tenant_id_deleted_at_N_idx ON public.users (tenant_id) WHERE deleted_at ISNULL; CREATE INDEX users_name_gin_idx ON public.users USING GIN (name gin_trgm_ops);