Merge 7324283c29 into cd70633d1f
This commit is contained in:
commit
e1fdbb1c36
19 changed files with 1958 additions and 1260 deletions
1
ee/api/.gitignore
vendored
1
ee/api/.gitignore
vendored
|
|
@ -283,4 +283,3 @@ Pipfile.lock
|
||||||
/chalicelib/utils/contextual_validators.py
|
/chalicelib/utils/contextual_validators.py
|
||||||
/routers/subs/product_analytics.py
|
/routers/subs/product_analytics.py
|
||||||
/schemas/product_analytics.py
|
/schemas/product_analytics.py
|
||||||
/ee/bin/*
|
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,8 @@ xmlsec = "==1.3.14"
|
||||||
python-multipart = "==0.0.20"
|
python-multipart = "==0.0.20"
|
||||||
redis = "==6.1.0"
|
redis = "==6.1.0"
|
||||||
azure-storage-blob = "==12.25.1"
|
azure-storage-blob = "==12.25.1"
|
||||||
|
scim2-server = "*"
|
||||||
|
scim2-models = "*"
|
||||||
|
|
||||||
[dev-packages]
|
[dev-packages]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from decouple import config
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.middleware.gzip import GZipMiddleware
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
|
from fastapi.middleware.wsgi import WSGIMiddleware
|
||||||
from psycopg import AsyncConnection
|
from psycopg import AsyncConnection
|
||||||
from psycopg.rows import dict_row
|
from psycopg.rows import dict_row
|
||||||
from starlette import status
|
from starlette import status
|
||||||
|
|
@ -21,12 +22,20 @@ from chalicelib.utils import pg_client, ch_client
|
||||||
from crons import core_crons, ee_crons, core_dynamic_crons
|
from crons import core_crons, ee_crons, core_dynamic_crons
|
||||||
from routers import core, core_dynamic
|
from routers import core, core_dynamic
|
||||||
from routers import ee
|
from routers import ee
|
||||||
from routers.subs import insights, metrics, v1_api, health, usability_tests, spot, product_analytics
|
from routers.subs import (
|
||||||
|
insights,
|
||||||
|
metrics,
|
||||||
|
v1_api,
|
||||||
|
health,
|
||||||
|
usability_tests,
|
||||||
|
spot,
|
||||||
|
product_analytics,
|
||||||
|
)
|
||||||
from routers.subs import v1_api_ee
|
from routers.subs import v1_api_ee
|
||||||
|
|
||||||
if config("ENABLE_SSO", cast=bool, default=True):
|
if config("ENABLE_SSO", cast=bool, default=True):
|
||||||
from routers import saml
|
from routers import saml
|
||||||
from routers import scim
|
from routers.scim import api as scim
|
||||||
|
|
||||||
loglevel = config("LOGLEVEL", default=logging.WARNING)
|
loglevel = config("LOGLEVEL", default=logging.WARNING)
|
||||||
print(f">Loglevel set to: {loglevel}")
|
print(f">Loglevel set to: {loglevel}")
|
||||||
|
|
@ -34,7 +43,6 @@ logging.basicConfig(level=loglevel)
|
||||||
|
|
||||||
|
|
||||||
class ORPYAsyncConnection(AsyncConnection):
|
class ORPYAsyncConnection(AsyncConnection):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, row_factory=dict_row, **kwargs)
|
super().__init__(*args, row_factory=dict_row, **kwargs)
|
||||||
|
|
||||||
|
|
@ -43,7 +51,7 @@ class ORPYAsyncConnection(AsyncConnection):
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup
|
# Startup
|
||||||
logging.info(">>>>> starting up <<<<<")
|
logging.info(">>>>> starting up <<<<<")
|
||||||
ap_logger = logging.getLogger('apscheduler')
|
ap_logger = logging.getLogger("apscheduler")
|
||||||
ap_logger.setLevel(loglevel)
|
ap_logger.setLevel(loglevel)
|
||||||
|
|
||||||
app.schedule = AsyncIOScheduler()
|
app.schedule = AsyncIOScheduler()
|
||||||
|
|
@ -53,12 +61,23 @@ async def lifespan(app: FastAPI):
|
||||||
await events_queue.init()
|
await events_queue.init()
|
||||||
app.schedule.start()
|
app.schedule.start()
|
||||||
|
|
||||||
for job in core_crons.cron_jobs + core_dynamic_crons.cron_jobs + traces.cron_jobs + ee_crons.ee_cron_jobs:
|
for job in (
|
||||||
|
core_crons.cron_jobs
|
||||||
|
+ core_dynamic_crons.cron_jobs
|
||||||
|
+ traces.cron_jobs
|
||||||
|
+ ee_crons.ee_cron_jobs
|
||||||
|
):
|
||||||
app.schedule.add_job(id=job["func"].__name__, **job)
|
app.schedule.add_job(id=job["func"].__name__, **job)
|
||||||
|
|
||||||
ap_logger.info(">Scheduled jobs:")
|
ap_logger.info(">Scheduled jobs:")
|
||||||
for job in app.schedule.get_jobs():
|
for job in app.schedule.get_jobs():
|
||||||
ap_logger.info({"Name": str(job.id), "Run Frequency": str(job.trigger), "Next Run": str(job.next_run_time)})
|
ap_logger.info(
|
||||||
|
{
|
||||||
|
"Name": str(job.id),
|
||||||
|
"Run Frequency": str(job.trigger),
|
||||||
|
"Next Run": str(job.next_run_time),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
database = {
|
database = {
|
||||||
"host": config("pg_host", default="localhost"),
|
"host": config("pg_host", default="localhost"),
|
||||||
|
|
@ -69,9 +88,12 @@ async def lifespan(app: FastAPI):
|
||||||
"application_name": "AIO" + config("APP_NAME", default="PY"),
|
"application_name": "AIO" + config("APP_NAME", default="PY"),
|
||||||
}
|
}
|
||||||
|
|
||||||
database = psycopg_pool.AsyncConnectionPool(kwargs=database, connection_class=ORPYAsyncConnection,
|
database = psycopg_pool.AsyncConnectionPool(
|
||||||
min_size=config("PG_AIO_MINCONN", cast=int, default=1),
|
kwargs=database,
|
||||||
max_size=config("PG_AIO_MAXCONN", cast=int, default=5), )
|
connection_class=ORPYAsyncConnection,
|
||||||
|
min_size=config("PG_AIO_MINCONN", cast=int, default=1),
|
||||||
|
max_size=config("PG_AIO_MAXCONN", cast=int, default=5),
|
||||||
|
)
|
||||||
app.state.postgresql = database
|
app.state.postgresql = database
|
||||||
|
|
||||||
# App listening
|
# App listening
|
||||||
|
|
@ -86,16 +108,24 @@ async def lifespan(app: FastAPI):
|
||||||
await pg_client.terminate()
|
await pg_client.terminate()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(root_path=config("root_path", default="/api"), docs_url=config("docs_url", default=""),
|
app = FastAPI(
|
||||||
redoc_url=config("redoc_url", default=""), lifespan=lifespan)
|
root_path=config("root_path", default="/api"),
|
||||||
|
docs_url=config("docs_url", default=""),
|
||||||
|
redoc_url=config("redoc_url", default=""),
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||||
|
|
||||||
|
|
||||||
@app.middleware('http')
|
@app.middleware("http")
|
||||||
async def or_middleware(request: Request, call_next):
|
async def or_middleware(request: Request, call_next):
|
||||||
from chalicelib.core import unlock
|
from chalicelib.core import unlock
|
||||||
|
|
||||||
if not unlock.is_valid():
|
if not unlock.is_valid():
|
||||||
return JSONResponse(content={"errors": ["expired license"]}, status_code=status.HTTP_403_FORBIDDEN)
|
return JSONResponse(
|
||||||
|
content={"errors": ["expired license"]},
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
if helper.TRACK_TIME:
|
if helper.TRACK_TIME:
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
@ -110,8 +140,10 @@ async def or_middleware(request: Request, call_next):
|
||||||
now = time.time() - now
|
now = time.time() - now
|
||||||
if now > 2:
|
if now > 2:
|
||||||
now = round(now, 2)
|
now = round(now, 2)
|
||||||
logging.warning(f"Execution time: {now} s for {request.method}: {request.url.path}")
|
logging.warning(
|
||||||
response.headers["x-robots-tag"] = 'noindex, nofollow'
|
f"Execution time: {now} s for {request.method}: {request.url.path}"
|
||||||
|
)
|
||||||
|
response.headers["x-robots-tag"] = "noindex, nofollow"
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -162,3 +194,4 @@ if config("ENABLE_SSO", cast=bool, default=True):
|
||||||
app.include_router(scim.public_app)
|
app.include_router(scim.public_app)
|
||||||
app.include_router(scim.app)
|
app.include_router(scim.app)
|
||||||
app.include_router(scim.app_apikey)
|
app.include_router(scim.app_apikey)
|
||||||
|
app.mount("/sso/scim/v2", WSGIMiddleware(scim.scim_app))
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import json
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
|
|
@ -10,13 +9,15 @@ from chalicelib.utils.TimeUTC import TimeUTC
|
||||||
|
|
||||||
def __exists_by_name(tenant_id: int, name: str, exclude_id: Optional[int]) -> bool:
|
def __exists_by_name(tenant_id: int, name: str, exclude_id: Optional[int]) -> bool:
|
||||||
with pg_client.PostgresClient() as cur:
|
with pg_client.PostgresClient() as cur:
|
||||||
query = cur.mogrify(f"""SELECT EXISTS(SELECT 1
|
query = cur.mogrify(
|
||||||
|
f"""SELECT EXISTS(SELECT 1
|
||||||
FROM public.roles
|
FROM public.roles
|
||||||
WHERE tenant_id = %(tenant_id)s
|
WHERE tenant_id = %(tenant_id)s
|
||||||
AND name ILIKE %(name)s
|
AND name ILIKE %(name)s
|
||||||
AND deleted_at ISNULL
|
AND deleted_at ISNULL
|
||||||
{"AND role_id!=%(exclude_id)s" if exclude_id else ""}) AS exists;""",
|
{"AND role_id!=%(exclude_id)s" if exclude_id else ""}) AS exists;""",
|
||||||
{"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id})
|
{"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id},
|
||||||
|
)
|
||||||
cur.execute(query=query)
|
cur.execute(query=query)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
return row["exists"]
|
return row["exists"]
|
||||||
|
|
@ -28,24 +29,31 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema):
|
||||||
if not admin["admin"] and not admin["superAdmin"]:
|
if not admin["admin"] and not admin["superAdmin"]:
|
||||||
return {"errors": ["unauthorized"]}
|
return {"errors": ["unauthorized"]}
|
||||||
if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=role_id):
|
if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=role_id):
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists."
|
||||||
|
)
|
||||||
|
|
||||||
if not data.all_projects and (data.projects is None or len(data.projects) == 0):
|
if not data.all_projects and (data.projects is None or len(data.projects) == 0):
|
||||||
return {"errors": ["must specify a project or all projects"]}
|
return {"errors": ["must specify a project or all projects"]}
|
||||||
if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
|
if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
|
||||||
data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id)
|
data.projects = projects.is_authorized_batch(
|
||||||
|
project_ids=data.projects, tenant_id=tenant_id
|
||||||
|
)
|
||||||
with pg_client.PostgresClient() as cur:
|
with pg_client.PostgresClient() as cur:
|
||||||
query = cur.mogrify("""SELECT 1
|
query = cur.mogrify(
|
||||||
|
"""SELECT 1
|
||||||
FROM public.roles
|
FROM public.roles
|
||||||
WHERE role_id = %(role_id)s
|
WHERE role_id = %(role_id)s
|
||||||
AND tenant_id = %(tenant_id)s
|
AND tenant_id = %(tenant_id)s
|
||||||
AND protected = TRUE
|
AND protected = TRUE
|
||||||
LIMIT 1;""",
|
LIMIT 1;""",
|
||||||
{"tenant_id": tenant_id, "role_id": role_id})
|
{"tenant_id": tenant_id, "role_id": role_id},
|
||||||
|
)
|
||||||
cur.execute(query=query)
|
cur.execute(query=query)
|
||||||
if cur.fetchone() is not None:
|
if cur.fetchone() is not None:
|
||||||
return {"errors": ["this role is protected"]}
|
return {"errors": ["this role is protected"]}
|
||||||
query = cur.mogrify("""UPDATE public.roles
|
query = cur.mogrify(
|
||||||
|
"""UPDATE public.roles
|
||||||
SET name= %(name)s,
|
SET name= %(name)s,
|
||||||
description= %(description)s,
|
description= %(description)s,
|
||||||
permissions= %(permissions)s,
|
permissions= %(permissions)s,
|
||||||
|
|
@ -57,43 +65,36 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema):
|
||||||
RETURNING *, COALESCE((SELECT ARRAY_AGG(project_id)
|
RETURNING *, COALESCE((SELECT ARRAY_AGG(project_id)
|
||||||
FROM roles_projects
|
FROM roles_projects
|
||||||
WHERE roles_projects.role_id=%(role_id)s),'{}') AS projects;""",
|
WHERE roles_projects.role_id=%(role_id)s),'{}') AS projects;""",
|
||||||
{"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()})
|
{"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()},
|
||||||
|
)
|
||||||
cur.execute(query=query)
|
cur.execute(query=query)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
|
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
|
||||||
if not data.all_projects:
|
if not data.all_projects:
|
||||||
d_projects = [i for i in row["projects"] if i not in data.projects]
|
d_projects = [i for i in row["projects"] if i not in data.projects]
|
||||||
if len(d_projects) > 0:
|
if len(d_projects) > 0:
|
||||||
query = cur.mogrify("""DELETE FROM roles_projects
|
query = cur.mogrify(
|
||||||
|
"""DELETE FROM roles_projects
|
||||||
WHERE role_id=%(role_id)s
|
WHERE role_id=%(role_id)s
|
||||||
AND project_id IN %(project_ids)s""",
|
AND project_id IN %(project_ids)s""",
|
||||||
{"role_id": role_id, "project_ids": tuple(d_projects)})
|
{"role_id": role_id, "project_ids": tuple(d_projects)},
|
||||||
|
)
|
||||||
cur.execute(query=query)
|
cur.execute(query=query)
|
||||||
n_projects = [i for i in data.projects if i not in row["projects"]]
|
n_projects = [i for i in data.projects if i not in row["projects"]]
|
||||||
if len(n_projects) > 0:
|
if len(n_projects) > 0:
|
||||||
query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id)
|
query = cur.mogrify(
|
||||||
|
f"""INSERT INTO roles_projects(role_id, project_id)
|
||||||
VALUES {",".join([f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(n_projects))])}""",
|
VALUES {",".join([f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(n_projects))])}""",
|
||||||
{"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(n_projects)}})
|
{
|
||||||
|
"role_id": role_id,
|
||||||
|
**{f"project_id_{i}": p for i, p in enumerate(n_projects)},
|
||||||
|
},
|
||||||
|
)
|
||||||
cur.execute(query=query)
|
cur.execute(query=query)
|
||||||
row["projects"] = data.projects
|
row["projects"] = data.projects
|
||||||
|
|
||||||
return helper.dict_to_camel_case(row)
|
return helper.dict_to_camel_case(row)
|
||||||
|
|
||||||
def update_group_name(tenant_id, group_id, name):
|
|
||||||
with pg_client.PostgresClient() as cur:
|
|
||||||
query = cur.mogrify("""UPDATE public.roles
|
|
||||||
SET name= %(name)s
|
|
||||||
WHERE roles.data->>'group_id' = %(group_id)s
|
|
||||||
AND tenant_id = %(tenant_id)s
|
|
||||||
AND deleted_at ISNULL
|
|
||||||
AND protected = FALSE
|
|
||||||
RETURNING *;""",
|
|
||||||
{"tenant_id": tenant_id, "group_id": group_id, "name": name })
|
|
||||||
cur.execute(query=query)
|
|
||||||
row = cur.fetchone()
|
|
||||||
|
|
||||||
return helper.dict_to_camel_case(row)
|
|
||||||
|
|
||||||
|
|
||||||
def create(tenant_id, user_id, data: schemas.RolePayloadSchema):
|
def create(tenant_id, user_id, data: schemas.RolePayloadSchema):
|
||||||
admin = users.get(user_id=user_id, tenant_id=tenant_id)
|
admin = users.get(user_id=user_id, tenant_id=tenant_id)
|
||||||
|
|
@ -102,57 +103,44 @@ def create(tenant_id, user_id, data: schemas.RolePayloadSchema):
|
||||||
return {"errors": ["unauthorized"]}
|
return {"errors": ["unauthorized"]}
|
||||||
|
|
||||||
if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=None):
|
if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=None):
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists."
|
||||||
|
)
|
||||||
|
|
||||||
if not data.all_projects and (data.projects is None or len(data.projects) == 0):
|
if not data.all_projects and (data.projects is None or len(data.projects) == 0):
|
||||||
return {"errors": ["must specify a project or all projects"]}
|
return {"errors": ["must specify a project or all projects"]}
|
||||||
if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
|
if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
|
||||||
data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id)
|
data.projects = projects.is_authorized_batch(
|
||||||
|
project_ids=data.projects, tenant_id=tenant_id
|
||||||
|
)
|
||||||
with pg_client.PostgresClient() as cur:
|
with pg_client.PostgresClient() as cur:
|
||||||
query = cur.mogrify("""INSERT INTO roles(tenant_id, name, description, permissions, all_projects)
|
query = cur.mogrify(
|
||||||
|
"""INSERT INTO roles(tenant_id, name, description, permissions, all_projects)
|
||||||
VALUES (%(tenant_id)s, %(name)s, %(description)s, %(permissions)s::text[], %(all_projects)s)
|
VALUES (%(tenant_id)s, %(name)s, %(description)s, %(permissions)s::text[], %(all_projects)s)
|
||||||
RETURNING *;""",
|
RETURNING *;""",
|
||||||
{"tenant_id": tenant_id, "name": data.name, "description": data.description,
|
{
|
||||||
"permissions": data.permissions, "all_projects": data.all_projects})
|
"tenant_id": tenant_id,
|
||||||
|
"name": data.name,
|
||||||
|
"description": data.description,
|
||||||
|
"permissions": data.permissions,
|
||||||
|
"all_projects": data.all_projects,
|
||||||
|
},
|
||||||
|
)
|
||||||
cur.execute(query=query)
|
cur.execute(query=query)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
|
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
|
||||||
row["projects"] = []
|
row["projects"] = []
|
||||||
if not data.all_projects:
|
if not data.all_projects:
|
||||||
role_id = row["role_id"]
|
role_id = row["role_id"]
|
||||||
query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id)
|
query = cur.mogrify(
|
||||||
|
f"""INSERT INTO roles_projects(role_id, project_id)
|
||||||
VALUES {",".join(f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(data.projects)))}
|
VALUES {",".join(f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(data.projects)))}
|
||||||
RETURNING project_id;""",
|
RETURNING project_id;""",
|
||||||
{"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(data.projects)}})
|
{
|
||||||
cur.execute(query=query)
|
"role_id": role_id,
|
||||||
row["projects"] = [r["project_id"] for r in cur.fetchall()]
|
**{f"project_id_{i}": p for i, p in enumerate(data.projects)},
|
||||||
return helper.dict_to_camel_case(row)
|
},
|
||||||
|
)
|
||||||
def create_as_admin(tenant_id, group_id, data: schemas.RolePayloadSchema):
|
|
||||||
|
|
||||||
if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=None):
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.")
|
|
||||||
|
|
||||||
if not data.all_projects and (data.projects is None or len(data.projects) == 0):
|
|
||||||
return {"errors": ["must specify a project or all projects"]}
|
|
||||||
if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
|
|
||||||
data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id)
|
|
||||||
with pg_client.PostgresClient() as cur:
|
|
||||||
query = cur.mogrify("""INSERT INTO roles(tenant_id, name, description, permissions, all_projects, data)
|
|
||||||
VALUES (%(tenant_id)s, %(name)s, %(description)s, %(permissions)s::text[], %(all_projects)s, %(data)s)
|
|
||||||
RETURNING *;""",
|
|
||||||
{"tenant_id": tenant_id, "name": data.name, "description": data.description,
|
|
||||||
"permissions": data.permissions, "all_projects": data.all_projects, "data": json.dumps({ "group_id": group_id })})
|
|
||||||
cur.execute(query=query)
|
|
||||||
row = cur.fetchone()
|
|
||||||
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
|
|
||||||
row["projects"] = []
|
|
||||||
if not data.all_projects:
|
|
||||||
role_id = row["role_id"]
|
|
||||||
query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id)
|
|
||||||
VALUES {",".join(f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(data.projects)))}
|
|
||||||
RETURNING project_id;""",
|
|
||||||
{"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(data.projects)}})
|
|
||||||
cur.execute(query=query)
|
cur.execute(query=query)
|
||||||
row["projects"] = [r["project_id"] for r in cur.fetchall()]
|
row["projects"] = [r["project_id"] for r in cur.fetchall()]
|
||||||
return helper.dict_to_camel_case(row)
|
return helper.dict_to_camel_case(row)
|
||||||
|
|
@ -160,7 +148,8 @@ def create_as_admin(tenant_id, group_id, data: schemas.RolePayloadSchema):
|
||||||
|
|
||||||
def get_roles(tenant_id):
|
def get_roles(tenant_id):
|
||||||
with pg_client.PostgresClient() as cur:
|
with pg_client.PostgresClient() as cur:
|
||||||
query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects
|
query = cur.mogrify(
|
||||||
|
"""SELECT roles.*, COALESCE(projects, '{}') AS projects
|
||||||
FROM public.roles
|
FROM public.roles
|
||||||
LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects
|
LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects
|
||||||
FROM roles_projects
|
FROM roles_projects
|
||||||
|
|
@ -171,66 +160,25 @@ def get_roles(tenant_id):
|
||||||
AND deleted_at IS NULL
|
AND deleted_at IS NULL
|
||||||
AND not service_role
|
AND not service_role
|
||||||
ORDER BY role_id;""",
|
ORDER BY role_id;""",
|
||||||
{"tenant_id": tenant_id})
|
{"tenant_id": tenant_id},
|
||||||
|
)
|
||||||
cur.execute(query=query)
|
cur.execute(query=query)
|
||||||
rows = cur.fetchall()
|
rows = cur.fetchall()
|
||||||
for r in rows:
|
for r in rows:
|
||||||
r["created_at"] = TimeUTC.datetime_to_timestamp(r["created_at"])
|
r["created_at"] = TimeUTC.datetime_to_timestamp(r["created_at"])
|
||||||
return helper.list_to_camel_case(rows)
|
return helper.list_to_camel_case(rows)
|
||||||
|
|
||||||
def get_roles_with_uuid(tenant_id):
|
|
||||||
with pg_client.PostgresClient() as cur:
|
|
||||||
query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects
|
|
||||||
FROM public.roles
|
|
||||||
LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects
|
|
||||||
FROM roles_projects
|
|
||||||
INNER JOIN projects USING (project_id)
|
|
||||||
WHERE roles_projects.role_id = roles.role_id
|
|
||||||
AND projects.deleted_at ISNULL ) AS role_projects ON (TRUE)
|
|
||||||
WHERE tenant_id =%(tenant_id)s
|
|
||||||
AND data ? 'group_id'
|
|
||||||
AND deleted_at IS NULL
|
|
||||||
AND not service_role
|
|
||||||
ORDER BY role_id;""",
|
|
||||||
{"tenant_id": tenant_id})
|
|
||||||
cur.execute(query=query)
|
|
||||||
rows = cur.fetchall()
|
|
||||||
for r in rows:
|
|
||||||
r["created_at"] = TimeUTC.datetime_to_timestamp(r["created_at"])
|
|
||||||
return helper.list_to_camel_case(rows)
|
|
||||||
|
|
||||||
def get_roles_with_uuid_paginated(tenant_id, start_index, count=None, name=None):
|
|
||||||
with pg_client.PostgresClient() as cur:
|
|
||||||
query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects
|
|
||||||
FROM public.roles
|
|
||||||
LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects
|
|
||||||
FROM roles_projects
|
|
||||||
INNER JOIN projects USING (project_id)
|
|
||||||
WHERE roles_projects.role_id = roles.role_id
|
|
||||||
AND projects.deleted_at ISNULL ) AS role_projects ON (TRUE)
|
|
||||||
WHERE tenant_id =%(tenant_id)s
|
|
||||||
AND data ? 'group_id'
|
|
||||||
AND deleted_at IS NULL
|
|
||||||
AND not service_role
|
|
||||||
AND name = COALESCE(%(name)s, name)
|
|
||||||
ORDER BY role_id
|
|
||||||
LIMIT %(count)s
|
|
||||||
OFFSET %(startIndex)s;""",
|
|
||||||
{"tenant_id": tenant_id, "name": name, "startIndex": start_index - 1, "count": count})
|
|
||||||
cur.execute(query=query)
|
|
||||||
rows = cur.fetchall()
|
|
||||||
return helper.list_to_camel_case(rows)
|
|
||||||
|
|
||||||
|
|
||||||
def get_role_by_name(tenant_id, name):
|
def get_role_by_name(tenant_id, name):
|
||||||
### "name" isn't unique in database
|
|
||||||
with pg_client.PostgresClient() as cur:
|
with pg_client.PostgresClient() as cur:
|
||||||
query = cur.mogrify("""SELECT *
|
query = cur.mogrify(
|
||||||
|
"""SELECT *
|
||||||
FROM public.roles
|
FROM public.roles
|
||||||
WHERE tenant_id =%(tenant_id)s
|
WHERE tenant_id =%(tenant_id)s
|
||||||
AND deleted_at IS NULL
|
AND deleted_at IS NULL
|
||||||
AND name ILIKE %(name)s;""",
|
AND name ILIKE %(name)s;""",
|
||||||
{"tenant_id": tenant_id, "name": name})
|
{"tenant_id": tenant_id, "name": name},
|
||||||
|
)
|
||||||
cur.execute(query=query)
|
cur.execute(query=query)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
if row is not None:
|
if row is not None:
|
||||||
|
|
@ -244,139 +192,55 @@ def delete(tenant_id, user_id, role_id):
|
||||||
if not admin["admin"] and not admin["superAdmin"]:
|
if not admin["admin"] and not admin["superAdmin"]:
|
||||||
return {"errors": ["unauthorized"]}
|
return {"errors": ["unauthorized"]}
|
||||||
with pg_client.PostgresClient() as cur:
|
with pg_client.PostgresClient() as cur:
|
||||||
query = cur.mogrify("""SELECT 1
|
query = cur.mogrify(
|
||||||
|
"""SELECT 1
|
||||||
FROM public.roles
|
FROM public.roles
|
||||||
WHERE role_id = %(role_id)s
|
WHERE role_id = %(role_id)s
|
||||||
AND tenant_id = %(tenant_id)s
|
AND tenant_id = %(tenant_id)s
|
||||||
AND protected = TRUE
|
AND protected = TRUE
|
||||||
LIMIT 1;""",
|
LIMIT 1;""",
|
||||||
{"tenant_id": tenant_id, "role_id": role_id})
|
{"tenant_id": tenant_id, "role_id": role_id},
|
||||||
|
)
|
||||||
cur.execute(query=query)
|
cur.execute(query=query)
|
||||||
if cur.fetchone() is not None:
|
if cur.fetchone() is not None:
|
||||||
return {"errors": ["this role is protected"]}
|
return {"errors": ["this role is protected"]}
|
||||||
query = cur.mogrify("""SELECT 1
|
query = cur.mogrify(
|
||||||
|
"""SELECT 1
|
||||||
FROM public.users
|
FROM public.users
|
||||||
WHERE role_id = %(role_id)s
|
WHERE role_id = %(role_id)s
|
||||||
AND tenant_id = %(tenant_id)s
|
AND tenant_id = %(tenant_id)s
|
||||||
LIMIT 1;""",
|
LIMIT 1;""",
|
||||||
{"tenant_id": tenant_id, "role_id": role_id})
|
{"tenant_id": tenant_id, "role_id": role_id},
|
||||||
|
)
|
||||||
cur.execute(query=query)
|
cur.execute(query=query)
|
||||||
if cur.fetchone() is not None:
|
if cur.fetchone() is not None:
|
||||||
return {"errors": ["this role is already attached to other user(s)"]}
|
return {"errors": ["this role is already attached to other user(s)"]}
|
||||||
query = cur.mogrify("""UPDATE public.roles
|
query = cur.mogrify(
|
||||||
|
"""UPDATE public.roles
|
||||||
SET deleted_at = timezone('utc'::text, now())
|
SET deleted_at = timezone('utc'::text, now())
|
||||||
WHERE role_id = %(role_id)s
|
WHERE role_id = %(role_id)s
|
||||||
AND tenant_id = %(tenant_id)s
|
AND tenant_id = %(tenant_id)s
|
||||||
AND protected = FALSE;""",
|
AND protected = FALSE;""",
|
||||||
{"tenant_id": tenant_id, "role_id": role_id})
|
{"tenant_id": tenant_id, "role_id": role_id},
|
||||||
|
)
|
||||||
cur.execute(query=query)
|
cur.execute(query=query)
|
||||||
return get_roles(tenant_id=tenant_id)
|
return get_roles(tenant_id=tenant_id)
|
||||||
|
|
||||||
def delete_scim_group(tenant_id, group_uuid):
|
|
||||||
|
|
||||||
with pg_client.PostgresClient() as cur:
|
|
||||||
query = cur.mogrify("""SELECT 1
|
|
||||||
FROM public.roles
|
|
||||||
WHERE data->>'group_id' = %(group_uuid)s
|
|
||||||
AND tenant_id = %(tenant_id)s
|
|
||||||
AND protected = TRUE
|
|
||||||
LIMIT 1;""",
|
|
||||||
{"tenant_id": tenant_id, "group_uuid": group_uuid})
|
|
||||||
cur.execute(query)
|
|
||||||
if cur.fetchone() is not None:
|
|
||||||
return {"errors": ["this role is protected"]}
|
|
||||||
|
|
||||||
query = cur.mogrify(
|
|
||||||
f"""DELETE FROM public.roles
|
|
||||||
WHERE roles.data->>'group_id' = %(group_uuid)s;""", # removed this: AND users.deleted_at IS NOT NULL
|
|
||||||
{"group_uuid": group_uuid})
|
|
||||||
cur.execute(query)
|
|
||||||
|
|
||||||
return get_roles(tenant_id=tenant_id)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_role(tenant_id, role_id):
|
def get_role(tenant_id, role_id):
|
||||||
with pg_client.PostgresClient() as cur:
|
with pg_client.PostgresClient() as cur:
|
||||||
query = cur.mogrify("""SELECT roles.*
|
query = cur.mogrify(
|
||||||
|
"""SELECT roles.*
|
||||||
FROM public.roles
|
FROM public.roles
|
||||||
WHERE tenant_id =%(tenant_id)s
|
WHERE tenant_id =%(tenant_id)s
|
||||||
AND deleted_at IS NULL
|
AND deleted_at IS NULL
|
||||||
AND not service_role
|
AND not service_role
|
||||||
AND role_id = %(role_id)s
|
AND role_id = %(role_id)s
|
||||||
LIMIT 1;""",
|
LIMIT 1;""",
|
||||||
{"tenant_id": tenant_id, "role_id": role_id})
|
{"tenant_id": tenant_id, "role_id": role_id},
|
||||||
|
)
|
||||||
cur.execute(query=query)
|
cur.execute(query=query)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
if row is not None:
|
if row is not None:
|
||||||
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
|
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
|
||||||
return helper.dict_to_camel_case(row)
|
return helper.dict_to_camel_case(row)
|
||||||
|
|
||||||
def get_role_by_group_id(tenant_id, group_id):
|
|
||||||
with pg_client.PostgresClient() as cur:
|
|
||||||
query = cur.mogrify("""SELECT roles.*
|
|
||||||
FROM public.roles
|
|
||||||
WHERE tenant_id =%(tenant_id)s
|
|
||||||
AND deleted_at IS NULL
|
|
||||||
AND not service_role
|
|
||||||
AND data->>'group_id' = %(group_id)s
|
|
||||||
LIMIT 1;""",
|
|
||||||
{"tenant_id": tenant_id, "group_id": group_id})
|
|
||||||
cur.execute(query=query)
|
|
||||||
row = cur.fetchone()
|
|
||||||
if row is not None:
|
|
||||||
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
|
|
||||||
return helper.dict_to_camel_case(row)
|
|
||||||
|
|
||||||
def get_users_by_group_uuid(tenant_id, group_id):
|
|
||||||
with pg_client.PostgresClient() as cur:
|
|
||||||
query = cur.mogrify("""SELECT
|
|
||||||
u.user_id,
|
|
||||||
u.name,
|
|
||||||
u.data
|
|
||||||
FROM public.roles r
|
|
||||||
LEFT JOIN public.users u USING (role_id, tenant_id)
|
|
||||||
WHERE u.tenant_id = %(tenant_id)s
|
|
||||||
AND u.deleted_at IS NULL
|
|
||||||
AND r.data->>'group_id' = %(group_id)s
|
|
||||||
""",
|
|
||||||
{"tenant_id": tenant_id, "group_id": group_id})
|
|
||||||
cur.execute(query=query)
|
|
||||||
rows = cur.fetchall()
|
|
||||||
return helper.list_to_camel_case(rows)
|
|
||||||
|
|
||||||
def get_member_permissions(tenant_id):
|
|
||||||
with pg_client.PostgresClient() as cur:
|
|
||||||
query = cur.mogrify("""SELECT
|
|
||||||
r.permissions
|
|
||||||
FROM public.roles r
|
|
||||||
WHERE r.tenant_id = %(tenant_id)s
|
|
||||||
AND r.name = 'Member'
|
|
||||||
AND r.deleted_at IS NULL
|
|
||||||
""",
|
|
||||||
{"tenant_id": tenant_id})
|
|
||||||
cur.execute(query=query)
|
|
||||||
row = cur.fetchone()
|
|
||||||
return helper.dict_to_camel_case(row)
|
|
||||||
|
|
||||||
def remove_group_membership(tenant_id, group_id, user_id):
|
|
||||||
with pg_client.PostgresClient() as cur:
|
|
||||||
query = cur.mogrify("""WITH r AS (
|
|
||||||
SELECT role_id
|
|
||||||
FROM public.roles
|
|
||||||
WHERE data->>'group_id' = %(group_id)s
|
|
||||||
LIMIT 1
|
|
||||||
)
|
|
||||||
UPDATE public.users u
|
|
||||||
SET role_id= NULL
|
|
||||||
FROM r
|
|
||||||
WHERE u.data->>'user_id' = %(user_id)s
|
|
||||||
AND u.role_id = r.role_id
|
|
||||||
AND u.tenant_id = %(tenant_id)s
|
|
||||||
AND u.deleted_at IS NULL
|
|
||||||
RETURNING *;""",
|
|
||||||
{"tenant_id": tenant_id, "group_id": group_id, "user_id": user_id})
|
|
||||||
cur.execute(query=query)
|
|
||||||
row = cur.fetchone()
|
|
||||||
|
|
||||||
return helper.dict_to_camel_case(row)
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -23,20 +23,18 @@ SAML2 = {
|
||||||
"entityId": config("SITE_URL") + API_PREFIX + "/sso/saml2/metadata/",
|
"entityId": config("SITE_URL") + API_PREFIX + "/sso/saml2/metadata/",
|
||||||
"assertionConsumerService": {
|
"assertionConsumerService": {
|
||||||
"url": config("SITE_URL") + API_PREFIX + "/sso/saml2/acs/",
|
"url": config("SITE_URL") + API_PREFIX + "/sso/saml2/acs/",
|
||||||
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
|
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
|
||||||
},
|
},
|
||||||
"singleLogoutService": {
|
"singleLogoutService": {
|
||||||
"url": config("SITE_URL") + API_PREFIX + "/sso/saml2/sls/",
|
"url": config("SITE_URL") + API_PREFIX + "/sso/saml2/sls/",
|
||||||
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
|
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
|
||||||
},
|
},
|
||||||
"NameIDFormat": "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
|
"NameIDFormat": "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
|
||||||
"x509cert": config("sp_crt", default=""),
|
"x509cert": config("sp_crt", default=""),
|
||||||
"privateKey": config("sp_key", default=""),
|
"privateKey": config("sp_key", default=""),
|
||||||
},
|
},
|
||||||
"security": {
|
"security": {"requestedAuthnContext": False},
|
||||||
"requestedAuthnContext": False
|
"idp": None,
|
||||||
},
|
|
||||||
"idp": None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# in case tenantKey is included in the URL
|
# in case tenantKey is included in the URL
|
||||||
|
|
@ -50,25 +48,29 @@ if config("SAML2_MD_URL", default=None) is not None and len(config("SAML2_MD_URL
|
||||||
print("SAML2_MD_URL provided, getting IdP metadata config")
|
print("SAML2_MD_URL provided, getting IdP metadata config")
|
||||||
from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser
|
from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser
|
||||||
|
|
||||||
idp_data = OneLogin_Saml2_IdPMetadataParser.parse_remote(config("SAML2_MD_URL", default=None))
|
idp_data = OneLogin_Saml2_IdPMetadataParser.parse_remote(
|
||||||
|
config("SAML2_MD_URL", default=None)
|
||||||
|
)
|
||||||
idp = idp_data.get("idp")
|
idp = idp_data.get("idp")
|
||||||
|
|
||||||
if SAML2["idp"] is None:
|
if SAML2["idp"] is None:
|
||||||
if len(config("idp_entityId", default="")) > 0 \
|
if (
|
||||||
and len(config("idp_sso_url", default="")) > 0 \
|
len(config("idp_entityId", default="")) > 0
|
||||||
and len(config("idp_x509cert", default="")) > 0:
|
and len(config("idp_sso_url", default="")) > 0
|
||||||
|
and len(config("idp_x509cert", default="")) > 0
|
||||||
|
):
|
||||||
idp = {
|
idp = {
|
||||||
"entityId": config("idp_entityId"),
|
"entityId": config("idp_entityId"),
|
||||||
"singleSignOnService": {
|
"singleSignOnService": {
|
||||||
"url": config("idp_sso_url"),
|
"url": config("idp_sso_url"),
|
||||||
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
|
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
|
||||||
},
|
},
|
||||||
"x509cert": config("idp_x509cert")
|
"x509cert": config("idp_x509cert"),
|
||||||
}
|
}
|
||||||
if len(config("idp_sls_url", default="")) > 0:
|
if len(config("idp_sls_url", default="")) > 0:
|
||||||
idp["singleLogoutService"] = {
|
idp["singleLogoutService"] = {
|
||||||
"url": config("idp_sls_url"),
|
"url": config("idp_sls_url"),
|
||||||
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
|
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
|
||||||
}
|
}
|
||||||
|
|
||||||
if idp is None:
|
if idp is None:
|
||||||
|
|
@ -106,8 +108,8 @@ async def prepare_request(request: Request):
|
||||||
session = {}
|
session = {}
|
||||||
# If server is behind proxys or balancers use the HTTP_X_FORWARDED fields
|
# If server is behind proxys or balancers use the HTTP_X_FORWARDED fields
|
||||||
headers = request.headers
|
headers = request.headers
|
||||||
proto = headers.get('x-forwarded-proto', 'http')
|
proto = headers.get("x-forwarded-proto", "http")
|
||||||
url_data = urlparse('%s://%s' % (proto, headers['host']))
|
url_data = urlparse("%s://%s" % (proto, headers["host"]))
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
site_url = urlparse(config("SITE_URL"))
|
site_url = urlparse(config("SITE_URL"))
|
||||||
# to support custom port without changing IDP config
|
# to support custom port without changing IDP config
|
||||||
|
|
@ -117,21 +119,21 @@ async def prepare_request(request: Request):
|
||||||
|
|
||||||
# add / to /acs
|
# add / to /acs
|
||||||
if not path.endswith("/"):
|
if not path.endswith("/"):
|
||||||
path = path + '/'
|
path = path + "/"
|
||||||
if len(API_PREFIX) > 0 and not path.startswith(API_PREFIX):
|
if len(API_PREFIX) > 0 and not path.startswith(API_PREFIX):
|
||||||
path = API_PREFIX + path
|
path = API_PREFIX + path
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'https': 'on' if proto == 'https' else 'off',
|
"https": "on" if proto == "https" else "off",
|
||||||
'http_host': request.headers['host'] + host_suffix,
|
"http_host": request.headers["host"] + host_suffix,
|
||||||
'server_port': url_data.port,
|
"server_port": url_data.port,
|
||||||
'script_name': path,
|
"script_name": path,
|
||||||
'get_data': request.args.copy(),
|
"get_data": request.args.copy(),
|
||||||
# Uncomment if using ADFS as IdP, https://github.com/onelogin/python-saml/pull/144
|
# Uncomment if using ADFS as IdP, https://github.com/onelogin/python-saml/pull/144
|
||||||
# 'lowercase_urlencoding': True,
|
# 'lowercase_urlencoding': True,
|
||||||
'post_data': request.form.copy(),
|
"post_data": request.form.copy(),
|
||||||
'cookie': {"session": session},
|
"cookie": {"session": session},
|
||||||
'request': request
|
"request": request,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -140,8 +142,11 @@ def is_saml2_available():
|
||||||
|
|
||||||
|
|
||||||
def get_saml2_provider():
|
def get_saml2_provider():
|
||||||
return config("idp_name", default="saml2") if is_saml2_available() and len(
|
return (
|
||||||
config("idp_name", default="saml2")) > 0 else None
|
config("idp_name", default="saml2")
|
||||||
|
if is_saml2_available() and len(config("idp_name", default="saml2")) > 0
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_landing_URL(query_params: dict = None, redirect_to_link2=False):
|
def get_landing_URL(query_params: dict = None, redirect_to_link2=False):
|
||||||
|
|
@ -152,11 +157,14 @@ def get_landing_URL(query_params: dict = None, redirect_to_link2=False):
|
||||||
|
|
||||||
if redirect_to_link2:
|
if redirect_to_link2:
|
||||||
if len(config("sso_landing_override", default="")) == 0:
|
if len(config("sso_landing_override", default="")) == 0:
|
||||||
logging.warning("SSO trying to redirect to custom URL, but sso_landing_override env var is empty")
|
logging.warning(
|
||||||
|
"SSO trying to redirect to custom URL, but sso_landing_override env var is empty"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return config("sso_landing_override") + query_params
|
return config("sso_landing_override") + query_params
|
||||||
|
|
||||||
return config("SITE_URL") + config("sso_landing", default="/login") + query_params
|
base_url = config("SITE_URLx") if config("LOCAL_DEV") else config("SITE_URL")
|
||||||
|
return base_url + config("sso_landing", default="/login") + query_params
|
||||||
|
|
||||||
|
|
||||||
environ["hastSAML2"] = str(is_saml2_available())
|
environ["hastSAML2"] = str(is_saml2_available())
|
||||||
|
|
|
||||||
|
|
@ -13,23 +13,9 @@ REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY")
|
||||||
ALGORITHM = config("SCIM_JWT_ALGORITHM")
|
ALGORITHM = config("SCIM_JWT_ALGORITHM")
|
||||||
ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS"))
|
ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS"))
|
||||||
REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS"))
|
REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS"))
|
||||||
AUDIENCE="okta_client"
|
AUDIENCE = config("SCIM_AUDIENCE")
|
||||||
ISSUER=config("JWT_ISSUER"),
|
ISSUER = (config("JWT_ISSUER"),)
|
||||||
|
|
||||||
# Simulated Okta Client Credentials
|
|
||||||
# OKTA_CLIENT_ID = "okta-client"
|
|
||||||
# OKTA_CLIENT_SECRET = "okta-secret"
|
|
||||||
|
|
||||||
# class TokenRequest(BaseModel):
|
|
||||||
# client_id: str
|
|
||||||
# client_secret: str
|
|
||||||
|
|
||||||
# async def authenticate_client(token_request: TokenRequest):
|
|
||||||
# """Validate Okta Client Credentials and issue JWT"""
|
|
||||||
# if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET:
|
|
||||||
# raise HTTPException(status_code=401, detail="Invalid client credentials")
|
|
||||||
|
|
||||||
# return {"access_token": create_jwt(), "token_type": "bearer"}
|
|
||||||
|
|
||||||
def create_tokens(tenant_id):
|
def create_tokens(tenant_id):
|
||||||
curr_time = time.time()
|
curr_time = time.time()
|
||||||
|
|
@ -38,7 +24,7 @@ def create_tokens(tenant_id):
|
||||||
"sub": "scim_server",
|
"sub": "scim_server",
|
||||||
"aud": AUDIENCE,
|
"aud": AUDIENCE,
|
||||||
"iss": ISSUER,
|
"iss": ISSUER,
|
||||||
"exp": ""
|
"exp": "",
|
||||||
}
|
}
|
||||||
access_payload.update({"exp": curr_time + ACCESS_TOKEN_EXPIRE_SECONDS})
|
access_payload.update({"exp": curr_time + ACCESS_TOKEN_EXPIRE_SECONDS})
|
||||||
access_token = jwt.encode(access_payload, ACCESS_SECRET_KEY, algorithm=ALGORITHM)
|
access_token = jwt.encode(access_payload, ACCESS_SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
@ -47,20 +33,26 @@ def create_tokens(tenant_id):
|
||||||
refresh_payload.update({"exp": curr_time + REFRESH_TOKEN_EXPIRE_SECONDS})
|
refresh_payload.update({"exp": curr_time + REFRESH_TOKEN_EXPIRE_SECONDS})
|
||||||
refresh_token = jwt.encode(refresh_payload, REFRESH_SECRET_KEY, algorithm=ALGORITHM)
|
refresh_token = jwt.encode(refresh_payload, REFRESH_SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
||||||
return access_token, refresh_token
|
return access_token, refresh_token, ACCESS_TOKEN_EXPIRE_SECONDS
|
||||||
|
|
||||||
|
|
||||||
def verify_access_token(token: str):
|
def verify_access_token(token: str):
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
|
payload = jwt.decode(
|
||||||
|
token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE
|
||||||
|
)
|
||||||
return payload
|
return payload
|
||||||
except jwt.ExpiredSignatureError:
|
except jwt.ExpiredSignatureError:
|
||||||
raise HTTPException(status_code=401, detail="Token expired")
|
raise HTTPException(status_code=401, detail="Token expired")
|
||||||
except jwt.InvalidTokenError:
|
except jwt.InvalidTokenError:
|
||||||
raise HTTPException(status_code=401, detail="Invalid token")
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
|
||||||
|
|
||||||
def verify_refresh_token(token: str):
|
def verify_refresh_token(token: str):
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
|
payload = jwt.decode(
|
||||||
|
token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE
|
||||||
|
)
|
||||||
return payload
|
return payload
|
||||||
except jwt.ExpiredSignatureError:
|
except jwt.ExpiredSignatureError:
|
||||||
raise HTTPException(status_code=401, detail="Token expired")
|
raise HTTPException(status_code=401, detail="Token expired")
|
||||||
|
|
@ -68,10 +60,25 @@ def verify_refresh_token(token: str):
|
||||||
raise HTTPException(status_code=401, detail="Invalid token")
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
required_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
|
|
||||||
# Authentication Dependency
|
# Authentication Dependency
|
||||||
def auth_required(token: str = Depends(oauth2_scheme)):
|
def auth_required(token: str = Depends(required_oauth2_scheme)):
|
||||||
"""Dependency to check Authorization header."""
|
"""Dependency to check Authorization header."""
|
||||||
if config("SCIM_AUTH_TYPE") == "OAuth2":
|
if config("SCIM_AUTH_TYPE") == "OAuth2":
|
||||||
payload = verify_access_token(token)
|
payload = verify_access_token(token)
|
||||||
return payload["tenant_id"]
|
return payload["tenant_id"]
|
||||||
|
|
||||||
|
|
||||||
|
optional_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
def auth_optional(token: str | None = Depends(optional_oauth2_scheme)):
|
||||||
|
if token is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
tenant_id = auth_required(token)
|
||||||
|
return tenant_id
|
||||||
|
except HTTPException:
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -1,466 +0,0 @@
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from decouple import config
|
|
||||||
from fastapi import Depends, HTTPException, Header, Query, Response
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
import schemas
|
|
||||||
from chalicelib.core import users, roles, tenants
|
|
||||||
from chalicelib.utils.scim_auth import auth_required, create_tokens, verify_refresh_token
|
|
||||||
from routers.base import get_routers
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
|
|
||||||
|
|
||||||
|
|
||||||
"""Authentication endpoints"""
|
|
||||||
|
|
||||||
class RefreshRequest(BaseModel):
|
|
||||||
refresh_token: str
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
||||||
|
|
||||||
# Login endpoint to generate tokens
|
|
||||||
@public_app.post("/token")
|
|
||||||
async def login(host: str = Header(..., alias="Host"), form_data: OAuth2PasswordRequestForm = Depends()):
|
|
||||||
subdomain = host.split(".")[0]
|
|
||||||
|
|
||||||
# Missing authentication part, to add
|
|
||||||
if form_data.username != config("SCIM_USER") or form_data.password != config("SCIM_PASSWORD"):
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
|
||||||
|
|
||||||
subdomain = "Openreplay EE"
|
|
||||||
tenant = tenants.get_by_name(subdomain)
|
|
||||||
access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"])
|
|
||||||
|
|
||||||
return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
|
|
||||||
|
|
||||||
# Refresh token endpoint
|
|
||||||
@public_app.post("/refresh")
|
|
||||||
async def refresh_token(r: RefreshRequest):
|
|
||||||
|
|
||||||
payload = verify_refresh_token(r.refresh_token)
|
|
||||||
new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"])
|
|
||||||
|
|
||||||
return {"access_token": new_access_token, "token_type": "Bearer"}
|
|
||||||
|
|
||||||
"""
|
|
||||||
User endpoints
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Name(BaseModel):
|
|
||||||
givenName: str
|
|
||||||
familyName: str
|
|
||||||
|
|
||||||
class Email(BaseModel):
|
|
||||||
primary: bool
|
|
||||||
value: str
|
|
||||||
type: str
|
|
||||||
|
|
||||||
class UserRequest(BaseModel):
|
|
||||||
schemas: list[str]
|
|
||||||
userName: str
|
|
||||||
name: Name
|
|
||||||
emails: list[Email]
|
|
||||||
displayName: str
|
|
||||||
locale: str
|
|
||||||
externalId: str
|
|
||||||
groups: list[dict]
|
|
||||||
password: str = Field(default=None)
|
|
||||||
active: bool
|
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
|
||||||
schemas: list[str]
|
|
||||||
id: str
|
|
||||||
userName: str
|
|
||||||
name: Name
|
|
||||||
emails: list[Email] # ignore for now
|
|
||||||
displayName: str
|
|
||||||
locale: str
|
|
||||||
externalId: str
|
|
||||||
active: bool
|
|
||||||
groups: list[dict]
|
|
||||||
meta: dict = Field(default={"resourceType": "User"})
|
|
||||||
|
|
||||||
class PatchUserRequest(BaseModel):
|
|
||||||
schemas: list[str]
|
|
||||||
Operations: list[dict]
|
|
||||||
|
|
||||||
|
|
||||||
@public_app.get("/Users", dependencies=[Depends(auth_required)])
|
|
||||||
async def get_users(
|
|
||||||
start_index: int = Query(1, alias="startIndex"),
|
|
||||||
count: Optional[int] = Query(None, alias="count"),
|
|
||||||
email: Optional[str] = Query(None, alias="filter"),
|
|
||||||
):
|
|
||||||
"""Get SCIM Users"""
|
|
||||||
if email:
|
|
||||||
email = email.split(" ")[2].strip('"')
|
|
||||||
result_users = users.get_users_paginated(start_index, count, email)
|
|
||||||
|
|
||||||
serialized_users = []
|
|
||||||
for user in result_users:
|
|
||||||
serialized_users.append(
|
|
||||||
UserResponse(
|
|
||||||
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
|
||||||
id = user["data"]["userId"],
|
|
||||||
userName = user["email"],
|
|
||||||
name = Name.model_validate(user["data"]["name"]),
|
|
||||||
emails = [Email.model_validate(user["data"]["emails"])],
|
|
||||||
displayName = user["name"],
|
|
||||||
locale = user["data"]["locale"],
|
|
||||||
externalId = user["internalId"],
|
|
||||||
active = True, # ignore for now, since, can't insert actual timestamp
|
|
||||||
groups = [], # ignore
|
|
||||||
).model_dump(mode='json')
|
|
||||||
)
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=200,
|
|
||||||
content={
|
|
||||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
|
|
||||||
"totalResults": len(serialized_users),
|
|
||||||
"startIndex": start_index,
|
|
||||||
"itemsPerPage": len(serialized_users),
|
|
||||||
"Resources": serialized_users,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@public_app.get("/Users/{user_id}", dependencies=[Depends(auth_required)])
|
|
||||||
def get_user(user_id: str):
|
|
||||||
"""Get SCIM User"""
|
|
||||||
tenant_id = 1
|
|
||||||
user = users.get_by_uuid(user_id, tenant_id)
|
|
||||||
if not user:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=404,
|
|
||||||
content={
|
|
||||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
|
|
||||||
"detail": "User not found",
|
|
||||||
"status": 404,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
res = UserResponse(
|
|
||||||
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
|
||||||
id = user["data"]["userId"],
|
|
||||||
userName = user["email"],
|
|
||||||
name = Name.model_validate(user["data"]["name"]),
|
|
||||||
emails = [Email.model_validate(user["data"]["emails"])],
|
|
||||||
displayName = user["name"],
|
|
||||||
locale = user["data"]["locale"],
|
|
||||||
externalId = user["internalId"],
|
|
||||||
active = True, # ignore for now, since, can't insert actual timestamp
|
|
||||||
groups = [], # ignore
|
|
||||||
)
|
|
||||||
return JSONResponse(status_code=201, content=res.model_dump(mode='json'))
|
|
||||||
|
|
||||||
|
|
||||||
@public_app.post("/Users", dependencies=[Depends(auth_required)])
|
|
||||||
async def create_user(r: UserRequest):
|
|
||||||
"""Create SCIM User"""
|
|
||||||
tenant_id = 1
|
|
||||||
existing_user = users.get_by_email_only(r.userName)
|
|
||||||
deleted_user = users.get_deleted_user_by_email(r.userName)
|
|
||||||
|
|
||||||
if existing_user:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code = 409,
|
|
||||||
content = {
|
|
||||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
|
|
||||||
"detail": "User already exists in the database.",
|
|
||||||
"status": 409,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
elif deleted_user:
|
|
||||||
user_id = users.get_deleted_by_uuid(deleted_user["data"]["userId"], tenant_id)
|
|
||||||
user = users.restore_scim_user(user_id=user_id["userId"], tenant_id=tenant_id, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False,
|
|
||||||
display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'),
|
|
||||||
origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
user = users.create_scim_user(tenant_id=tenant_id, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False,
|
|
||||||
display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'),
|
|
||||||
origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId)
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|
||||||
res = UserResponse(
|
|
||||||
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
|
||||||
id = user["data"]["userId"],
|
|
||||||
userName = r.userName,
|
|
||||||
name = r.name,
|
|
||||||
emails = r.emails,
|
|
||||||
displayName = r.displayName,
|
|
||||||
locale = r.locale,
|
|
||||||
externalId = r.externalId,
|
|
||||||
active = r.active, # ignore for now, since, can't insert actual timestamp
|
|
||||||
groups = [], # ignore
|
|
||||||
)
|
|
||||||
return JSONResponse(status_code=201, content=res.model_dump(mode='json'))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@public_app.put("/Users/{user_id}", dependencies=[Depends(auth_required)])
|
|
||||||
def update_user(user_id: str, r: UserRequest):
|
|
||||||
"""Update SCIM User"""
|
|
||||||
tenant_id = 1
|
|
||||||
user = users.get_by_uuid(user_id, tenant_id)
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
|
||||||
|
|
||||||
changes = r.model_dump(mode='json', exclude={"schemas", "emails", "name", "locale", "groups", "password", "active"}) # some of these should be added later if necessary
|
|
||||||
nested_changes = r.model_dump(mode='json', include={"name", "emails"})
|
|
||||||
mapping = {"userName": "email", "displayName": "name", "externalId": "internal_id"} # mapping between scim schema field names and local database model, can be done as config?
|
|
||||||
for k, v in mapping.items():
|
|
||||||
if k in changes:
|
|
||||||
changes[v] = changes.pop(k)
|
|
||||||
changes["data"] = {}
|
|
||||||
for k, v in nested_changes.items():
|
|
||||||
value_to_insert = v[0] if k == "emails" else v
|
|
||||||
changes["data"][k] = value_to_insert
|
|
||||||
try:
|
|
||||||
users.update(tenant_id, user["userId"], changes)
|
|
||||||
res = UserResponse(
|
|
||||||
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
|
||||||
id = user["data"]["userId"],
|
|
||||||
userName = r.userName,
|
|
||||||
name = r.name,
|
|
||||||
emails = r.emails,
|
|
||||||
displayName = r.displayName,
|
|
||||||
locale = r.locale,
|
|
||||||
externalId = r.externalId,
|
|
||||||
active = r.active, # ignore for now, since, can't insert actual timestamp
|
|
||||||
groups = [], # ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
return JSONResponse(status_code=201, content=res.model_dump(mode='json'))
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|
||||||
|
|
||||||
@public_app.patch("/Users/{user_id}", dependencies=[Depends(auth_required)])
|
|
||||||
def deactivate_user(user_id: str, r: PatchUserRequest):
|
|
||||||
"""Deactivate user, soft-delete"""
|
|
||||||
tenant_id = 1
|
|
||||||
active = r.model_dump(mode='json')["Operations"][0]["value"]["active"]
|
|
||||||
if active:
|
|
||||||
raise HTTPException(status_code=404, detail="Activating user is not supported")
|
|
||||||
user = users.get_by_uuid(user_id, tenant_id)
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
|
||||||
users.delete_member_as_admin(tenant_id, user["userId"])
|
|
||||||
|
|
||||||
return Response(status_code=204, content="")
|
|
||||||
|
|
||||||
@public_app.delete("/Users/{user_uuid}", dependencies=[Depends(auth_required)])
|
|
||||||
def delete_user(user_uuid: str):
|
|
||||||
"""Delete user from database, hard-delete"""
|
|
||||||
tenant_id = 1
|
|
||||||
user = users.get_by_uuid(user_uuid, tenant_id)
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
|
||||||
|
|
||||||
users.__hard_delete_user_uuid(user_uuid)
|
|
||||||
return Response(status_code=204, content="")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
Group endpoints
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Operation(BaseModel):
|
|
||||||
op: str
|
|
||||||
path: str = Field(default=None)
|
|
||||||
value: list[dict] | dict = Field(default=None)
|
|
||||||
|
|
||||||
class GroupGetResponse(BaseModel):
|
|
||||||
schemas: list[str] = Field(default=["urn:ietf:params:scim:api:messages:2.0:ListResponse"])
|
|
||||||
totalResults: int
|
|
||||||
startIndex: int
|
|
||||||
itemsPerPage: int
|
|
||||||
resources: list = Field(alias="Resources")
|
|
||||||
|
|
||||||
class GroupRequest(BaseModel):
|
|
||||||
schemas: list[str] = Field(default=["urn:ietf:params:scim:schemas:core:2.0:Group"])
|
|
||||||
displayName: str = Field(default=None)
|
|
||||||
members: list = Field(default=None)
|
|
||||||
operations: list[Operation] = Field(default=None, alias="Operations")
|
|
||||||
|
|
||||||
class GroupPatchRequest(BaseModel):
|
|
||||||
schemas: list[str] = Field(default=["urn:ietf:params:scim:api:messages:2.0:PatchOp"])
|
|
||||||
operations: list[Operation] = Field(alias="Operations")
|
|
||||||
|
|
||||||
class GroupResponse(BaseModel):
|
|
||||||
schemas: list[str] = Field(default=["urn:ietf:params:scim:schemas:core:2.0:Group"])
|
|
||||||
id: str
|
|
||||||
displayName: str
|
|
||||||
members: list
|
|
||||||
meta: dict = Field(default={"resourceType": "Group"})
|
|
||||||
|
|
||||||
|
|
||||||
@public_app.get("/Groups", dependencies=[Depends(auth_required)])
|
|
||||||
def get_groups(
|
|
||||||
start_index: int = Query(1, alias="startIndex"),
|
|
||||||
count: Optional[int] = Query(None, alias="count"),
|
|
||||||
group_name: Optional[str] = Query(None, alias="filter"),
|
|
||||||
):
|
|
||||||
"""Get groups"""
|
|
||||||
tenant_id = 1
|
|
||||||
res = []
|
|
||||||
if group_name:
|
|
||||||
group_name = group_name.split(" ")[2].strip('"')
|
|
||||||
|
|
||||||
groups = roles.get_roles_with_uuid_paginated(tenant_id, start_index, count, group_name)
|
|
||||||
res = [{
|
|
||||||
"id": group["data"]["groupId"],
|
|
||||||
"meta": {
|
|
||||||
"created": group["createdAt"],
|
|
||||||
"lastModified": "", # not currently a field
|
|
||||||
"version": "v1.0"
|
|
||||||
},
|
|
||||||
"displayName": group["name"]
|
|
||||||
} for group in groups
|
|
||||||
]
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=200,
|
|
||||||
content=GroupGetResponse(
|
|
||||||
totalResults=len(groups),
|
|
||||||
startIndex=start_index,
|
|
||||||
itemsPerPage=len(groups),
|
|
||||||
Resources=res
|
|
||||||
).model_dump(mode='json'))
|
|
||||||
|
|
||||||
@public_app.get("/Groups/{group_id}", dependencies=[Depends(auth_required)])
|
|
||||||
def get_group(group_id: str):
|
|
||||||
"""Get a group by id"""
|
|
||||||
tenant_id = 1
|
|
||||||
group = roles.get_role_by_group_id(tenant_id, group_id)
|
|
||||||
if not group:
|
|
||||||
raise HTTPException(status_code=404, detail="Group not found")
|
|
||||||
members = roles.get_users_by_group_uuid(tenant_id, group["data"]["groupId"])
|
|
||||||
members = [{"value": member["data"]["userId"], "display": member["name"]} for member in members]
|
|
||||||
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=200,
|
|
||||||
content=GroupResponse(
|
|
||||||
id=group["data"]["groupId"],
|
|
||||||
displayName=group["name"],
|
|
||||||
members=members,
|
|
||||||
).model_dump(mode='json'))
|
|
||||||
|
|
||||||
@public_app.post("/Groups", dependencies=[Depends(auth_required)])
|
|
||||||
def create_group(r: GroupRequest):
|
|
||||||
"""Create a group"""
|
|
||||||
tenant_id = 1
|
|
||||||
member_role = roles.get_member_permissions(tenant_id)
|
|
||||||
try:
|
|
||||||
data = schemas.RolePayloadSchema(name=r.displayName, permissions=member_role["permissions"]) # permissions by default are same as for member role
|
|
||||||
group = roles.create_as_admin(tenant_id, uuid.uuid4().hex, data)
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|
||||||
added_members = []
|
|
||||||
for member in r.members:
|
|
||||||
user = users.get_by_uuid(member["value"], tenant_id)
|
|
||||||
if user:
|
|
||||||
users.update(tenant_id, user["userId"], {"role_id": group["roleId"]})
|
|
||||||
added_members.append({
|
|
||||||
"value": user["data"]["userId"],
|
|
||||||
"display": user["name"]
|
|
||||||
})
|
|
||||||
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=200,
|
|
||||||
content=GroupResponse(
|
|
||||||
id=group["data"]["groupId"],
|
|
||||||
displayName=group["name"],
|
|
||||||
members=added_members,
|
|
||||||
).model_dump(mode='json'))
|
|
||||||
|
|
||||||
|
|
||||||
@public_app.put("/Groups/{group_id}", dependencies=[Depends(auth_required)])
|
|
||||||
def update_put_group(group_id: str, r: GroupRequest):
|
|
||||||
"""Update a group or members of the group (not used by anything yet)"""
|
|
||||||
tenant_id = 1
|
|
||||||
group = roles.get_role_by_group_id(tenant_id, group_id)
|
|
||||||
if not group:
|
|
||||||
raise HTTPException(status_code=404, detail="Group not found")
|
|
||||||
|
|
||||||
if r.operations and r.operations[0].op == "replace" and r.operations[0].path is None:
|
|
||||||
roles.update_group_name(tenant_id, group["data"]["groupId"], r.operations[0].value["displayName"])
|
|
||||||
return Response(status_code=200, content="")
|
|
||||||
|
|
||||||
members = r.members
|
|
||||||
modified_members = []
|
|
||||||
for member in members:
|
|
||||||
user = users.get_by_uuid(member["value"], tenant_id)
|
|
||||||
if user:
|
|
||||||
users.update(tenant_id, user["userId"], {"role_id": group["roleId"]})
|
|
||||||
modified_members.append({
|
|
||||||
"value": user["data"]["userId"],
|
|
||||||
"display": user["name"]
|
|
||||||
})
|
|
||||||
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=200,
|
|
||||||
content=GroupResponse(
|
|
||||||
id=group_id,
|
|
||||||
displayName=group["name"],
|
|
||||||
members=modified_members,
|
|
||||||
).model_dump(mode='json'))
|
|
||||||
|
|
||||||
|
|
||||||
@public_app.patch("/Groups/{group_id}", dependencies=[Depends(auth_required)])
|
|
||||||
def update_patch_group(group_id: str, r: GroupPatchRequest):
|
|
||||||
"""Update a group or members of the group, used by AIW"""
|
|
||||||
tenant_id = 1
|
|
||||||
group = roles.get_role_by_group_id(tenant_id, group_id)
|
|
||||||
if not group:
|
|
||||||
raise HTTPException(status_code=404, detail="Group not found")
|
|
||||||
if r.operations[0].op == "replace" and r.operations[0].path is None:
|
|
||||||
roles.update_group_name(tenant_id, group["data"]["groupId"], r.operations[0].value["displayName"])
|
|
||||||
return Response(status_code=200, content="")
|
|
||||||
|
|
||||||
modified_members = []
|
|
||||||
for op in r.operations:
|
|
||||||
if op.op == "add" or op.op == "replace":
|
|
||||||
# Both methods work as "replace"
|
|
||||||
for u in op.value:
|
|
||||||
user = users.get_by_uuid(u["value"], tenant_id)
|
|
||||||
if user:
|
|
||||||
users.update(tenant_id, user["userId"], {"role_id": group["roleId"]})
|
|
||||||
modified_members.append({
|
|
||||||
"value": user["data"]["userId"],
|
|
||||||
"display": user["name"]
|
|
||||||
})
|
|
||||||
elif op.op == "remove":
|
|
||||||
user_id = re.search(r'\[value eq \"([a-f0-9]+)\"\]', op.path).group(1)
|
|
||||||
roles.remove_group_membership(tenant_id, group_id, user_id)
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=200,
|
|
||||||
content=GroupResponse(
|
|
||||||
id=group_id,
|
|
||||||
displayName=group["name"],
|
|
||||||
members=modified_members,
|
|
||||||
).model_dump(mode='json'))
|
|
||||||
|
|
||||||
|
|
||||||
@public_app.delete("/Groups/{group_id}", dependencies=[Depends(auth_required)])
|
|
||||||
def delete_group(group_id: str):
|
|
||||||
"""Delete a group, hard-delete"""
|
|
||||||
# possibly need to set the user's roles to default member role, instead of null
|
|
||||||
tenant_id = 1
|
|
||||||
group = roles.get_role_by_group_id(tenant_id, group_id)
|
|
||||||
if not group:
|
|
||||||
raise HTTPException(status_code=404, detail="Group not found")
|
|
||||||
roles.delete_scim_group(tenant_id, group["data"]["groupId"])
|
|
||||||
|
|
||||||
return Response(status_code=200, content="")
|
|
||||||
0
ee/api/routers/scim/__init__.py
Normal file
0
ee/api/routers/scim/__init__.py
Normal file
164
ee/api/routers/scim/api.py
Normal file
164
ee/api/routers/scim/api.py
Normal file
|
|
@ -0,0 +1,164 @@
|
||||||
|
from scim2_server import utils
|
||||||
|
|
||||||
|
|
||||||
|
from routers.base import get_routers
|
||||||
|
from routers.scim.providers import MultiTenantProvider
|
||||||
|
from routers.scim.backends import PostgresBackend
|
||||||
|
from routers.scim.postgres_resource import PostgresResource
|
||||||
|
from routers.scim import users, groups, helpers
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
from chalicelib.utils import pg_client
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
from chalicelib.utils.scim_auth import (
|
||||||
|
create_tokens,
|
||||||
|
verify_refresh_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
b = PostgresBackend()
|
||||||
|
b.register_postgres_resource(
|
||||||
|
"User",
|
||||||
|
PostgresResource(
|
||||||
|
query_resources=users.query_resources,
|
||||||
|
get_resource=users.get_resource,
|
||||||
|
create_resource=users.create_resource,
|
||||||
|
search_existing=users.search_existing,
|
||||||
|
restore_resource=users.restore_resource,
|
||||||
|
delete_resource=users.delete_resource,
|
||||||
|
update_resource=users.update_resource,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
b.register_postgres_resource(
|
||||||
|
"Group",
|
||||||
|
PostgresResource(
|
||||||
|
query_resources=groups.query_resources,
|
||||||
|
get_resource=groups.get_resource,
|
||||||
|
create_resource=groups.create_resource,
|
||||||
|
search_existing=groups.search_existing,
|
||||||
|
restore_resource=None,
|
||||||
|
delete_resource=groups.delete_resource,
|
||||||
|
update_resource=groups.update_resource,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
scim_app = MultiTenantProvider(b)
|
||||||
|
|
||||||
|
for schema in utils.load_default_schemas().values():
|
||||||
|
scim_app.register_schema(schema)
|
||||||
|
for schema in helpers.load_custom_schemas().values():
|
||||||
|
scim_app.register_schema(schema)
|
||||||
|
for resource_type in helpers.load_custom_resource_types().values():
|
||||||
|
scim_app.register_resource_type(resource_type)
|
||||||
|
|
||||||
|
|
||||||
|
public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
|
||||||
|
|
||||||
|
|
||||||
|
@public_app.post("/token")
|
||||||
|
async def post_token(r: Request):
|
||||||
|
form = await r.form()
|
||||||
|
|
||||||
|
client_id = form.get("client_id")
|
||||||
|
client_secret = form.get("client_secret")
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
cur.mogrify(
|
||||||
|
"""
|
||||||
|
SELECT tenant_id
|
||||||
|
FROM public.tenants
|
||||||
|
WHERE tenant_id=%(tenant_id)s AND tenant_key=%(tenant_key)s
|
||||||
|
""",
|
||||||
|
{"tenant_id": int(client_id), "tenant_key": client_secret},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||||
|
|
||||||
|
tenant = cur.fetchone()
|
||||||
|
if not tenant:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||||
|
|
||||||
|
grant_type = form.get("grant_type")
|
||||||
|
if grant_type == "refresh_token":
|
||||||
|
refresh_token = form.get("refresh_token")
|
||||||
|
verify_refresh_token(refresh_token)
|
||||||
|
else:
|
||||||
|
code = form.get("code")
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
cur.execute(
|
||||||
|
cur.mogrify(
|
||||||
|
"""
|
||||||
|
SELECT *
|
||||||
|
FROM public.scim_auth_codes
|
||||||
|
WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE
|
||||||
|
""",
|
||||||
|
{"auth_code": code, "tenant_id": int(client_id)},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if cur.fetchone() is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401, detail="Invalid code/client_id pair"
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
cur.mogrify(
|
||||||
|
"""
|
||||||
|
UPDATE public.scim_auth_codes
|
||||||
|
SET used=TRUE
|
||||||
|
WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE
|
||||||
|
""",
|
||||||
|
{"auth_code": code, "tenant_id": int(client_id)},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
access_token, refresh_token, expires_in = create_tokens(
|
||||||
|
tenant_id=tenant["tenant_id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": expires_in,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# note(jon): this might be specific to okta. if so, we should probably put specify that in the endpoint
|
||||||
|
@public_app.get("/authorize")
|
||||||
|
async def get_authorize(
|
||||||
|
r: Request,
|
||||||
|
response_type: str,
|
||||||
|
client_id: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
state: str | None = None,
|
||||||
|
):
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
cur.execute(
|
||||||
|
cur.mogrify(
|
||||||
|
"""
|
||||||
|
UPDATE public.scim_auth_codes
|
||||||
|
SET used=TRUE
|
||||||
|
WHERE tenant_id=%(tenant_id)s
|
||||||
|
""",
|
||||||
|
{"tenant_id": int(client_id)},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
cur.mogrify(
|
||||||
|
"""
|
||||||
|
INSERT INTO public.scim_auth_codes (tenant_id)
|
||||||
|
VALUES (%(tenant_id)s)
|
||||||
|
RETURNING auth_code
|
||||||
|
""",
|
||||||
|
{"tenant_id": int(client_id)},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
code = cur.fetchone()["auth_code"]
|
||||||
|
params = {"code": code}
|
||||||
|
if state:
|
||||||
|
params["state"] = state
|
||||||
|
url = f"{redirect_uri}?{urlencode(params)}"
|
||||||
|
return RedirectResponse(url)
|
||||||
203
ee/api/routers/scim/backends.py
Normal file
203
ee/api/routers/scim/backends.py
Normal file
|
|
@ -0,0 +1,203 @@
|
||||||
|
from scim2_server import backend
|
||||||
|
from scim2_server.filter import evaluate_filter
|
||||||
|
from scim2_server.utils import SCIMException
|
||||||
|
|
||||||
|
from scim2_models import (
|
||||||
|
SearchRequest,
|
||||||
|
Resource,
|
||||||
|
Context,
|
||||||
|
Error,
|
||||||
|
)
|
||||||
|
from scim2_filter_parser import lexer
|
||||||
|
from scim2_filter_parser.parser import SCIMParser
|
||||||
|
from routers.scim.postgres_resource import PostgresResource
|
||||||
|
from scim2_server.operators import ResolveSortOperator
|
||||||
|
import operator
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresBackend(backend.Backend):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._postgres_resources = {}
|
||||||
|
|
||||||
|
def register_postgres_resource(
|
||||||
|
self, resource_type_id: str, postgres_resource: PostgresResource
|
||||||
|
):
|
||||||
|
self._postgres_resources[resource_type_id] = postgres_resource
|
||||||
|
|
||||||
|
def query_resources(
|
||||||
|
self,
|
||||||
|
search_request: SearchRequest,
|
||||||
|
tenant_id: int,
|
||||||
|
resource_type_id: str | None = None,
|
||||||
|
) -> tuple[int, list[Resource]]:
|
||||||
|
"""Query the backend for a set of resources.
|
||||||
|
|
||||||
|
:param search_request: SearchRequest instance describing the
|
||||||
|
query.
|
||||||
|
:param resource_type_id: ID of the resource type to query. If
|
||||||
|
None, all resource types are queried.
|
||||||
|
:return: A tuple of "total results" and a List of found
|
||||||
|
Resources. The List must contain a copy of resources.
|
||||||
|
Mutating elements in the List must not modify the data
|
||||||
|
stored in the backend.
|
||||||
|
:raises SCIMException: If the backend only supports querying for
|
||||||
|
one resource type at a time, setting resource_type_id to
|
||||||
|
None the backend may raise a
|
||||||
|
SCIMException(Error.make_too_many_error()).
|
||||||
|
"""
|
||||||
|
start_index = (search_request.start_index or 1) - 1
|
||||||
|
|
||||||
|
tree = None
|
||||||
|
if search_request.filter is not None:
|
||||||
|
token_stream = lexer.SCIMLexer().tokenize(search_request.filter)
|
||||||
|
tree = SCIMParser().parse(token_stream)
|
||||||
|
|
||||||
|
# todo(jon): handle the case when resource_type_id is None.
|
||||||
|
# we're assuming it's never None for now.
|
||||||
|
# but, this is fine to leave as it doesn't seem to used or reached in
|
||||||
|
# any of my tests yet.
|
||||||
|
if not resource_type_id:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
resources = self._postgres_resources[resource_type_id].query_resources(
|
||||||
|
tenant_id
|
||||||
|
)
|
||||||
|
model = self.get_model(resource_type_id)
|
||||||
|
resources = [
|
||||||
|
model.model_validate(r, scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
|
||||||
|
for r in resources
|
||||||
|
]
|
||||||
|
resources = [r for r in resources if (tree is None or evaluate_filter(r, tree))]
|
||||||
|
|
||||||
|
if search_request.sort_by is not None:
|
||||||
|
descending = search_request.sort_order == SearchRequest.SortOrder.descending
|
||||||
|
sort_operator = ResolveSortOperator(search_request.sort_by)
|
||||||
|
|
||||||
|
# To ensure that unset attributes are sorted last (when ascending, as defined in the RFC),
|
||||||
|
# we have to divide the result set into a set and unset subset.
|
||||||
|
unset_values = []
|
||||||
|
set_values = []
|
||||||
|
for resource in resources:
|
||||||
|
result = sort_operator(resource)
|
||||||
|
if result is None:
|
||||||
|
unset_values.append(resource)
|
||||||
|
else:
|
||||||
|
set_values.append((resource, result))
|
||||||
|
|
||||||
|
set_values.sort(key=operator.itemgetter(1), reverse=descending)
|
||||||
|
set_values = [value[0] for value in set_values]
|
||||||
|
if descending:
|
||||||
|
resources = unset_values + set_values
|
||||||
|
else:
|
||||||
|
resources = set_values + unset_values
|
||||||
|
|
||||||
|
found_resources = resources[start_index:]
|
||||||
|
if search_request.count is not None:
|
||||||
|
found_resources = resources[: search_request.count]
|
||||||
|
|
||||||
|
return len(resources), found_resources
|
||||||
|
|
||||||
|
def get_resource(
|
||||||
|
self, tenant_id: int, resource_type_id: str, object_id: str
|
||||||
|
) -> Resource | None:
|
||||||
|
"""Query the backend for a resources by its ID.
|
||||||
|
|
||||||
|
:param resource_type_id: ID of the resource type to get the
|
||||||
|
object from.
|
||||||
|
:param object_id: ID of the object to get.
|
||||||
|
:return: The resource object if it exists, None otherwise. The
|
||||||
|
resource must be a copy, modifying it must not change the
|
||||||
|
data stored in the backend.
|
||||||
|
"""
|
||||||
|
resource = self._postgres_resources[resource_type_id].get_resource(
|
||||||
|
object_id, tenant_id
|
||||||
|
)
|
||||||
|
if resource:
|
||||||
|
model = self.get_model(resource_type_id)
|
||||||
|
resource = model.model_validate(resource)
|
||||||
|
return resource
|
||||||
|
|
||||||
|
def delete_resource(
|
||||||
|
self, tenant_id: int, resource_type_id: str, object_id: str
|
||||||
|
) -> bool:
|
||||||
|
"""Delete a resource.
|
||||||
|
|
||||||
|
:param resource_type_id: ID of the resource type to delete the
|
||||||
|
object from.
|
||||||
|
:param object_id: ID of the object to delete.
|
||||||
|
:return: True if the resource was deleted, False otherwise.
|
||||||
|
"""
|
||||||
|
resource = self.get_resource(tenant_id, resource_type_id, object_id)
|
||||||
|
if resource:
|
||||||
|
self._postgres_resources[resource_type_id].delete_resource(
|
||||||
|
object_id, tenant_id
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def create_resource(
|
||||||
|
self, tenant_id: int, resource_type_id: str, resource: Resource
|
||||||
|
) -> Resource | None:
|
||||||
|
"""Create a resource.
|
||||||
|
|
||||||
|
:param resource_type_id: ID of the resource type to create.
|
||||||
|
:param resource: Resource to create.
|
||||||
|
:return: The created resource. Creation should set system-
|
||||||
|
defined attributes (ID, Metadata). May be the same object
|
||||||
|
that is passed in.
|
||||||
|
"""
|
||||||
|
model = self.get_model(resource_type_id)
|
||||||
|
existing = self._postgres_resources[resource_type_id].search_existing(
|
||||||
|
tenant_id, resource
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
existing = model.model_validate(existing)
|
||||||
|
if existing.active:
|
||||||
|
raise SCIMException(Error.make_uniqueness_error())
|
||||||
|
resource = self._postgres_resources[resource_type_id].restore_resource(
|
||||||
|
tenant_id, resource
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
resource = self._postgres_resources[resource_type_id].create_resource(
|
||||||
|
tenant_id, resource
|
||||||
|
)
|
||||||
|
resource = model.model_validate(resource)
|
||||||
|
return resource
|
||||||
|
|
||||||
|
def update_resource(
|
||||||
|
self, tenant_id: int, resource_type_id: str, resource: Resource
|
||||||
|
) -> Resource | None:
|
||||||
|
"""Update a resource. The resource is identified by its ID.
|
||||||
|
|
||||||
|
:param resource_type_id: ID of the resource type to update.
|
||||||
|
:param resource: Resource to update.
|
||||||
|
:return: The updated resource. Updating should update the
|
||||||
|
"meta.lastModified" data. May be the same object that is
|
||||||
|
passed in.
|
||||||
|
"""
|
||||||
|
model = self.get_model(resource_type_id)
|
||||||
|
existing = self._postgres_resources[resource_type_id].search_existing(
|
||||||
|
tenant_id, resource
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
existing = model.model_validate(existing)
|
||||||
|
if existing.active:
|
||||||
|
if existing.id != resource.id:
|
||||||
|
raise SCIMException(Error.make_uniqueness_error())
|
||||||
|
resource = self._postgres_resources[resource_type_id].update_resource(
|
||||||
|
tenant_id, resource
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._postgres_resources[resource_type_id].delete_resource(
|
||||||
|
existing.id, tenant_id
|
||||||
|
)
|
||||||
|
resource = self._postgres_resources[resource_type_id].update_resource(
|
||||||
|
resource.id, tenant_id, resource
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
resource = self._postgres_resources[resource_type_id].update_resource(
|
||||||
|
tenant_id, resource
|
||||||
|
)
|
||||||
|
resource = model.model_validate(resource)
|
||||||
|
return resource
|
||||||
36
ee/api/routers/scim/fixtures/custom_resource_types.json
Normal file
36
ee/api/routers/scim/fixtures/custom_resource_types.json
Normal file
|
|
@ -0,0 +1,36 @@
|
||||||
|
[{
|
||||||
|
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"],
|
||||||
|
"id": "User",
|
||||||
|
"name": "User",
|
||||||
|
"endpoint": "/Users",
|
||||||
|
"description": "User Account",
|
||||||
|
"schema": "urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
"schemaExtensions": [
|
||||||
|
{
|
||||||
|
"schema":
|
||||||
|
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||||
|
"required": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"schema":
|
||||||
|
"urn:ietf:params:scim:schemas:extension:openreplay:2.0:User",
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"meta": {
|
||||||
|
"location": "/v2/ResourceTypes/User",
|
||||||
|
"resourceType": "ResourceType"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"],
|
||||||
|
"id": "Group",
|
||||||
|
"name": "Group",
|
||||||
|
"endpoint": "/Groups",
|
||||||
|
"description": "Group",
|
||||||
|
"schema": "urn:ietf:params:scim:schemas:core:2.0:Group",
|
||||||
|
"meta": {
|
||||||
|
"location": "/v2/ResourceTypes/Group",
|
||||||
|
"resourceType": "ResourceType"
|
||||||
|
}
|
||||||
|
}]
|
||||||
32
ee/api/routers/scim/fixtures/custom_schemas.json
Normal file
32
ee/api/routers/scim/fixtures/custom_schemas.json
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "urn:ietf:params:scim:schemas:extension:openreplay:2.0:User",
|
||||||
|
"name": "OpenreplayUser",
|
||||||
|
"description": "Openreplay User Account Extension",
|
||||||
|
"attributes": [
|
||||||
|
{
|
||||||
|
"name": "permissions",
|
||||||
|
"type": "string",
|
||||||
|
"multiValued": true,
|
||||||
|
"description": "A list of permissions for the User that represent a thing the User is capable of doing.",
|
||||||
|
"required": false,
|
||||||
|
"canonicalValues": ["SESSION_REPLAY", "DEV_TOOLS", "METRICS", "ASSIST_LIVE", "ASSIST_CALL", "SPOT", "SPOT_PUBLIC"],
|
||||||
|
"mutability": "readWrite",
|
||||||
|
"returned": "default"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "projectKeys",
|
||||||
|
"type": "string",
|
||||||
|
"multiValued": true,
|
||||||
|
"description": "A list of project keys for the User that represent a project the User is allowed to work on.",
|
||||||
|
"required": false,
|
||||||
|
"mutability": "readWrite",
|
||||||
|
"returned": "default"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"meta": {
|
||||||
|
"resourceType": "Schema",
|
||||||
|
"location": "/v2/Schemas/urn:ietf:params:scim:schemas:extension:openreplay:2.0:User"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
203
ee/api/routers/scim/groups.py
Normal file
203
ee/api/routers/scim/groups.py
Normal file
|
|
@ -0,0 +1,203 @@
|
||||||
|
from typing import Any
|
||||||
|
from datetime import datetime
|
||||||
|
from psycopg2.extensions import AsIs
|
||||||
|
|
||||||
|
from chalicelib.utils import pg_client
|
||||||
|
from routers.scim import helpers
|
||||||
|
|
||||||
|
from scim2_models import Error, Resource
|
||||||
|
from scim2_server.utils import SCIMException
|
||||||
|
|
||||||
|
|
||||||
|
def convert_provider_resource_to_client_resource(
|
||||||
|
provider_resource: dict,
|
||||||
|
) -> dict:
|
||||||
|
members = provider_resource["users"] or []
|
||||||
|
return {
|
||||||
|
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
|
||||||
|
"id": str(provider_resource["role_id"]),
|
||||||
|
"meta": {
|
||||||
|
"resourceType": "Group",
|
||||||
|
"created": provider_resource["created_at"].strftime("%Y-%m-%dT%H:%M:%SZ"),
|
||||||
|
"lastModified": provider_resource["updated_at"].strftime(
|
||||||
|
"%Y-%m-%dT%H:%M:%SZ"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"displayName": provider_resource["name"],
|
||||||
|
"members": [
|
||||||
|
{
|
||||||
|
"value": str(member["user_id"]),
|
||||||
|
"$ref": f"Users/{member['user_id']}",
|
||||||
|
"type": "User",
|
||||||
|
}
|
||||||
|
for member in members
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def query_resources(tenant_id: int) -> list[dict]:
|
||||||
|
query = _main_select_query(tenant_id)
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
cur.execute(query)
|
||||||
|
items = cur.fetchall()
|
||||||
|
return [convert_provider_resource_to_client_resource(item) for item in items]
|
||||||
|
|
||||||
|
|
||||||
|
def get_resource(resource_id: str, tenant_id: int) -> dict | None:
|
||||||
|
query = _main_select_query(tenant_id, resource_id)
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
cur.execute(query)
|
||||||
|
item = cur.fetchone()
|
||||||
|
if item:
|
||||||
|
return convert_provider_resource_to_client_resource(item)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def delete_resource(resource_id: str, tenant_id: int) -> None:
|
||||||
|
_update_resource_sql(
|
||||||
|
resource_id=resource_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
deleted_at=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def search_existing(tenant_id: int, resource: Resource) -> dict | None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_resource(tenant_id: int, resource: Resource) -> dict:
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
user_ids = (
|
||||||
|
[int(x.value) for x in resource.members] if resource.members else None
|
||||||
|
)
|
||||||
|
user_id_clause = helpers.safe_mogrify_array(user_ids, "int", cur)
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
cur.mogrify(
|
||||||
|
"""
|
||||||
|
INSERT INTO public.roles (
|
||||||
|
name,
|
||||||
|
tenant_id
|
||||||
|
)
|
||||||
|
VALUES (
|
||||||
|
%(name)s,
|
||||||
|
%(tenant_id)s
|
||||||
|
)
|
||||||
|
RETURNING role_id
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"name": resource.display_name,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise SCIMException(Error.make_invalid_value_error())
|
||||||
|
role_id = cur.fetchone()["role_id"]
|
||||||
|
cur.execute(
|
||||||
|
f"""
|
||||||
|
UPDATE public.users
|
||||||
|
SET
|
||||||
|
updated_at = now(),
|
||||||
|
role_id = {role_id}
|
||||||
|
WHERE users.user_id = ANY({user_id_clause})
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
cur.execute(f"{_main_select_query(tenant_id, role_id)} LIMIT 1")
|
||||||
|
item = cur.fetchone()
|
||||||
|
return convert_provider_resource_to_client_resource(item)
|
||||||
|
|
||||||
|
|
||||||
|
def update_resource(tenant_id: int, resource: Resource) -> dict | None:
|
||||||
|
item = _update_resource_sql(
|
||||||
|
resource_id=resource.id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
name=resource.display_name,
|
||||||
|
user_ids=[int(x.value) for x in resource.members],
|
||||||
|
deleted_at=None,
|
||||||
|
)
|
||||||
|
return convert_provider_resource_to_client_resource(item)
|
||||||
|
|
||||||
|
|
||||||
|
def _main_select_query(tenant_id: int, resource_id: str | None = None) -> str:
|
||||||
|
where_and_clauses = [
|
||||||
|
f"roles.tenant_id = {tenant_id}",
|
||||||
|
"roles.deleted_at IS NULL",
|
||||||
|
]
|
||||||
|
if resource_id is not None:
|
||||||
|
where_and_clauses.append(f"roles.role_id = {resource_id}")
|
||||||
|
where_clause = " AND ".join(where_and_clauses)
|
||||||
|
return f"""
|
||||||
|
SELECT
|
||||||
|
roles.*,
|
||||||
|
COALESCE(
|
||||||
|
(
|
||||||
|
SELECT json_agg(users)
|
||||||
|
FROM public.users
|
||||||
|
WHERE users.role_id = roles.role_id
|
||||||
|
),
|
||||||
|
'[]'
|
||||||
|
) AS users,
|
||||||
|
COALESCE(
|
||||||
|
(
|
||||||
|
SELECT json_agg(projects.project_key)
|
||||||
|
FROM public.projects
|
||||||
|
LEFT JOIN public.roles_projects USING (project_id)
|
||||||
|
WHERE roles_projects.role_id = roles.role_id
|
||||||
|
),
|
||||||
|
'[]'
|
||||||
|
) AS project_keys
|
||||||
|
FROM public.roles
|
||||||
|
WHERE {where_clause}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _update_resource_sql(
|
||||||
|
resource_id: int,
|
||||||
|
tenant_id: int,
|
||||||
|
user_ids: list[int] | None = None,
|
||||||
|
**kwargs: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
kwargs["updated_at"] = datetime.now()
|
||||||
|
set_fragments = [
|
||||||
|
cur.mogrify("%s = %s", (AsIs(k), v)).decode("utf-8")
|
||||||
|
for k, v in kwargs.items()
|
||||||
|
]
|
||||||
|
set_clause = ", ".join(set_fragments)
|
||||||
|
user_id_clause = helpers.safe_mogrify_array(user_ids, "int", cur)
|
||||||
|
cur.execute(
|
||||||
|
f"""
|
||||||
|
UPDATE public.users
|
||||||
|
SET
|
||||||
|
updated_at = now(),
|
||||||
|
role_id = NULL
|
||||||
|
WHERE
|
||||||
|
users.role_id = {resource_id}
|
||||||
|
AND users.user_id != ALL({user_id_clause})
|
||||||
|
RETURNING *
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
f"""
|
||||||
|
UPDATE public.users
|
||||||
|
SET
|
||||||
|
updated_at = now(),
|
||||||
|
role_id = {resource_id}
|
||||||
|
WHERE
|
||||||
|
(users.role_id != {resource_id} OR users.role_id IS NULL)
|
||||||
|
AND users.user_id = ANY({user_id_clause})
|
||||||
|
RETURNING *
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
f"""
|
||||||
|
UPDATE public.roles
|
||||||
|
SET {set_clause}
|
||||||
|
WHERE
|
||||||
|
roles.role_id = {resource_id}
|
||||||
|
AND roles.tenant_id = {tenant_id}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
cur.execute(f"{_main_select_query(tenant_id, resource_id)} LIMIT 1")
|
||||||
|
return cur.fetchone()
|
||||||
44
ee/api/routers/scim/helpers.py
Normal file
44
ee/api/routers/scim/helpers.py
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
from typing import Any, Literal
|
||||||
|
from chalicelib.utils import pg_client
|
||||||
|
from scim2_models import Schema, Resource, ResourceType
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def safe_mogrify_array(
|
||||||
|
items: list[Any] | None,
|
||||||
|
array_type: Literal["varchar", "int"],
|
||||||
|
cursor: pg_client.PostgresClient,
|
||||||
|
) -> str:
|
||||||
|
items = items or []
|
||||||
|
fragments = [cursor.mogrify("%s", (item,)).decode("utf-8") for item in items]
|
||||||
|
result = f"ARRAY[{', '.join(fragments)}]::{array_type}[]"
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def load_json_resource(json_name: str) -> dict:
|
||||||
|
with open(json_name) as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def load_scim_resource(
|
||||||
|
json_name: str, type_: type[Resource]
|
||||||
|
) -> dict[str, type[Resource]]:
|
||||||
|
ret = {}
|
||||||
|
definitions = load_json_resource(json_name)
|
||||||
|
for d in definitions:
|
||||||
|
model = type_.model_validate(d)
|
||||||
|
ret[model.id] = model
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def load_custom_schemas() -> dict[str, Schema]:
|
||||||
|
json_name = os.path.join("routers", "scim", "fixtures", "custom_schemas.json")
|
||||||
|
return load_scim_resource(json_name, Schema)
|
||||||
|
|
||||||
|
|
||||||
|
def load_custom_resource_types() -> dict[str, ResourceType]:
|
||||||
|
json_name = os.path.join(
|
||||||
|
"routers", "scim", "fixtures", "custom_resource_types.json"
|
||||||
|
)
|
||||||
|
return load_scim_resource(json_name, ResourceType)
|
||||||
14
ee/api/routers/scim/postgres_resource.py
Normal file
14
ee/api/routers/scim/postgres_resource.py
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable
|
||||||
|
from scim2_models import Resource
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PostgresResource:
|
||||||
|
query_resources: Callable[[int], list[dict]]
|
||||||
|
get_resource: Callable[[str, int], dict | None]
|
||||||
|
create_resource: Callable[[int, Resource], dict]
|
||||||
|
search_existing: Callable[[int, Resource], dict | None]
|
||||||
|
restore_resource: Callable[[int, Resource], dict] | None
|
||||||
|
delete_resource: Callable[[str, int], None]
|
||||||
|
update_resource: Callable[[int, Resource], dict]
|
||||||
280
ee/api/routers/scim/providers.py
Normal file
280
ee/api/routers/scim/providers.py
Normal file
|
|
@ -0,0 +1,280 @@
|
||||||
|
import traceback
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from scim2_server import provider
|
||||||
|
|
||||||
|
from scim2_models import (
|
||||||
|
AuthenticationScheme,
|
||||||
|
ServiceProviderConfig,
|
||||||
|
Patch,
|
||||||
|
Bulk,
|
||||||
|
Filter,
|
||||||
|
Sort,
|
||||||
|
ETag,
|
||||||
|
Meta,
|
||||||
|
ChangePassword,
|
||||||
|
Error,
|
||||||
|
ResourceType,
|
||||||
|
Context,
|
||||||
|
ListResponse,
|
||||||
|
PatchOp,
|
||||||
|
)
|
||||||
|
|
||||||
|
from werkzeug import Request, Response
|
||||||
|
from werkzeug.exceptions import HTTPException, NotFound, PreconditionFailed
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from werkzeug.routing.exceptions import RequestRedirect
|
||||||
|
from scim2_server.utils import SCIMException, merge_resources
|
||||||
|
|
||||||
|
from chalicelib.utils.scim_auth import verify_access_token
|
||||||
|
|
||||||
|
|
||||||
|
class MultiTenantProvider(provider.SCIMProvider):
|
||||||
|
def check_auth(self, request: Request):
|
||||||
|
auth = request.headers.get("Authorization")
|
||||||
|
if not auth or not auth.startswith("Bearer "):
|
||||||
|
return None
|
||||||
|
token = auth[len("Bearer ") :]
|
||||||
|
if not token:
|
||||||
|
return Response(
|
||||||
|
"Missing or invalid Authorization header",
|
||||||
|
status=401,
|
||||||
|
headers={"WWW-Authenticate": 'Bearer realm="login required"'},
|
||||||
|
)
|
||||||
|
payload = verify_access_token(token)
|
||||||
|
tenant_id = payload["tenant_id"]
|
||||||
|
return tenant_id
|
||||||
|
|
||||||
|
def get_service_provider_config(self):
|
||||||
|
auth_schemes = [
|
||||||
|
AuthenticationScheme(
|
||||||
|
type="oauthbearertoken",
|
||||||
|
name="OAuth Bearer Token",
|
||||||
|
description="Authentication scheme using the OAuth Bearer Token Standard. The access token should be sent in the 'Authorization' header using the Bearer schema.",
|
||||||
|
spec_uri="https://datatracker.ietf.org/doc/html/rfc6750",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
return ServiceProviderConfig(
|
||||||
|
# todo(jon): write correct documentation uri
|
||||||
|
documentation_uri="https://www.example.com/",
|
||||||
|
patch=Patch(supported=True),
|
||||||
|
bulk=Bulk(supported=False),
|
||||||
|
filter=Filter(supported=True, max_results=1000),
|
||||||
|
change_password=ChangePassword(supported=False),
|
||||||
|
sort=Sort(supported=True),
|
||||||
|
etag=ETag(supported=False),
|
||||||
|
authentication_schemes=auth_schemes,
|
||||||
|
meta=Meta(resource_type="ServiceProviderConfig"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def query_resource(
|
||||||
|
self, request: Request, tenant_id: int, resource: ResourceType | None
|
||||||
|
):
|
||||||
|
search_request = self.build_search_request(request)
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
if resource is not None:
|
||||||
|
kwargs["resource_type_id"] = resource.id
|
||||||
|
total_results, results = self.backend.query_resources(
|
||||||
|
search_request=search_request, tenant_id=tenant_id, **kwargs
|
||||||
|
)
|
||||||
|
for r in results:
|
||||||
|
self.adjust_location(request, r)
|
||||||
|
|
||||||
|
resources = [
|
||||||
|
s.model_dump(
|
||||||
|
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
|
||||||
|
attributes=search_request.attributes,
|
||||||
|
excluded_attributes=search_request.excluded_attributes,
|
||||||
|
)
|
||||||
|
for s in results
|
||||||
|
]
|
||||||
|
|
||||||
|
return ListResponse[Union[tuple(self.backend.get_models())]]( # noqa: UP007
|
||||||
|
total_results=total_results,
|
||||||
|
items_per_page=search_request.count,
|
||||||
|
start_index=search_request.start_index,
|
||||||
|
resources=resources,
|
||||||
|
)
|
||||||
|
|
||||||
|
def call_resource(
|
||||||
|
self, request: Request, resource_endpoint: str, **kwargs
|
||||||
|
) -> Response:
|
||||||
|
resource_type = self.backend.get_resource_type_by_endpoint(
|
||||||
|
"/" + resource_endpoint
|
||||||
|
)
|
||||||
|
if not resource_type:
|
||||||
|
raise NotFound
|
||||||
|
|
||||||
|
if "tenant_id" not in kwargs:
|
||||||
|
raise Exception
|
||||||
|
tenant_id = kwargs["tenant_id"]
|
||||||
|
|
||||||
|
match request.method:
|
||||||
|
case "GET":
|
||||||
|
return self.make_response(
|
||||||
|
self.query_resource(request, tenant_id, resource_type).model_dump(
|
||||||
|
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
case _: # "POST"
|
||||||
|
payload = request.json
|
||||||
|
resource = self.backend.get_model(resource_type.id).model_validate(
|
||||||
|
payload, scim_ctx=Context.RESOURCE_CREATION_REQUEST
|
||||||
|
)
|
||||||
|
created_resource = self.backend.create_resource(
|
||||||
|
tenant_id,
|
||||||
|
resource_type.id,
|
||||||
|
resource,
|
||||||
|
)
|
||||||
|
self.adjust_location(request, created_resource)
|
||||||
|
return self.make_response(
|
||||||
|
created_resource.model_dump(
|
||||||
|
scim_ctx=Context.RESOURCE_CREATION_RESPONSE
|
||||||
|
),
|
||||||
|
status=201,
|
||||||
|
headers={"Location": created_resource.meta.location},
|
||||||
|
)
|
||||||
|
|
||||||
|
def call_single_resource(
|
||||||
|
self, request: Request, resource_endpoint: str, resource_id: str, **kwargs
|
||||||
|
) -> Response:
|
||||||
|
find_endpoint = "/" + resource_endpoint
|
||||||
|
resource_type = self.backend.get_resource_type_by_endpoint(find_endpoint)
|
||||||
|
if not resource_type:
|
||||||
|
raise NotFound
|
||||||
|
|
||||||
|
if "tenant_id" not in kwargs:
|
||||||
|
raise Exception
|
||||||
|
tenant_id = kwargs["tenant_id"]
|
||||||
|
|
||||||
|
match request.method:
|
||||||
|
case "GET":
|
||||||
|
if resource := self.backend.get_resource(
|
||||||
|
tenant_id, resource_type.id, resource_id
|
||||||
|
):
|
||||||
|
if self.continue_etag(request, resource):
|
||||||
|
response_args = self.get_attrs_from_request(request)
|
||||||
|
self.adjust_location(request, resource)
|
||||||
|
return self.make_response(
|
||||||
|
resource.model_dump(
|
||||||
|
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
|
||||||
|
**response_args,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.make_response(None, status=304)
|
||||||
|
raise NotFound
|
||||||
|
case "DELETE":
|
||||||
|
if self.backend.delete_resource(
|
||||||
|
tenant_id, resource_type.id, resource_id
|
||||||
|
):
|
||||||
|
return self.make_response(None, 204)
|
||||||
|
else:
|
||||||
|
raise NotFound
|
||||||
|
case "PUT":
|
||||||
|
response_args = self.get_attrs_from_request(request)
|
||||||
|
resource = self.backend.get_resource(
|
||||||
|
tenant_id, resource_type.id, resource_id
|
||||||
|
)
|
||||||
|
if resource is None:
|
||||||
|
raise NotFound
|
||||||
|
if not self.continue_etag(request, resource):
|
||||||
|
raise PreconditionFailed
|
||||||
|
|
||||||
|
updated_attributes = self.backend.get_model(
|
||||||
|
resource_type.id
|
||||||
|
).model_validate(request.json)
|
||||||
|
merge_resources(resource, updated_attributes)
|
||||||
|
updated = self.backend.update_resource(
|
||||||
|
tenant_id, resource_type.id, resource
|
||||||
|
)
|
||||||
|
self.adjust_location(request, updated)
|
||||||
|
return self.make_response(
|
||||||
|
updated.model_dump(
|
||||||
|
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE,
|
||||||
|
**response_args,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
case _: # "PATCH"
|
||||||
|
payload = request.json
|
||||||
|
# MS Entra sometimes passes a "id" attribute
|
||||||
|
if "id" in payload:
|
||||||
|
del payload["id"]
|
||||||
|
operations = payload.get("Operations", [])
|
||||||
|
for operation in operations:
|
||||||
|
if "name" in operation:
|
||||||
|
# MS Entra sometimes passes a "name" attribute
|
||||||
|
del operation["name"]
|
||||||
|
|
||||||
|
patch_operation = PatchOp.model_validate(payload)
|
||||||
|
response_args = self.get_attrs_from_request(request)
|
||||||
|
resource = self.backend.get_resource(
|
||||||
|
tenant_id, resource_type.id, resource_id
|
||||||
|
)
|
||||||
|
if resource is None:
|
||||||
|
raise NotFound
|
||||||
|
if not self.continue_etag(request, resource):
|
||||||
|
raise PreconditionFailed
|
||||||
|
|
||||||
|
self.apply_patch_operation(resource, patch_operation)
|
||||||
|
updated = self.backend.update_resource(
|
||||||
|
tenant_id, resource_type.id, resource
|
||||||
|
)
|
||||||
|
|
||||||
|
if response_args:
|
||||||
|
self.adjust_location(request, updated)
|
||||||
|
return self.make_response(
|
||||||
|
updated.model_dump(
|
||||||
|
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE,
|
||||||
|
**response_args,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# RFC 7644, section 3.5.2:
|
||||||
|
# A PATCH operation MAY return a 204 (no content)
|
||||||
|
# if no attributes were requested
|
||||||
|
return self.make_response(
|
||||||
|
None, 204, headers={"ETag": updated.meta.version}
|
||||||
|
)
|
||||||
|
|
||||||
|
def wsgi_app(self, request: Request, environ):
|
||||||
|
try:
|
||||||
|
if environ.get("PATH_INFO", "").endswith(".scim"):
|
||||||
|
# RFC 7644, Section 3.8
|
||||||
|
# Just strip .scim suffix, the provider always returns application/scim+json
|
||||||
|
environ["PATH_INFO"], _, _ = environ["PATH_INFO"].rpartition(".scim")
|
||||||
|
urls = self.url_map.bind_to_environ(environ)
|
||||||
|
endpoint, args = urls.match()
|
||||||
|
|
||||||
|
tenant_id = None
|
||||||
|
if endpoint != "service_provider_config":
|
||||||
|
# RFC7643, Section 5: skip authentication for ServiceProviderConfig
|
||||||
|
tenant_id = self.check_auth(request)
|
||||||
|
|
||||||
|
# Wrap the entire call in a transaction. Should probably be optimized (use transaction only when necessary).
|
||||||
|
with self.backend:
|
||||||
|
if endpoint == "service_provider_config" or endpoint == "schema":
|
||||||
|
response = getattr(self, f"call_{endpoint}")(request, **args)
|
||||||
|
else:
|
||||||
|
response = getattr(self, f"call_{endpoint}")(
|
||||||
|
request, **args, tenant_id=tenant_id
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
except RequestRedirect as e:
|
||||||
|
# urls.match may cause a redirect, handle it as a special case of HTTPException
|
||||||
|
self.log.exception(e)
|
||||||
|
return e.get_response(environ)
|
||||||
|
except HTTPException as e:
|
||||||
|
self.log.exception(e)
|
||||||
|
return self.make_error(Error(status=e.code, detail=e.description))
|
||||||
|
except SCIMException as e:
|
||||||
|
self.log.exception(e)
|
||||||
|
return self.make_error(e.scim_error)
|
||||||
|
except ValidationError as e:
|
||||||
|
self.log.exception(e)
|
||||||
|
return self.make_error(Error(status=400, detail=str(e)))
|
||||||
|
except Exception as e:
|
||||||
|
self.log.exception(e)
|
||||||
|
tb = traceback.format_exc()
|
||||||
|
return self.make_error(Error(status=500, detail=str(e) + "\n" + tb))
|
||||||
371
ee/api/routers/scim/users.py
Normal file
371
ee/api/routers/scim/users.py
Normal file
|
|
@ -0,0 +1,371 @@
|
||||||
|
from routers.scim import helpers
|
||||||
|
|
||||||
|
from chalicelib.utils import pg_client
|
||||||
|
from scim2_models import Resource
|
||||||
|
|
||||||
|
|
||||||
|
def convert_provider_resource_to_client_resource(
|
||||||
|
provider_resource: dict,
|
||||||
|
) -> dict:
|
||||||
|
groups = []
|
||||||
|
if provider_resource["role_id"]:
|
||||||
|
groups.append(
|
||||||
|
{
|
||||||
|
"value": str(provider_resource["role_id"]),
|
||||||
|
"$ref": f"Groups/{provider_resource['role_id']}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"id": str(provider_resource["user_id"]),
|
||||||
|
"schemas": [
|
||||||
|
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||||
|
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||||
|
"urn:ietf:params:scim:schemas:extension:openreplay:2.0:User",
|
||||||
|
],
|
||||||
|
"meta": {
|
||||||
|
"resourceType": "User",
|
||||||
|
"created": provider_resource["created_at"].strftime("%Y-%m-%dT%H:%M:%SZ"),
|
||||||
|
"lastModified": provider_resource["updated_at"].strftime(
|
||||||
|
"%Y-%m-%dT%H:%M:%SZ"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"userName": provider_resource["email"],
|
||||||
|
"externalId": provider_resource["internal_id"],
|
||||||
|
"name": {
|
||||||
|
"formatted": provider_resource["name"],
|
||||||
|
},
|
||||||
|
"displayName": provider_resource["name"] or provider_resource["email"],
|
||||||
|
"active": provider_resource["deleted_at"] is None,
|
||||||
|
"groups": groups,
|
||||||
|
"urn:ietf:params:scim:schemas:extension:openreplay:2.0:User": {
|
||||||
|
"permissions": provider_resource.get("permissions") or [],
|
||||||
|
"projectKeys": provider_resource.get("project_keys") or [],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def query_resources(tenant_id: int) -> list[dict]:
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
cur.execute(
|
||||||
|
f"""
|
||||||
|
SELECT
|
||||||
|
users.*,
|
||||||
|
roles.permissions AS permissions,
|
||||||
|
COALESCE(
|
||||||
|
(
|
||||||
|
SELECT json_agg(projects.project_key)
|
||||||
|
FROM public.projects
|
||||||
|
LEFT JOIN public.roles_projects USING (project_id)
|
||||||
|
WHERE roles_projects.role_id = roles.role_id
|
||||||
|
),
|
||||||
|
'[]'
|
||||||
|
) AS project_keys
|
||||||
|
FROM public.users
|
||||||
|
LEFT JOIN public.roles ON roles.role_id = users.role_id
|
||||||
|
WHERE users.tenant_id = {tenant_id} AND users.deleted_at IS NULL
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
items = cur.fetchall()
|
||||||
|
return [convert_provider_resource_to_client_resource(item) for item in items]
|
||||||
|
|
||||||
|
|
||||||
|
def get_resource(resource_id: str, tenant_id: int) -> dict | None:
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
cur.execute(
|
||||||
|
f"""
|
||||||
|
SELECT
|
||||||
|
users.*,
|
||||||
|
roles.permissions AS permissions,
|
||||||
|
COALESCE(
|
||||||
|
(
|
||||||
|
SELECT json_agg(projects.project_key)
|
||||||
|
FROM public.projects
|
||||||
|
LEFT JOIN public.roles_projects USING (project_id)
|
||||||
|
WHERE roles_projects.role_id = roles.role_id
|
||||||
|
),
|
||||||
|
'[]'
|
||||||
|
) AS project_keys
|
||||||
|
FROM public.users
|
||||||
|
LEFT JOIN public.roles ON roles.role_id = users.role_id
|
||||||
|
WHERE users.tenant_id = {tenant_id} AND users.deleted_at IS NULL AND users.user_id = {resource_id}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
item = cur.fetchone()
|
||||||
|
if item:
|
||||||
|
return convert_provider_resource_to_client_resource(item)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def delete_resource(resource_id: str, tenatn_id: int) -> None:
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
cur.execute(
|
||||||
|
cur.mogrify(
|
||||||
|
"""
|
||||||
|
UPDATE public.users
|
||||||
|
SET
|
||||||
|
deleted_at = NULL,
|
||||||
|
updated_at = now()
|
||||||
|
WHERE users.user_id = %(user_id)s
|
||||||
|
""",
|
||||||
|
{"user_id": resource_id},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def search_existing(tenant_id: int, resource: Resource) -> dict | None:
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
cur.execute(
|
||||||
|
cur.mogrify(
|
||||||
|
"""
|
||||||
|
SELECT *
|
||||||
|
FROM public.users
|
||||||
|
WHERE email = %(email)s
|
||||||
|
""",
|
||||||
|
{"email": resource.user_name},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
item = cur.fetchone()
|
||||||
|
if item:
|
||||||
|
return convert_provider_resource_to_client_resource(item)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def restore_resource(tenant_id: int, resource: Resource) -> dict | None:
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
cur.execute(
|
||||||
|
cur.mogrify(
|
||||||
|
"""
|
||||||
|
SELECT role_id
|
||||||
|
FROM public.users
|
||||||
|
WHERE user_id = %(user_id)s
|
||||||
|
""",
|
||||||
|
{"user_id": resource.id},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
item = cur.fetchone()
|
||||||
|
if item and item["role_id"] is not None:
|
||||||
|
_update_role_projects_and_permissions(
|
||||||
|
item["role_id"],
|
||||||
|
resource.OpenreplayUser.project_keys,
|
||||||
|
resource.OpenreplayUser.permissions,
|
||||||
|
cur,
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
cur.mogrify(
|
||||||
|
"""
|
||||||
|
WITH u AS (
|
||||||
|
UPDATE public.users
|
||||||
|
SET
|
||||||
|
tenant_id = %(tenant_id)s,
|
||||||
|
email = %(email)s,
|
||||||
|
name = %(name)s,
|
||||||
|
internal_id = %(internal_id)s,
|
||||||
|
deleted_at = NULL,
|
||||||
|
created_at = now(),
|
||||||
|
updated_at = now(),
|
||||||
|
api_key = default,
|
||||||
|
jwt_iat = NULL,
|
||||||
|
weekly_report = default
|
||||||
|
WHERE users.email = %(email)s
|
||||||
|
RETURNING *
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
u.*,
|
||||||
|
roles.permissions AS permissions,
|
||||||
|
COALESCE(
|
||||||
|
(
|
||||||
|
SELECT json_agg(projects.project_key)
|
||||||
|
FROM public.projects
|
||||||
|
LEFT JOIN public.roles_projects USING (project_id)
|
||||||
|
WHERE roles_projects.role_id = roles.role_id
|
||||||
|
),
|
||||||
|
'[]'
|
||||||
|
) AS project_keys
|
||||||
|
FROM u
|
||||||
|
LEFT JOIN public.roles ON roles.role_id = u.role_id
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"email": resource.user_name,
|
||||||
|
"name": " ".join(
|
||||||
|
[
|
||||||
|
x
|
||||||
|
for x in [
|
||||||
|
resource.name.honorific_prefix,
|
||||||
|
resource.name.given_name,
|
||||||
|
resource.name.middle_name,
|
||||||
|
resource.name.family_name,
|
||||||
|
resource.name.honorific_suffix,
|
||||||
|
]
|
||||||
|
if x
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if resource.name
|
||||||
|
else "",
|
||||||
|
"internal_id": resource.external_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
item = cur.fetchone()
|
||||||
|
return convert_provider_resource_to_client_resource(item)
|
||||||
|
|
||||||
|
|
||||||
|
def create_resource(tenant_id: int, resource: Resource) -> dict:
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
cur.execute(
|
||||||
|
cur.mogrify(
|
||||||
|
"""
|
||||||
|
WITH u AS (
|
||||||
|
INSERT INTO public.users (
|
||||||
|
tenant_id,
|
||||||
|
email,
|
||||||
|
name,
|
||||||
|
internal_id
|
||||||
|
)
|
||||||
|
VALUES (
|
||||||
|
%(tenant_id)s,
|
||||||
|
%(email)s,
|
||||||
|
%(name)s,
|
||||||
|
%(internal_id)s
|
||||||
|
)
|
||||||
|
RETURNING *
|
||||||
|
)
|
||||||
|
SELECT *
|
||||||
|
FROM u
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"email": resource.user_name,
|
||||||
|
"name": " ".join(
|
||||||
|
[
|
||||||
|
x
|
||||||
|
for x in [
|
||||||
|
resource.name.honorific_prefix,
|
||||||
|
resource.name.given_name,
|
||||||
|
resource.name.middle_name,
|
||||||
|
resource.name.family_name,
|
||||||
|
resource.name.honorific_suffix,
|
||||||
|
]
|
||||||
|
if x
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if resource.name
|
||||||
|
else "",
|
||||||
|
"internal_id": resource.external_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
item = cur.fetchone()
|
||||||
|
return convert_provider_resource_to_client_resource(item)
|
||||||
|
|
||||||
|
|
||||||
|
def update_resource(tenant_id: int, resource: Resource) -> dict | None:
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
cur.execute(
|
||||||
|
cur.mogrify(
|
||||||
|
"""
|
||||||
|
SELECT role_id
|
||||||
|
FROM public.users
|
||||||
|
WHERE user_id = %(user_id)s
|
||||||
|
""",
|
||||||
|
{"user_id": resource.id},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
item = cur.fetchone()
|
||||||
|
if item and item["role_id"] is not None:
|
||||||
|
_update_role_projects_and_permissions(
|
||||||
|
item["role_id"],
|
||||||
|
resource.OpenreplayUser.project_keys,
|
||||||
|
resource.OpenreplayUser.permissions,
|
||||||
|
cur,
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
cur.mogrify(
|
||||||
|
"""
|
||||||
|
WITH u AS (
|
||||||
|
UPDATE public.users
|
||||||
|
SET
|
||||||
|
tenant_id = %(tenant_id)s,
|
||||||
|
email = %(email)s,
|
||||||
|
name = %(name)s,
|
||||||
|
internal_id = %(internal_id)s,
|
||||||
|
updated_at = now()
|
||||||
|
WHERE user_id = %(user_id)s
|
||||||
|
RETURNING *
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
u.*,
|
||||||
|
roles.permissions AS permissions,
|
||||||
|
COALESCE(
|
||||||
|
(
|
||||||
|
SELECT json_agg(projects.project_key)
|
||||||
|
FROM public.projects
|
||||||
|
LEFT JOIN public.roles_projects USING (project_id)
|
||||||
|
WHERE roles_projects.role_id = roles.role_id
|
||||||
|
),
|
||||||
|
'[]'
|
||||||
|
) AS project_keys
|
||||||
|
FROM u
|
||||||
|
LEFT JOIN public.roles ON roles.role_id = u.role_id
|
||||||
|
""",
|
||||||
|
{
|
||||||
|
"user_id": resource.id,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"email": resource.user_name,
|
||||||
|
"name": " ".join(
|
||||||
|
[
|
||||||
|
x
|
||||||
|
for x in [
|
||||||
|
resource.name.honorific_prefix,
|
||||||
|
resource.name.given_name,
|
||||||
|
resource.name.middle_name,
|
||||||
|
resource.name.family_name,
|
||||||
|
resource.name.honorific_suffix,
|
||||||
|
]
|
||||||
|
if x
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if resource.name
|
||||||
|
else "",
|
||||||
|
"internal_id": resource.external_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
item = cur.fetchone()
|
||||||
|
return convert_provider_resource_to_client_resource(item)
|
||||||
|
|
||||||
|
|
||||||
|
def _update_role_projects_and_permissions(
|
||||||
|
role_id: int,
|
||||||
|
project_keys: list[str] | None,
|
||||||
|
permissions: list[str] | None,
|
||||||
|
cur: pg_client.PostgresClient,
|
||||||
|
) -> None:
|
||||||
|
all_projects = "true" if not project_keys else "false"
|
||||||
|
project_key_clause = helpers.safe_mogrify_array(project_keys, "varchar", cur)
|
||||||
|
permission_clause = helpers.safe_mogrify_array(permissions, "varchar", cur)
|
||||||
|
cur.execute(
|
||||||
|
f"""
|
||||||
|
UPDATE public.roles
|
||||||
|
SET
|
||||||
|
updated_at = now(),
|
||||||
|
all_projects = {all_projects},
|
||||||
|
permissions = {permission_clause}
|
||||||
|
WHERE role_id = {role_id}
|
||||||
|
RETURNING *
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
f"""
|
||||||
|
DELETE FROM public.roles_projects
|
||||||
|
WHERE roles_projects.role_id = {role_id}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
f"""
|
||||||
|
INSERT INTO public.roles_projects (role_id, project_id)
|
||||||
|
SELECT {role_id}, projects.project_id
|
||||||
|
FROM public.projects
|
||||||
|
WHERE projects.project_key = ANY({project_key_clause})
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
@ -108,6 +108,16 @@ CREATE TABLE public.tenants
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE public.scim_auth_codes
|
||||||
|
(
|
||||||
|
auth_code_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY,
|
||||||
|
tenant_id integer NOT NULL REFERENCES public.tenants (tenant_id) ON DELETE CASCADE,
|
||||||
|
auth_code text NOT NULL UNIQUE DEFAULT generate_api_key(20),
|
||||||
|
created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
|
||||||
|
used bool NOT NULL DEFAULT FALSE
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
CREATE TABLE public.roles
|
CREATE TABLE public.roles
|
||||||
(
|
(
|
||||||
role_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY,
|
role_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY,
|
||||||
|
|
@ -118,6 +128,7 @@ CREATE TABLE public.roles
|
||||||
protected bool NOT NULL DEFAULT FALSE,
|
protected bool NOT NULL DEFAULT FALSE,
|
||||||
all_projects bool NOT NULL DEFAULT TRUE,
|
all_projects bool NOT NULL DEFAULT TRUE,
|
||||||
created_at timestamp NOT NULL DEFAULT timezone('utc'::text, now()),
|
created_at timestamp NOT NULL DEFAULT timezone('utc'::text, now()),
|
||||||
|
updated_at timestamp NOT NULL DEFAULT timezone('utc'::text, now()),
|
||||||
deleted_at timestamp NULL DEFAULT NULL,
|
deleted_at timestamp NULL DEFAULT NULL,
|
||||||
service_role bool NOT NULL DEFAULT FALSE
|
service_role bool NOT NULL DEFAULT FALSE
|
||||||
);
|
);
|
||||||
|
|
@ -132,6 +143,7 @@ CREATE TABLE public.users
|
||||||
role user_role NOT NULL DEFAULT 'member',
|
role user_role NOT NULL DEFAULT 'member',
|
||||||
name text NOT NULL,
|
name text NOT NULL,
|
||||||
created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
|
created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
|
||||||
|
updated_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
|
||||||
deleted_at timestamp without time zone NULL DEFAULT NULL,
|
deleted_at timestamp without time zone NULL DEFAULT NULL,
|
||||||
api_key text UNIQUE DEFAULT generate_api_key(20) NOT NULL,
|
api_key text UNIQUE DEFAULT generate_api_key(20) NOT NULL,
|
||||||
jwt_iat timestamp without time zone NULL DEFAULT NULL,
|
jwt_iat timestamp without time zone NULL DEFAULT NULL,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue