dump-things-server/dump_things_service/main.py
Christian Monch 34fadbd73a replace schema-attribute in pydantic-objects
Replaces the `schema`-attribute with the
attribute `schema_location`. That prevents
shadowing of pydantic's internal `schema`-attribute
and gets rid of pydantic-issues warnings.
2026-06-10 16:35:48 +02:00

664 lines
19 KiB
Python

from __future__ import annotations # noqa: I001 -- the patches have to be imported early
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING
# Perform the patching before importing any third-party libraries. This
# includes project modules that may themselves import third-party libraries,
# which is why the project imports directly below come after the patches.
from dump_things_service.patches import enabled # noqa F401 -- used by generated code
from dump_things_service.abstract_config import store_config
from dump_things_service.commands.load_config import convert_to_new_format
from dump_things_service.manifest import manifest_configuration
import yaml
import uvicorn
from fastapi import (
    Depends,
    FastAPI,
    HTTPException,
    Request,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi_pagination import (
    Page,
    add_pagination,
    paginate,
)
from fastapi_pagination.utils import disable_installed_extensions_check
from pydantic import (
    BaseModel,
    Field,
)
from starlette.responses import (
    PlainTextResponse,
    RedirectResponse,
)
from dump_things_service import (
    HTTP_400_BAD_REQUEST,
    HTTP_403_FORBIDDEN,
    HTTP_404_NOT_FOUND,
    Format,
)
from dump_things_service.__about__ import __version__
from dump_things_service.abstract_config import (
    Configuration,
    check_collection,
    hash_token_representation,
    read_config,
)
from dump_things_service.api_key import api_key_header_scheme
from dump_things_service.converter import (
    FormatConverter,
    ConvertingList,
)
from dump_things_service.curated import router as curated_router
from dump_things_service.exceptions import CurieResolutionError
from dump_things_service.incoming import router as incoming_router
from dump_things_service.instance_state import (
    create_instance_state,
    InstanceState,
)
from dump_things_service.lazy_list import (
    PriorityList,
    ModifierList,
)
from dump_things_service.model import get_subclasses
from dump_things_service.collection_endpoints import router as collection_router
from dump_things_service.token_endpoints import (
    hash_matcher,
    router as token_router,
)
from dump_things_service.utils import (
    authenticate_token,
    check_bounds,
    process_token,
    wrap_http_exception,
)

if TYPE_CHECKING:
    from dump_things_service.lazy_list import LazyList
class MaintenanceRequest(BaseModel):
    # Request body of `POST /maintenance`.
    # (Documented with comments, not a docstring, because pydantic would put a
    # docstring into the generated OpenAPI schema.)

    # Name of the collection whose maintenance mode is changed.
    collection: str
    # True adds the collection to the maintenance set, False takes it out.
    active: bool
class ServerCollectionResponse(BaseModel):
    # One per-collection entry of the `GET /server` response.

    # Name of the collection.
    name: str
    # Location of the schema associated with the collection. The attribute is
    # called `schema_location` so that it does not shadow pydantic's own
    # `schema` attribute; the alias `schema` is the name used when the model
    # is constructed (see `server()`) and when it is serialized by alias.
    schema_location: str = Field(alias='schema')
    # Class names recorded in the instance's schema information for this
    # schema.
    classes: list[str]
class ServerCollectionCountedResponse(ServerCollectionResponse):
    # Variant of `ServerCollectionResponse` that additionally carries a record
    # count. NOTE(review): not constructed anywhere in this module — confirm
    # whether another module still produces it.

    # Number of records in the collection.
    records: int
class ServerResponse(BaseModel):
    # Response model of `GET /server`.

    # Version of the running service (`__version__`).
    version: str
    # One entry per configured collection.
    collections: list[ServerCollectionResponse|ServerCollectionCountedResponse]
# Root logging is configured at WARNING. The service logger's own level is
# adjusted later from `--log-level`.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger('dump_things_service')

# Command line interface. The arguments are parsed at import time further
# below, so importing this module requires a valid command line.
parser = argparse.ArgumentParser()
# Binding to all interfaces is intentional for a network service.
parser.add_argument('--host', default='0.0.0.0') # noqa S104
parser.add_argument('--port', default=8000, type=int)
# May be given several times; every value is an allowed CORS origin.
parser.add_argument('--origins', action='append', default=[])
parser.add_argument(
    '--admin-token-hash',
    type=str,
    default='',
    help='The sha256 hash of an initial admin token that will allow to add or '
    'remove tokens, collections, and additional admin tokens (64 '
    'characters hex-digit). NOTE: an admin token in plaintext is read '
    'from the environment variable `DTS_ADMIN_TOKEN` if it is set, and '
    'if this option is not provided.',
)
parser.add_argument(
    '-c',
    '--config',
    metavar='CONFIG_FILE',
    help="Read the configuration from 'CONFIG_FILE' if no persisted "
    "configuration is found in the data store root directory, and "
    "initialize the persistent configuration and the service state with "
    "the values in 'CONFIG_FILE'.",
)
parser.add_argument(
    '--root-path',
    default='',
    help="Set the ASGI 'root_path' for applications submounted below a given URL path.",
)
parser.add_argument(
    '--log-level',
    default='WARNING',
    help="Set the log level for the service, allowed values are 'ERROR', 'WARNING', 'INFO', 'DEBUG'. Default is 'warning'.",
)
parser.add_argument(
    'store',
    help='The root of the data store, it should contain a global_store and token_stores.',
)

# Markdown description shown on the generated API documentation page (passed
# to `FastAPI(description=...)`).
description = """
A service to store and retrieve data that is structured according to given
schemata.
Data is stored in **collections**.
Each collection has a name and an associated schema.
All data records in the collection have to adhere to the given schema.
Users store data in an incoming area and read data from a curated area and their
incoming area. There can be many incoming areas, but only one curated area.
Curators store data in an incoming area or in the curated area and read data
from any incoming area or the curated area.
For more information refer to the [README-file](https://hub.psychoinformatics.de/orinoco/dump-things-server)
of the project.
"""
arguments = parser.parse_args()

# Bootstrap admin token: the hash given with `--admin-token-hash` wins;
# otherwise a plaintext token may be supplied in `DTS_ADMIN_TOKEN`.
if not arguments.admin_token_hash:
    # An empty `DTS_ADMIN_TOKEN` is treated like an unset one. Hashing it would
    # turn the hash of the empty string into a valid admin token hash.
    plaintext_admin_token = os.environ.get('DTS_ADMIN_TOKEN', '')
    if plaintext_admin_token:
        arguments.admin_token_hash = hash_token_representation(
            plaintext_admin_token,
        )
else:
    # Validate the hash token format
    if not hash_matcher.match(arguments.admin_token_hash):
        print(
            'Hashed admin token is not a 64-digits hex-number',
            file=sys.stderr,
            flush=True,
        )
        sys.exit(1)

# Set the log level of the service logger. For an unknown level name the
# logger's own level stays unset, so the WARNING level configured via
# `logging.basicConfig` above remains effective.
numeric_level = getattr(logging, arguments.log_level.upper(), None)
if not isinstance(numeric_level, int):
    logger.error(
        'Invalid log level: %s, defaulting to level "WARNING"',
        arguments.log_level,
    )
else:
    logger.setLevel(level=numeric_level)

# The data store root must already exist; it is not created here.
store_path = Path(arguments.store).resolve()
if not store_path.exists():
    logger.error('Store path does not exist: %s', store_path)
    raise SystemExit(1)
# Disable fastapi-pagination's check for installed extensions.
# NOTE(review): presumably done to suppress its startup warning — confirm.
disable_installed_extensions_check()

# The ASGI application served by `main()`. `description` is the markdown text
# shown on the documentation page.
app = FastAPI(
    title='Dump Things Service',
    description=description,
    version=__version__,
)

# Endpoints that are defined in other modules of the package.
app.include_router(curated_router)
app.include_router(incoming_router)
app.include_router(token_router)
app.include_router(collection_router)

# Add CORS origins (from the repeatable `--origins` option)
app.add_middleware(
    CORSMiddleware,
    allow_origins=arguments.origins,
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)

# Add pagination
add_pagination(app)

# Module-global instance state used by the endpoints below. It receives the
# store root, the bootstrap admin token hash (may be the empty string), and
# the application object.
g_instance_state = create_instance_state(
    store_path=store_path,
    bootstrap_token=arguments.admin_token_hash,
    fastapi_app=app,
)
# Configuration persisted in the store. It may be replaced below if it is
# empty and `--config` was given.
g_configuration = read_config(store_path)
def initialize_from_config_file(
    instance_state: InstanceState,
    config_file: str,
) -> Configuration:
    """Read a YAML configuration file and return it as a ``Configuration``.

    Version 1 files are converted to the version 2 layout with
    ``convert_to_new_format`` before validation. Version 2 files are used
    as they are.

    :param instance_state: state of the running instance; its ``store_path``
        is passed to the version 1 converter.
    :param config_file: path of the YAML configuration file.
    :return: the configuration built from the file content.
    :raises ValueError: if the file does not contain a YAML mapping, or if
        its ``version`` entry is missing or is neither 1 nor 2.
    """
    # YAML files are UTF-8; do not depend on the locale's default encoding.
    with open(config_file, encoding='utf-8') as f:
        # `safe_load`: the file content is not trusted to construct objects.
        config_dict = yaml.safe_load(f)

    # An empty file yields None and a scalar/list top level is invalid, too;
    # report this instead of failing with a TypeError below.
    if not isinstance(config_dict, dict):
        msg = f'Invalid config file, expected a mapping: {config_file}'
        raise ValueError(msg)

    config_version = config_dict.get('version')
    if config_version == 1:
        # Log the file that is actually being converted.
        logger.info(
            'Converting version 1 configuration at %s',
            config_file,
        )
        config_dict = convert_to_new_format(
            config_dict,
            instance_state.store_path,
        )
    elif config_version != 2:
        msg = f'Invalid version in config file: {config_version}'
        raise ValueError(msg)
    return Configuration(**config_dict)
# If the persisted configuration is empty (no admin tokens, no collections,
# and no tokens), initialize it from the file given with `--config`, if any,
# and persist the result in the store.
if not (
    g_configuration.admin_tokens
    or g_configuration.collections
    or g_configuration.tokens
):
    if arguments.config:
        logger.info(
            'Initializing empty persisted configuration from %s',
            arguments.config,
        )
        g_configuration = initialize_from_config_file(
            g_instance_state,
            arguments.config,
        )
        # Persist the configuration
        store_config(
            store_path=g_instance_state.store_path,
            config=g_configuration,
        )

# Apply the (possibly freshly initialized) configuration to the instance
# state. NOTE(review): presumably this also registers collection-specific
# state and routes on the app — confirm in `manifest_configuration`.
manifest_configuration(
    configuration=g_configuration,
    instance_state=g_instance_state,
)

# Reset the cached OpenAPI schema so that FastAPI regenerates it on the next
# request, then run `setup()` and `add_pagination()` again.
# NOTE(review): both were already applied to this app during creation; the
# repeat is presumably needed because the routes changed above — confirm.
g_instance_state.fastapi_app.openapi_schema = None
g_instance_state.fastapi_app.setup()
add_pagination(g_instance_state.fastapi_app)
@app.get('/', response_class=RedirectResponse)
async def root(request: Request) -> RedirectResponse:
    # Redirect to the interactive API documentation.
    # The ASGI root path (`--root-path`) is prepended so that the redirect
    # also points to the right place when the service is mounted below a URL
    # prefix. FastAPI builds its own docs URLs the same way. Without a root
    # path the target is '/docs', as before.
    root_path = request.scope.get('root_path', '').rstrip('/')
    return RedirectResponse(f'{root_path}/docs')
@app.get(
    '/server',
    tags=['Server management'],
    name='get server information'
)
async def server() -> ServerResponse:
    # Report the service version and, for every configured collection, its
    # name, its schema location, and the classes known for that schema.
    entries = []
    for collection_name in g_configuration.collections:
        location = g_configuration.collections[collection_name].schema_location
        entries.append(
            ServerCollectionResponse(
                name=collection_name,
                schema=location,
                classes=g_instance_state.schema_info[location].classes,
            )
        )
    return ServerResponse(version=__version__, collections=entries)
@app.post(
    '/maintenance',
    tags=['Server management'],
    name='put a collection in maintenance mode'
)
async def maintenance(
    body: MaintenanceRequest,
    api_key: str | None = Depends(api_key_header_scheme),
):
    # Switch maintenance mode of `body.collection` on (`body.active` is true)
    # or off. The token must grant curated read and write permission and
    # zone access for that collection; otherwise 400 is returned.
    if api_key is None:
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail='Token required for this operation',
        )

    collection = body.collection
    active = body.active

    # Try to authenticate the token with the authentication providers that
    # are associated with the collection.
    check_collection(g_configuration, collection)
    auth_info = authenticate_token(g_instance_state, collection, api_key)
    permissions = auth_info.token_permission
    if not (
        permissions.curated_write
        and permissions.curated_read
        and permissions.zones_access
    ):
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail='Curator permissions required for this operation',
        )

    if active:
        g_instance_state.maintenance_mode.add(collection)
    elif collection in g_instance_state.maintenance_mode:
        # Deactivating a collection that is not in maintenance mode is a
        # no-op. If `maintenance_mode` is a set (as `add`/`remove` suggest),
        # an unconditional `remove` raises KeyError and the request would
        # fail with an internal server error.
        g_instance_state.maintenance_mode.remove(collection)
    return
@app.get(
    '/{collection}/record',
    tags=['Read records'],
    name='Read the record with the given PID from the given collection',
)
async def read_record_with_pid(
    collection: str,
    pid: str,
    format: Format = Format.json,  # noqa A002
    api_key: str = Depends(api_key_header_scheme),
):
    # Fetch a single record by PID. The token's incoming area is searched
    # first. The curated area is consulted only if that produced no record
    # and the token may read curated data. Returns None if nothing was found;
    # for `format=ttl` the record is converted and returned as Turtle text.
    check_collection(g_configuration, collection)
    permissions, incoming_store = await process_token(
        g_configuration, g_instance_state, api_key, collection
    )

    record_class, record = None, None
    if permissions.incoming_read:
        with wrap_http_exception(CurieResolutionError, header='CURIE error:'):
            record_class, record = incoming_store.get_object_by_pid(pid)

    if not record and permissions.curated_read:
        with wrap_http_exception(CurieResolutionError, header='CURIE error:'):
            curated_store = g_instance_state.curated_stores[collection]
            record_class, record = curated_store.get_object_by_pid(pid)

    if not record:
        return None

    if format != Format.ttl:
        return record

    converter = FormatConverter(
        schema=g_configuration.collections[collection].schema_location,
        input_format=Format.json,
        output_format=format,
    )
    with wrap_http_exception(ValueError, header='Conversion error'):
        turtle_text = converter.convert(record, record_class)
    return PlainTextResponse(turtle_text, media_type='text/turtle')
@app.get(
    '/{collection}/records/',
    tags=['Read records'],
    name='Read all records from the given collection',
)
async def read_all_records(
    collection: str,
    matching: str | None = None,
    format: Format = Format.json,  # noqa A002
    api_key: str = Depends(api_key_header_scheme),
):
    # Non-paginated listing of all records of the collection that the token
    # may read. The result size is bounded to keep processing time of single
    # requests short and to avoid overloading the server; the bound is
    # enforced in `_read_all_records`.
    unpaginated_limit = 1000
    return await _read_all_records(
        collection=collection,
        matching=matching,
        format=format,
        api_key=api_key,
        bound=unpaginated_limit,
    )
@app.get(
    '/{collection}/records/p/',
    tags=['Read records'],
    name='Read all records from the given collection with pagination',
)
async def read_all_records_paginated(
    collection: str,
    matching: str | None = None,
    format: Format = Format.json,  # noqa A002
    api_key: str = Depends(api_key_header_scheme),
) -> Page[dict | str]:
    # Paginated listing of all readable records. No size bound is applied
    # here; the (lazy) result list is handed to `paginate`.
    return paginate(
        await _read_all_records(
            collection=collection,
            matching=matching,
            format=format,
            api_key=api_key,
            bound=None,
        )
    )
@app.get(
    '/{collection}/records/{class_name}',
    tags=['Read records'],
    name='Read records of the given class (or subclass) from the given collection',
)
async def read_records_of_type(
    collection: str,
    class_name: str,
    matching: str | None = None,
    format: Format = Format.json,  # noqa A002
    api_key: str = Depends(api_key_header_scheme),
):
    # Non-paginated listing of the records of `class_name` and its
    # subclasses. The result size is bounded to keep processing time of
    # single requests short and to avoid overloading the server; the bound is
    # enforced in `_read_records_of_type`.
    unpaginated_limit = 1000
    return await _read_records_of_type(
        collection=collection,
        class_name=class_name,
        matching=matching,
        format=format,
        api_key=api_key,
        bound=unpaginated_limit,
    )
@app.get(
    '/{collection}/records/p/{class_name}',
    tags=['Read records'],
    name='Read records of the given class (or subclass) from the given collection with pagination',
)
async def read_records_of_type_paginated(
    collection: str,
    class_name: str,
    matching: str | None = None,
    format: Format = Format.json,  # noqa A002
    api_key: str = Depends(api_key_header_scheme),
) -> Page[dict | str]:
    # Paginated listing of the records of `class_name` and its subclasses.
    # No size bound is applied; the lazy result list goes to `paginate`.
    return paginate(
        await _read_records_of_type(
            collection=collection,
            class_name=class_name,
            matching=matching,
            format=format,
            api_key=api_key,
            bound=None,
        )
    )
async def _read_all_records(
    collection: str,
    matching: str | None = None,
    format: Format = Format.json,  # noqa A002
    api_key: str = Depends(api_key_header_scheme),
    bound: int | None = None,
) -> LazyList:
    """Return all records of ``collection`` that ``api_key`` may read.

    Records from the token's incoming area (if ``incoming_read`` is granted)
    and from the curated area (if ``curated_read`` is granted) are merged
    into a ``PriorityList`` and sorted by its ``sort_key``.

    :param bound: if truthy, ``check_bounds`` is called with the length of
        each source list before it is merged.
    :return: for ``Format.ttl`` a ``ConvertingList`` that converts records
        lazily (conversion errors become HTTP 400); otherwise a
        ``ModifierList`` yielding each record's ``json_object``.
    """
    def raise_conversion_error(exc: BaseException):
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail=f'Conversion error: {exc}',
        ) from exc

    check_collection(g_configuration, collection)
    permissions, incoming_store = await process_token(
        g_configuration, g_instance_state, api_key, collection
    )

    merged = PriorityList()

    def include(records):
        # Enforce the optional bound on this source list, then merge it.
        if bound:
            check_bounds(len(records), bound, collection, 'records/p/')
        merged.add_list(records)

    if permissions.incoming_read:
        include(incoming_store.get_all_objects(matching=matching))
    if permissions.curated_read:
        curated_store = g_instance_state.curated_stores[collection]
        include(curated_store.get_all_objects(matching=matching))

    merged.sort(key=merged.sort_key)

    if format == Format.ttl:
        return ConvertingList(
            merged,
            g_configuration.collections[collection].schema_location,
            input_format=Format.json,
            output_format=format,
            exception_handler=raise_conversion_error,
        )
    return ModifierList(merged, lambda record_info: record_info.json_object)
async def _read_records_of_type(
    collection: str,
    class_name: str,
    matching: str | None = None,
    format: Format = Format.json,  # noqa A002
    api_key: str = Depends(api_key_header_scheme),
    bound: int | None = None,
) -> LazyList:
    """Return the readable records of ``class_name`` and its subclasses.

    For every class returned by ``get_subclasses`` the matching records are
    collected from the token's incoming area (if ``incoming_read`` is
    granted) and from the curated area (if ``curated_read`` is granted).
    They are merged into a ``PriorityList`` and sorted by its ``sort_key``.

    :param bound: if truthy, ``check_bounds`` is called with the length of
        each per-class source list before it is merged.
    :return: for ``Format.ttl`` a ``ConvertingList`` that converts records
        lazily (conversion errors become HTTP 400); otherwise a
        ``ModifierList`` yielding each record's ``json_object``.
    :raises HTTPException: 404 if ``class_name`` is not an active class of
        the collection.
    """
    def convert_to_http_exception(e: BaseException):
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail=f'Conversion error: {e}',
        ) from e

    check_collection(g_configuration, collection)
    schema_location = g_configuration.collections[collection].schema_location
    model = g_instance_state.schema_info[schema_location].pydantic_module_info.module
    if class_name not in g_instance_state.collections[collection].active_classes:
        raise HTTPException(
            status_code=HTTP_404_NOT_FOUND,
            detail=f"No '{class_name}'-class in collection '{collection}'.",
        )

    final_permissions, token_store = await process_token(
        g_configuration, g_instance_state, api_key, collection
    )

    # The subclass set does not depend on the store being searched: resolve
    # it once (as a list, so it can be iterated twice) instead of per store.
    search_class_names = list(get_subclasses(model, class_name))
    bounds_path = f'/records/p/{class_name}'

    result_list = PriorityList()
    if final_permissions.incoming_read:
        for search_class_name in search_class_names:
            token_store_list = token_store.get_objects_of_class(
                class_name=search_class_name,
                matching=matching,
            )
            if bound:
                check_bounds(len(token_store_list), bound, collection, bounds_path)
            result_list.add_list(token_store_list)

    if final_permissions.curated_read:
        curated_store = g_instance_state.curated_stores[collection]
        for search_class_name in search_class_names:
            curated_store_list = curated_store.get_objects_of_class(
                class_name=search_class_name,
                matching=matching,
            )
            if bound:
                check_bounds(len(curated_store_list), bound, collection, bounds_path)
            result_list.add_list(curated_store_list)

    # Sort the result list.
    result_list.sort(key=result_list.sort_key)

    if format == Format.ttl:
        result_list = ConvertingList(
            result_list,
            schema_location,
            input_format=Format.json,
            output_format=format,
            exception_handler=convert_to_http_exception,
        )
    else:
        result_list = ModifierList(
            result_list,
            lambda record_info: record_info.json_object,
        )
    return result_list
@app.delete(
    '/{collection}/record',
    tags=['Delete records'],
    name='Delete record with the given pid from the given collection',
)
async def delete_record(
    collection: str,
    pid: str,
    api_key: str = Depends(api_key_header_scheme),
):
    # Delete the record with `pid` via the token's incoming store; only that
    # store is asked to delete. Requires `incoming_write` (403 otherwise).
    # Returns True on success and 404 if the store reports nothing removed.
    check_collection(g_configuration, collection)
    permissions, incoming_store = await process_token(
        g_configuration, g_instance_state, api_key, collection
    )

    if not permissions.incoming_write:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail=f"No write access to incoming data in collection '{collection}'.",
        )

    with wrap_http_exception(Exception):
        removed = incoming_store.delete_object(pid)

    if removed:
        return True

    raise HTTPException(
        status_code=HTTP_404_NOT_FOUND,
        detail=(
            f"Could not remove record with PID '{pid}' from the token "
            f"associated incoming area of collection '{collection}'."
        ),
    )
def main():
    """Serve ``app`` with uvicorn using host, port and root path from the CLI."""
    server_options = {
        'host': arguments.host,
        'port': arguments.port,
        'root_path': arguments.root_path,
    }
    uvicorn.run(app, **server_options)


if __name__ == '__main__':
    main()