update GET Users api to be minimal working rfc version

This commit is contained in:
Jonathan Griffin 2025-04-17 14:01:59 +02:00
parent 13f46fe566
commit bcfa421b8f
4 changed files with 363 additions and 220 deletions

View file

@ -288,33 +288,20 @@ def get_by_uuid(user_uuid, tenant_id):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
f"""SELECT
users.user_id,
users.tenant_id,
email,
role,
users.name,
users.data,
users.internal_id,
(CASE WHEN role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin,
(CASE WHEN role = 'admin' THEN TRUE ELSE FALSE END) AS admin,
(CASE WHEN role = 'member' THEN TRUE ELSE FALSE END) AS member,
origin,
role_id,
roles.name AS role_name,
roles.permissions,
roles.all_projects,
basic_authentication.password IS NOT NULL AS has_password,
users.service_account
FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id
LEFT JOIN public.roles USING (role_id)
WHERE
users.data->>'user_id' = %(user_uuid)s
AND users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s)
LIMIT 1;""",
{"user_uuid": user_uuid, "tenant_id": tenant_id})
"""
SELECT *
FROM public.users
WHERE
users.deleted_at IS NULL
AND users.user_id = %(user_id)s
AND users.tenant_id = %(tenant_id)s
LIMIT 1;
""",
{
"user_id": user_uuid,
"tenant_id": tenant_id,
},
)
)
r = cur.fetchone()
return helper.dict_to_camel_case(r)
@ -481,39 +468,28 @@ def get_by_email_only(email):
r = cur.fetchone()
return helper.dict_to_camel_case(r)
def get_users_paginated(start_index, count=None, email=None):
def get_users_paginated(start_index, tenant_id, count=None):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
f"""SELECT
users.user_id AS id,
users.tenant_id,
users.email AS email,
users.data AS data,
users.role,
users.name AS name,
(CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin,
(CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin,
(CASE WHEN users.role = 'member' THEN TRUE ELSE FALSE END) AS member,
origin,
basic_authentication.password IS NOT NULL AS has_password,
role_id,
internal_id,
roles.name AS role_name
FROM public.users LEFT JOIN public.basic_authentication USING(user_id)
INNER JOIN public.roles USING(role_id)
WHERE users.deleted_at IS NULL
AND users.data ? 'user_id'
AND email = COALESCE(%(email)s, email)
LIMIT %(count)s
OFFSET %(startIndex)s;;""",
{"startIndex": start_index - 1, "count": count, "email": email})
"""
SELECT *
FROM public.users
WHERE
users.deleted_at IS NULL
AND users.tenant_id = %(tenant_id)s
LIMIT %(limit)s
OFFSET %(offset)s;
""",
{
"offset": start_index - 1,
"limit": count,
"tenant_id": tenant_id
},
)
)
r = cur.fetchall()
if len(r):
r = helper.list_to_camel_case(r)
return r
return []
return helper.list_to_camel_case(r)
def get_member(tenant_id, user_id):

View file

@ -1,20 +1,22 @@
import logging
import re
import uuid
from typing import Optional
from typing import Any, Literal, Optional
import copy
from datetime import datetime
from decouple import config
from fastapi import Depends, HTTPException, Header, Query, Response, Request
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_serializer
import schemas
from chalicelib.core import users, roles, tenants
from chalicelib.utils.scim_auth import auth_optional, auth_required, create_tokens, verify_refresh_token
from routers.base import get_routers
from routers.scim_constants import RESOURCE_TYPES, SCHEMAS, SERVICE_PROVIDER_CONFIG
from routers import scim_helpers
logger = logging.getLogger(__name__)
@ -189,90 +191,124 @@ class UserRequest(BaseModel):
password: str = Field(default=None)
active: bool
class UserResponse(BaseModel):
schemas: list[str]
id: str
userName: str
name: Name
emails: list[Email] # ignore for now
displayName: str
locale: str
externalId: str
active: bool
groups: list[dict]
meta: dict = Field(default={"resourceType": "User"})
class PatchUserRequest(BaseModel):
schemas: list[str]
Operations: list[dict]
@public_app.get("/Users", dependencies=[Depends(auth_required)])
async def get_users(
start_index: int = Query(1, alias="startIndex"),
count: Optional[int] = Query(None, alias="count"),
email: Optional[str] = Query(None, alias="filter"),
):
"""Get SCIM Users"""
if email:
email = email.split(" ")[2].strip('"')
result_users = users.get_users_paginated(start_index, count, email)
class ResourceMetaResponse(BaseModel):
resourceType: Literal["ServiceProviderConfig", "ResourceType", "Schema", "User"] | None = None
created: datetime | None = None
lastModified: datetime | None = None
location: str | None = None
version: str | None = None
serialized_users = []
for user in result_users:
serialized_users.append(
UserResponse(
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
id = user["data"]["userId"],
userName = user["email"],
name = Name.model_validate(user["data"]["name"]),
emails = [Email.model_validate(user["data"]["emails"])],
displayName = user["name"],
locale = user["data"]["locale"],
externalId = user["internalId"],
active = True, # ignore for now, since, can't insert actual timestamp
groups = [], # ignore
).model_dump(mode='json')
)
@field_serializer("created", "lastModified")
def serialize_datetime(self, dt: datetime) -> str | None:
if not dt:
return None
return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
class CommonResourceResponse(BaseModel):
id: str
externalId: str | None = None
schemas: list[
Literal[
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig",
"urn:ietf:params:scim:schemas:core:2.0:ResourceType",
"urn:ietf:params:scim:schemas:core:2.0:Schema",
"urn:ietf:params:scim:schemas:core:2.0:User",
]
]
meta: ResourceMetaResponse | None = None
class UserResponse(CommonResourceResponse):
schemas: list[Literal["urn:ietf:params:scim:schemas:core:2.0:User"]] = ["urn:ietf:params:scim:schemas:core:2.0:User"]
userName: str | None = None
class QueryResourceResponse(BaseModel):
schemas: list[Literal["urn:ietf:params:scim:api:messages:2.0:ListResponse"]] = ["urn:ietf:params:scim:api:messages:2.0:ListResponse"]
totalResults: int
# todo(jon): add the other schemas
Resources: list[UserResponse]
startIndex: int
itemsPerPage: int
MAX_USERS_PER_PAGE = 10
def _convert_db_user_to_scim_user(db_user: dict[str, Any], attributes: list[str] | None = None, excluded_attributes: list[str] | None = None) -> UserResponse:
user_schema = SCHEMA_IDS_TO_SCHEMA_DETAILS["urn:ietf:params:scim:schemas:core:2.0:User"]
all_attributes = scim_helpers.get_all_attribute_names(user_schema)
attributes = attributes or all_attributes
always_returned_attributes = scim_helpers.get_all_attribute_names_where_returned_is_always(user_schema)
included_attributes = list(set(attributes).union(set(always_returned_attributes)))
excluded_attributes = excluded_attributes or []
excluded_attributes = list(set(excluded_attributes).difference(set(always_returned_attributes)))
scim_user = {
"id": str(db_user["userId"]),
"meta": {
"resourceType": "User",
"created": db_user["createdAt"],
"lastModified": db_user["createdAt"], # todo(jon): we currently don't keep track of this in the db
"location": f"Users/{db_user['userId']}"
},
"userName": db_user["email"],
}
scim_user = scim_helpers.filter_attributes(scim_user, included_attributes)
scim_user = scim_helpers.exclude_attributes(scim_user, excluded_attributes)
return UserResponse(**scim_user)
@public_app.get("/Users")
async def get_users(
tenant_id = Depends(auth_required),
requested_start_index: int = Query(1, alias="startIndex"),
requested_items_per_page: int | None = Query(None, alias="count"),
attributes: list[str] | None = Query(None),
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
start_index = max(1, requested_start_index)
items_per_page = min(max(0, requested_items_per_page or MAX_USERS_PER_PAGE), MAX_USERS_PER_PAGE)
# todo(jon): this might not be the most efficient thing to do. could be better to just do a count.
# but this is the fastest thing at the moment just to test that it's working
total_users = users.get_users_paginated(1, tenant_id)
db_users = users.get_users_paginated(start_index, tenant_id, count=items_per_page)
scim_users = [
_convert_db_user_to_scim_user(user, attributes, excluded_attributes)
for user in db_users
]
return JSONResponse(
status_code=200,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
"totalResults": len(serialized_users),
"startIndex": start_index,
"itemsPerPage": len(serialized_users),
"Resources": serialized_users,
},
content=QueryResourceResponse(
totalResults=len(total_users),
startIndex=start_index,
itemsPerPage=len(scim_users),
Resources=scim_users,
).model_dump(mode="json", exclude_none=True),
)
@public_app.get("/Users/{user_id}", dependencies=[Depends(auth_required)])
def get_user(user_id: str):
"""Get SCIM User"""
tenant_id = 1
user = users.get_by_uuid(user_id, tenant_id)
if not user:
return JSONResponse(
status_code=404,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": "User not found",
"status": 404,
}
)
res = UserResponse(
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
id = user["data"]["userId"],
userName = user["email"],
name = Name.model_validate(user["data"]["name"]),
emails = [Email.model_validate(user["data"]["emails"])],
displayName = user["name"],
locale = user["data"]["locale"],
externalId = user["internalId"],
active = True, # ignore for now, since, can't insert actual timestamp
groups = [], # ignore
@public_app.get("/Users/{user_id}")
def get_user(
user_id: str,
tenant_id = Depends(auth_required),
attributes: list[str] | None = Query(None),
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
db_user = users.get_by_uuid(user_id, tenant_id)
if not db_user:
return _not_found_error_response(user_id)
scim_user = _convert_db_user_to_scim_user(db_user, attributes, excluded_attributes)
return JSONResponse(
status_code=200,
content=scim_user.model_dump(mode="json", exclude_none=True)
)
return JSONResponse(status_code=201, content=res.model_dump(mode='json'))
@public_app.post("/Users", dependencies=[Depends(auth_required)])

View file

@ -1,5 +1,6 @@
# note(jon): please see https://datatracker.ietf.org/doc/html/rfc7643 for details on these constants
from typing import Any
from typing import Any, Literal
def _attribute_characteristics(
name: str,
@ -102,12 +103,12 @@ def _common_resource_attributes(id_required: bool=True, id_uniqueness: str="none
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig",
"urn:ietf:params:scim:schemas:core:2.0:ResourceType",
"urn:ietf:params:scim:schemas:core:2.0:Schema",
# todo(jon): add the user and group schem when completed
"urn:ietf:params:scim:schemas:core:2.0:User",
],
case_exact=True,
mutability="readOnly",
returned="default",
required=True,
returned="always",
),
_attribute_characteristics(
name="meta",
@ -670,13 +671,38 @@ SCHEMA_SCHEMA = {
}
USER_SCHEMA = {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
"id": "urn:ietf:params:scim:schemas:core:2.0:User",
"name": "User",
"description": "User account.",
"meta": {
"resourceType": "Schema",
"created": "2025-04-16T14:48:00Z",
# note(jon): we might want to think about adding this resource as part of our db
# and then updating these timestamps from an api and such. for now, if we update
# the configuration, we should update the timestamp here.
"lastModified": "2025-04-16T14:48:00Z",
"location": "Schemas/urn:ietf:params:scim:schemas:core:2.0:User",
},
"attributes": [
*_common_resource_attributes(),
_attribute_characteristics(
name="userName",
description="A service provider's unique identifier for the user.",
required=True,
),
],
}
SCHEMAS = sorted(
# todo(jon): add the user schema
[
SERVICE_PROVIDER_CONFIG_SCHEMA,
RESOURCE_TYPE_SCHEMA,
SCHEMA_SCHEMA,
USER_SCHEMA,
],
key=lambda x: x["id"],
)

View file

@ -0,0 +1,105 @@
from typing import Any
from copy import deepcopy
def get_all_attribute_names(schema: dict[str, Any]) -> list[str]:
result = []
def _walk(attrs, prefix=None):
for attr in attrs:
name = attr["name"]
path = f"{prefix}.{name}" if prefix else name
result.append(path)
if attr["type"] == "complex":
sub = attr.get("subAttributes") or attr.get("attributes") or []
_walk(sub, path)
_walk(schema["attributes"])
return result
def get_all_attribute_names_where_returned_is_always(schema: dict[str, Any]) -> list[str]:
result = []
def _walk(attrs, prefix=None):
for attr in attrs:
name = attr["name"]
path = f"{prefix}.{name}" if prefix else name
if attr["returned"] == "always":
result.append(path)
if attr["type"] == "complex":
sub = attr.get("subAttributes") or attr.get("attributes") or []
_walk(sub, path)
_walk(schema["attributes"])
return result
def filter_attributes(resource: dict[str, Any], include_list: list[str]) -> dict[str, Any]:
result = {}
for attr in include_list:
parts = attr.split(".", 1)
key = parts[0]
if key not in resource:
continue
if len(parts) == 1:
# toplevel attr
result[key] = resource[key]
else:
# nested attr
sub = resource[key]
rest = parts[1]
if isinstance(sub, dict):
filtered = filter_attributes(sub, [rest])
if filtered:
result.setdefault(key, {}).update(filtered)
elif isinstance(sub, list):
# apply to each element
new_list = []
for item in sub:
if isinstance(item, dict):
f = filter_attributes(item, [rest])
if f:
new_list.append(f)
if new_list:
result[key] = new_list
return result
def exclude_attributes(resource: dict[str, Any], exclude_list: list[str]) -> dict[str, Any]:
exclude_map = {}
for attr in exclude_list:
parts = attr.split(".", 1)
key = parts[0]
# rest is empty string for top-level exclusion
rest = parts[1] if len(parts) == 2 else ""
exclude_map.setdefault(key, []).append(rest)
new_resource = {}
for key, value in resource.items():
if key in exclude_map:
subs = exclude_map[key]
# If any attr has no rest, exclude entire key
if "" in subs:
continue
# Exclude nested attributes
if isinstance(value, dict):
new_sub = exclude_attributes(value, subs)
if not new_sub:
continue
new_resource[key] = new_sub
elif isinstance(value, list):
new_list = []
for item in value:
if isinstance(item, dict):
new_item = exclude_attributes(item, subs)
new_list.append(new_item)
else:
new_list.append(item)
new_resource[key] = new_list
else:
new_resource[key] = value
else:
# No exclusion for this key: copy safely
if isinstance(value, (dict, list)):
new_resource[key] = deepcopy(value)
else:
new_resource[key] = value
return new_resource