update GET Users api to be minimal working rfc version
This commit is contained in:
parent
13f46fe566
commit
bcfa421b8f
4 changed files with 363 additions and 220 deletions
|
|
@ -288,33 +288,20 @@ def get_by_uuid(user_uuid, tenant_id):
|
|||
with pg_client.PostgresClient() as cur:
|
||||
cur.execute(
|
||||
cur.mogrify(
|
||||
f"""SELECT
|
||||
users.user_id,
|
||||
users.tenant_id,
|
||||
email,
|
||||
role,
|
||||
users.name,
|
||||
users.data,
|
||||
users.internal_id,
|
||||
(CASE WHEN role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin,
|
||||
(CASE WHEN role = 'admin' THEN TRUE ELSE FALSE END) AS admin,
|
||||
(CASE WHEN role = 'member' THEN TRUE ELSE FALSE END) AS member,
|
||||
origin,
|
||||
role_id,
|
||||
roles.name AS role_name,
|
||||
roles.permissions,
|
||||
roles.all_projects,
|
||||
basic_authentication.password IS NOT NULL AS has_password,
|
||||
users.service_account
|
||||
FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id
|
||||
LEFT JOIN public.roles USING (role_id)
|
||||
WHERE
|
||||
users.data->>'user_id' = %(user_uuid)s
|
||||
AND users.tenant_id = %(tenant_id)s
|
||||
AND users.deleted_at IS NULL
|
||||
AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s)
|
||||
LIMIT 1;""",
|
||||
{"user_uuid": user_uuid, "tenant_id": tenant_id})
|
||||
"""
|
||||
SELECT *
|
||||
FROM public.users
|
||||
WHERE
|
||||
users.deleted_at IS NULL
|
||||
AND users.user_id = %(user_id)s
|
||||
AND users.tenant_id = %(tenant_id)s
|
||||
LIMIT 1;
|
||||
""",
|
||||
{
|
||||
"user_id": user_uuid,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
r = cur.fetchone()
|
||||
return helper.dict_to_camel_case(r)
|
||||
|
|
@ -481,39 +468,28 @@ def get_by_email_only(email):
|
|||
r = cur.fetchone()
|
||||
return helper.dict_to_camel_case(r)
|
||||
|
||||
def get_users_paginated(start_index, count=None, email=None):
|
||||
def get_users_paginated(start_index, tenant_id, count=None):
|
||||
with pg_client.PostgresClient() as cur:
|
||||
cur.execute(
|
||||
cur.mogrify(
|
||||
f"""SELECT
|
||||
users.user_id AS id,
|
||||
users.tenant_id,
|
||||
users.email AS email,
|
||||
users.data AS data,
|
||||
users.role,
|
||||
users.name AS name,
|
||||
(CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin,
|
||||
(CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin,
|
||||
(CASE WHEN users.role = 'member' THEN TRUE ELSE FALSE END) AS member,
|
||||
origin,
|
||||
basic_authentication.password IS NOT NULL AS has_password,
|
||||
role_id,
|
||||
internal_id,
|
||||
roles.name AS role_name
|
||||
FROM public.users LEFT JOIN public.basic_authentication USING(user_id)
|
||||
INNER JOIN public.roles USING(role_id)
|
||||
WHERE users.deleted_at IS NULL
|
||||
AND users.data ? 'user_id'
|
||||
AND email = COALESCE(%(email)s, email)
|
||||
LIMIT %(count)s
|
||||
OFFSET %(startIndex)s;;""",
|
||||
{"startIndex": start_index - 1, "count": count, "email": email})
|
||||
"""
|
||||
SELECT *
|
||||
FROM public.users
|
||||
WHERE
|
||||
users.deleted_at IS NULL
|
||||
AND users.tenant_id = %(tenant_id)s
|
||||
LIMIT %(limit)s
|
||||
OFFSET %(offset)s;
|
||||
""",
|
||||
{
|
||||
"offset": start_index - 1,
|
||||
"limit": count,
|
||||
"tenant_id": tenant_id
|
||||
},
|
||||
)
|
||||
)
|
||||
r = cur.fetchall()
|
||||
if len(r):
|
||||
r = helper.list_to_camel_case(r)
|
||||
return r
|
||||
return []
|
||||
return helper.list_to_camel_case(r)
|
||||
|
||||
|
||||
def get_member(tenant_id, user_id):
|
||||
|
|
|
|||
|
|
@ -1,20 +1,22 @@
|
|||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from typing import Any, Literal, Optional
|
||||
import copy
|
||||
from datetime import datetime
|
||||
|
||||
from decouple import config
|
||||
from fastapi import Depends, HTTPException, Header, Query, Response, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_serializer
|
||||
|
||||
import schemas
|
||||
from chalicelib.core import users, roles, tenants
|
||||
from chalicelib.utils.scim_auth import auth_optional, auth_required, create_tokens, verify_refresh_token
|
||||
from routers.base import get_routers
|
||||
from routers.scim_constants import RESOURCE_TYPES, SCHEMAS, SERVICE_PROVIDER_CONFIG
|
||||
from routers import scim_helpers
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -189,90 +191,124 @@ class UserRequest(BaseModel):
|
|||
password: str = Field(default=None)
|
||||
active: bool
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
schemas: list[str]
|
||||
id: str
|
||||
userName: str
|
||||
name: Name
|
||||
emails: list[Email] # ignore for now
|
||||
displayName: str
|
||||
locale: str
|
||||
externalId: str
|
||||
active: bool
|
||||
groups: list[dict]
|
||||
meta: dict = Field(default={"resourceType": "User"})
|
||||
|
||||
class PatchUserRequest(BaseModel):
|
||||
schemas: list[str]
|
||||
Operations: list[dict]
|
||||
|
||||
|
||||
@public_app.get("/Users", dependencies=[Depends(auth_required)])
|
||||
async def get_users(
|
||||
start_index: int = Query(1, alias="startIndex"),
|
||||
count: Optional[int] = Query(None, alias="count"),
|
||||
email: Optional[str] = Query(None, alias="filter"),
|
||||
):
|
||||
"""Get SCIM Users"""
|
||||
if email:
|
||||
email = email.split(" ")[2].strip('"')
|
||||
result_users = users.get_users_paginated(start_index, count, email)
|
||||
class ResourceMetaResponse(BaseModel):
|
||||
resourceType: Literal["ServiceProviderConfig", "ResourceType", "Schema", "User"] | None = None
|
||||
created: datetime | None = None
|
||||
lastModified: datetime | None = None
|
||||
location: str | None = None
|
||||
version: str | None = None
|
||||
|
||||
serialized_users = []
|
||||
for user in result_users:
|
||||
serialized_users.append(
|
||||
UserResponse(
|
||||
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
id = user["data"]["userId"],
|
||||
userName = user["email"],
|
||||
name = Name.model_validate(user["data"]["name"]),
|
||||
emails = [Email.model_validate(user["data"]["emails"])],
|
||||
displayName = user["name"],
|
||||
locale = user["data"]["locale"],
|
||||
externalId = user["internalId"],
|
||||
active = True, # ignore for now, since, can't insert actual timestamp
|
||||
groups = [], # ignore
|
||||
).model_dump(mode='json')
|
||||
)
|
||||
@field_serializer("created", "lastModified")
|
||||
def serialize_datetime(self, dt: datetime) -> str | None:
|
||||
if not dt:
|
||||
return None
|
||||
return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
class CommonResourceResponse(BaseModel):
|
||||
id: str
|
||||
externalId: str | None = None
|
||||
schemas: list[
|
||||
Literal[
|
||||
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig",
|
||||
"urn:ietf:params:scim:schemas:core:2.0:ResourceType",
|
||||
"urn:ietf:params:scim:schemas:core:2.0:Schema",
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
]
|
||||
]
|
||||
meta: ResourceMetaResponse | None = None
|
||||
|
||||
|
||||
class UserResponse(CommonResourceResponse):
|
||||
schemas: list[Literal["urn:ietf:params:scim:schemas:core:2.0:User"]] = ["urn:ietf:params:scim:schemas:core:2.0:User"]
|
||||
userName: str | None = None
|
||||
|
||||
|
||||
class QueryResourceResponse(BaseModel):
|
||||
schemas: list[Literal["urn:ietf:params:scim:api:messages:2.0:ListResponse"]] = ["urn:ietf:params:scim:api:messages:2.0:ListResponse"]
|
||||
totalResults: int
|
||||
# todo(jon): add the other schemas
|
||||
Resources: list[UserResponse]
|
||||
startIndex: int
|
||||
itemsPerPage: int
|
||||
|
||||
|
||||
MAX_USERS_PER_PAGE = 10
|
||||
|
||||
|
||||
def _convert_db_user_to_scim_user(db_user: dict[str, Any], attributes: list[str] | None = None, excluded_attributes: list[str] | None = None) -> UserResponse:
|
||||
user_schema = SCHEMA_IDS_TO_SCHEMA_DETAILS["urn:ietf:params:scim:schemas:core:2.0:User"]
|
||||
all_attributes = scim_helpers.get_all_attribute_names(user_schema)
|
||||
attributes = attributes or all_attributes
|
||||
always_returned_attributes = scim_helpers.get_all_attribute_names_where_returned_is_always(user_schema)
|
||||
included_attributes = list(set(attributes).union(set(always_returned_attributes)))
|
||||
excluded_attributes = excluded_attributes or []
|
||||
excluded_attributes = list(set(excluded_attributes).difference(set(always_returned_attributes)))
|
||||
scim_user = {
|
||||
"id": str(db_user["userId"]),
|
||||
"meta": {
|
||||
"resourceType": "User",
|
||||
"created": db_user["createdAt"],
|
||||
"lastModified": db_user["createdAt"], # todo(jon): we currently don't keep track of this in the db
|
||||
"location": f"Users/{db_user['userId']}"
|
||||
},
|
||||
"userName": db_user["email"],
|
||||
}
|
||||
scim_user = scim_helpers.filter_attributes(scim_user, included_attributes)
|
||||
scim_user = scim_helpers.exclude_attributes(scim_user, excluded_attributes)
|
||||
return UserResponse(**scim_user)
|
||||
|
||||
|
||||
@public_app.get("/Users")
|
||||
async def get_users(
|
||||
tenant_id = Depends(auth_required),
|
||||
requested_start_index: int = Query(1, alias="startIndex"),
|
||||
requested_items_per_page: int | None = Query(None, alias="count"),
|
||||
attributes: list[str] | None = Query(None),
|
||||
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
|
||||
):
|
||||
start_index = max(1, requested_start_index)
|
||||
items_per_page = min(max(0, requested_items_per_page or MAX_USERS_PER_PAGE), MAX_USERS_PER_PAGE)
|
||||
# todo(jon): this might not be the most efficient thing to do. could be better to just do a count.
|
||||
# but this is the fastest thing at the moment just to test that it's working
|
||||
total_users = users.get_users_paginated(1, tenant_id)
|
||||
db_users = users.get_users_paginated(start_index, tenant_id, count=items_per_page)
|
||||
scim_users = [
|
||||
_convert_db_user_to_scim_user(user, attributes, excluded_attributes)
|
||||
for user in db_users
|
||||
]
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
|
||||
"totalResults": len(serialized_users),
|
||||
"startIndex": start_index,
|
||||
"itemsPerPage": len(serialized_users),
|
||||
"Resources": serialized_users,
|
||||
},
|
||||
content=QueryResourceResponse(
|
||||
totalResults=len(total_users),
|
||||
startIndex=start_index,
|
||||
itemsPerPage=len(scim_users),
|
||||
Resources=scim_users,
|
||||
).model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
@public_app.get("/Users/{user_id}", dependencies=[Depends(auth_required)])
|
||||
def get_user(user_id: str):
|
||||
"""Get SCIM User"""
|
||||
tenant_id = 1
|
||||
user = users.get_by_uuid(user_id, tenant_id)
|
||||
if not user:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
|
||||
"detail": "User not found",
|
||||
"status": 404,
|
||||
}
|
||||
)
|
||||
|
||||
res = UserResponse(
|
||||
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
id = user["data"]["userId"],
|
||||
userName = user["email"],
|
||||
name = Name.model_validate(user["data"]["name"]),
|
||||
emails = [Email.model_validate(user["data"]["emails"])],
|
||||
displayName = user["name"],
|
||||
locale = user["data"]["locale"],
|
||||
externalId = user["internalId"],
|
||||
active = True, # ignore for now, since, can't insert actual timestamp
|
||||
groups = [], # ignore
|
||||
@public_app.get("/Users/{user_id}")
|
||||
def get_user(
|
||||
user_id: str,
|
||||
tenant_id = Depends(auth_required),
|
||||
attributes: list[str] | None = Query(None),
|
||||
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
|
||||
):
|
||||
db_user = users.get_by_uuid(user_id, tenant_id)
|
||||
if not db_user:
|
||||
return _not_found_error_response(user_id)
|
||||
scim_user = _convert_db_user_to_scim_user(db_user, attributes, excluded_attributes)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content=scim_user.model_dump(mode="json", exclude_none=True)
|
||||
)
|
||||
return JSONResponse(status_code=201, content=res.model_dump(mode='json'))
|
||||
|
||||
|
||||
@public_app.post("/Users", dependencies=[Depends(auth_required)])
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
# note(jon): please see https://datatracker.ietf.org/doc/html/rfc7643 for details on these constants
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
def _attribute_characteristics(
|
||||
name: str,
|
||||
|
|
@ -102,12 +103,12 @@ def _common_resource_attributes(id_required: bool=True, id_uniqueness: str="none
|
|||
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig",
|
||||
"urn:ietf:params:scim:schemas:core:2.0:ResourceType",
|
||||
"urn:ietf:params:scim:schemas:core:2.0:Schema",
|
||||
# todo(jon): add the user and group schem when completed
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
],
|
||||
case_exact=True,
|
||||
mutability="readOnly",
|
||||
returned="default",
|
||||
required=True,
|
||||
returned="always",
|
||||
),
|
||||
_attribute_characteristics(
|
||||
name="meta",
|
||||
|
|
@ -670,13 +671,38 @@ SCHEMA_SCHEMA = {
|
|||
}
|
||||
|
||||
|
||||
USER_SCHEMA = {
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
|
||||
"id": "urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
"name": "User",
|
||||
"description": "User account.",
|
||||
"meta": {
|
||||
"resourceType": "Schema",
|
||||
"created": "2025-04-16T14:48:00Z",
|
||||
# note(jon): we might want to think about adding this resource as part of our db
|
||||
# and then updating these timestamps from an api and such. for now, if we update
|
||||
# the configuration, we should update the timestamp here.
|
||||
"lastModified": "2025-04-16T14:48:00Z",
|
||||
"location": "Schemas/urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
},
|
||||
"attributes": [
|
||||
*_common_resource_attributes(),
|
||||
_attribute_characteristics(
|
||||
name="userName",
|
||||
description="A service provider's unique identifier for the user.",
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
|
||||
SCHEMAS = sorted(
|
||||
# todo(jon): add the user schema
|
||||
[
|
||||
SERVICE_PROVIDER_CONFIG_SCHEMA,
|
||||
RESOURCE_TYPE_SCHEMA,
|
||||
SCHEMA_SCHEMA,
|
||||
USER_SCHEMA,
|
||||
],
|
||||
key=lambda x: x["id"],
|
||||
)
|
||||
|
|
|
|||
105
ee/api/routers/scim_helpers.py
Normal file
105
ee/api/routers/scim_helpers.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
from typing import Any
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
def get_all_attribute_names(schema: dict[str, Any]) -> list[str]:
|
||||
result = []
|
||||
def _walk(attrs, prefix=None):
|
||||
for attr in attrs:
|
||||
name = attr["name"]
|
||||
path = f"{prefix}.{name}" if prefix else name
|
||||
result.append(path)
|
||||
if attr["type"] == "complex":
|
||||
sub = attr.get("subAttributes") or attr.get("attributes") or []
|
||||
_walk(sub, path)
|
||||
_walk(schema["attributes"])
|
||||
return result
|
||||
|
||||
|
||||
def get_all_attribute_names_where_returned_is_always(schema: dict[str, Any]) -> list[str]:
|
||||
result = []
|
||||
def _walk(attrs, prefix=None):
|
||||
for attr in attrs:
|
||||
name = attr["name"]
|
||||
path = f"{prefix}.{name}" if prefix else name
|
||||
if attr["returned"] == "always":
|
||||
result.append(path)
|
||||
if attr["type"] == "complex":
|
||||
sub = attr.get("subAttributes") or attr.get("attributes") or []
|
||||
_walk(sub, path)
|
||||
_walk(schema["attributes"])
|
||||
return result
|
||||
|
||||
|
||||
def filter_attributes(resource: dict[str, Any], include_list: list[str]) -> dict[str, Any]:
|
||||
result = {}
|
||||
for attr in include_list:
|
||||
parts = attr.split(".", 1)
|
||||
key = parts[0]
|
||||
if key not in resource:
|
||||
continue
|
||||
|
||||
if len(parts) == 1:
|
||||
# top‑level attr
|
||||
result[key] = resource[key]
|
||||
else:
|
||||
# nested attr
|
||||
sub = resource[key]
|
||||
rest = parts[1]
|
||||
if isinstance(sub, dict):
|
||||
filtered = filter_attributes(sub, [rest])
|
||||
if filtered:
|
||||
result.setdefault(key, {}).update(filtered)
|
||||
elif isinstance(sub, list):
|
||||
# apply to each element
|
||||
new_list = []
|
||||
for item in sub:
|
||||
if isinstance(item, dict):
|
||||
f = filter_attributes(item, [rest])
|
||||
if f:
|
||||
new_list.append(f)
|
||||
if new_list:
|
||||
result[key] = new_list
|
||||
return result
|
||||
|
||||
|
||||
def exclude_attributes(resource: dict[str, Any], exclude_list: list[str]) -> dict[str, Any]:
|
||||
exclude_map = {}
|
||||
for attr in exclude_list:
|
||||
parts = attr.split(".", 1)
|
||||
key = parts[0]
|
||||
# rest is empty string for top-level exclusion
|
||||
rest = parts[1] if len(parts) == 2 else ""
|
||||
exclude_map.setdefault(key, []).append(rest)
|
||||
|
||||
new_resource = {}
|
||||
for key, value in resource.items():
|
||||
if key in exclude_map:
|
||||
subs = exclude_map[key]
|
||||
# If any attr has no rest, exclude entire key
|
||||
if "" in subs:
|
||||
continue
|
||||
# Exclude nested attributes
|
||||
if isinstance(value, dict):
|
||||
new_sub = exclude_attributes(value, subs)
|
||||
if not new_sub:
|
||||
continue
|
||||
new_resource[key] = new_sub
|
||||
elif isinstance(value, list):
|
||||
new_list = []
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
new_item = exclude_attributes(item, subs)
|
||||
new_list.append(new_item)
|
||||
else:
|
||||
new_list.append(item)
|
||||
new_resource[key] = new_list
|
||||
else:
|
||||
new_resource[key] = value
|
||||
else:
|
||||
# No exclusion for this key: copy safely
|
||||
if isinstance(value, (dict, list)):
|
||||
new_resource[key] = deepcopy(value)
|
||||
else:
|
||||
new_resource[key] = value
|
||||
return new_resource
|
||||
Loading…
Add table
Reference in a new issue