Initial commit

This commit is contained in:
hiperman
2025-12-04 00:33:37 -05:00
commit 7ca0a21283
798 changed files with 190424 additions and 0 deletions

20
backend/.dockerignore Normal file
View File

@@ -0,0 +1,20 @@
# Version control
.git/
.gitignore
.gitattributes
# Python
**/__pycache__/
# Docker
Dockerfile*
docker-compose*.yml
.dockerignore
# Development
.env.local
.env.development
*.log
# Formatting and linting
**/.ruff_cache

15
backend/.gitignore vendored Normal file
View File

@@ -0,0 +1,15 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv
# Project specific directories/files
covers/
books/
.postgres/

1
backend/.python-version Normal file
View File

@@ -0,0 +1 @@
3.13

57
backend/Dockerfile Normal file
View File

@@ -0,0 +1,57 @@
# An example using multi-stage image builds to create a final image without uv.
# First, build the application in the `/app` directory.
# See `Dockerfile` for details.
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder
ENV UV_COMPILE_BYTECODE=1 \
UV_LINK_MODE=copy
# Disable Python downloads, because we want to use the system interpreter
# across both images. If using a managed Python version, it needs to be
# copied from the build image into the final image; see `standalone.Dockerfile`
# for an example.
ENV UV_PYTHON_DOWNLOADS=0
WORKDIR /app
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=uv.lock,target=uv.lock \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
uv sync --locked --no-install-project --no-dev
COPY uv.lock pyproject.toml README.md /app/
COPY src /app
COPY migrations ./migrations
COPY entrypoint.sh alembic.ini /app/
RUN chmod +x /app/entrypoint.sh
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --locked --no-dev --no-editable
FROM python:3.13-slim-bookworm
# Setup a non-root user
RUN groupadd -g 1000 appuser && \
useradd -u 1000 -g appuser -m -d /app -s /sbin/nologin appuser
RUN apt-get update && apt-get install -y \
gosu \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy the application from the builder
COPY --from=builder --chown=appuser:appuser /app /app
# Place executables in the environment at the front of the path
ENV PATH="/app/.venv/bin:$PATH"
HEALTHCHECK --interval=10s --timeout=3s --retries=5 \
CMD curl -f http://localhost:8000/healthcheck || exit 1
# Use `/app` as the working directory
WORKDIR /app
ENTRYPOINT [ "/app/entrypoint.sh" ]
# Run the web application
CMD ["litestar", "--app-dir", "chitai", "run", "--host", "0.0.0.0", "--port", "8000"]

127
backend/README.md Normal file
View File

@@ -0,0 +1,127 @@
# Chitai API
A RESTful API for managing ebooks, built with [Litestar](https://litestar.dev/) and Python.
## Overview
This backend service provides a comprehensive API for ebook management, including:
- Ebook catalog management (CRUD operations)
- Virtual libraries and user bookshelves
- Metadata management (authors, publishers, tags, identifiers)
- File storage and retrieval
- Search and filtering capabilities
- Reading progress tracking
- User library management
**Tech Stack:**
- Python 3.13
- Litestar (ASGI web framework)
- Advanced-alchemy/SQLAlchemy (ORM)
- PostgreSQL (database)
- Alembic (migrations)
## Prerequisites
- Python 3.13
- PostgreSQL 17
- uv
## Getting Started
### 1. Install Dependencies
Using uv (recommended):
```bash
uv sync
```
### 2. Environment Configuration
Copy the example environment file and configure environment variables:
```bash
cp .env.example .env
```
### 3. Database Setup
Run migrations:
```bash
alchemy --config chitai.database.config.config upgrade
```
### 4. Run the Application
Development mode with auto-reload:
```bash
uv run litestar --app-dir src/chitai/ run --reload
```
The API will be available at `http://localhost:8000`
## Development
### Project Structure
```
backend/
├── src/
| └── chitai
| | ├── app.py # Litestar app initialization
| | ├── config.py # Configuration settings
| | ├── controllers # API route handlers
| | ├── database
| | │ ├── models # SQLAlchemy models
| | ├── exceptions # Custom exceptions and handlers
| | ├── schemas # Pydantic schemas (DTOs)
| | └── services # Business logic layer
├── migrations/ # Alembic migrations
├── tests/
│ ├── unit/
│ ├── integration/
│ └── conftest.py
├── alembic.ini
├── pyproject.toml
└── README.md
```
### Running Tests
Run all tests:
```bash
pytest tests/
```
Run specific test categories:
```bash
# Unit tests only
pytest tests/unit
# Integration tests only
pytest tests/integration
```
### Code Quality
Format code:
```bash
ruff format src/
```
### Creating Database Migrations
After modifying models:
```bash
alchemy --config chitai.database.config.config make-migrations
```
For manual migrations:
```bash
alchemy --config chitai.database.config.config make-migrations --no-autogenerate
```
## API Documentation
Once the server is running, interactive API documentation is available at:
- **Swagger UI**: http://localhost:8000/schema/
- **OpenAPI JSON**: http://localhost:8000/schema/openapi.json

73
backend/alembic.ini Normal file
View File

@@ -0,0 +1,73 @@
# Advanced Alchemy Alembic Asyncio Config
[alembic]
prepend_sys_path = src:.
# path to migration scripts
script_location = migrations
# template used to generate migration files
file_template = %%(year)d-%%(month).2d-%%(day).2d_%%(slug)s_%%(rev)s
# This is not required to be set when running through `advanced_alchemy`
# sqlalchemy.url = driver://user:pass@localhost/dbname
# timezone to use when rendering the date
# within the migration file as well as the filename.
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone = UTC
# max length of characters to apply to the
# "slug" field
truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; this defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path
# version_locations = %(here)s/bar %(here)s/bat alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
output_encoding = utf-8
# [post_write_hooks]
# This section defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner,
# against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME

39
backend/entrypoint.sh Executable file
View File

@@ -0,0 +1,39 @@
#!/bin/sh
set -e
PUID=${PUID:-1000}
PGID=${PGID:-1000}
echo "Starting with UID=$PUID, GID=$PGID"
WRITABLE_DIRS="/app/covers /app/libraries"
# Modify user PGID and PUID to match ther user on the host
if [ "$PUID" != "1000" ] || [ "$PGID" != "1000" ]; then
echo "Adjusting user to UID=$PUID, GID=$PGID"
groupmod -o -g "$PGID" appuser
usermod -o -u "$PUID" appuser
# Update permissions on writable directories
for dir in $WRITABLE_DIRS; do
if [ -d "$dir" ]; then
echo "Fixing permissions for $dir"
chown -R appuser:appuser "$dir"
chmod -R 755 "$dir"
else
echo "Creating $dir"
mkdir -p "$dir"
chown -R appuser:appuser "$dir"
chmod -R 755 "$dir"
fi
done
fi
echo "Running database migrations..."
alchemy --config chitai.database.config.config upgrade --no-prompt
echo "Starting application..."
exec gosu appuser "$@"

View File

107
backend/migrations/env.py Normal file
View File

@@ -0,0 +1,107 @@
import asyncio
from typing import TYPE_CHECKING, cast
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import AsyncEngine, async_engine_from_config
from advanced_alchemy.base import metadata_registry
from alembic import context
from alembic.autogenerate import rewriter
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from advanced_alchemy.alembic.commands import AlembicCommandConfig
__all__ = ("do_run_migrations", "run_migrations_offline", "run_migrations_online")
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config: "AlembicCommandConfig" = context.config # type: ignore
writer = rewriter.Rewriter()
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
context.configure(
url=config.db_url,
target_metadata=metadata_registry.get(config.bind_key),
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=config.compare_type,
version_table=config.version_table_name,
version_table_pk=config.version_table_pk,
user_module_prefix=config.user_module_prefix,
render_as_batch=config.render_as_batch,
process_revision_directives=writer,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: "Connection") -> None:
"""Run migrations."""
context.configure(
connection=connection,
target_metadata=metadata_registry.get(config.bind_key),
compare_type=config.compare_type,
version_table=config.version_table_name,
version_table_pk=config.version_table_pk,
user_module_prefix=config.user_module_prefix,
render_as_batch=config.render_as_batch,
process_revision_directives=writer,
)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine and associate a
connection with the context.
Raises:
RuntimeError: If the engine cannot be created from the config.
"""
configuration = config.get_section(config.config_ini_section) or {}
configuration["sqlalchemy.url"] = config.db_url
connectable = cast(
"AsyncEngine",
config.engine
or async_engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
future=True,
),
)
if connectable is None: # pyright: ignore[reportUnnecessaryComparison]
msg = "Could not get engine from config. Please ensure your `alembic.ini` according to the official Alembic documentation."
raise RuntimeError(
msg,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())

View File

@@ -0,0 +1,72 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
import warnings
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from advanced_alchemy.types import EncryptedString, EncryptedText, GUID, ORA_JSONB, DateTimeUTC, StoredObject, PasswordHash, FernetBackend
from advanced_alchemy.types.encrypted_string import PGCryptoBackend
from advanced_alchemy.types.password_hash.argon2 import Argon2Hasher
from advanced_alchemy.types.password_hash.passlib import PasslibHasher
from advanced_alchemy.types.password_hash.pwdlib import PwdlibHasher
from sqlalchemy import Text # noqa: F401
${imports if imports else ""}
if TYPE_CHECKING:
from collections.abc import Sequence
__all__ = ["downgrade", "upgrade", "schema_upgrades", "schema_downgrades", "data_upgrades", "data_downgrades"]
sa.GUID = GUID
sa.DateTimeUTC = DateTimeUTC
sa.ORA_JSONB = ORA_JSONB
sa.EncryptedString = EncryptedString
sa.EncryptedText = EncryptedText
sa.StoredObject = StoredObject
sa.PasswordHash = PasswordHash
sa.Argon2Hasher = Argon2Hasher
sa.PasslibHasher = PasslibHasher
sa.PwdlibHasher = PwdlibHasher
sa.FernetBackend = FernetBackend
sa.PGCryptoBackend = PGCryptoBackend
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade() -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
schema_upgrades()
data_upgrades()
def downgrade() -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
data_downgrades()
schema_downgrades()
def schema_upgrades() -> None:
"""schema upgrade migrations go here."""
${upgrades if upgrades else "pass"}
def schema_downgrades() -> None:
"""schema downgrade migrations go here."""
${downgrades if downgrades else "pass"}
def data_upgrades() -> None:
"""Add any optional data upgrade migrations here!"""
def data_downgrades() -> None:
"""Add any optional data downgrade migrations here!"""

View File

@@ -0,0 +1,65 @@
"""Add pg_trgm extension
Revision ID: 26022ec86f32
Revises:
Create Date: 2025-10-31 18:45:55.027462
"""
import warnings
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from advanced_alchemy.types import EncryptedString, EncryptedText, GUID, ORA_JSONB, DateTimeUTC, StoredObject, PasswordHash
from sqlalchemy import Text # noqa: F401
if TYPE_CHECKING:
from collections.abc import Sequence
__all__ = ["downgrade", "upgrade", "schema_upgrades", "schema_downgrades", "data_upgrades", "data_downgrades"]
sa.GUID = GUID
sa.DateTimeUTC = DateTimeUTC
sa.ORA_JSONB = ORA_JSONB
sa.EncryptedString = EncryptedString
sa.EncryptedText = EncryptedText
sa.StoredObject = StoredObject
sa.PasswordHash = PasswordHash
# revision identifiers, used by Alembic.
revision = '26022ec86f32'
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(sa.text('create EXTENSION if not EXISTS "pgcrypto"'))
op.execute(sa.text('create EXTENSION if not EXISTS "pg_trgm"'))
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
schema_upgrades()
data_upgrades()
def downgrade() -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
data_downgrades()
schema_downgrades()
def schema_upgrades() -> None:
"""schema upgrade migrations go here."""
pass
def schema_downgrades() -> None:
"""schema downgrade migrations go here."""
pass
def data_upgrades() -> None:
"""Add any optional data upgrade migrations here!"""
def data_downgrades() -> None:
"""Add any optional data downgrade migrations here!"""

View File

43
backend/pyproject.toml Normal file
View File

@@ -0,0 +1,43 @@
[project]
name = "chitai"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
{ name = "patrykj", email = "patrykjaroszewski@proton.me" }
]
requires-python = ">=3.13"
dependencies = [
"advanced-alchemy==1.8.0",
"aiofiles>=24.1.0",
"asyncpg>=0.30.0",
"ebooklib>=0.19",
"jinja2>=3.1.6",
"litestar[jwt,standard]>=2.16.0",
"passlib>=1.7.4",
"pillow>=11.2.1",
"pwdlib[argon2]>=0.2.1",
"pydantic>=2.11.5",
"pydantic-settings>=2.9.1",
"pypdfium2>=4.30.0",
"watchfiles>=1.1.1",
"xmltodict>=1.0.2",
]
[project.scripts]
chitai = "chitai:main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[dependency-groups]
dev = [
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",
"pytest-databases[postgres]>=0.15.0",
]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]

91
backend/shell.nix Normal file
View File

@@ -0,0 +1,91 @@
{ pkgs ? import <nixpkgs> {}}:
pkgs.mkShell {
buildInputs = with pkgs; [
# Python development environment for Chitai
python313Full
python313Packages.greenlet
python313Packages.ruff
uv
# postgres database
postgresql
];
shellHook = ''
# Only cd if we're not already in backend
if [[ $(basename $(pwd)) != "backend" ]]; then
cd backend
fi
BACKEND_PATH=$(pwd)
echo "Initializing Python development environment with uv"
# Create a virtual environment if it doesn't exist
if [ ! -d ".venv" ]; then
echo "Creating virtual environment..."
uv venv .venv
fi
# Activate the virtual environment
source .venv/bin/activate
echo "Python environment ready!"
echo "Python version: $(python --version)"
echo "uv version: $(uv --version)"
echo "Virtual environment activated at: $VIRTUAL_ENV"
# Install required packages
uv sync
echo "Successfully installed packages!"
echo "Setting up the postgres database"
export PGHOST=$BACKEND_PATH/.postgres
export PGDATA=$PGHOST/data
export PGDATABASE=chitai
export PGUSERNAME=chitai_user
export PGPASSWORD="chitai_password"
export PGLOG=$PGHOST/postgres.log
export LD_LIBRARY_PATH="${pkgs.postgresql.lib}/lib:$LD_LIBRARY_PATH"
mkdir -p $PGHOST
# Initialize the postgres database if not present
if [ ! -d "$PGDATA" ]; then
echo "Initializing PostgreSQL..."
initdb -D $BACKEND_PATH/.postgres/data \
-U $PGUSERNAME \
--pwfile=<(echo "$PGPASSWORD") \
--auth=md5 \
--encoding=UTF-8
fi
if ! pg_ctl status > /dev/null 2>&1
then
echo "Starting PostgreSQL..."
pg_ctl start -l $PGLOG -o "--unix_socket_directories='$PGHOST'"
fi
echo "PostgreSQL is running!"
if ! psql -h $PGHOST -d postgres -lqt 2>/dev/null | cut -d \| -f 1 | grep -qw chitai; then
echo "Creating database chitai..."
createdb -U $PGUSERNAME -h $PGHOST chitai
fi
# Run database migrations
uv run alchemy --config chitai.database.config.config upgrade --no-prompt
# Return to root directory
cd -
'';
exitHook = ''
BACKEND_PATH=$(pwd)/backend
cd $BACKEND_PATH
echo "Stopping PostgreSQL..."
pg_ctl stop
cd -
'';
}

View File

@@ -0,0 +1,2 @@
def main() -> None:
print("Hello from chitai!")

136
backend/src/chitai/app.py Normal file
View File

@@ -0,0 +1,136 @@
import asyncio
from typing import Any
from chitai.services.book import BookService
from chitai.services.consume import ConsumeDirectoryWatcher
from litestar import Litestar, get
from litestar.openapi.config import OpenAPIConfig
from litestar.openapi.plugins import SwaggerRenderPlugin
from litestar.di import Provide
from litestar.security.jwt import OAuth2PasswordBearerAuth, Token
from litestar.connection import ASGIConnection
from litestar.static_files import create_static_files_router
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from chitai import controllers as c
from chitai.config import settings
from chitai.database.config import alchemy
from chitai.database.models.user import User
from chitai.schemas.library import LibraryCreate
from chitai.services.dependencies import provide_user, provide_user_service
from chitai.services.library import LibraryService
from chitai.services.utils import create_directory
from chitai.exceptions.handlers import exception_handlers
@get("/")
async def index() -> str:
return "Hello from Chitai!"
@get("/healthcheck")
async def healthcheck(db_session: AsyncSession) -> dict[str, str]:
await db_session.execute(text("SELECT 1"))
return {"status": "OK"}
async def retrieve_user_handler(
token: Token, connection: ASGIConnection[Any, Any, Any, Any]
) -> User | None:
service = await anext(
provide_user_service(
settings.alchemy_config.provide_session(
connection.app.state, connection.scope
)
)
)
return await service.get_one_or_none(email=token.sub)
oauth2_auth = OAuth2PasswordBearerAuth[User](
retrieve_user_handler=retrieve_user_handler,
token_secret=settings.token_secret,
# we are specifying the URL for retrieving a JWT access token
token_url="/access/login",
# we are specifying which endpoints should be excluded from authentication. In this case the login endpoint
# and our openAPI docs.
exclude=[
"/healthcheck",
"/access/login",
"/access/login/token",
"/access/signup",
"/opds",
"/schema",
],
)
watcher_task: asyncio.Task
async def startup():
"""Run setup."""
# Setup databse
async with settings.alchemy_config.get_session() as db_session:
# Create default library if none exist
library_service = LibraryService(session=db_session)
_, total = await library_service.list_and_count()
if total == 0:
await library_service.create(
LibraryCreate(
name=settings.default_library_name,
root_path=settings.default_library_path,
)
)
await db_session.commit()
# book_service = BookService(session=db_session)
# Create book covers directory if it does not exist
await create_directory(settings.book_cover_path)
# Create consume directory
await create_directory(settings.consume_path)
# file_watcher = ConsumeDirectoryWatcher(settings.consume_path, library_service, book_service)
# watcher_task = asyncio.create_task(file_watcher.init_watcher())
async def shutdown():
""" Run shutdown tasks. """
watcher_task.cancel()
def create_app() -> Litestar:
return Litestar(
route_handlers=[
c.BookController,
c.LibraryController,
c.AccessController,
c.BookshelfController,
c.AuthorController,
c.PublisherController,
c.TagController,
c.OpdsController,
create_static_files_router(path="/covers", directories=["./covers"]),
index,
healthcheck,
],
exception_handlers=exception_handlers,
on_startup=[startup],
plugins=[alchemy],
on_app_init=[oauth2_auth.on_app_init],
openapi_config=OpenAPIConfig(
title="Chitai",
description="Chitai API docs",
version="0.0.1",
render_plugins=[SwaggerRenderPlugin()],
),
dependencies={"current_user": Provide(provide_user, sync_to_thread=False)},
debug=settings.api_debug,
)
app = create_app()

View File

@@ -0,0 +1,55 @@
from pydantic import Field, PostgresDsn, computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict
from advanced_alchemy.extensions.litestar import (
SQLAlchemyAsyncConfig,
)
class Settings(BaseSettings):
version: str = Field("0.0.1")
project_name: str = Field("chitai")
# Database settings
postgres_db: str = Field("chitai")
postgres_user: str = Field("chitai_user")
postgres_password: str = Field("chitai_password")
postgres_host: str = Field("localhost")
postgres_port: str = Field("5432")
postgres_echo: bool = Field(False)
postgres_scheme: str = "postgresql+asyncpg"
# Debug settings
api_debug: bool = Field(False)
# JWT token secret
token_secret: str = Field("secret")
# Defaut library
default_library_name: str = Field("Books")
default_library_path: str = Field("./books")
# Path to book covers
book_cover_path: str = Field("./covers")
# Path to consume directory
consume_path: str = Field("./consume")
@computed_field
@property
def postgres_uri(self) -> PostgresDsn:
return PostgresDsn(
f"{self.postgres_scheme}://{self.postgres_user}:{self.postgres_password}"
f"@{self.postgres_host}:{self.postgres_port}/{self.postgres_db}"
)
model_config = SettingsConfigDict(
env_prefix="chitai_", env_file=".env", validate_default=False, extra="ignore"
)
@property
def alchemy_config(self) -> SQLAlchemyAsyncConfig:
from chitai.database.config import config
return config
settings = Settings() # type: ignore

View File

@@ -0,0 +1,8 @@
from .book import BookController
from .library import LibraryController
from .access import AccessController
from .shelf import BookshelfController
from .author import AuthorController
from .tag import TagController
from .publisher import PublisherController
from .opds import OpdsController

View File

@@ -0,0 +1,116 @@
# src/chitai/controllers/access.py
# Standard library
from typing import Annotated, Any
import logging
# Third-party libraries
from litestar.security.jwt import OAuth2Login
from litestar import Controller, Response, get, post
from litestar.params import Body
from litestar.enums import RequestEncodingType
from litestar.status_codes import HTTP_409_CONFLICT, HTTP_401_UNAUTHORIZED
from litestar.exceptions import HTTPException, PermissionDeniedException
from advanced_alchemy.extensions.litestar.providers import create_service_dependencies
from advanced_alchemy.exceptions import DuplicateKeyError
# Local imports
from chitai.services.user import UserService
from chitai.schemas.user import UserCreate, UserLogin, UserRead
from chitai.database.models.user import User
logger = logging.getLogger(__name__)
class AccessController(Controller):
"""Controller for user authentication and access management."""
path = "/access"
dependencies = create_service_dependencies(UserService, key="users_service")
@post("/signup")
async def register(self, data: UserCreate, users_service: UserService) -> UserRead:
"""
Register a new user account.
Creates a new user with the provided credentials. Email addresses must be unique.
Request Body:
data: User registration data including email and password.
Injected Dependencies:
users_service: The user service for database operations.
Returns:
The created user account as a UserRead schema.
Raises:
HTTPException (409): If a user with the provided email already exists.
"""
try:
user = await users_service.create(data)
return users_service.to_schema(user, schema_type=UserRead)
except DuplicateKeyError:
raise HTTPException(
status_code=HTTP_409_CONFLICT,
detail="A user with this email already exists",
)
@post("/login")
async def login(
self,
data: Annotated[
UserLogin,
Body(title="OAuth2 Login", media_type=RequestEncodingType.URL_ENCODED),
],
users_service: UserService,
) -> Response[OAuth2Login]:
"""
Authenticate a user and generate an OAuth2 token.
Verifies the provided credentials and returns an authentication token
for subsequent authenticated requests.
Request Body:
data: Login credentials (email and password).
Injected Dependencies:
users_service: The user service for authentication.
Returns:
OAuth2 token and login information.
Raises:
HTTPException (401): If credentials are invalid or user not found.
"""
try:
user = await users_service.authenticate(data.email, data.password)
except PermissionDeniedException:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail="Invalid email or password"
)
from chitai.app import oauth2_auth
return oauth2_auth.login(
identifier=data.email,
token_extras={"user_id": user.id, "email": user.email},
)
@get("/me")
async def get_user_info(self, current_user: User) -> UserRead:
"""
Retrieve the currently authenticated user's information.
Returns the profile information of the user making the request.
Injected Dependencies:
current_user: The authenticated user from the OAuth2 token.
Returns:
The current user's information as a UserRead schema.
"""
return UserRead(**current_user.to_dict())

View File

@@ -0,0 +1,68 @@
# src/chitai/controllers/author.py
# Standard library
from typing import Annotated
# Third-party libraries
from litestar import Controller, post, get, patch, delete
from litestar.params import Dependency
from litestar.exceptions import HTTPException
from advanced_alchemy.extensions.litestar.providers import create_service_dependencies
from advanced_alchemy.service.pagination import OffsetPagination
from advanced_alchemy.service import FilterTypeT
# Local imports
from chitai.services import AuthorService
from chitai.schemas import AuthorRead
from chitai.services.filters.author import AuthorLibraryFilter
class AuthorController(Controller):
"""Controller for managing author information."""
path = "/authors"
dependencies = create_service_dependencies(
AuthorService,
key="author_service",
filters={
"id_filter": int,
"pagination_type": "limit_offset",
"sort_field": "name",
"search": ["name"],
"search_ignore_case": True,
},
)
@get()
async def list_authors(
self,
author_service: AuthorService,
libraries: list[int] | None = None,
filters: Annotated[list[FilterTypeT], Dependency(skip_validation=True)] = [],
) -> OffsetPagination[AuthorRead]:
"""
List authors with filtering, pagination, and search.
Supports searching by name and filtering by library. Results can be sorted
and paginated using standard query parameters.
Query Parameters:
library_id: Optional library ID to filter authors. If None, returns all authors.
id_filter: Filter by author ID (from create_filter_dependencies).
sort_field: Field to sort by (default: 'name').
search: Search term for author name field.
search_ignore_case: Case-insensitive search (default: True).
limit: Number of results per page (pagination).
offset: Number of results to skip (pagination).
Injected Dependencies:
author_service: The author service for database operations.
filters: Pre-constructed filters from dependencies.
Returns:
Paginated list of authors matching the criteria.
"""
filters.append(AuthorLibraryFilter(libraries))
authors, total = await author_service.list_and_count(*filters, uniquify=True)
return author_service.to_schema(authors, total, filters, schema_type=AuthorRead)

View File

@@ -0,0 +1,454 @@
# src/chitai/controllers/book.py
# TODO: Make the endpoints more consistent (i.e sometimes book_id is a path parameter, other times it a query param)
# Standard library
from typing import Annotated
# Third-party libraries
from litestar import Controller, get, post, put, patch, delete
from litestar.di import Provide
from litestar.params import Dependency, Body
from litestar.enums import RequestEncodingType
from litestar.response import File, Stream
from litestar.exceptions import HTTPException
from litestar.status_codes import HTTP_400_BAD_REQUEST
from litestar.datastructures import UploadFile
from advanced_alchemy.service.pagination import OffsetPagination
from advanced_alchemy.filters import CollectionFilter
from advanced_alchemy.service import FilterTypeT
# Local imports
from chitai.services import dependencies as deps
from chitai import schemas as s
from chitai.database import models as m
from chitai.services import BookService, BookProgressService
class BookController(Controller):
"""Controller for managing book-related operations."""
path = "/books"
dependencies = {
"books_service": Provide(deps.provide_book_service),
"progress_service": Provide(deps.provide_progress_service),
"library_service": Provide(deps.provide_library_service),
"library": Provide(deps.get_library_by_id),
"authors_filter": Provide(deps.provide_authors_filter, sync_to_thread=False),
"publishers_filter": Provide(
deps.provide_publishers_filter, sync_to_thread=False
),
"tags_filter": Provide(deps.provide_tags_filter, sync_to_thread=False),
"bookshelves_filter": Provide(
deps.provide_bookshelves_filter, sync_to_thread=False
),
"progress_filter": Provide(deps.provide_progress_filter, sync_to_thread=False),
"libraries_filter": Provide(
deps.provide_libraries_filter, sync_to_thread=False
),
"book_filters": Provide(deps.provide_book_filters, sync_to_thread=False),
} | deps.create_book_filter_dependencies(
{
"id_filter": int,
"pagination_type": "limit_offset",
"sort_field": "title",
"search": ["title"],
"search_ignore_case": True,
}
)
@post(request_max_body_size=None)
async def create_book(
self,
books_service: BookService,
library: m.Library,
data: Annotated[s.BookCreate, Body(media_type=RequestEncodingType.MULTI_PART)],
) -> s.BookRead:
"""
Create a new book with metadata and files.
Accepts book metadata and associated files, processes them through the
book service, and returns the created book details.
Path Parameters:
library_id: The ID of the library the book belongs to.
Request Body:
data: Book creation data including metadata and files.
Injected Dependencies:
books_service: The book service for database operations.
library: The library the book belongs to.
Returns:
The created book as a BookRead schema.
"""
result = await books_service.create(data, library)
book = await books_service.get(result.id)
return books_service.to_schema(book, schema_type=s.BookRead)
@post("fromFiles", request_max_body_size=None)
async def create_books_without_metadata(
self,
books_service: BookService,
library: m.Library,
data: Annotated[
s.BooksCreateFromFiles, Body(media_type=RequestEncodingType.MULTI_PART)
],
) -> OffsetPagination[s.BookRead]:
"""
Create multiple books from uploaded files.
Groups files by directory and creates separate books for each group.
Metadata is automatically extracted from the files.
Request Body:
data: Container with list of uploaded files.
Injected Dependencies:
books_service: The book service for database operations.
library: The library the books belong to.
Returns:
Paginated list of created books.
"""
try:
results = await books_service.create_many_from_files(data, library)
except ValueError:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST, detail="Must upload at least one file"
)
books = await books_service.list(
CollectionFilter("id", [result.id for result in results])
)
return books_service.to_schema(books, schema_type=s.BookRead)
@get(path="/{book_id:int}")
async def get_book_by_id(
self, book_id: int, books_service: BookService, current_user: m.User
) -> s.BookRead:
"""
Retrieve book details by ID.
Path Parameters:
book_id: The ID of the book to retrieve.
Injected Dependencies:
books_service: The book service for database operations.
current_user: The authenticated user making the request.
Returns:
The book details as a BookRead schema.
"""
book = await books_service.get(book_id)
return books_service.to_schema(book, schema_type=s.BookRead)
@get()
async def list_books(
self,
books_service: BookService,
book_filters: Annotated[
list[FilterTypeT], Dependency(skip_validation=True)
] = [],
filters: Annotated[list[FilterTypeT], Dependency(skip_validation=True)] = [],
) -> OffsetPagination[s.BookRead]:
"""
List books with filtering, pagination, and search.
Supports filtering by library, tags, authors, shelves, and publishers.
Query Parameters:
library_id: The library to list books from.
tags: Optional list of tag IDs to filter by.
authors: Optional list of author IDs to filter by.
shelves: Optional list of bookshelf IDs to filter by.
publishers: Optional list of publisher IDs to filter by.
id: Injected filter. Filter by book ID.
sort_field: Injected filter. Field to sort by (default: title).
search: Injected filter. Search on title field.
search_ignore_case: Injected filter. Case-insensitive search.
Injected Dependencies:
books_service: The book service for database operations.
current_user: The authenticated user making the request.
filters: Pre-constructed filters from dependencies.
Returns:
Paginated list of books matching the criteria.
"""
books, total = await books_service.list_and_count(*filters, *book_filters)
return books_service.to_schema(books, total, filters, schema_type=s.BookRead)
@patch(path="{book_id:int}")
async def update_book_metadata(
self,
book_id: int,
data: Annotated[
s.BookMetadataUpdate, Body(media_type=RequestEncodingType.JSON)
],
books_service: BookService,
library: m.Library,
) -> s.BookRead:
"""
Update a book's metadata and optionally reorganize files if necessary.
Allows editing book metadata such as title, author, and series information.
If metadata changes affect the file organization path, files are
automatically reorganized.
Path Parameters:
book_id: The ID of the book to update.
Request Body:
data: Updated metadata for the book.
Injected Dependencies:
books_service: The book service for database operations.
library: The library containing the book.
Returns:
The updated book as a BookRead schema.
"""
await books_service.update(book_id, data, library)
book = await books_service.get(book_id)
return books_service.to_schema(book, schema_type=s.BookRead)
@put(path="{book_id:int}/cover")
async def update_cover(
self,
book_id: int,
data: Annotated[UploadFile, Body(media_type=RequestEncodingType.MULTI_PART)],
books_service: BookService,
library: m.Library,
) -> s.BookRead:
"""
Update's a book's cover image.
Allows uploading a new cover image that replaces the current cover.
Path Parameters:
book_id: The ID of the book to update.
Request Body:
cover_image: The new cover image for the book.
Injected Dependencies:
books_service: The book service for database operations.
library: The library containing the book.
Returns:
The updated book as a BookRead schema.
"""
await books_service.update(book_id, {"cover_image": data}, library)
updated_book = await books_service.get(book_id)
return books_service.to_schema(updated_book, schema_type=s.BookRead)
@get(path="download/{book_id:int}/{file_id:int}")
async def get_file(
self, book_id: int, file_id: int, books_service: BookService
) -> File:
"""
Download a single file from a book.
Path Parameters:
book_id: The ID of the book containing the file.
file_id: The ID of the file to download.
Injected Dependencies:
books_service: The book service for database operations.
Returns:
The file ready for download.
"""
return await books_service.get_file(book_id, file_id)
@get(path="download")
async def get_files(
self, book_ids: list[int], library_id: int, books_service: BookService
) -> Stream:
"""
Download multiple books as a single ZIP file.
Compresses all files from specified books into a single archive for download.
Query Parameters:
book_ids: List of book IDs to include in the download.
library_id: The library containing the books.
Injected Dependencies:
books_service: The book service for database operations.
Returns:
A stream of the compressed ZIP file.
"""
return Stream(
books_service.get_files(book_ids, library_id),
headers={
"Content-Disposition": 'attachment; filename="download.zip"',
"Content-Type": "application/zip",
},
)
@post(path="/{book_id:int}/files", request_max_body_size=None)
async def add_book_files(
self,
book_id: int,
data: Annotated[
list[UploadFile], Body(media_type=RequestEncodingType.MULTI_PART)
],
library: m.Library,
books_service: BookService,
) -> s.BookRead:
"""
Add files to an existing book.
Path Parameters:
book_id: The ID of the book to modify
Request Body:
files: The files to add to the book
Injected Dependencies:
library: The library containing the book.
books_service: The book service for database operations.
Returns:
The modified book
"""
await books_service.add_files(book_id, data, library)
book = await books_service.get(book_id)
return books_service.to_schema(book, schema_type=s.BookRead)
@delete(path="{book_id:int}/files")
async def delete_book_files(
self,
library: m.Library,
books_service: BookService,
book_id: int,
file_ids: list[int],
delete_files: bool = False,
) -> None:
"""
Delete specific files from a book.
Removes the specified files from both the database and optionally
remove them from the filesystem.
Path Parameters:
book_id: The ID of the book.
Query Parameters:
file_ids: List of file IDs to delete.
delete_files: If True, also delete files from the filesystem.
Injected Dependencies:
books_service: The book service for database operations.
library: The library containing the book.
"""
if len(file_ids) == 0:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="At least one file to delete is required",
)
await books_service.remove_files(book_id, file_ids, delete_files, library)
@delete(path="/")
async def delete_books(
self,
books_service: BookService,
library: m.Library,
book_ids: list[int],
delete_files: bool = False,
) -> None:
"""
Delete multiple books by ID.
Removes books from the database and optionally deletes associated files
from the filesystem.
Query Parameters:
book_ids: List of book IDs to delete.
delete_files: If True, also delete associated files from the filesystem.
library_id: The library the given books belong to.
Injected Dependencies:
books_service: The book service for database operations.
library: The library containing the books.
"""
await books_service.delete(book_ids, library, delete_files=delete_files)
@post(path="/progress/{book_id:int}")
async def update_progress(
self,
book_id: int,
current_user: m.User,
data: s.BookProgressCreate,
progress_service: BookProgressService,
) -> None:
"""
Update or create reading progress for a book.
Records the user's progress on a book, creating or updating the progress entry.
Path Parameters:
book_id: The ID of the book.
Request Body:
data: Progress data to update.
Injected Dependencies:
current_user: The authenticated user.
progress_service: The progress service for database operations.
"""
progress = await progress_service.get_one_or_none(
m.BookProgress.user_id == current_user.id, m.BookProgress.book_id == book_id
)
dict_data = data.model_dump() | {"user_id": current_user.id, "book_id": book_id}
await progress_service.upsert(dict_data, progress.id if progress else None)
@post(path="progress")
async def set_book_progress_batch(
self,
book_ids: list[int],
data: Annotated[s.BookProgressCreate, Body()],
current_user: m.User,
progress_service: BookProgressService,
) -> None:
"""
Marks all the user progress on the books matching the IDs as completed.
Query Parameters:
book_ids: List of book IDs to mark as completed.
Injected Dependencies:
current_user: The authenticated user.
progress_service: The progress service for database operations.
"""
# TODO Optimize this by performing a batch query and batch upsert
for book_id in book_ids:
existing_progress = await progress_service.get_one_or_none(
m.BookProgress.user_id == current_user.id,
m.BookProgress.book_id == book_id,
)
progress_data = data.model_dump() | {
"user_id": current_user.id,
"book_id": book_id,
}
await progress_service.upsert(
progress_data, existing_progress.id if existing_progress else None
)

View File

@@ -0,0 +1,76 @@
# src/chitai/controllers/library.py
# Standard library
from typing import Annotated
# Third-party libraries
from litestar import Controller, post, get, patch, delete
from litestar.params import Dependency
from litestar.exceptions import HTTPException
from advanced_alchemy.extensions.litestar.providers import create_service_dependencies
from advanced_alchemy.service.pagination import OffsetPagination
from advanced_alchemy.service import FilterTypeT
# Local imports
from chitai.database import models as m
from chitai.services import LibraryService
from chitai.schemas.library import LibraryCreate, LibraryRead
from chitai.services.utils import DirectoryDoesNotExist
class LibraryController(Controller):
"""Controller for managing library operations."""
path = "/libraries"
dependencies = create_service_dependencies(
LibraryService,
key="library_service",
filters={"id_filter": int, "pagination_type": "limit_offset"},
)
@post()
async def create_library(
self, library_service: LibraryService, data: LibraryCreate
) -> LibraryRead:
"""
Create a new library.
Creates a new library entry in the system using the provided data. If the library is read-only
the target root path must exist.
Request Body:
data: Data used to create the new library.
Injected Dependencies:
library_service: The service used to manage library creation and persistence.
Returns:
The created library.
"""
try:
library = await library_service.create(data)
return library_service.to_schema(library, schema_type=LibraryRead)
except DirectoryDoesNotExist as exc:
raise HTTPException(status_code=400, detail=str(exc))
@get()
async def list_libraries(
self,
library_service: LibraryService,
filters: Annotated[list[FilterTypeT], Dependency(skip_validation=True)],
) -> OffsetPagination[LibraryRead]:
"""
List all libraries.
Retrieves a paginated list of all libraries, optionally filtered by parameters.
Query Parameters:
filters: Filtering by ID and pagination parameters.
Injected Dependencies:
library_service: The service used to query and return library data.
"""
results, total = await library_service.list_and_count(*filters, load=[m.Library.books])
return library_service.to_schema(
results, total, filters, schema_type=LibraryRead
)

View File

@@ -0,0 +1,298 @@
from chitai.services import dependencies as deps
from chitai.database import models as m
from chitai.services.author import AuthorService
from chitai.services.filters.author import AuthorLibraryFilter
from chitai.services.filters.publisher import PublisherLibraryFilter
from chitai.services.filters.tags import TagLibraryFilter
from chitai.services.opds.models import Entry, Link, LinkTypes, LinkRelations
from chitai.services.opds.opds import create_acquisition_feed, create_navigation_feed, create_library_navigation_feed, create_collection_navigation_feed, create_pagination_links, create_search_link, get_opensearch_document
from chitai.services import BookService, ShelfService, LibraryService
from chitai.services.publisher import PublisherService
from chitai.services.tag import TagService
from litestar import Controller, Request, Response, get
from litestar.response import File
from litestar.di import Provide
from litestar.exceptions import HTTPException
from advanced_alchemy.filters import CollectionFilter, LimitOffset, OrderBy
from chitai.middleware.basic_auth import basic_auth_mw
from urllib.parse import urlencode
from typing import Annotated
from litestar.params import Dependency
from advanced_alchemy.service import FilterTypeT
class OpdsController(Controller):
""" Controller for managing OPDS endpoints """
middleware=[basic_auth_mw]
dependencies = {
"user": Provide(deps.provide_user_via_basic_auth),
"library_service": Provide(deps.provide_library_service),
"library": Provide(deps.get_library_by_id),
"books_service": Provide(deps.provide_book_service),
"shelf_service": Provide(deps.provide_shelf_service),
"tag_service": Provide(deps.provide_tag_service),
"author_service": Provide(deps.provide_author_service),
"publisher_service": Provide(deps.provide_publisher_service),
"authors_filter": Provide(deps.provide_authors_filter, sync_to_thread=False),
"publishers_filter": Provide(
deps.provide_publishers_filter, sync_to_thread=False
),
"tags_filter": Provide(deps.provide_tags_filter, sync_to_thread=False),
"bookshelves_filter": Provide(
deps.provide_bookshelves_filter, sync_to_thread=False
),
"progress_filter": Provide(deps.provide_progress_filter, sync_to_thread=False),
"libraries_filter": Provide(
deps.provide_libraries_filter, sync_to_thread=False
),
"book_filters": Provide(deps.provide_book_filters, sync_to_thread=False),
} | deps.create_book_filter_dependencies(
{
"id_filter": int,
"pagination_type": "limit_offset",
"sort_field": "title",
"search": ["title"],
"search_ignore_case": True,
}
)
path = "/opds"
@get()
async def get_root_feed(self, library_service: LibraryService) -> Response:
libraries = await library_service.list()
entries = [
Entry(
id=f"/opds/library/{lib.id}",
title=lib.name,
link=[
Link(
title=lib.name,
href=f"/opds/library/{lib.id}",
rel=LinkRelations.SUBSECTION,
type=LinkTypes.NAVIGATION
)
]
) for lib in libraries
]
feed = create_navigation_feed(
id="/opds",
title="Root",
self_url="/opds",
links=[
Link(
rel="search",
href="/opds/opensearch",
type="application/opensearchdescription+xml",
title="Search books",
)
],
entries=entries
)
return Response(
feed,
media_type="application/xml"
)
@get("/acquisition")
async def get_acquisition_feed(
self,
request: Request,
feed_id: str,
feed_title: str,
books_service: BookService,
book_filters: Annotated[list[FilterTypeT], Dependency(skip_validation=True)] = [],
filters: Annotated[list[FilterTypeT], Dependency(skip_validation=True)] = []
) -> Response:
all_filters = [*filters, *book_filters]
books, total = await books_service.list_and_count(*all_filters)
limit, offset = extract_limit_offset(all_filters)
links = []
# Create pagination links if it is a paginated feed
if request.query_params.get('paginated'):
pagination = create_pagination_links(
request=request,
total=total,
limit=limit,
offset=offset,
feed_title=feed_title,
link_type=LinkTypes.ACQUISITION
)
links.extend([link for link in [pagination.next_link, pagination.prev_link] if link])
# Add search link if this is a searchable feed
if request.query_params.get('search'):
links.append(create_search_link(request))
# Create self URL
self_url = f"{request.url.path}?{urlencode(list(request.query_params.items()), doseq=True)}"
feed = create_acquisition_feed(
id=feed_id,
title=feed_title,
url=self_url,
books=books,
links=links,
)
return Response(feed, media_type="application/xml")
@get("/opensearch")
async def opensearch(self, user: m.User, request: Request) -> Response:
return Response(
get_opensearch_document(
base_url=f'/opds/search?{urlencode(list(request.query_params.items()), doseq=True)}&'
),
media_type="application/xml"
)
@get("/library/{library_id:int}")
async def get_library_feed(self, library: m.Library) -> Response:
feed = create_library_navigation_feed(library)
return Response(feed, media_type="application/xml")
@get("/library/{library_id:int}/{collection_type:str}")
async def get_library_collection_feed(
self,
collection_type: str,
library: m.Library,
user: m.User,
author_service: AuthorService,
shelf_service: ShelfService,
tag_service: TagService,
publisher_service: PublisherService,
request: Request,
filters: Annotated[list[FilterTypeT], Dependency(skip_validation=True)] = []
) -> Response:
service_map = {
'shelves': (shelf_service, lambda: shelf_service.list_and_count(
*filters,
CollectionFilter("library_id", values=[library.id]),
OrderBy("name", "asc"),
m.BookList.user_id == user.id
)),
'tags': (tag_service, lambda: tag_service.list_and_count(
*filters,
TagLibraryFilter(libraries=[library.id]),
OrderBy("name", "asc"),
uniquify=True,
)),
'authors': (author_service, lambda: author_service.list_and_count(
*filters,
AuthorLibraryFilter(libraries=[library.id]),
OrderBy("name", "asc"),
uniquify=True
)),
'publishers': (publisher_service, lambda: publisher_service.list_and_count(
*filters,
PublisherLibraryFilter(libraries=[library.id]),
OrderBy("name", "asc"),
uniquify=True
))
}
if collection_type not in service_map:
raise HTTPException(status_code=404, detail="Collection type not found")
_, fetch_items = service_map[collection_type]
items, total = await fetch_items()
links = []
# Create pagination links if it is a paginated feed
limit, offset = extract_limit_offset(filters)
if request.query_params.get('paginated'):
pagination = create_pagination_links(
request=request,
total=total,
limit=limit,
offset=offset,
feed_title=collection_type,
link_type=LinkTypes.ACQUISITION
)
links.extend([link for link in [pagination.next_link, pagination.prev_link] if link])
feed = create_collection_navigation_feed(library, collection_type, items, links)
return Response(feed, media_type="application/xml")
@get("/search")
async def search_books(
self, books_service: BookService,
request: Request,
book_filters: Annotated[
list[FilterTypeT], Dependency(skip_validation=True)
] = [],
filters: Annotated[list[FilterTypeT], Dependency(skip_validation=True)] = []
) -> Response:
filters = [*filters, *book_filters]
books, total = await books_service.list_and_count(
*filters
)
limit, offset = extract_limit_offset(filters)
# Create pagination links
pagination = create_pagination_links(
request=request,
total=total,
limit=limit,
offset=offset,
feed_title="Search Results",
link_type=LinkTypes.ACQUISITION
)
links = [link for link in [pagination.next_link, pagination.prev_link] if link]
catalog_xml = create_acquisition_feed(
id=f"/opds/search?q=q",
title="Search results",
url=f"/opds/search?q=q",
books=books,
links=links
)
return Response(catalog_xml, media_type="application/xml")
@get(path="download/{book_id:int}/{file_id:int}")
async def get_file(
self, book_id: int, file_id: int, books_service: BookService
) -> File:
return await books_service.get_file(book_id, file_id)
def extract_limit_offset(filters: list[FilterTypeT]) -> tuple[int, int]:
"""Extract page size and offset from filters"""
limit_offset_filter = next(
(f for f in filters if isinstance(f, LimitOffset)),
None
)
if limit_offset_filter:
return limit_offset_filter.limit, limit_offset_filter.offset
raise ValueError("LimitOffset filter not found")

View File

@@ -0,0 +1,72 @@
# src/chitai/controllers/publisher.py
# Standard library
from typing import Annotated
# Third-party libraries
from litestar import Controller, post, get, patch, delete
from litestar.params import Dependency
from advanced_alchemy.extensions.litestar.providers import create_service_dependencies
from advanced_alchemy.service.pagination import OffsetPagination
from advanced_alchemy.service import FilterTypeT
# Local imports
from chitai.services import PublisherService
from chitai.schemas import PublisherRead
from chitai.services.filters.publisher import PublisherLibraryFilter
class PublisherController(Controller):
"""Controller for managing publisher information."""
path = "/publishers"
dependencies = create_service_dependencies(
PublisherService,
key="publisher_service",
filters={
"id_filter": int,
"pagination_type": "limit_offset",
"sort_field": "name",
"search": ["name"],
"search_ignore_case": True,
},
)
@get()
async def list_publishers(
self,
publisher_service: PublisherService,
libraries: list[int] | None = None,
filters: Annotated[list[FilterTypeT], Dependency(skip_validation=True)] = [],
) -> OffsetPagination[PublisherRead]:
"""
List publishers with filtering, pagination, and search.
Supports searching by name and filtering by library. Results can be sorted
and paginated using standard query parameters.
Query Parameters:
library_id: Optional library ID to filter publishers. If None, returns all publishers.
id_filter: Filter by publisher ID (from create_filter_dependencies).
sort_field: Field to sort by (default: 'name').
search: Search term for publisher name field.
search_ignore_case: Case-insensitive search (default: True).
limit: Number of results per page (pagination).
offset: Number of results to skip (pagination).
Injected Dependencies:
publisher_service: The publisher service for database operations.
filters: Pre-constructed filters from dependencies.
Returns:
Paginated list of publishers matching the criteria.
"""
filters.append(PublisherLibraryFilter(libraries))
publishers, total = await publisher_service.list_and_count(
*filters, uniquify=True
)
return publisher_service.to_schema(
publishers, total, filters, schema_type=PublisherRead
)

View File

@@ -0,0 +1,242 @@
# src/chitai/controllers/shelf.py
# Standard library
from typing import Annotated
import logging
# Third-party libraries
from litestar import Controller, get, post, delete
from litestar.di import Provide
from litestar.params import Dependency
from litestar.status_codes import HTTP_403_FORBIDDEN
from litestar.exceptions import HTTPException
from advanced_alchemy.extensions.litestar.providers import create_filter_dependencies
from advanced_alchemy.service import OffsetPagination
from advanced_alchemy.service import FilterTypeT
from advanced_alchemy.exceptions import IntegrityError
from advanced_alchemy.filters import CollectionFilter
from sqlalchemy.orm import selectinload
# Local imports
from chitai.schemas.shelf import ShelfCreate, ShelfRead
from chitai.services import dependencies as deps
from chitai.services.bookshelf import ShelfService
from chitai.database import models as m
logger = logging.getLogger(__name__)
class BookshelfController(Controller):
"""Controller for managing bookshelf operations."""
path = "/shelves"
dependencies = {
"shelf_service": Provide(deps.provide_shelf_service)
} | create_filter_dependencies(
{"id_filter": int, "pagination_type": "limit_offset", "search": "title"}
)
@get()
async def list_shelves(
self,
shelf_service: ShelfService,
current_user: m.User,
libraries: list[int] | None = None,
filters: Annotated[list[FilterTypeT], Dependency(skip_validation=True)] = [],
) -> OffsetPagination[ShelfRead]:
"""
List bookshelves for the authenticated user.
Retrieves all shelves created by the current user, with optional filtering
and search by shelf title.
Query Parameters:
library_id: Optional library ID to filter shelves (currently unused).
id_filter: Filter by shelf ID (from create_filter_dependencies).
search: Search term for shelf title (case-insensitive).
limit: Number of results per page (pagination).
offset: Number of results to skip (pagination).
Injected Dependencies:
shelf_service: The shelf service for database operations.
current_user: The authenticated user making the request.
filters: Pre-constructed filters from dependencies.
Returns:
Paginated list of the user's bookshelves.
"""
if libraries:
filters.append(CollectionFilter("library_id", values=libraries))
filters.append(m.BookList.user_id == current_user.id)
results, total = await shelf_service.list_and_count(*filters, load=[selectinload(m.BookList.book_links)])
return shelf_service.to_schema(results, total, filters, schema_type=ShelfRead)
@post()
async def create_shelf(
self, data: ShelfCreate, current_user: m.User, shelf_service: ShelfService
) -> ShelfRead:
"""
Create a new bookshelf for the authenticated user.
Request Body:
data: Shelf creation data including title and optional description.
Injected Dependencies:
current_user: The authenticated user creating the shelf.
shelf_service: The shelf service for database operations.
Returns:
The created shelf as a ShelfRead schema.
"""
try:
shelf = await shelf_service.create(
data.model_dump() | {"user_id": current_user.id}
)
await shelf_service.add_books(shelf.id, data.book_ids)
return shelf_service.to_schema(shelf, schema_type=ShelfRead)
except IntegrityError:
raise HTTPException(
status_code=400,
detail=f"Library with ID {data.library_id} does not exist",
)
@delete("{shelf_id:int}")
async def delete_shelf(
self, shelf_id: int, current_user: m.User, shelf_service: ShelfService
) -> None:
"""
Delete a bookshelf.
Removes a shelf and all its associations. Only the shelf owner can delete it.
Path Parameters:
shelf_id: The ID of the shelf to delete.
Injected Dependencies:
current_user: The authenticated user (for authorization).
shelf_service: The shelf service for database operations.
Raises:
HTTPException (403): If the user does not own the shelf.
HTTPException (404): If the shelf does not exist.
"""
shelf = await shelf_service.get(shelf_id)
if shelf.user_id != current_user.id:
logger.warning(
f"Unauthorized deletion attempt: user {current_user.id} attempted to delete shelf {shelf_id}"
)
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="You do not have permission to delete this shelf",
)
await shelf_service.delete(shelf_id)
@post("{shelf_id:int}/books")
async def add_books_to_shelf(
self,
shelf_id: int,
book_ids: list[int],
current_user: m.User,
shelf_service: ShelfService,
) -> ShelfRead:
"""
Add books to a bookshelf.
Adds one or more books to an existing shelf. Only the shelf owner can add books.
Path Parameters:
shelf_id: The ID of the shelf.
Query Parameters:
book_ids: List of book IDs to add to the shelf.
Injected Dependencies:
current_user: The authenticated user (for authorization).
shelf_service: The shelf service for database operations.
Returns:
the updated shelf
Raises:
HTTPException (403): If the user does not own the shelf.
HTTPException (404): If the shelf or any book does not exist.
"""
shelf = await shelf_service.get(shelf_id)
if shelf.user_id != current_user.id:
logger.warning(
f"Unauthorized add attempt: user {current_user.id} attempted to modify shelf {shelf_id}"
)
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="You do not have permission to modify this shelf",
)
try:
bookshelf = await shelf_service.add_books(shelf_id, book_ids)
except ValueError:
raise HTTPException(
status_code=400, detail="Attempted to add books that do not exist."
)
logger.debug(
f"Added {len(book_ids)} books to shelf {shelf_id} by user {current_user.id}"
)
return shelf_service.to_schema(bookshelf, schema_type=ShelfRead)
@delete("{shelf_id:int}/books", status_code=200)
async def remove_books_from_shelf(
self,
shelf_id: int,
book_ids: list[int],
current_user: m.User,
shelf_service: ShelfService,
) -> ShelfRead:
"""
Remove books from a bookshelf.
Removes one or more books from a shelf. Only the shelf owner can remove books.
Path Parameters:
shelf_id: The ID of the shelf.
Query Parameters:
book_ids: List of book IDs to remove from the shelf.
Injected Dependencies:
current_user: The authenticated user (for authorization).
shelf_service: The shelf service for database operations.
Returns:
the updated shelf
Raises:
HTTPException (403): If the user does not own the shelf.
HTTPException (404): If the shelf or any book does not exist.
"""
shelf = await shelf_service.get(shelf_id)
if shelf.user_id != current_user.id:
logger.warning(
f"Unauthorized remove attempt: user {current_user.id} attempted to modify shelf {shelf_id}"
)
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="You do not have permission to modify this shelf",
)
bookshelf = await shelf_service.remove_books(shelf_id, book_ids)
logger.debug(
f"Removed {len(book_ids)} books from shelf {shelf_id} by user {current_user.id}"
)
return shelf_service.to_schema(bookshelf, schema_type=ShelfRead)

View File

@@ -0,0 +1,67 @@
# src/chitai/controllers/tag.py
# Standard library
from typing import Annotated
# Third-party libraries
from litestar import Controller, post, get, patch, delete
from litestar.params import Dependency
from advanced_alchemy.extensions.litestar.providers import create_service_dependencies
from advanced_alchemy.service.pagination import OffsetPagination
from advanced_alchemy.service import FilterTypeT
# Local imports
from chitai.services import TagService
from chitai.schemas import TagRead
from chitai.services.filters.tags import TagLibraryFilter
class TagController(Controller):
"""Controller for managing tag information."""
path = "/tags"
dependencies = create_service_dependencies(
TagService,
key="tag_service",
filters={
"id_filter": int,
"pagination_type": "limit_offset",
"sort_field": "name",
"search": ["name"],
"search_ignore_case": True,
},
)
@get()
async def list_tags(
self,
tag_service: TagService,
libraries: list[int] | None = None,
filters: Annotated[list[FilterTypeT], Dependency(skip_validation=True)] = [],
) -> OffsetPagination[TagRead]:
"""
List tags with filtering, pagination, and search.
Supports searching by name and filtering by library. Results can be sorted
and paginated using standard query parameters.
Query Parameters:
library_id: Optional library ID to filter tags. If None, returns all tags.
id_filter: Filter by tag ID (from create_filter_dependencies).
sort_field: Field to sort by (default: 'name').
search: Search term for tag name field.
search_ignore_case: Case-insensitive search (default: True).
limit: Number of results per page (pagination).
offset: Number of results to skip (pagination).
Injected Dependencies:
tag_service: The tag service for database operations.
filters: Pre-constructed filters from dependencies.
Returns:
Paginated list of tags matching the criteria.
"""
filters.append(TagLibraryFilter(libraries))
tags, total = await tag_service.list_and_count(*filters, uniquify=True)
return tag_service.to_schema(tags, total, filters, schema_type=TagRead)

View File

@@ -0,0 +1,23 @@
from chitai.config import settings
from advanced_alchemy.extensions.litestar import (
SQLAlchemyAsyncConfig,
AsyncSessionConfig,
SQLAlchemyPlugin,
)
from sqlalchemy.ext.asyncio import create_async_engine
from chitai.database import models
DATABASE_URL = str(settings.postgres_uri)
session_config = AsyncSessionConfig(expire_on_commit=False)
config = SQLAlchemyAsyncConfig(
engine_instance=create_async_engine(DATABASE_URL, echo=settings.postgres_echo),
session_config=session_config,
before_send_handler="autocommit",
create_all=True,
)
alchemy = SQLAlchemyPlugin(config=config)

View File

@@ -0,0 +1,11 @@
from .author import Author, BookAuthorLink
from .book import Book, Identifier, FileMetadata
from .book_list import BookList, BookListLink
from .book_progress import BookProgress
from .book_series import BookSeries
from .library import Library
from .publisher import Publisher
from .tag import Tag, BookTagLink
from .user import User
from advanced_alchemy.base import BigIntBase

View File

@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING, Optional
from collections.abc import Hashable
from sqlalchemy import ColumnElement, ForeignKey
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from advanced_alchemy.base import BigIntAuditBase, BigIntBase
from advanced_alchemy.mixins import UniqueMixin
if TYPE_CHECKING:
from .book import Book
class Author(BigIntAuditBase, UniqueMixin):
__tablename__ = "authors"
name: Mapped[str] = mapped_column(unique=True, index=True)
description: Mapped[Optional[str]]
@classmethod
def unique_hash(cls, name: str) -> Hashable:
"""Generate a unique hash for deduplication."""
return name
@classmethod
def unique_filter(cls, name: str) -> ColumnElement[bool]:
"""SQL filter for finding existing records."""
return cls.name == name
def __repr__(self) -> str:
return f"Author({self.name!r})"
class BookAuthorLink(BigIntBase):
__tablename__ = "book_author_links"
book_id: Mapped[int] = mapped_column(
ForeignKey("books.id", ondelete="cascade"), primary_key=True
)
author_id: Mapped[int] = mapped_column(ForeignKey("authors.id"), primary_key=True)
position: Mapped[int]
book: Mapped["Book"] = relationship(back_populates="author_links")
author: Mapped[Author] = relationship()

View File

@@ -0,0 +1,152 @@
from datetime import date
from typing import TYPE_CHECKING, Any, Optional
from sqlalchemy import Index
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.ext.orderinglist import ordering_list
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.associationproxy import AssociationProxy
from sqlalchemy.orm.collections import attribute_keyed_dict
from advanced_alchemy.base import BigIntAuditBase, BigIntBase
from .author import Author, BookAuthorLink
from .publisher import Publisher
from .tag import Tag, BookTagLink
from .library import Library
if TYPE_CHECKING:
from chitai.database.models import (
BookProgress,
BookList,
BookSeries,
BookListLink,
)
class Book(BigIntAuditBase):
__tablename__ = "books"
__table_args__ = (
# Create trigram indexes
Index(
"ix_books_title_trigram",
"title",
postgresql_using="gin",
postgresql_ops={"title": "gin_trgm_ops"},
),
)
library_id: Mapped[int] = mapped_column(ForeignKey("libraries.id"), nullable=False)
library: Mapped["Library"] = relationship(back_populates="books")
title: Mapped[str]
subtitle: Mapped[Optional[str]]
description: Mapped[Optional[str]]
published_date: Mapped[Optional[date]]
language: Mapped[Optional[str]]
pages: Mapped[Optional[int]]
cover_image: Mapped[Optional[str]]
edition: Mapped[Optional[int]]
path: Mapped[Optional[str]]
library: Mapped[Library] = relationship(back_populates="books")
author_links: Mapped[list["BookAuthorLink"]] = relationship(
back_populates="book",
cascade="all, delete-orphan",
passive_deletes=True,
order_by="BookAuthorLink.position",
collection_class=ordering_list("position"),
)
authors: AssociationProxy[list["Author"]] = association_proxy(
"author_links", "author", creator=lambda author: BookAuthorLink(author=author)
)
publisher_id: Mapped[Optional[int]] = mapped_column(ForeignKey("publishers.id"))
publisher: Mapped[Optional[Publisher]] = relationship()
tag_links: Mapped[list["BookTagLink"]] = relationship(
back_populates="book",
cascade="all, delete-orphan",
passive_deletes=True,
order_by="BookTagLink.position",
collection_class=ordering_list("position"),
)
tags: AssociationProxy[list["Tag"]] = association_proxy(
"tag_links", "tag", creator=lambda tag: BookTagLink(tag=tag)
)
identifiers: Mapped[list["Identifier"]] = relationship(
cascade="all, delete-orphan",
passive_deletes=True,
)
files: Mapped[list["FileMetadata"]] = relationship(
cascade="all, delete-orphan", passive_deletes=True
)
list_links: Mapped[list["BookListLink"]] = relationship(
back_populates="book", cascade="all, delete-orphan", passive_deletes=True
)
lists: AssociationProxy[list["BookList"]] = association_proxy(
"list_links", "book_list"
)
series_id: Mapped[Optional[int]] = mapped_column(ForeignKey("book_series.id"))
series: Mapped[Optional["BookSeries"]] = relationship()
series_position: Mapped[Optional[str]]
progress_records: Mapped[list["BookProgress"]] = relationship(
cascade="all, delete-orphan", passive_deletes=True
)
@property
def progress(self) -> Optional["BookProgress"]:
return self.progress_records[0] if self.progress_records else None
def __repr__(self) -> str:
return f"Book({self.title=!r})"
def to_dict(self, exclude: set[str] | None = None) -> dict[str, Any]:
data = super().to_dict(exclude)
data["authors"] = [author.name for author in self.authors]
data["tags"] = [tag.name for tag in self.tags]
data["publisher"] = self.publisher.name if self.publisher else None
data["series"] = self.series.title if self.series else None
data["files"] = [file.to_dict() for file in self.files]
return data
class Identifier(BigIntBase):
__tablename__ = "identifiers"
name: Mapped[str] = mapped_column(primary_key=True)
book_id: Mapped[int] = mapped_column(
ForeignKey("books.id", ondelete="cascade"), primary_key=True
)
value: Mapped[str]
def __repr__(self):
return f"Identifier({self.name!r} : {self.value!r})"
class FileMetadata(BigIntBase):
__tablename__ = "file_metadata"
book_id: Mapped[int] = mapped_column(ForeignKey("books.id", ondelete="cascade"))
book: Mapped[Book] = relationship(back_populates="files")
hash: Mapped[str]
path: Mapped[str]
size: Mapped[int]
content_type: Mapped[Optional[str]]
def __repr__(self) -> str:
return f"FileMetadata({self.path!r})"

View File

@@ -0,0 +1,48 @@
from typing import Optional
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.ext.associationproxy import association_proxy, AssociationProxy
from sqlalchemy.ext.orderinglist import ordering_list
from advanced_alchemy.base import BigIntAuditBase, BigIntBase
from .book import Book
class BookList(BigIntAuditBase):
__tablename__ = "book_lists"
library_id: Mapped[Optional[int]] = mapped_column(
ForeignKey("libraries.id"), nullable=True
)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
title: Mapped[str]
book_links: Mapped[list["BookListLink"]] = relationship(
back_populates="book_list",
cascade="all, delete-orphan",
order_by="BookListLink.position",
collection_class=ordering_list("position"),
)
books: AssociationProxy[list[Book]] = association_proxy(
"book_links", "book", creator=lambda book: BookListLink(book=book)
)
@property
def total(self) -> int | None:
"""Return count of books if book_links is loaded, None otherwise."""
try:
return len(self.book_links) if self.book_links else 0
except Exception:
return None
class BookListLink(BigIntBase):
__tablename__ = "book_list_links"
book_id: Mapped[int] = mapped_column(
ForeignKey("books.id", ondelete="cascade"), primary_key=True
)
list_id: Mapped[int] = mapped_column(ForeignKey("book_lists.id"), primary_key=True)
position: Mapped[int]
book: Mapped[Book] = relationship(back_populates="list_links")
book_list: Mapped[BookList] = relationship(back_populates="book_links")

View File

@@ -0,0 +1,20 @@
from typing import Optional
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import relationship
from advanced_alchemy.base import BigIntAuditBase
class BookProgress(BigIntAuditBase):
__tablename__ = "book_progress"
user_id: Mapped[int] = mapped_column(
ForeignKey("users.id", ondelete="cascade"), nullable=False
)
book_id: Mapped[int] = mapped_column(
ForeignKey("books.id", ondelete="cascade"), nullable=False
)
epub_loc: Mapped[Optional[str]]
pdf_loc: Mapped[Optional[int]]
progress: Mapped[float]
completed: Mapped[Optional[bool]]

View File

@@ -0,0 +1,24 @@
from collections.abc import Hashable
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import ColumnElement, ForeignKey
from advanced_alchemy.base import BigIntAuditBase
from advanced_alchemy.mixins import UniqueMixin
from .book import Book
class BookSeries(BigIntAuditBase, UniqueMixin):
__tablename__ = "book_series"
title: Mapped[str] = mapped_column(unique=True, index=True)
@classmethod
def unique_hash(cls, title: str) -> Hashable:
"""Generate a unique hash for deduplication."""
return title
@classmethod
def unique_filter(cls, title: str) -> ColumnElement[bool]:
"""SQL filter for finding existing records."""
return cls.title == title

View File

@@ -0,0 +1,33 @@
from sqlalchemy.orm import Mapped, mapped_column, relationship
from advanced_alchemy.base import BigIntAuditBase
from advanced_alchemy.mixins import SlugKey
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from .book import Book
class Library(BigIntAuditBase, SlugKey):
__tablename__ = "libraries"
name: Mapped[str] = mapped_column(unique=True)
root_path: Mapped[str]
# Which structure to save the files in the filesystem (i.e {author_name}/{title}.{ext})
path_template: Mapped[str]
description: Mapped[Optional[str]]
read_only: Mapped[bool] = mapped_column(nullable=False, default=False)
books: Mapped[list["Book"]] = relationship(back_populates="library")
@property
def total(self) -> int | None:
"""Return count of books if books is loaded, None otherwise."""
try:
return len(self.books) if self.books else 0
except Exception:
return None
def __repr__(self) -> str:
return f"Library({self.name=!r})"

View File

@@ -0,0 +1,28 @@
from collections.abc import Hashable
from typing import Optional
from sqlalchemy import ColumnElement
from sqlalchemy.orm import Mapped
from advanced_alchemy.base import BigIntAuditBase
from advanced_alchemy.mixins import UniqueMixin
class Publisher(BigIntAuditBase, UniqueMixin):
__tablename__ = "publishers"
name: Mapped[str]
description: Mapped[Optional[str]]
@classmethod
def unique_hash(cls, name: str) -> Hashable:
"""Generate a unique hash for deduplication."""
return name
@classmethod
def unique_filter(cls, name: str) -> ColumnElement[bool]:
"""SQL filter for finding existing records."""
return cls.name == name
def __repr__(self) -> str:
return f"Publisher({self.name!r})"

View File

@@ -0,0 +1,49 @@
from collections.abc import Hashable
from typing import TYPE_CHECKING
from sqlalchemy import ColumnElement, ForeignKey
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from advanced_alchemy.base import BigIntBase
from advanced_alchemy.mixins import UniqueMixin
if TYPE_CHECKING:
from .book import Book
class Tag(BigIntBase, UniqueMixin):
__tablename__ = "tags"
name: Mapped[str] = mapped_column(unique=True, index=True)
@classmethod
def unique_hash(cls, name: str) -> Hashable:
"""Generate a unique hash for deduplication."""
return name
@classmethod
def unique_filter(cls, name: str) -> ColumnElement[bool]:
"""SQL filter for finding existing records."""
return cls.name == name
def __repr__(self) -> str:
return f"Tag({self.name!r})"
class BookTagLink(BigIntBase):
__tablename__ = "book_tag_link"
book_id: Mapped[int] = mapped_column(
ForeignKey("books.id", ondelete="cascade"), primary_key=True
)
tag_id: Mapped[int] = mapped_column(ForeignKey("tags.id"), primary_key=True)
position: Mapped[int]
book: Mapped["Book"] = relationship(back_populates="tag_links")
tag: Mapped[Tag] = relationship()

View File

@@ -0,0 +1,25 @@
from typing import TYPE_CHECKING
from sqlalchemy.orm import Mapped, mapped_column, relationship
from advanced_alchemy.base import BigIntAuditBase
from advanced_alchemy.types import PasswordHash, HashedPassword
from advanced_alchemy.types.password_hash.pwdlib import PwdlibHasher
from pwdlib.hashers.argon2 import Argon2Hasher as PwdlibArgon2Hasher
if TYPE_CHECKING:
from chitai.database.models import BookList
PasswordType = HashedPassword
else:
PasswordType = str
class User(BigIntAuditBase):
__tablename__ = "users"
email: Mapped[str] = mapped_column(unique=True)
password: Mapped[PasswordType] = mapped_column(
PasswordHash(backend=PwdlibHasher(hasher=PwdlibArgon2Hasher()))
)
shelves: Mapped[list["BookList"]] = relationship(cascade="all, delete-orphan")

View File

@@ -0,0 +1,17 @@
from litestar import Request, Response, MediaType
from litestar.exceptions import HTTPException
from litestar.status_codes import HTTP_404_NOT_FOUND
from advanced_alchemy.exceptions import NotFoundError
def not_found_exception_handler(_: Request, exc: NotFoundError) -> Response:
"""Default handler for NotFoundError."""
return Response(
media_type=MediaType.TEXT,
content=f"Not found error: {exc.detail}",
status_code=404,
)
exception_handlers = {NotFoundError: not_found_exception_handler}

View File

@@ -0,0 +1,35 @@
from base64 import b64decode
from chitai.services.user import UserService
from litestar.middleware import (
AbstractAuthenticationMiddleware,
AuthenticationResult,
DefineMiddleware
)
from litestar.connection import ASGIConnection
from litestar.exceptions import NotAuthorizedException, PermissionDeniedException
from chitai.config import settings
class BasicAuthenticationMiddleware(AbstractAuthenticationMiddleware):
async def authenticate_request(self, connection: ASGIConnection) -> AuthenticationResult:
"""Given a request, parse the header for Base64 encoded basic auth credentials. """
# retrieve the auth header
auth_header = connection.headers.get("Authorization", None)
if not auth_header:
raise NotAuthorizedException()
username, password = b64decode(auth_header.split("Basic ")[1]).decode().split(":")
try:
db_session = settings.alchemy_config.provide_session(connection.app.state, connection.scope)
user_service = UserService(db_session)
user = await user_service.authenticate(username, password)
return AuthenticationResult(user=user, auth=None)
except PermissionDeniedException:
raise NotAuthorizedException()
basic_auth_mw = DefineMiddleware(BasicAuthenticationMiddleware)

View File

@@ -0,0 +1,16 @@
from .book import (
BookCreate,
BookRead,
BookProgressCreate,
BookProgressRead,
BooksCreateFromFiles,
BookMetadataUpdate,
FileMetadataRead,
BookSeriesRead,
)
from .shelf import ShelfRead, ShelfCreate
from .library import LibraryCreate, LibraryRead
from .user import UserCreate, UserLogin, UserRead
from .author import AuthorRead
from .tag import TagRead
from .publisher import PublisherRead

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class AuthorRead(BaseModel):
id: int
name: str

View File

@@ -0,0 +1,174 @@
from datetime import date
import json
from pathlib import Path
from typing import Annotated, Optional, Union
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
SkipValidation,
field_validator,
computed_field,
)
from litestar.datastructures import UploadFile
from chitai.schemas.shelf import ShelfRead
from chitai.schemas.author import AuthorRead
from chitai.schemas.tag import TagRead
from chitai.schemas.publisher import PublisherRead
class BookSeriesRead(BaseModel):
id: int
title: str
class FileMetadataRead(BaseModel):
id: int
path: str
hash: str
size: int
content_type: str
@computed_field
@property
def filename(self) -> str:
return Path(self.path).name
class BookRead(BaseModel):
id: int
library_id: int
title: str
subtitle: str | None
description: str | None
cover_image: Path | None
authors: list["AuthorRead"]
edition: int | None
publisher: Optional["PublisherRead"]
published_date: date | None
identifiers: dict[str, str]
language: str | None
pages: int | None
tags: list["TagRead"]
lists: list[ShelfRead]
series: Optional["BookSeriesRead"]
series_position: str | None
files: list["FileMetadataRead"]
progress: Union["BookProgressRead", None] = None
@field_validator("identifiers", mode="before")
@classmethod
def convert_dict(cls, value) -> dict[str, str]:
if isinstance(value, list):
return {item.name: item.value for item in value}
return value
def split_string_and_remove_duplicates(value: list | str) -> list:
if isinstance(
value, str
): # If a string was received, split it into individual authors
value = value.split(",")
return list(dict.fromkeys(value)) # Remove duplicates, keep insertion order
class BookCreate(BaseModel):
library_id: int
title: str | None = None
subtitle: str | None = None
description: str | None = None
authors: Annotated[
list[str],
BeforeValidator(split_string_and_remove_duplicates),
] = Field(default_factory=list)
tags: Annotated[
list[str],
BeforeValidator(split_string_and_remove_duplicates),
] = Field(default_factory=list)
edition: int | None = None
publisher: str | None = None
published_date: date | None = None
identifiers: dict[str, str] | None = Field(default_factory=dict)
language: str | None = None
pages: int | None = None
series: str | None = None
series_position: str | None = None
cover_image: Annotated[UploadFile | None, SkipValidation] = None
files: Annotated[list[UploadFile], SkipValidation] = Field(default_factory=list)
@field_validator("identifiers", mode="before")
@classmethod
def to_dict(cls, value) -> dict[str, str]:
if isinstance(value, str):
return json.loads(value)
return value or {}
model_config = ConfigDict(arbitrary_types_allowed=True)
class BooksCreateFromFiles(BaseModel):
"""Upload schema for multiple books with only files."""
files: Annotated[list[UploadFile], SkipValidation, Field(default_factory=list)]
model_config = ConfigDict(arbitrary_types_allowed=True)
class BookMetadataUpdate(BaseModel):
title: str | None = None
subtitle: str | None = None
description: str | None = None
authors: Annotated[
list[str],
BeforeValidator(split_string_and_remove_duplicates),
Field(default_factory=list),
]
tags: Annotated[
list[str],
BeforeValidator(split_string_and_remove_duplicates),
Field(default_factory=list),
]
edition: int | None = None
publisher: str | None = None
published_date: date | None = None
identifiers: dict[str, str] | None = dict()
language: str | None = None
pages: int | None = None
series: str | None = None
series_position: str | None = None
@field_validator("identifiers", mode="before")
@classmethod
def to_dict(cls, value) -> dict[str, str]:
if isinstance(value, str):
return json.loads(value)
return value or {}
model_config = ConfigDict(arbitrary_types_allowed=True)
@field_validator("title", mode="before")
@classmethod
def title_not_none(cls, v):
if v is None:
raise ValueError("title cannot be explicitly set to None")
return v
class BookProgressCreate(BaseModel):
progress: float
epub_loc: str | None = None
pdf_loc: int | None = None
completed: bool | None = None
class BookProgressRead(BaseModel):
progress: float
epub_loc: str | None = None
pdf_loc: int | None = None
completed: bool | None = False

View File

@@ -0,0 +1,33 @@
from pathlib import Path
from typing import Annotated
from pydantic import BaseModel, Field, computed_field
from advanced_alchemy.utils.text import slugify
class LibraryCreate(BaseModel):
name: Annotated[str, Field(min_length=1)]
root_path: str
path_template: str | None = "{author}/{title}"
description: str | None = None
read_only: bool = False
@computed_field
@property
def slug(self) -> str:
return slugify(self.name)
class LibraryRead(BaseModel):
id: int
name: str
root_path: str
path_template: str
description: str | None
read_only: bool
total: int | None = None
class LibraryUpdate(BaseModel):
name: str | None
root_path: str | None
path_template: str | None
description: str | None
read_only: bool | None

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class PublisherRead(BaseModel):
id: int
name: str

View File

@@ -0,0 +1,19 @@
from pydantic import BaseModel, ConfigDict, Field, field_validator
class ShelfRead(BaseModel):
id: int
title: str
library_id: int | None = None
total: int | None = None # Number of books in the shelf
class ShelfCreate(BaseModel):
title: str = Field(min_length=1)
library_id: int | None = None
book_ids: list[int] = []
@field_validator("title", mode="before")
@classmethod
def strip_title(cls, v: str) -> str:
return v.strip() if isinstance(v, str) else v

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class TagRead(BaseModel):
id: int
name: str

View File

@@ -0,0 +1,23 @@
from typing import Annotated
from pydantic import BaseModel, BeforeValidator
def is_strong(value: str) -> str:
if len(value) < 8:
raise ValueError("Password must be at least 8 characters long")
return value
class UserCreate(BaseModel):
email: str
password: Annotated[str, BeforeValidator(is_strong)]
class UserRead(BaseModel):
email: str
class UserLogin(BaseModel):
email: str
password: str

View File

@@ -0,0 +1,8 @@
from .book import BookService
from .library import LibraryService
from .user import UserService
from .bookshelf import ShelfService
from .author import AuthorService
from .tag import TagService
from .publisher import PublisherService
from .book_progress import BookProgressService

View File

@@ -0,0 +1,19 @@
# src/chitai/services/author.py
# Third-party libraries
from advanced_alchemy.service import SQLAlchemyAsyncRepositoryService
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
# Local Imports
from chitai.database.models import Author
class AuthorService(SQLAlchemyAsyncRepositoryService[Author]):
"""Author service for managing author operations."""
class Repo(SQLAlchemyAsyncRepository[Author]):
"""Author repository."""
model_type = Author
repository_type = Repo

View File

@@ -0,0 +1,610 @@
# src/chitai/services/book.py
# Standard library
from __future__ import annotations
from collections import defaultdict
import mimetypes
from io import BytesIO
from pathlib import Path
import uuid
import zipfile
from typing import TYPE_CHECKING, Any, AsyncIterator
# Third-party libraries
from advanced_alchemy.extensions.litestar import service
from advanced_alchemy.service import (
SQLAlchemyAsyncRepositoryService,
ModelDictT,
is_dict,
)
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
from advanced_alchemy.filters import CollectionFilter
from sqlalchemy import inspect
from litestar.response import File
from litestar.datastructures import UploadFile
import aiofiles
from aiofiles import os as aios
from PIL import Image
# Local imports
from chitai.config import settings
from chitai.database.models import (
Book,
Author,
Tag,
Publisher,
BookSeries,
FileMetadata,
Identifier,
BookList,
Library,
)
from chitai.database.models.book_progress import BookProgress
from chitai.schemas.book import BooksCreateFromFiles
from chitai.services.filesystem_library import BookPathGenerator
from chitai.services.metadata_extractor import Extractor as MetadataExtractor
from chitai.services.utils import (
cleanup_empty_parent_directories,
delete_directory,
delete_file,
is_empty,
move_dir_contents,
move_file,
save_image,
)
class BookService(SQLAlchemyAsyncRepositoryService[Book]):
"""Book service for managing book operations."""
class Repo(SQLAlchemyAsyncRepository[Book]):
"""Book repository."""
model_type = Book
repository_type = Repo
async def create(self, data: ModelDictT[Book], library: Library, **kwargs) -> Book:
"""
Create a new book entity.
Orchestrates the full book creation pipeline: extracting metadata from files,
saving cover images, and storing book files to the filesystem.
Args:
data: Book data as a dictionary or model.
library: The library the book belongs to.
*args: Additional positional arguments passed to parent create.
**kwargs: Additional keyword arguments passed to parent create.
Returns:
The created Book entity.
"""
if not is_dict(data):
data = data.model_dump()
await self._parse_metadata_from_files(data)
await self._save_cover_image(data)
await self._save_book_files(library, data)
return await super().create(data, **kwargs)
async def create_many_from_files(
self, data: BooksCreateFromFiles, library: Library, **kwargs
) -> list[Book]:
"""
Create multiple books from uploaded files.
Groups files by their parent directory to organize books. Files in the root
directory are treated as separate individual books.
Args:
data: Container with list of uploaded files.
library: The library the books belong to.
*args: Additional positional arguments passed to create.
**kwargs: Additional keyword arguments passed to create.
Returns:
List of created Book entities.
"""
if not data.files:
raise ValueError("Must upload at least one file")
# Group book files if they are within the same nested directory
books = defaultdict(list)
for file in data.files:
filepath = Path(file.filename)
# Books within the root directory should be treated as separate books
if len(filepath.parent.parts) > 1:
books[filepath.parent].append(file)
else:
books[filepath].append(file)
return [
await self.create(
{"files": [file for file in files], "library_id": library.id},
library,
**kwargs,
)
for files in books.values()
]
async def create_many_from_existing_files(
self,
file_paths: list[Path],
consume_path: Path,
library: Library,
**kwargs
) -> list[Book]:
# Group up files if they are in the same leaf directory
books = []
file_groups: dict[Path, list[Path]] = defaultdict(list)
for file_path in file_paths:
rel_path = file_path.relative_to(consume_path)
parent_rel = rel_path.parent
# Files at root level should each be in their own group
# Use the file's relative path as the group key instead of parent
if parent_rel == Path("."):
group_key = rel_path # Use the file itself as the key
else:
group_key = parent_rel # Use parent directory as key
# Add to appropriate group
file_groups[group_key].append(file_path)
# For each grouping
for group, files in file_groups.items():
data: dict[str, Any] = {'files': files}
await self._parse_metadata_from_files(data, root_path=consume_path)
await self._save_cover_image(data)
# Get info from files
path_gen = BookPathGenerator(library.root_path)
parent = path_gen.generate_path(data)
data["path"] = str(parent)
data["library_id"] = library.id
file_metadata = []
for file in files:
stats = await aios.stat(file)
file_size = stats.st_size
content_type, _ = mimetypes.guess_type(file)
filename = path_gen.generate_filename(data, Path(file.name))
file_metadata.append(
FileMetadata(
path=str(filename),
size=file_size,
hash="stub-hash", # TODO: implement file hashing to catch duplicates
content_type=content_type,
)
)
data["files"] = file_metadata
# Move files to appropriate directory
if len(files) > 1:
await move_dir_contents(consume_path / group, parent)
else:
if len(group.parts) > 1:
await move_file(consume_path / group / files[0].name, parent / files[0].name)
else:
await move_file(consume_path / group, parent / group)
cleanup_empty_parent_directories(consume_path / group, consume_path)
books.append(await super().create(data))
await self.repository.session.commit()
return books
async def delete(
self,
book_ids: list[int],
library: Library,
delete_files: bool = False,
**kwargs,
) -> None:
"""
Delete books and optionally delete their associated files.
Removes book records from the database and optionally deletes files from the
filesystem, including cover images and book files.
Args:
book_ids: List of book IDs to delete.
library: The library containing the books.
delete_files: If True, also delete associated files from the filesystem.
**kwargs: Additional keyword arguments.
"""
books = await self.list(
CollectionFilter(field_name="id", values=book_ids),
Book.library_id == library.id,
)
for book in books:
await self.remove_files(
book.id,
[file.id for file in book.files],
library=library,
delete_files=delete_files,
)
if book.cover_image:
await delete_file(book.cover_image)
await super().delete_many(book_ids)
async def get_file(self, book_id: int, file_id: int) -> File:
"""
Retrieve a file for download.
Args:
book_id: The ID of the book containing the file.
file_id: The ID of the file to retrieve.
Returns:
A File object ready for download.
Raises:
ValueError: If the file is missing or not found for the given book.
"""
book = await self.get(book_id)
for file in book.files:
if file.id != file_id:
continue
path = Path(book.path) / Path(file.path)
if not await aios.path.isfile(path):
raise ValueError("The file is missing")
return File(path, media_type=file.content_type)
raise ValueError("No such file for the given book")
async def get_files(
self, book_ids: list[int], library_id: int
) -> AsyncIterator[bytes]:
"""
Get all selected book files as a compressed zip file.
Streams the zip file in chunks to avoid loading the entire file into memory.
Args:
book_ids: List of book IDs to include in the zip.
library_id: The ID of the library containing the books.
Yields:
Chunks of compressed zip file data.
"""
books = await self.list(Book.id.in_(book_ids), Book.library_id == library_id)
files = [file for book in books for file in book.files]
buffer = BytesIO()
with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
for file in files:
path = Path(file.path)
if path.exists():
zip_file.write(path, arcname=path.name)
buffer.seek(0)
chunk_size = 32768 # 32 KiB
while True:
chunk = buffer.read(chunk_size)
if not chunk:
break
yield chunk
async def update(
self, book_id: int, update_data: ModelDictT[Book], library: Library
) -> Book:
"""
Update a book's metadata and files.
Handles cover image updates, path reorganization based on updated metadata,
and cleanup of empty directories. Ensures related entities (authors, series, etc.)
remain unique in the database.
Args:
book_id: The ID of the book to update.
update_data: The updated book data.
library: The library containing the book.
Returns:
The updated Book entity.
"""
book = await self.get(book_id)
if not is_dict(update_data):
data = service.schema_dump(update_data, exclude_unset=True)
else:
data = update_data
data["id"] = book_id
if data.get("cover_image", None):
# Delete the existing cover image
if book.cover_image:
await delete_file(book.cover_image)
await self._save_cover_image(data)
# TODO: extract out into its own function _update_book_path
# Check if file path must be updated
path_gen = BookPathGenerator(library.root_path)
updated_path = path_gen.generate_path(book.to_dict() | data)
if str(updated_path) != book.path:
# TODO: Move only the files associated with the book instead of the whole directory
await move_dir_contents(book.path, updated_path)
data["path"] = str(updated_path)
cleanup_empty_parent_directories(Path(book.path), Path(library.root_path))
return await super().update(data, item_id=book_id, execution_options={"populate_existing": True})
async def add_files(
self, book_id: int, files: list[UploadFile], library: Library
) -> None:
"""
Add additional files to an existing book.
Args:
book_id: The ID of the book.
files: List of files to add.
library: The library containing the book.
"""
book = await self.get(book_id)
data = book.to_dict()
data["files"] = files
await self._save_book_files(library, data)
book.files.extend(data["files"])
await self.update(book.id, {"files": [file for file in book.files]}, library)
async def remove_files(
self, book_id: int, file_ids: list[int], delete_files: bool, library: Library
) -> None:
"""
Remove files from an existing book.
Optionally deletes the files from the filesystem and cleans up empty directories.
Args:
book_id: The ID of the book.
file_ids: List of file IDs to remove.
delete_files: If True, also delete the files from the filesystem.
library: The library containing the book.
"""
book = await self.get_one(Book.id == book_id, Book.library_id == library.id)
if delete_files:
# TODO: Extract this out into its own function
for file in (file for file in book.files if file.id in file_ids):
full_path = Path(book.path) / Path(file.path)
if await aios.path.isfile(full_path):
await aios.remove(full_path)
cleanup_empty_parent_directories(
Path(book.path), Path(library.root_path)
)
book.files = [file for file in book.files if file.id not in file_ids]
await super().update(book.files, book.id)
async def to_model_on_create(self, data: ModelDictT[Book]) -> ModelDictT[Book]:
"""
Hook called during model creation to populate unique relationships.
Args:
data: The book data to transform.
Returns:
The transformed data with unique relationship entities.
"""
data = service.schema_dump(data)
self._preprocess_book_data(data)
return await self._populate_with_unique_relationships(data)
async def to_model_on_update(self, data: ModelDictT[Book]) -> ModelDictT[Book]:
"""
Hook called during model update to populate unique relationships.
Args:
data: The book data to transform.
Returns:
The transformed data with unique relationship entities.
"""
data = service.schema_dump(data)
self._preprocess_book_data(data)
model_data = await self._populate_with_unique_relationships(data)
return model_data
def _preprocess_book_data(self, data: dict) -> dict:
"""Transform API input format to model format."""
if not isinstance(data, dict):
return data
# Transform dict identifiers to list of Identifier objects
if "identifiers" in data and isinstance(data["identifiers"], dict):
data["identifiers"] = [
Identifier(name=key, value=val)
for key, val in data["identifiers"].items()
]
return data
async def _populate_with_unique_relationships(self, data: ModelDictT[Book]):
"""
Ensure relationship entities (authors, series, tags, etc.) are unique in the database.
Fetches or creates unique instances of related entities to prevent duplicates.
Args:
data: Book data containing relationship fields.
Returns:
The data with relationship entities replaced by unique instances.
"""
if not isinstance(data, dict):
return data
if data.get("id", None):
model_data = await super().get(data["id"])
# Update all scalar fields dynamically
mapper = inspect(Book)
for column in mapper.columns:
if column.name in data:
setattr(model_data, column.name, data[column.name])
else:
model_data = await super().to_model(data)
if "authors" in data:
model_data.authors = [
await Author.as_unique_async(self.repository.session, name=author)
for author in data["authors"]
]
if "series" in data:
if data["series"]:
model_data.series = await BookSeries.as_unique_async(
self.repository.session, title=data["series"]
)
else:
model_data.series = None
if "publisher" in data:
if data["publisher"]:
model_data.publisher = await Publisher.as_unique_async(
self.repository.session, name=data["publisher"]
)
else:
model_data.publisher = None
if "tags" in data:
model_data.tags = [
await Tag.as_unique_async(self.repository.session, name=tag)
for tag in data["tags"]
]
if "identifiers" in data:
model_data.identifiers = data["identifiers"]
if "files" in data:
model_data.files = data["files"]
return model_data
async def _save_book_files(self, library: Library, data: dict) -> dict:
"""
Save uploaded book files to the filesystem.
Uses the library's path template to organize files and extracts file metadata
(size, content type, etc.) for database storage.
Args:
library: The library containing the book.
data: Book data with files to save.
Returns:
The data with file paths and metadata populated.
"""
# Use the library template path with the book's data to generate a filepath
path_gen = BookPathGenerator(library.root_path)
parent = path_gen.generate_path(data)
data["path"] = str(parent)
file_metadata = []
CHUNK_SIZE = 262144 # 256 KiB
for file in data.pop("files", []):
if TYPE_CHECKING:
file: UploadFile
# TODO filename should be generated with generate_path function
filename = path_gen.generate_filename(data, Path(file.filename).name)
# Store files in the correct directory and return the file's metadata
await file.seek(0)
path = parent / filename
path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(path, "wb") as dest:
# Read spooled file and save it to the local filesystem
while chunk := await file.read(CHUNK_SIZE):
await dest.write(chunk)
stats = await aios.stat(path)
file_size = stats.st_size
file_metadata.append(
FileMetadata(
path=str(filename),
size=file_size,
hash="stub-hash", # TODO: implement file hashing to catch duplicates
content_type=file.content_type,
)
)
data["files"] = file_metadata
return data
async def _parse_metadata_from_files(self, data: dict, root_path: Path | None = None) -> dict:
"""
Extract metadata (title, author, etc.) from book files.
Fills in missing book metadata fields using extracted information from files.
Args:
data: Book data to populate with extracted metadata.
Returns:
The data with extracted metadata populated in empty fields.
"""
extracted_metadata = await MetadataExtractor.extract_metadata(data["files"], root_path)
# Add missing fields and update empty (falsey) fields with extracted metadata
for attr in extracted_metadata.keys():
if not data.get(attr, None):
data[attr] = extracted_metadata[attr]
return data
async def _save_cover_image(self, data: dict) -> dict:
"""
Save the book's cover image to the filesystem.
Converts uploaded images to WebP format for consistent storage and assigns
a unique UUID-based filename.
Args:
data: Book data potentially containing a cover_image field.
Returns:
The data with cover_image path populated or removed if not provided.
"""
if (image := data.pop("cover_image", None)) is None:
return data
if isinstance(image, UploadFile):
content = await image.read()
image = Image.open(BytesIO(content))
filepath = Path(settings.book_cover_path)
filename = Path(f"{uuid.uuid4()}.webp") # Random filename
await save_image(image, filepath / filename)
data["cover_image"] = str(filepath / filename)
return data

View File

@@ -0,0 +1,17 @@
# src/chitai/services/book_progress.py
from advanced_alchemy.service import SQLAlchemyAsyncRepositoryService
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
from chitai.database import models as m
class BookProgressService(SQLAlchemyAsyncRepositoryService[m.BookProgress]):
"""Book progress service for managing progress operations."""
class Repo(SQLAlchemyAsyncRepository[m.BookProgress]):
"""Book progress repository."""
model_type = m.BookProgress
repository_type = Repo

View File

@@ -0,0 +1,88 @@
# src/chitai/services/bookshelf.py
# Third-party libraries
from typing import Any, Sequence
from advanced_alchemy.exceptions import ErrorMessages
from advanced_alchemy.filters import StatementFilter
from advanced_alchemy.service import SQLAlchemyAsyncRepositoryService
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
from advanced_alchemy.utils.dataclass import Empty, EmptyType
from sqlalchemy import ColumnElement, Select, delete
# Local imports
from chitai.database.models.book_list import BookList, BookListLink
from chitai.services import BookService
from chitai.services.filters.book import BookshelfFilter
from chitai.database import models as m
class ShelfService(SQLAlchemyAsyncRepositoryService[BookList]):
"""Service for managing bookshelves and their contents."""
class Repo(SQLAlchemyAsyncRepository[BookList]):
"""Repository for BookList entities."""
model_type = BookList
repository_type = Repo
async def add_books(self, shelf_id: int, book_ids: list[int]) -> m.BookList:
"""
Add books to the specified bookshelf.
Creates BookListLink entries to associate books with the shelf, automatically
assigning positions based on the current shelf size.
Args:
shelf_id: The ID of the shelf to add books to.
book_ids: List of book IDs to add.
Raises:
ValueError: If the shelf does not exist.
"""
shelf = await self.get(shelf_id, load=[m.BookList.book_links])
async with BookService.new(session=self.repository.session) as book_service:
# Verify all books exist
total = await book_service.count(m.Book.id.in_(book_ids))
if total != len(book_ids):
raise ValueError("One or more books not found")
# Get all books that exist on the shelf
shelf_books = await book_service.list(BookshelfFilter(lists=[shelf_id]))
# Filter out books that are already on the shelf to avoid adding them again
existing_book_ids = [book.id for book in shelf_books]
new_book_ids = [bid for bid in book_ids if bid not in existing_book_ids]
# Add new books
if new_book_ids:
for book_id in new_book_ids:
shelf.book_links.append(BookListLink(book_id=book_id))
await self.update(shelf)
return await self.get(shelf_id, load=[m.BookList.book_links])
async def remove_books(self, shelf_id: int, book_ids: list[int]) -> m.BookList:
"""
Remove books from the specified bookshelf.
Args:
shelf_id: The ID of the shelf to remove books from.
book_ids: List of book IDs to remove.
Raises:
ValueError: If the shelf does not exist.
"""
shelf = await self.get(shelf_id, load=[BookList.book_links])
book_ids_set = set(book_ids) # for O(1) lookup
for i in range(len(shelf.book_links) - 1, -1, -1):
if shelf.book_links[i].book_id in book_ids_set:
shelf.book_links.pop(i)
await self.update(shelf)
return await self.get(shelf_id, load=[m.BookList.book_links])

View File

@@ -0,0 +1,127 @@
import asyncio
from pathlib import Path
from collections import defaultdict
from chitai.database.models.library import Library
from chitai.services import BookService, LibraryService
from chitai.services.metadata_extractor import Extractor
from chitai.services.utils import create_directory
from watchfiles import awatch, Change
class ConsumeDirectoryWatcher:
"""Watches a directory and batch processes files by their relative path."""
def __init__(self, watch_path: str, library_service: LibraryService, book_service: BookService, batch_delay: float = 3.0):
"""
Initialize the file watcher.
Args:
watch_path: Directory path to watch
batch_delay: Seconds to wait before processing a batch
"""
self.watch_path = Path(watch_path).resolve() # Convert to absolute path
self.batch_delay = batch_delay
self.file_groups: dict[str, set[Path]] = defaultdict(set)
self._processing_tasks = set()
self.book_service = book_service
self.library_service = library_service
async def init_watcher(self):
"""Background task that watches for file changes."""
# Create consume directories for each library if they do not exist
libraries = await self.library_service.list()
for lib in libraries:
await create_directory(Path(self.watch_path) / Path(lib.slug))
print(f"Starting file watcher on {self.watch_path}")
try:
async for changes in awatch(str(self.watch_path)):
for change_type, file_path in changes:
if change_type != Change.added:
continue
file_path = Path(file_path)
# If a directory was added, scan it for existing files
if file_path.is_dir():
print(f"Directory added: {file_path}")
await self._handle_directory_added(file_path)
else:
print(f"File added: {file_path}")
await self._handle_file_added(file_path)
except asyncio.CancelledError:
print("File watcher stopped")
# Wait for any pending processing tasks
if self._processing_tasks:
await asyncio.gather(*self._processing_tasks, return_exceptions=True)
raise
async def _handle_file_added(self, file_path: Path):
"""Handle a single file being added."""
# Get relative path from watch directory
rel_path = file_path.relative_to(self.watch_path)
parent_rel = rel_path.parent
library = parent_rel.parts[0]
# Add to appropriate group
self.file_groups[library].add(file_path)
# Schedule batch processing for this group
self._schedule_batch_processing(library)
async def _handle_directory_added(self, dir_path: Path):
"""Handle a directory being added - scan it for existing files."""
try:
# Recursively find all files in the added directory
for file_path in dir_path.rglob("*"):
if file_path.is_file():
print(f"Found existing file: {file_path}")
await self._handle_file_added(file_path)
except Exception as e:
print(f"Error scanning directory {dir_path}: {e}")
def _schedule_batch_processing(self, library_slug: str):
"""Schedule batch processing for a specific path group."""
# Create a task to process this group after a delay
task = asyncio.create_task(self._delayed_batch_process(library_slug))
self._processing_tasks.add(task)
task.add_done_callback(self._processing_tasks.discard)
async def _delayed_batch_process(self, library_slug: str):
"""Wait for batch delay, then process accumulated files."""
await asyncio.sleep(self.batch_delay)
# Get and clear the file list for this path
if library_slug not in self.file_groups:
return
files_to_process = self.file_groups[library_slug].copy()
self.file_groups[library_slug].clear()
if not files_to_process:
return
print(f"Batch processing {len(files_to_process)} files from {library_slug}")
await self._process_batch(files_to_process, library_slug)
async def _process_batch(self, file_paths: set[Path], library_slug: str):
"""Process a batch of files."""
try:
books = await self.book_service.create_many_from_existing_files(
list(file_paths),
self.watch_path / Path(library_slug),
library=await self._get_library(library_slug),
)
print(f"Created {len(books)} books!")
except Exception as e:
print(f"Error processing batch: {e}")
raise e
async def _get_library(self, slug: str) -> Library:
return await self.library_service.get_one(Library.slug == slug)

View File

@@ -0,0 +1,342 @@
# src/chitai/services/dependencies.py
# Standard library
from __future__ import annotations
from typing import Any, AsyncGenerator, Callable, NotRequired, Optional
# Third-party libraries
from advanced_alchemy.extensions.litestar.providers import (
create_service_provider,
FilterConfig,
DEPENDENCY_DEFAULTS,
DependencyDefaults,
)
from advanced_alchemy.exceptions import NotFoundError
from advanced_alchemy.filters import CollectionFilter, StatementFilter
from advanced_alchemy.service import FilterTypeT
from sqlalchemy.orm import selectinload, with_loader_criteria
from sqlalchemy.ext.asyncio import AsyncSession
from litestar import Request
from litestar.params import Dependency, Parameter
from litestar.security.jwt import Token
from litestar.exceptions import HTTPException
from litestar.di import Provide
from advanced_alchemy.extensions.litestar.providers import create_filter_dependencies
# Local imports
from chitai import schemas as s
from chitai.database import models as m
from chitai.services import (
UserService,
BookService,
LibraryService,
BookProgressService,
ShelfService,
TagService,
AuthorService,
PublisherService,
)
from chitai.config import settings
from chitai.services.filters.book import (
AuthorFilter,
BookshelfFilter,
CustomOrderBy,
ProgressFilter,
TagFilter,
TrigramSearchFilter,
)
async def provide_book_service(
db_session: AsyncSession, current_user: m.User | None = None
) -> AsyncGenerator[BookService, None]:
load = [
selectinload(m.Book.author_links).selectinload(m.BookAuthorLink.author),
selectinload(m.Book.tag_links).selectinload(m.BookTagLink.tag),
m.Book.publisher,
m.Book.files,
m.Book.identifiers,
m.Book.series,
]
# Load in specific user-book data
if current_user:
# Load progress data
load.extend(
[
selectinload(m.Book.progress_records),
with_loader_criteria(
m.BookProgress, m.BookProgress.user_id == current_user.id
),
]
)
# Load shelf data
load.extend(
[
selectinload(m.Book.list_links).selectinload(m.BookListLink.book_list),
with_loader_criteria(
m.BookListLink,
m.Book.lists.any(m.BookList.user_id == current_user.id),
),
]
)
provider_func = create_service_provider(
BookService,
load=load,
uniquify=True,
error_messages={
"integrity": "Book operation failed.",
"not_found": "The book does not exist.",
},
config=settings.alchemy_config,
)
async for service in provider_func(db_session=db_session):
yield service
def create_book_filter_dependencies(
config: FilterConfig,
dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS,
) -> dict[str, Provide]:
"""Create filter dependencies for books, including custom progress filters.
Overrides:
- SearchFilter: Uses trigram search for better fuzzy matching
- OrderBy: Adds "random" sort order option
Args:
config: FilterConfig instance with desired settings.
dep_defaults: Dependency defaults to use for the filter dependencies
Returns:
Dictionary of filter provider functions including base and custom filters.
"""
# Get base filters first
filters = create_filter_dependencies(config, dep_defaults)
# OVERRIDE: Custom search filter with trigram search
if config.get("search"):
search_fields = config.get("search")
def provide_trigram_search_filter(
search_string: str | None = Parameter(
title="Field to search",
query="searchString",
default=None,
required=False,
),
ignore_case: bool | None = Parameter(
title="Search should be case sensitive",
query="searchIgnoreCase",
default=config.get("search_ignore_case", False),
required=False,
),
) -> TrigramSearchFilter | None:
if not search_string:
return None
# field_names = set(search_fields.split(",")) if isinstance(search_fields, str) else set(search_fields)
return TrigramSearchFilter(
field_name="book.title",
value=search_string,
ignore_case=ignore_case or False,
)
filters[dep_defaults.SEARCH_FILTER_DEPENDENCY_KEY] = Provide(
provide_trigram_search_filter, sync_to_thread=False
)
# OVERRIDE: Custom order by with "random" option
if config.get("sort_field"):
sort_field = config.get("sort_field")
def provide_custom_order_by(
field_name: str | None = Parameter(
title="Order by field",
query="orderBy",
default=sort_field,
required=False,
),
sort_order: str | None = Parameter(
title="Sort order (asc, desc, or random)",
query="sortOrder",
default=config.get("sort_order", "desc"),
required=False,
),
current_user: m.User | None = Dependency(default=None),
) -> CustomOrderBy | None:
if not field_name:
return None
# Validate sort_order
valid_orders = {"asc", "desc", "random"}
if sort_order not in valid_orders:
raise ValueError(f"sort_order must be one of {valid_orders}")
return CustomOrderBy(
field_name=field_name, sort_order=sort_order, user=current_user
)
filters[dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY] = Provide(
provide_custom_order_by, sync_to_thread=False
)
return filters
def provide_libraries_filter(
library_ids: Optional[list[int]] = Parameter(
title="Filter by libraries", query="libraries", default=None, required=None
),
) -> CollectionFilter | None:
if not library_ids:
return None
return CollectionFilter(field_name="library_id", values=library_ids)
def provide_authors_filter(
author_ids: Optional[list[int]] = Parameter(
title="Filter by authors", query="authors", default=None, required=False
),
) -> AuthorFilter | None:
if not author_ids:
return None
return AuthorFilter(authors=author_ids)
def provide_publishers_filter(
publisher_ids: Optional[list[int]] = Parameter(
title="Filter by publishers", query="publishers", default=None, required=False
),
) -> CollectionFilter | None:
if not publisher_ids:
return None
return CollectionFilter(field_name="publisher_id", values=publisher_ids)
def provide_tags_filter(
tag_ids: Optional[list[int]] = Parameter(
title="Filter by tags", query="tags", default=None, required=False
),
) -> TagFilter | None:
if not tag_ids:
return None
return TagFilter(tags=tag_ids)
def provide_bookshelves_filter(
shelf_ids: Optional[list[int]] = Parameter(
title="Filter by bookshelves", query="shelves", default=None, required=False
),
current_user: m.User = Dependency(skip_validation=True),
) -> BookshelfFilter | None:
if not shelf_ids:
return None
return BookshelfFilter(lists=shelf_ids, user_id=current_user.id)
def provide_progress_filter(
progress_statuses: Optional[list[str]] = Parameter(
title="Filter by progress status",
query="progress",
default=None,
required=False,
),
current_user: m.User = Dependency(skip_validation=True),
) -> ProgressFilter | None:
if not progress_statuses:
return None
return ProgressFilter(user_id=current_user.id, statuses=set(progress_statuses))
def provide_book_filters(
libraries_filter: CollectionFilter | None = Dependency(skip_validation=True),
authors_filter: AuthorFilter | None = Dependency(skip_validation=True),
publishers_filter: CollectionFilter | None = Dependency(skip_validation=True),
tags_filter: TagFilter | None = Dependency(skip_validation=True),
bookshelves_filter: BookshelfFilter | None = Dependency(skip_validation=True),
progress_filter: ProgressFilter | None = Dependency(skip_validation=True),
) -> list[StatementFilter]:
"""Combine all optional filters into a single list."""
return [
f
for f in [
libraries_filter,
authors_filter,
publishers_filter,
tags_filter,
bookshelves_filter,
progress_filter,
]
if f is not None
]
provide_library_service = create_service_provider(
LibraryService,
)
provide_user_service = create_service_provider(
UserService,
error_messages={
"duplicate_key": "Verification token already exists.",
"integrity": "User operation failed.",
},
)
provide_shelf_service = create_service_provider(
ShelfService,
)
provide_progress_service = create_service_provider(
BookProgressService,
)
provide_tag_service = create_service_provider(TagService)
provide_author_service = create_service_provider(AuthorService)
provide_publisher_service = create_service_provider(PublisherService)
async def get_library_by_id(
library_service: LibraryService,
books_service: BookService,
library_id: int | None = None,
book_id: int | None = Dependency(),
) -> m.Library:
"""Retrieves the library matching the id."""
if not library_id:
try:
book = await books_service.get(book_id)
library_id = book.library_id
except NotFoundError:
raise HTTPException(status_code=404, detail="The given book does not exist")
try:
return await library_service.get(library_id)
except NotFoundError:
raise HTTPException(status_code=404, detail="The given library does not exist")
def provide_user(request: Request[m.User, Token, Any]) -> m.User:
return request.user
def provide_optional_user(request: Request[m.User, Token, Any]) -> m.User | None:
if request.user:
return request.user
return None
async def provide_user_via_basic_auth(request: Request[m.User, None, Any]) -> m.User:
return request.user

View File

@@ -0,0 +1,135 @@
# src/chitai/services/filesystem_library.py
from pathlib import Path
import re
from jinja2 import Template
from advanced_alchemy.service import ModelDictT
import chitai.database.models as m
# TODO: Replace Jinja2 templates with a simpler custom templating system.
# Current Jinja2 implementation is overly complex for basic path generation.
# Consider a simpler approach:
# - Use simple placeholder syntax: {author}/{series}/{series_position} - {title}
# - Support only essential variables: {author}, {series}, {series_position}, {title}, {isbn}
# - Auto-handle missing values (e.g., skip {series}/ if series is empty)
default_path_template = """
/{{book.authors[0] if book.authors else 'Unknown'}}
{%- if book.series -%}
/{{book.series}}
{%- endif -%}
{%- if book.series_position -%}
{%- set formatted_position = '{:04.1f}'.format(book.series_position|float) -%}
/{{formatted_position.rstrip('0').rstrip('.')}} - {{book.title}}
{%- else %}/{{book.title}}
{% endif -%}
"""
class BookPathGenerator:
"""
Generates organized filesystem paths for books based on customizable Jinja2 templates.
Uses book metadata (author, series, title, etc.) to organize books into a hierarchical
directory structure. Supports separate templates for the directory path and filename.
"""
def __init__(
self,
root_path: Path | str,
path_template: str = default_path_template,
filename_template: str | None = None,
) -> None:
"""
Initialize the path generator with templates.
Args:
root_path: The base filesystem path where books will be organized.
path_template: Jinja2 template string for directory structure.
Default organizes by author, series, and title.
filename_template: Optional Jinja2 template for custom filenames.
If None, original filenames are preserved.
"""
self.root_path = Path(root_path)
self.path_template: Template = Template(path_template)
self.filename_template = (
Template(filename_template) if filename_template else None
)
def generate_full_path(self, book_data: dict, filename: Path | str) -> Path:
"""
Generate the complete filesystem path including directory structure and filename.
Combines the generated book directory path with the generated filename to produce
the full path where a book file should be stored. This is the complete path
relative to the generator's root path.
Args:
book_data: Dictionary containing book metadata (author, title, series, etc.).
filename: The original filename or Path object to transform.
Returns:
Complete relative Path including directory structure and filename.
"""
return self.generate_path(book_data) / self.generate_filename(
book_data, filename
)
def generate_path(self, book_data: dict) -> Path:
"""
Generate the organized directory path for a book based on its metadata.
Uses the path template to create a hierarchical structure. Handles edge cases like
missing authors (defaults to 'Unknown'), optional series information, and series
position formatting. Sanitizes the result by removing consecutive slashes and
unnecessary whitespace.
Args:
book_data: Dictionary containing book metadata with keys like:
- authors: List of author names
- series: Series name (optional)
- series_position: Position in series (optional)
- title: Book title
Returns:
Cleaned Path object relative to the generator's root path.
"""
result = self.root_path / Path(self.path_template.render(book=book_data))
# Clean up
result = re.sub(r"/+", "/", str(result)) # Remove consecutive backslashes
result = re.sub(r"/\s+", "/", result) # Remove whitespace after backslashes
return Path(result.strip())
def generate_filename(self, book_data: dict, filename: Path | str) -> Path:
"""
Generate or transform a filename based on book metadata.
If a filename template is configured, uses it to generate a custom filename
based on book metadata. Otherwise, returns the original filename unchanged.
Args:
book_data: Dictionary containing book metadata.
filename: The original filename or Path object to transform.
Returns:
The generated or original filename as a Path object.
Raises:
jinja2.TemplateError: If the filename template contains invalid Jinja2 syntax.
"""
filename = Path(filename)
# TODO: Implement a filename template.
# For now just returns the original filename
if self.filename_template:
...
return Path(filename)

View File

@@ -0,0 +1,32 @@
from typing import Any, Optional
from dataclasses import dataclass
from advanced_alchemy.filters import (
StatementTypeT,
StatementFilter,
ModelT,
)
from chitai.database import models as m
from sqlalchemy import select
@dataclass
class AuthorLibraryFilter(StatementFilter):
"""Filter authors by library_id - returns authors with at least one book in the library."""
libraries: Optional[list[int]] = None
def append_to_statement(
self, statement: StatementTypeT, model: type[ModelT], *args, **kwargs
) -> StatementTypeT:
if self.libraries:
statement = statement.where(
m.Author.id.in_(
select(m.BookAuthorLink.author_id)
.join(m.Book)
.where(m.Book.library_id.in_(self.libraries))
)
)
return super().append_to_statement(statement, model, *args, **kwargs)

View File

@@ -0,0 +1,241 @@
from enum import StrEnum
from typing import Any, Optional
from dataclasses import dataclass, field
from sqlalchemy.orm import aliased
from sqlalchemy import Select, and_, desc, func, or_, text
from advanced_alchemy.filters import (
StatementTypeT,
StatementFilter,
CollectionFilter,
ModelT,
)
from chitai.database import models as m
@dataclass
class TagFilter(StatementFilter):
"""Filters for books containing at least one of the tags."""
tags: Optional[list[int]]
def append_to_statement(
self, statement: StatementTypeT, model: type[ModelT], *args, **kwargs
) -> StatementTypeT:
if self.tags:
statement = statement.where(m.Book.tags.any(m.Tag.id.in_(self.tags)))
return super().append_to_statement(statement, model, *args, **kwargs)
@dataclass
class AuthorFilter(StatementFilter):
"""Filters for books by any of the given authors."""
authors: Optional[list[int]]
def append_to_statement(
self, statement: StatementTypeT, model: type[ModelT], *args, **kwargs
) -> StatementTypeT:
if self.authors:
statement = statement.where(
m.Book.authors.any(m.Author.id.in_(self.authors))
)
return super().append_to_statement(statement, model, *args, **kwargs)
@dataclass
class BookshelfFilter(StatementFilter):
"""Filters for books in the given bookshelves."""
lists: Optional[list[int]]
user_id: Optional[int] = None
def append_to_statement(
self, statement: StatementTypeT, model: type[ModelT], *args, **kwargs
) -> StatementTypeT:
if self.lists:
if self.user_id:
statement = statement.where(
m.Book.lists.any(
and_(
m.BookList.user_id == self.user_id,
m.BookList.id.in_(self.lists),
)
)
)
else:
statement = statement.where(
m.Book.lists.any(m.BookList.id.in_(self.lists))
)
return super().append_to_statement(statement, model, *args, **kwargs)
class ProgressStatus(StrEnum):
"""Enum for different progress statuses."""
IN_PROGRESS = "in_progress"
READ = "read"
UNREAD = "unread"
@dataclass
class ProgressFilter(StatementFilter):
"""
Filters books based on their progress status.
Supports filtering completed, in progress, and unread books.
"""
user_id: int
statuses: set[ProgressStatus] = field(default_factory=set)
def append_to_statement(
self, statement: StatementTypeT, model: type[ModelT], *args: Any, **kwargs: Any
) -> StatementTypeT:
if not isinstance(statement, Select) or not self.statuses:
return super().append_to_statement(statement, model, *args, **kwargs)
# Check if UNREAD is the only status or part of the filter
has_unread = ProgressStatus.UNREAD in self.statuses
has_other_statuses = bool(self.statuses - {ProgressStatus.UNREAD})
# If only filtering by UNREAD, use a simpler query
if has_unread and not has_other_statuses:
statement = statement.outerjoin(
m.BookProgress,
and_(
m.BookProgress.book_id == model.id,
m.BookProgress.user_id == self.user_id,
),
).where(m.BookProgress.id.is_(None))
return super().append_to_statement(statement, model, *args, **kwargs)
# For other statuses, build the where clause
statement = statement.outerjoin(
m.BookProgress,
and_(
m.BookProgress.book_id == model.id,
m.BookProgress.user_id == self.user_id,
),
).where(self._build_where_clause())
return super().append_to_statement(statement, model, *args, **kwargs)
def _build_where_clause(self) -> Any:
"""Build the appropriate where clause based on the status."""
status_conditions = []
if ProgressStatus.IN_PROGRESS in self.statuses:
status_conditions.append(
and_(
or_(
m.BookProgress.completed == False,
m.BookProgress.completed.is_(None),
),
m.BookProgress.progress > 0,
)
)
if ProgressStatus.READ in self.statuses:
status_conditions.append(m.BookProgress.completed == True)
if ProgressStatus.UNREAD in self.statuses:
status_conditions.append(m.BookProgress.id.is_(None))
return or_(*status_conditions)
@dataclass
class FileFilter(StatementFilter):
"""Filter books that are related to the given files."""
file_ids: list[int]
def append_to_statement(
self, statement: StatementTypeT, model: type[ModelT], *args: Any, **kwargs: Any
) -> StatementTypeT:
statement = statement.where(
m.Book.files.any(m.FileMetadata.id.in_(self.file_ids))
)
return super().append_to_statement(statement, model, *args, **kwargs)
@dataclass
class CustomOrderBy(StatementFilter):
"""Order by filter with support for 'random' and 'last accessed' orderings."""
field_name: str
sort_order: str = "desc"
user: m.User | None = None
def append_to_statement(
self, statement: Any, model: type[Any], *args: Any, **kwargs: Any
) -> Any:
if not isinstance(statement, Select):
return super().append_to_statement(statement, model, *args, **kwargs)
# Sort randomly
if self.sort_order == "random":
statement = statement.order_by(func.random())
# Sort by BookProgress.updated_at for the current user
elif self.field_name == "last_accessed":
if self.user:
BP = aliased(m.BookProgress)
statement = statement.outerjoin(
BP, (BP.book_id == model.id) & (BP.user_id == self.user.id)
)
order_column = BP.updated_at
if self.sort_order == "asc":
statement = statement.order_by(order_column.asc().nullslast())
else:
statement = statement.order_by(order_column.desc().nullslast())
else:
# Skip sorting if no user provided
return statement
# Sort by a regular field
else:
field = getattr(model, self.field_name, None)
if field is not None:
if self.sort_order == "asc":
statement = statement.order_by(field.asc())
else:
statement = statement.order_by(field.desc())
return super().append_to_statement(statement, model, *args, **kwargs)
@dataclass
class TrigramSearchFilter(StatementFilter):
"""Custom search filter using PostgreSQL trigram similarity for fuzzy matching."""
field_name: str
value: str
ignore_case: bool = False
def append_to_statement(
self, statement: StatementTypeT, model: type[ModelT], *args: Any, **kwargs: Any
) -> StatementTypeT:
if not isinstance(statement, Select):
return super().append_to_statement(statement, model, *args, **kwargs)
statement = statement.where(
or_(
text("books.title % :term").bindparams(term=self.value),
m.Book.title.ilike(f"%{self.value}%"),
)
).order_by(
text("greatest(similarity(books.title, :term)) DESC").bindparams(
term=self.value
)
)
return super().append_to_statement(statement, model, *args, **kwargs)

View File

@@ -0,0 +1,28 @@
from typing import Optional
from dataclasses import dataclass
from advanced_alchemy.filters import (
StatementTypeT,
StatementFilter,
ModelT,
)
from chitai.database import models as m
@dataclass
class PublisherLibraryFilter(StatementFilter):
"""Filter authors by library_id - returns publishers with at least one book in the library."""
libraries: Optional[list[int]] = None
def append_to_statement(
self, statement: StatementTypeT, model: type[ModelT], *args, **kwargs
) -> StatementTypeT:
if self.libraries:
statement = statement.where(
m.Book.publisher_id == m.Publisher.id,
m.Book.library_id.in_(self.libraries),
)
return super().append_to_statement(statement, model, *args, **kwargs)

View File

@@ -0,0 +1,31 @@
from typing import Optional
from dataclasses import dataclass
from advanced_alchemy.filters import (
StatementTypeT,
StatementFilter,
ModelT,
)
from chitai.database import models as m
from sqlalchemy import select
@dataclass
class TagLibraryFilter(StatementFilter):
"""Filter tags by library_id - returns tags with at least one book in the library."""
libraries: Optional[list[int]] = None
def append_to_statement(
self, statement: StatementTypeT, model: type[ModelT], *args, **kwargs
) -> StatementTypeT:
if self.libraries:
statement = statement.where(
m.Tag.id.in_(
select(m.BookTagLink.tag_id)
.join(m.Book)
.where(m.Book.library_id.in_(self.libraries))
)
)
return super().append_to_statement(statement, model, *args, **kwargs)

View File

@@ -0,0 +1,134 @@
# src/chitai/services/library.py
# Third-party libraries
from pathlib import Path
from advanced_alchemy.service import SQLAlchemyAsyncRepositoryService
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
from advanced_alchemy import service
from advanced_alchemy.utils.text import slugify
# Local imports
from chitai.database.models.library import Library
from chitai.schemas.library import LibraryCreate, LibraryUpdate
from chitai.services.utils import (
DirectoryDoesNotExist,
create_directory,
directory_exists,
)
from chitai.config import settings
class LibraryService(SQLAlchemyAsyncRepositoryService[Library]):
"""Service for managing libraries and their configuration."""
class Repo(SQLAlchemyAsyncRepository[Library]):
"""Repository for Library entities."""
model_type = Library
repository_type = Repo
async def create(self, library: LibraryCreate, **kwargs) -> Library:
"""
Create a new library and initialize its filesystem structure.
For writable libraries, creates the root directory and any missing parent directories.
For read-only libraries, validates that the root directory already exists.
Args:
library: The library configuration to create.
**kwargs: Additional keyword arguments passed to parent create.
Returns:
The created Library entity.
Raises:
DirectoryDoesNotExist: If a read-only library's root directory does not exist.
OSError: If directory creation fails due to filesystem issues.
"""
# TODO: What if a library root_path is a child of an existing library?
if existing := await self.list(Library.root_path == library.root_path):
raise ValueError(f"Library already exists at '{library.root_path}'")
if library.read_only:
# Validate read-only library directory exists
if not await directory_exists(library.root_path):
raise DirectoryDoesNotExist(
f"Root directory '{library.root_path}' must exist for a read-only library"
)
# TODO: Verify the read-only library has read permissions
created_library = await super().create(
service.schema_dump(library, exclude_unset=False), **kwargs
)
if not library.read_only:
# Create directory for writable libraries
await create_directory(library.root_path)
# Create a consume directory
await create_directory(Path(settings.consume_path) / Path(library.slug))
return created_library
# TODO: Implement library deletion and optional file deletion
async def delete(
self, item_id: int, delete_files: bool = False, **kwargs
) -> Library:
"""
Delete a library and optionally delete its associated files from the filesystem.
For writable libraries, can optionally delete the root directory and all its contents.
Read-only libraries cannot have their files deleted through this method, as they are
typically external or managed elsewhere.
Args:
item_id: The ID of the library to delete.
delete_files: If True, delete the library's root directory and all files.
Cannot be True for read-only libraries. Defaults to False.
**kwargs: Additional keyword arguments passed to parent delete method.
Returns:
The deleted Library entity.
Raises:
ValueError: If delete_files=True is specified for a read-only library.
DirectoryDoesNotExist: If the library's root directory does not exist when attempting
to delete files.
OSError: If file deletion fails due to filesystem errors.
"""
raise NotImplementedError()
async def update(
self,
data: LibraryUpdate,
item_id: int | None = None,
**kwargs,
) -> Library:
"""
Update a library's configuration and optionally move its files to a new root path.
If the root_path is being changed and the library is not read-only, migrates all
associated book files to the new location. Read-only libraries cannot have their
files moved, and attempting to change their root_path will raise an error.
Args:
item_id: The ID of the library to update.
data: LibraryUpdate configuration with updated library fields.
**kwargs: Additional keyword arguments passed to parent update method.
Returns:
The updated Library entity.
Raises:
ValueError: If attempting to change root_path for a read-only library.
DirectoryDoesNotExist: If the current library root directory does not exist
when attempting to migrate files.
OSError: If file migration fails due to filesystem errors.
"""
raise NotImplementedError()

View File

@@ -0,0 +1,533 @@
# src/chitai/services/metadata_extractor.py
# TODO: Code is a mess. Clean it up and add docstrings
# Standard library
from abc import ABC, abstractmethod
import datetime
from pathlib import Path
from io import BytesIO
import re
from tempfile import SpooledTemporaryFile
from typing import Any, Protocol, BinaryIO
import logging
# Third-party libraries
import PIL
import PIL.Image
from litestar.datastructures import UploadFile
import pypdfium2
import ebooklib
from ebooklib import epub
# Local imports
from chitai.services.utils import (
create_temp_file,
get_file_extension,
get_filename,
is_valid_isbn,
)
logger = logging.getLogger(__name__)
class FileExtractor(Protocol):
@classmethod
async def extract_metadata(
cls, input: UploadFile | BinaryIO | bytes | Path | str
) -> dict[str, Any]: ...
@classmethod
async def extract_text(cls, input: UploadFile | BinaryIO | bytes | Path | str) -> str: ...
class Extractor:
"""Default extractor class that delegates to the proper extractor based on file type."""
format_priorities = {"epub": 1, "pdf": 2}
@classmethod
async def extract_metadata(cls, files: list[UploadFile] | list[Path], root_path: Path | None = None) -> dict[str, Any]:
metadata = {}
# Sort based on file priority
# EPUB tends to give better metadata results over pdf
sorted_files = sorted(files, key=lambda f: Extractor._get_file_priority(f))
for file in sorted_files:
match get_file_extension(file):
case "epub":
metadata = metadata | await EpubExtractor.extract_metadata(file)
case "pdf":
metadata = metadata | await PdfExtractor.extract_metadata(file)
case _:
break
# Get metadata from file names
for file in files:
metadata = FilenameExtractor.extract_metadata(file) | metadata
# Get metadata from filepath
metadata = metadata | FilepathExtractor.extract_metadata(files[0], root_path)
# format the title
if metadata.get('title', None):
title, subtitle = Extractor.format_book_title(metadata["title"])
metadata["title"] = title
metadata["subtitle"] = subtitle
return metadata
@classmethod
async def extract_text(cls, input: BinaryIO | bytes | Path | str) -> str: ...
@classmethod
def format_book_title(cls, title: str) -> tuple[str, str | None]:
# Convert all trailing underscores to colons
title = Extractor.convert_underscores_to_colons(title)
colon_count = title.count(":")
if colon_count < 2:
return title, None # No subtitle
# There is a subtitle, split the string at the second colon (colon is disregarded)
first_colon_idx = title.find(":")
second_colon_idx = title.find(
":", first_colon_idx + 1
) # Start search after first colon
return (title[:second_colon_idx], title[second_colon_idx + 1 :])
@classmethod
def _get_file_priority(cls, upload: UploadFile | Path) -> float:
filename = ""
if isinstance(upload, UploadFile):
filename = Path(upload.filename)
elif isinstance(upload, Path):
filename = upload.name
file_ext = get_file_extension(filename)
if file_ext is None:
return float('inf')
return Extractor.format_priorities.get(file_ext, float("inf"))
@classmethod
def convert_underscores_to_colons(cls, text: str) -> str:
"""
Converts underscores that trail a word and are followed by a space and another word to colons.
Only converts when there's a space between words.
For example:
- "hello_ world" -> "hello: world"
- "this_ is_ a_ test" -> "this: is: a: test"
- "hello_world" -> "hello_world" (unchanged, no space)
- "word_" -> "word_" (unchanged, no following word)
Args:
text (str): The input text to be processed
Returns:
str: The processed text with underscores converted to colons
"""
# This pattern matches a word followed by underscore(s), followed by a space, followed by another word
pattern = r"(\w+)(_+)(\s+)(\w+)"
# Replace all occurrences with word1: word2 (preserving the space)
result = re.sub(pattern, r"\1:\3\4", text)
return result
class PdfExtractor(FileExtractor):
"""Extract metadata and text from a PDF file."""
@classmethod
async def extract_metadata(
cls, data: UploadFile | SpooledTemporaryFile | BinaryIO | bytes | Path | str
) -> dict[str, Any]:
"""Extracts metadata from a PDF file."""
metadata = {}
if isinstance(data, UploadFile):
data = data.file
doc = pypdfium2.PdfDocument(data)
basic_metadata = doc.get_metadata_dict(skip_empty=False)
metadata["title"] = basic_metadata["Title"]
metadata["authors"] = PdfExtractor._parse_authors(basic_metadata["Author"])
metadata["pages"] = len(doc)
# Use the documents creation date as a guess for the date published
metadata["published_date"] = PdfExtractor._parse_pdf_date(
basic_metadata["CreationDate"]
)
metadata["cover_image"] = PdfExtractor._extract_cover(doc)
text_by_page = PdfExtractor._extract_text_from_pages(doc, 6)
metadata["identifiers"] = PdfExtractor._extract_isbns("".join(text_by_page))
doc.close()
return {key: val for key, val in metadata.items() if val} # Remove empty values
@classmethod
async def extract_text(cls, data: BinaryIO | bytes | Path | str) -> str:
raise NotImplementedError() # TODO: Implement text extraction from PDF files
@classmethod
def _parse_authors(cls, author_string: str):
"""Extracts author data from the authors string based on common formats."""
# Remove trailing delimiters and whitespace
author_string = author_string.strip("; &").strip()
if not author_string:
return []
try:
# Determine the format based on presence of commas
if "," in author_string:
# Handle LastName, FirstName; LastName, FirstName format
authors = [a.strip() for a in author_string.split(";") if a.strip()]
formatted_authors = []
for author in authors:
last_name, first_name = [name.strip() for name in author.split(",")]
formatted_authors.append(f"{first_name} {last_name}")
elif "&" in author_string:
# Handle FirstName LastName & FirstName LastName format
authors = [a.strip() for a in author_string.split("&") if a.strip()]
formatted_authors = authors
elif ";" in author_string:
# Handle FirstName LastName; FirstName LastName format
authors = [a.strip() for a in author_string.split(";") if a.strip()]
formatted_authors = authors
else:
return [author_string]
return formatted_authors
except Exception:
return []
@classmethod
def _parse_pdf_date(cls, date_string: str | None) -> datetime.date | None:
if date_string is None:
return None
# Remove "D:" prefix if present
if date_string.startswith("D:"):
date_string = date_string[2:]
# Extract just the date portion (YYYYMMDD) and parse it
date_portion = date_string[:8] # Just take first 8 characters (YYYYMMDD)
try:
return datetime.datetime.strptime(date_portion, "%Y%m%d").date()
except Exception as e:
return None
@classmethod
def _extract_text_from_pages(cls, doc: pypdfium2.PdfDocument, num_pages=5):
num_pages = min(num_pages, len(doc))
text_by_page = []
for i in range(num_pages):
text_by_page.append(doc[i].get_textpage().get_text_bounded())
return text_by_page
@classmethod
def _extract_isbns(cls, text: str) -> dict[str, str]:
isbn_pattern = (
r"(\b\d{9}[\dXx]\b|\b\d{1,5}-?\d{1,7}-?\d{1,6}-?\d{1,6}-?[\dXx]\b)"
)
isbns = {}
matches = re.findall(isbn_pattern, text)
for match in matches:
# Strip hyphens if present
isbn = match.replace("-", "")
if not is_valid_isbn(isbn):
continue
if len(isbn) == 10:
isbns.update({"isbn-10": isbn})
elif len(isbn) == 13:
isbns.update({"isbn-13": isbn})
if len(isbns) >= 2: # Exit after both isbns have been matched
break
return isbns
@classmethod
def _extract_cover(cls, doc: pypdfium2.PdfDocument) -> PIL.Image.Image | None:
page = doc[0]
bitmap: pypdfium2.PdfBitmap = page.render()
image = bitmap.to_pil()
return image
class EpubExtractor(FileExtractor):
"""Extract metadata and text from an EPUB file."""
@classmethod
async def extract_metadata(
cls, input: UploadFile | BinaryIO | bytes | Path | str
) -> dict[str, Any]:
"""Extract metadata from an EPUB file.
Args:
input: Can be an UploadFile, file-like object, bytes, Path, or string path
Returns:
Dictionary containing extracted metadata
"""
# Handle different input types
if isinstance(input, UploadFile):
await input.seek(0)
data = await input.read()
path = await create_temp_file(data)
elif isinstance(input, bytes):
path = await create_temp_file(input)
elif isinstance(input, BinaryIO):
input.seek(0)
data = input.read()
path = await create_temp_file(data)
elif isinstance(input, (Path, str)):
path = Path(input) if isinstance(input, str) else input
else:
raise TypeError(f"Unsupported input type: {type(input)}")
try:
metadata = EpubExtractor._get_metadata(path)
except Exception as e:
logger.error(f"Error extracting metadata from epub: {e}")
return {}
return metadata
@classmethod
def extract_text(cls, input: BinaryIO | bytes | Path | str) -> str:
raise NotImplementedError() # TODO: Implement text extraction from EPUB files
@classmethod
def _get_metadata(cls, file_path) -> dict[str, Any]:
"""Uses ebooklib to extract metadata from an EPUB file.
TODO: Some EPUBs fail to parse due to missing files, this is an issue with ebooklib;
more info and possible solution here: https://github.com/aerkalov/ebooklib/issues/281
"""
book = epub.read_epub(file_path)
metadata = {}
metadata["title"] = book.get_metadata("DC", "title")[0][0]
metadata["identifiers"] = EpubExtractor._extract_identifiers(book)
metadata["language"] = book.get_metadata("DC", "language")[0][0]
metadata["published_date"] = EpubExtractor._extract_published_date(book)
metadata["description"] = EpubExtractor._extract_description(book)
metadata["publisher"] = EpubExtractor._extract_publisher(book)
metadata["authors"] = EpubExtractor._extract_authors(book)
metadata["cover_image"] = EpubExtractor._extract_cover(book)
metadata = {
key: val for key, val in metadata.items() if val
} # Remove all empty values
return metadata
@classmethod
def _extract_authors(cls, epub: epub.EpubBook) -> list[str]:
authors = []
for creator in epub.get_metadata("DC", "creator"):
authors.append(creator[0])
return authors
@classmethod
def _extract_identifiers(cls, epub: epub.EpubBook) -> dict[str, str]:
identifiers = {}
for id in epub.get_metadata("DC", "identifier"):
if is_valid_isbn(id[0]):
if len(id[0]) == 13:
identifiers.update({"isbn-13": id[0]})
elif len(id[0]) == 10:
identifiers.update({"isbn-10": id[0]})
return identifiers
@classmethod
def _extract_description(cls, epub: epub.EpubBook) -> str | None:
try:
return epub.get_metadata("DC", "description")[0][0]
except:
return None
@classmethod
def _extract_published_date(cls, epub: epub.EpubBook) -> datetime.date | None:
try:
date_str = epub.get_metadata("DC", "date")[0][0].split("T")[0]
return datetime.date.fromisoformat(date_str)
except:
return None
@classmethod
def _extract_publisher(cls, epub: epub.EpubBook) -> str | None:
try:
epub.get_metadata("DC", "publisher")[0][0]
except:
return None
@classmethod
def _extract_cover(cls, epub: epub.EpubBook) -> PIL.Image.Image | None:
"""Extract cover image from EPUB."""
# Strategy 1: Check for cover metadata (most common)
try:
cover_meta = epub.get_metadata("OPF", "cover")
if cover_meta:
# Try to get the cover ID from metadata
for meta in cover_meta:
if isinstance(meta, tuple) and len(meta) > 1:
# Handle both dict and direct string references
cover_id = (
meta[1].get("content")
if isinstance(meta[1], dict)
else meta[1]
)
if cover_id:
cover_item = epub.get_item_with_id(cover_id)
if cover_item:
return PIL.Image.open(BytesIO(cover_item.content))
except Exception as e:
pass # Fallback to next strategy
# Strategy 2: Search image filenames for "cover" keyword
images = list(epub.get_items_of_type(ebooklib.ITEM_IMAGE))
for image in images:
filename = (
str(image.get_name()).casefold()
if hasattr(image, "get_name")
else str(image).casefold()
)
if "cover" in filename:
return PIL.Image.open(BytesIO(image.content))
return None
class FilepathExtractor(FileExtractor):
"""Extracts metadata from the filepath."""
@classmethod
def extract_metadata(cls, input: UploadFile | Path | str, root_path: Path | None = None) -> dict[str, Any]:
if isinstance(input, UploadFile):
path = Path(input.filename).parent
else:
path = Path(input).parent
if root_path:
path = path.relative_to(root_path)
parts = path.parts
metadata: dict[str, str | None] = {}
if len(parts) == 3:
# Format: Author/Series/Part - Title/filename
metadata['author'] = parts[0]
# Extract part number and title from directory name (parts[2])
dirname = parts[2]
match = re.match(r'^([\d.]+)\s*-\s*(.+)$', dirname)
if match:
metadata['series_position'] = match.group(1) # Keep as string
metadata['series'] = parts[1]
metadata['title'] = match.group(2).strip()
else:
metadata['series'] = parts[1]
metadata['title'] = path.stem
elif len(parts) == 2:
# Format: Author/Title
metadata['author'] = parts[0]
metadata['title'] = path.stem # Remove extension
return metadata
class FilenameExtractor(FileExtractor):
"""Extracts metadata from the filename."""
@classmethod
def extract_metadata(cls, input: Path | str | UploadFile) -> dict[str, Any]:
if isinstance(input, str):
filename = get_filename(Path(input), ext=False)
elif isinstance(input, Path):
filename = get_filename(input, ext=False)
elif isinstance(input, UploadFile):
filename = Path(input.filename).name
else:
raise ValueError("Input type not supported")
# Common patterns to split title and authors
patterns = [
r"^(.+?)\s*[-]\s*([^-]+)$", # Pattern with hyphen separator
r"^(.+?)\s*by\s+([^-]+)$", # Pattern with "by" separator
r"^([^:]+):\s*(.+)$", # Pattern with colon separator
]
title = filename
authors_str = ""
# Try each pattern until we find a match
for pattern in patterns:
match = re.match(pattern, filename, re.IGNORECASE)
if match:
title, authors_str = match.groups()
break
# Clean up title
title = re.sub(r"\s*\d+[^a-zA-Z]+.*$", "", title) # Remove edition numbers
title = re.sub(r"\s*\([^)]*\)", "", title) # Remove parentheses
title = title.strip()
# Parse authors
authors = []
if authors_str:
# Split on common author separators
author_parts = re.split(r"[,_]|\sand\s|\s&\s", authors_str)
for author in author_parts:
# Clean up each author name
author = author.strip()
if author:
# Remove any remaining parenthetical content
author = re.sub(r"\s*\([^)]*\)", "", author)
# Remove titles and other common prefixes
author = re.sub(
r"^(MD\.Dr\.|Mr\.|Mrs\.|Ms\.|Prof\.)\s+", "", author
)
authors.append(author.strip())
return {"title": title, "authors": authors}
@classmethod
def extract_text(cls, input: Path | str):
raise NotImplementedError() # Obviously can't extract text from filename...

View File

@@ -0,0 +1,178 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Literal, Optional, Sequence
from pydantic import BaseModel, Field
from datetime import datetime
class LinkTypes(StrEnum):
NAVIGATION = "application/atom+xml;profile=opds-catalog;kind=navigation"
ACQUISITION = "application/atom+xml;profile=opds-catalog;kind=acquisition"
OPEN_SEARCH = "application/opensearchdescription+xml"
class AcquisitionRelations(StrEnum):
_BASE = "http://opds-spec.org/acquisition"
ACQUISITION = _BASE # A generic relation that indicates that the entry may be retrieved
OPEN_ACCESS = f"{_BASE}/open-access" # Entry may be retrieved without any requirement
BORROW = f"{_BASE}/borrow" # Entry may be retrieved as part of a lending transaction
BUY = f"{_BASE}/buy" # Entry may be retrieved as part of a purchase
SAMPLE = f"{_BASE}/sample" # Subset of the entry may be retrieved
PREVIEW = f"{_BASE}/preview" # Subset of the entry may be retrieved
SUBSCRIBE = f"{_BASE}/subscribe" # Entry my be retrieved as a part of a subscription
class NavigationRelations(StrEnum):
_BASE = ""
class LinkRelations(StrEnum):
""" Link types for OPDSv1.2 related resources
https://specs.opds.io/opds-1.2.html#6-additional-link-relations
"""
_BASE = "http://opds-spec.org"
START = "start" # The OPDS catalog root
SUBSECTION = "subsection" # an OPDS feed not better described by any of the below relations
SHELF = f"{_BASE}/shelf" # Entries acquired by the euser
SUBSCRIPTIONS = f"{_BASE}/subscriptions" # Entries available with users's subscription
NEW = f"{_BASE}/sort/new" # Newest entries
POPULAR = f"{_BASE}/sort/popular" # Most popular entries
FEATURED = f"{_BASE}/featured" # Entries selected for promotion
RECOMMENDED = f"{_BASE}/recommended" # Entries recommended to the specific user
class Feed(BaseModel): # OPDS Catalog root element
xmlns: Literal["http://www.w3.org/2005/Atom"] = Field(
default="http://www.w3.org/2005/Atom", serialization_alias="@xmlns"
)
id: str
link: list["Link"]
title: str
updated: datetime
author: Optional["Author"] = None
entry: list["Entry"] = Field(default_factory=list)
def model_dump(self, **kwargs) -> dict:
data = super().model_dump(**kwargs)
return {"feed": data}
class NavigationFeed(Feed): ...
class AcquisitionFeed(Feed):
xmlns_dc: str | None = Field(
default="http://purl.org/dc/terms/", serialization_alias="@xmlns:dc"
)
xmlns_opds: str | None = Field(
default="http://opds-spec.org/2010/catalog", serialization_alias="@xmlns:opds"
)
class Link(BaseModel):
rel: str = Field(serialization_alias="@rel")
href: str = Field(serialization_alias="@href")
type: str = Field(serialization_alias="@type")
title: str | None = Field(default=None, serialization_alias="@title")
class AcquisitionLink(Link):
rel: str = Field(
default=AcquisitionRelations.OPEN_ACCESS, serialization_alias="@rel"
)
class ImageLink(Link):
rel: str = Field(default="http://opds-spec.org/image", serialization_alias="@rel")
class Author(BaseModel):
name: str
uri: str | None = None
class AtomCategory(BaseModel):
# https://datatracker.ietf.org/doc/html/rfc4287#section-4.2.2
# MUST have a term attribute
term: str = Field() # The category to which the entry or feed belongs to
# MAY have scheme and label attributes
scheme: str | None = Field(
default=None, serialization_alias="@scheme"
) # IRI that identifies a categorization scheme
label: str | None = Field(
default=None, serialization_alias="@label"
) # Human readable label for display in end-user applications
class AcquisitionFeedLink(Link):
rel: str = Field(default=LinkRelations.SUBSECTION, serialization_alias="@rel")
type: str = Field(
default=LinkTypes.ACQUISITION,
serialization_alias="@type",
)
class NavigationFeedLink(Link):
type: str = Field(
default=LinkTypes.NAVIGATION, serialization_alias="@type"
)
class Content(BaseModel):
type: Literal["text"] = Field(default="text", serialization_alias="@type")
text: str = Field(serialization_alias="#text")
class Summary(BaseModel):
type: Literal["text"] = Field(default="text", serialization_alias="@type")
text: str = Field(serialization_alias="#text")
class Entry(BaseModel):
# https://specs.opds.io/opds-1.2.html#5-opds-catalog-entry-documents
# TODO Follow for RFC 3339 date formats
# MUST have these fields
id: str
title: str
updated: datetime = Field(default_factory=datetime.now)
# SHOULD have these fields
dc_identifier: list[str] | None = Field(
default=None, serialization_alias="dc:identifier"
)
dc_issued: str | None = Field(
default=None, serialization_alias="dc:issued"
) # Publication date of the book
author: list[Author] | None = None
category: str | None = Field(default=None)
rights: str | None = Field(default=None)
link: Sequence["Link"] = Field(default_factory=list)
summary: str | None = None
content: Optional["Content"] = None
# MAY have these fields
contributor: str | None = (
None # represent contributors to the Publication beside its creators.
)
published: str | None = (
None # indicates when the OPDS Catalog Entry was first accessible
)
dc_language: str | None = Field(default=None, serialization_alias="dc:language")
dc_publisher: str | None = Field(default=None, serialization_alias="dc:publisher")
def model_dump(self, **kwargs) -> dict[str, Any]:
data = super().model_dump(**kwargs)
return {"entry": data}
@dataclass
class PaginationResult:
next_link: Optional[Link]
prev_link: Optional[Link]
current_offset: int
total_count: int

View File

@@ -0,0 +1,319 @@
from typing import Any, Callable, Sequence
from urllib.parse import quote_plus, urlencode
from litestar import Request
import xmltodict
from datetime import datetime
from chitai.database import models as m
from .models import (
Author,
Entry,
Link,
LinkTypes,
ImageLink,
AcquisitionLink,
AcquisitionFeed,
AcquisitionFeedLink,
NavigationFeed,
NavigationFeedLink,
PaginationResult
)
def get_opensearch_document(base_url: str = "/opds/search?") -> str:
search = {
"OpenSearchDescription": {
"@xmlns": "http://a9.com/-/spec/opensearch/1.1/",
"Url": {
"@type": "application/atom+xml;profile=opds-catalog;kind=acquisition",
"@template": base_url + "searchString={searchTerms}",
},
}
}
return xmltodict.unparse(search)
def convert_book_to_entry(book: m.Book) -> Entry:
return Entry(
id=str(book.id),
title=book.title,
updated=book.updated_at,
author=[Author(name=author.name) for author in book.authors],
dc_identifier=[f"{id.name}:{id.value}" for id in book.identifiers],
dc_publisher=book.publisher.name if book.publisher is not None else None,
dc_language=book.language,
dc_issued=str(book.published_date),
summary=book.description,
link=[
ImageLink(href=f"/{book.cover_image}", type="image/webp"),
*[
AcquisitionLink(
href=f"/opds/download/{book.id}/{file.id}", type=file.content_type
)
for file in book.files
],
],
)
def create_acquisition_feed(
id: str,
title: str,
url: str,
books: Sequence[m.Book],
updated: datetime = datetime.now(),
links: Sequence[Link] = list(),
) -> str:
feed = AcquisitionFeed(
id=id,
title=title,
updated=updated,
link=[
Link(
rel="self",
href=url,
type="application/atom+xml;profile=opds-catalog;kind=acquisition",
),
*links,
],
entry=[convert_book_to_entry(book) for book in books],
)
return xmltodict.unparse(
feed.model_dump(by_alias=True, exclude_none=True, exclude_defaults=False),
pretty=True,
)
def create_navigation_feed(
id: str,
title: str,
self_url: str,
entries: list[Entry] = list(),
links: Sequence[Link] = list(),
updated: datetime = datetime.now(),
parent_url: str | None = None,
) -> str:
self_link = Link(rel="self", href=self_url, type=LinkTypes.NAVIGATION)
parent_link = Link(
rel="up" if parent_url else "start",
href=parent_url if parent_url else self_url,
type=LinkTypes.NAVIGATION,
)
feed = NavigationFeed(
id=id,
link=[self_link, parent_link, *links],
title=title,
updated=updated,
entry=entries,
)
return xmltodict.unparse(
feed.model_dump(by_alias=True, exclude_none=True, exclude_defaults=False),
pretty=True,
)
def create_library_navigation_feed(library: m.Library) -> str:
entries = [
Entry(
id=f"/opds/library/{library.id}/all-books",
title='All Books',
link=[
AcquisitionFeedLink(
rel="subsection",
href=f"/opds/acquisition?libraries={library.id}&paginated=1&pageSize=50&feed_title=AllBooks&feed_id=/opds/library/{library.id}/all-books",
title="All Books",
)
]
),
Entry(
id=f"/opds/library/{library.id}/recently-added",
title="Recently Added",
link=[
NavigationFeedLink(
rel="http://opds-spec.org/sort/new",
href=f"/opds/acquisition?libraries={library.id}&orderBy=created_at&pageSize=50&feed_title=RecentlyAdded&feed_id=/opds/library/{library.id}/recently-added",
title="Recently Added"
)
]
),
Entry(
id=f"/opds/library/{library.id}/shelves",
title="Bookshelves",
link=[
NavigationFeedLink(
rel="subsection",
href=f"/opds/library/{library.id}/shelves?paginated=1&pageSize=10",
title="Bookshelves"
)
]
),
Entry(
id=f"/opds/library/{library.id}/tags",
title="Tags",
link=[
NavigationFeedLink(
rel="subsection",
href=f"/opds/library/{library.id}/tags?paginated=1&pageSize=10",
title="Tags"
)
]
),
Entry(
id=f"/opds/library/{library.id}/authors",
title="Authors",
link=[
NavigationFeedLink(
rel="subsection",
href=f"/opds/library/{library.id}/authors?paginated=1&pageSize=10",
title="Authors"
)
]
),
Entry(
id=f"/opds/library/{library.id}/publishers",
title="Publishers",
link=[
NavigationFeedLink(
rel="subsection",
href=f"/opds/library/{library.id}/publishers?paginated=1&pageSize=10",
title="Publishers"
)
]
),
]
feed = create_navigation_feed(
id=f'/library/{library.id}',
title=library.name,
self_url=f'/opds/library/{library.id}',
links=[
],
entries=entries
)
return feed
def create_collection_navigation_feed(
library: m.Library,
collection_type: str,
items: Sequence[m.BookList | m.Tag | m.Author | m.Publisher | m.BookSeries],
links: list[Link] = list(),
# Title is usually derived from the model's name or title
get_title: Callable[[Any], str] = lambda x: getattr(x, 'title', getattr(x, 'name', str(x)))
) -> str:
entries = [
Entry(
id=f"/opds/library/{library.id}/{collection_type}/{item.id}",
title=get_title(item),
link=[
AcquisitionFeedLink(
href=f"/opds/acquisition?{collection_type}={item.id}&pageSize=50&paginated=1&feed_title={quote_plus(get_title(item))}&feed_id=/opds/library/{library.id}/{collection_type}/{item.id}&search=True",
title=get_title(item),
)
]
) for item in items
]
return create_navigation_feed(
id=f"/opds/library/{library.id}/{collection_type}",
title=collection_type.title(),
self_url=f'/opds/library/{library.id}/{collection_type}',
entries=entries,
links=links
)
def create_next_paginated_link(
request: Request,
total: int,
current_count: int,
offset: int,
feed_title: str
) -> Link | None:
if total <= current_count + offset:
return None
params = dict(request.query_params)
params['currentPage'] = params.get('currentPage', 1) + 1
next_url = f"{request.url.path}?{urlencode(list(params.items()), doseq=True)}"
return Link(
rel="next",
href=next_url,
title=feed_title,
type=LinkTypes.NAVIGATION
)
def create_search_link(
request: Request,
exclude_params: set[str] | None = None
) -> Link:
"""Create search link with current filters applied"""
if exclude_params is None:
exclude_params = {'currentPage', 'feed_title', 'feed_id', 'search', 'paginated'}
params = {
k: v for k, v in request.query_params.items()
if k not in exclude_params
}
return Link(
rel="search",
href=f"/opds/opensearch?{urlencode(list(params.items()), doseq=True)}",
type=LinkTypes.OPEN_SEARCH,
title="Search books",
)
def create_pagination_links(
request: Request,
total: int,
limit: int,
offset: int,
feed_title: str = "",
link_type: str = LinkTypes.ACQUISITION,
) -> PaginationResult:
"""Create next/prev pagination links using limit/offset"""
next_link = None
prev_link = None
# Create next link if there are more items
if offset + limit < total:
params = dict(request.query_params)
# Calculate next page number
current_page = (offset // limit) + 1
params['currentPage'] = current_page + 1
next_url = f"{request.url.path}?{urlencode(params, doseq=True)}"
next_link = Link(
rel="next",
href=next_url,
title=f"{feed_title} - Next",
type=link_type
)
# Create previous link if not on first page
if offset > 0:
params = dict(request.query_params)
# Calculate previous page number
current_page = (offset // limit) + 1
params['currentPage'] = max(1, current_page - 1)
prev_url = f"{request.url.path}?{urlencode(params, doseq=True)}"
prev_link = Link(
rel="previous",
href=prev_url,
title=f"{feed_title} - Previous",
type=link_type
)
return PaginationResult(next_link, prev_link, offset, total)

View File

@@ -0,0 +1,19 @@
# src/chitai/services/publisher.py
# Third-party libraries
from advanced_alchemy.service import SQLAlchemyAsyncRepositoryService
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
# Local Imports
from chitai.database.models import Publisher
class PublisherService(SQLAlchemyAsyncRepositoryService[Publisher]):
"""Publisher service for managing publisher operations."""
class Repo(SQLAlchemyAsyncRepository[Publisher]):
"""Publisher repository."""
model_type = Publisher
repository_type = Repo

View File

@@ -0,0 +1,19 @@
# src/chitai/services/tag.py
# Third-party libraries
from advanced_alchemy.service import SQLAlchemyAsyncRepositoryService
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
# Local Imports
from chitai.database.models import Tag
class TagService(SQLAlchemyAsyncRepositoryService[Tag]):
"""Tag service for managing tag operations."""
class Repo(SQLAlchemyAsyncRepository[Tag]):
"""Tag repository."""
model_type = Tag
repository_type = Repo

View File

@@ -0,0 +1,76 @@
# src/chitai/services/user.py
# Standard library
import logging
# Third-party library
from advanced_alchemy.service import SQLAlchemyAsyncRepositoryService
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
from litestar.exceptions import PermissionDeniedException
from litestar.security.jwt import Token
# Local imports
from chitai.database.models.user import User
logger = logging.getLogger(__name__)
class UserService(SQLAlchemyAsyncRepositoryService[User]):
"""User service for managing user operations."""
class Repo(SQLAlchemyAsyncRepository[User]):
model_type = User
repository_type = Repo
async def authenticate(self, email: str, password: str) -> User:
"""
Authenticate a user by email and password.
Verifies that the user exists and the provided password matches their argon2 hash.
Args:
email: The user's email address.
password: The plain text password to verify.
Returns:
The authenticated User entity.
Raises:
PermissionDeniedException: If the user does not exist or password is invalid.
Uses generic message to prevent user enumeration.
"""
user = await self.get_one_or_none(email=email)
if not user or not user.password.verify(password):
logger.warning(f"Login failed for: {email}")
raise PermissionDeniedException("User not found or password invalid.")
return user
async def get_user_by_token(self, token: Token) -> User:
"""
Retrieve a user by their authentication token.
Assumes the token has already been validated and decoded by the OAuth2 middleware.
This function only fetches the corresponding user from the database using the
email from the token's 'sub' claim.
Args:
token: The decoded authentication token (pre-validated by OAuth2 middleware)
containing the user's email in the 'sub' claim.
Returns:
The User entity associated with the token.
Raises:
PermissionDeniedException: If no user exists for the email in the token,
or if the user account is inactive.
"""
user = await self.get_one_or_none(email=token.sub)
if not user:
raise PermissionDeniedException("Invalid token.")
return user

View File

@@ -0,0 +1,358 @@
# src/chitai/services/utils.py
# Standard library
from pathlib import Path
import shutil
from typing import BinaryIO
# Third-party libraries
import PIL
import PIL.Image
import aiofiles
import aiofiles.os as aios
from litestar.datastructures import UploadFile
##################################
# Filesystem related utilities #
##################################
class DirectoryDoesNotExist(Exception):
"""Raised when an operation requires a directory that does not exist."""
...
async def directory_exists(dir_path: Path | str):
"""
Check if a directory exists.
Args:
dir_path: The directory path to check.
Returns:
True if the directory exists, False otherwise.
"""
return await aios.path.exists(dir_path)
async def create_directory(dir_path: Path | str) -> None:
"""
Create a directory if it doesn't already exist.
Prints a message indicating whether the directory was created or already existed.
Args:
dir_path: The directory path to create.
"""
await aios.makedirs(dir_path, exist_ok=True)
async def move_file(src_path: Path, dest_path: Path, create_dirs=True) -> None:
"""
Move a file from source to destination asynchronously.
Args:
source_path: Path to the source file
destination_path: Path to the destination file
"""
# Create destination directory if needed
if create_dirs:
dest_dir = dest_path.parent
if dest_dir: # Only create if there's a directory path
await aios.makedirs(dest_dir, exist_ok=True)
await aios.rename(src_path, dest_path)
async def move_dir_contents(source_dir: Path | str, target_dir: Path | str) -> None:
"""
Move all contents from source directory to target directory.
Creates the target directory if it doesn't exist. Raises an error if the
source directory does not exist or is not a directory.
Args:
source_dir: The directory to move contents from.
target_dir: The directory to move contents to.
Raises:
FileNotFoundError: If source_dir does not exist.
NotADirectoryError: If source_dir is not a directory.
"""
source_dir = Path(source_dir)
target_dir = Path(target_dir)
if not source_dir.exists():
raise FileNotFoundError(f"The directory {source_dir} does not exist")
if not source_dir.is_dir():
raise NotADirectoryError(f"The path {source_dir} is not a directory")
if not target_dir.exists():
target_dir.mkdir(parents=True, exist_ok=True)
for item in source_dir.iterdir():
shutil.move(item, target_dir / item.name)
# TODO: Refactor delete_directory to use async I/O for better performance in async contexts.
def delete_directory(directory_path: Path | str):
"""
Recursively delete a directory and all its contents.
TODO: Consider changing the implementation to async
Args:
directory_path: The directory to delete.
"""
shutil.rmtree(directory_path)
def cleanup_empty_parent_directories(path: Path, root_path: Path) -> None:
"""
Remove empty directories from path up to (but not including) root_path.
Walks up the directory tree from the given path, deleting each empty directory
until reaching the root_path or a non-empty directory.
Args:
path: The starting directory path.
root_path: The root directory to stop at (not included in deletion).
"""
# Ensure that the path is a subpath of the root
if not path.is_relative_to(root_path):
raise ValueError("Path must be a subdirectory of root path.")
while path != root_path:
if path.exists() and is_empty(path):
delete_directory(path)
path = path.parent
def is_empty(directory: Path | str):
"""
Check if a directory is empty.
Args:
directory: The directory path to check.
Returns:
True if the directory is empty, False otherwise.
Raises:
ValueError: If the path does not exist or is not a directory.
"""
path = Path(directory)
if not path.exists() or not path.is_dir():
raise ValueError(f"{directory} is not a valid directory")
return not any(path.iterdir())
async def delete_file(filepath: Path | str):
"""
Idempotently delete a file.
Args:
filepath: The file path to delete.
"""
try:
await aios.remove(filepath)
except FileNotFoundError:
pass
async def create_temp_file(data: BinaryIO | bytes):
"""
Create a temporary file with the given data.
Args:
data: Binary data or file-like object to write to the temporary file.
Returns:
Path to the created temporary file.
"""
async with aiofiles.tempfile.NamedTemporaryFile(
"wb+", delete_on_close=False, delete=False
) as file:
if isinstance(data, bytes):
await file.write(data)
else:
# Handle file-like objects
data.seek(0)
file_data = data.read()
await file.write(file_data)
file_path = str(file.name)
return Path(file_path)
async def save_file(data: BinaryIO | bytes, filepath: Path) -> None:
"""
Save binary data to a file.
Handles both raw bytes and file-like objects. If data is a file-like object,
it will be read and the file pointer will be reset to the beginning.
Args:
data: Binary data or file-like object to save.
filepath: The file path where data will be saved.
"""
async with aiofiles.open(filepath, "wb+") as f:
if not isinstance(data, bytes):
await data.seek(0)
data = await data.read()
await f.write(data)
await data.seek(0)
async def save_image(
img: PIL.Image.Image, filepath: Path, image_type: str = "WEBP"
) -> None:
"""
Save a PIL Image to a file.
Args:
img: The PIL Image object to save.
filepath: The file path where the image will be saved.
image_type: The image format to save in (default: "WEBP").
"""
img.save(filepath, image_type)
##################################
# Path and filename operations #
##################################
def get_file_extension(file: Path | str | UploadFile) -> str | None:
"""
Extract the file extension from a file path.
Returns the extension in lowercase without the leading dot.
Args:
file_path: The file path as a string, Path object, or UploadFile object.
Returns:
The file extension in lowercase (e.g., "pdf", "txt"), or None if the path is empty.
"""
if isinstance(file, str):
return Path(file).suffix.lower()[1:]
elif isinstance(file, Path):
return file.suffix.lower()[1:]
elif isinstance(file, UploadFile):
return Path(file.filename).suffix.lower()[1:]
raise ValueError("file object type is not supported")
def get_filename(file: Path | str, ext: bool = True) -> str:
"""
Extract the filename from a file path.
Args:
file_path: The file path as a string or Path object.
ext: If True (default), returns the full filename with extension.
If False, returns just the stem (filename without extension).
Returns:
The filename, or None if the path is empty.
"""
if isinstance(file, UploadFile):
filename = Path(file.filename)
elif isinstance(file, str):
filename = Path(Path(file).name)
elif isinstance(file, Path):
filename = Path(file.name)
else:
raise ValueError("file object type is not supported")
if ext:
return str(filename)
return filename.stem
###############################
# ISBN Validation utilities #
###############################
def is_valid_isbn(isbn: str) -> bool:
"""
Validate an ISBN-10 or ISBN-13 number.
Automatically detects the ISBN format based on length and validates accordingly.
Args:
isbn: The ISBN string to validate.
Returns:
True if the ISBN is valid, False otherwise.
"""
try:
if len(isbn) == 10:
return is_valid_isbn10(isbn)
elif len(isbn) == 13:
return is_valid_isbn13(isbn)
else:
return False
except:
return False
def is_valid_isbn10(isbn: str) -> bool:
"""
Validate an ISBN-10 number using its check digit.
ISBN-10 uses a weighted sum where each digit is multiplied by its position
(10 down to 1) and the result modulo 11 determines the check digit.
Args:
isbn: The 10-character ISBN string to validate.
Returns:
True if the ISBN-10 is valid, False otherwise.
"""
sum = 0
for i in range(9):
sum += int(isbn[i]) * (10 - i)
check_digit = 11 - (sum % 11)
return str(check_digit) == isbn[-1] or (check_digit == 10 and isbn[-1] in "Xx")
def is_valid_isbn13(isbn: str) -> bool:
"""
Validate an ISBN-13 number using its check digit.
ISBN-13 uses a weighted sum where digits alternate between multipliers of 1 and 3,
and the result modulo 10 determines the check digit.
Args:
isbn: The 13-character ISBN string to validate.
Returns:
True if the ISBN-13 is valid, False otherwise.
"""
sum = 0
for i in range(12):
sum += int(isbn[i]) * (1 if i % 2 == 0 else 3)
check_digit = 10 - (sum % 10)
return str(check_digit % 10) == isbn[-1]

View File

257
backend/tests/conftest.py Normal file
View File

@@ -0,0 +1,257 @@
from pathlib import Path
from uuid import uuid4
import pytest
from advanced_alchemy.base import UUIDAuditBase
from litestar.testing import AsyncTestClient
from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
AsyncSession,
AsyncEngine,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.engine import URL
from sqlalchemy.pool import NullPool
from chitai.database import models as m
from chitai.config import settings
from collections.abc import AsyncGenerator
from litestar import Litestar
from pytest_databases.docker.postgres import PostgresService
from sqlalchemy.orm import selectinload
pytest_plugins = [
"tests.data_fixtures",
"pytest_databases.docker",
"pytest_databases.docker.postgres",
]
# Set the environment to use the testing database
# os.environ.update(
# {
# "POSTGRES_DB": "chitai_testing"
# }
# )
@pytest.fixture(autouse=True)
def _patch_settings(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(settings, "book_cover_path", f"{tmp_path}/covers")
@pytest.fixture(name="engine")
async def fx_engine(
postgres_service: PostgresService,
) -> AsyncGenerator[AsyncEngine, None]:
"""PostgreSQL instance for testing"""
engine = create_async_engine(
URL(
drivername="postgresql+asyncpg",
username=postgres_service.user,
password=postgres_service.password,
host=postgres_service.host,
port=postgres_service.port,
database=postgres_service.database,
query={}, # type:ignore[arg-type]
),
echo=False,
poolclass=NullPool,
)
# Add pg_trgm extension
async with engine.begin() as conn:
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
# Create all tables
metadata = UUIDAuditBase.registry.metadata
async with engine.begin() as conn:
await conn.run_sync(metadata.create_all)
yield engine
# Clean up
async with engine.begin() as conn:
await conn.run_sync(metadata.drop_all)
await engine.dispose()
@pytest.fixture(name="sessionmaker")
def fx_sessionmaker(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
"""Create sessionmaker factory."""
return async_sessionmaker(bind=engine, expire_on_commit=False)
@pytest.fixture
async def session(
sessionmaker: async_sessionmaker[AsyncSession],
) -> AsyncGenerator[AsyncSession, None]:
"""Create database session for tests."""
async with sessionmaker() as session:
yield session
await session.rollback()
await session.close()
@pytest.fixture
async def app() -> Litestar:
"""Create Litestar app for testing."""
from chitai.app import create_app
return create_app()
@pytest.fixture
async def client(app: Litestar) -> AsyncGenerator[AsyncTestClient, None]:
"""Create test client."""
async with AsyncTestClient(app=app) as client:
yield client
@pytest.fixture
async def authenticated_client(
client: AsyncTestClient, test_user: m.User
) -> AsyncTestClient:
"""Create authenticated test client."""
# login and set auth headers
login_response = await client.post(
"access/login", data={"email": test_user.email, "password": "password123"}
)
assert login_response.status_code == 201
result = login_response.json()
token = result["access_token"]
client.headers.update({"Authorization": f"Bearer {token}"})
return client
@pytest.fixture
async def other_authenticated_client(
app: Litestar, test_user2: m.User
) -> AsyncGenerator[AsyncTestClient, None]:
"""Create second authenticated test client as different user."""
async with AsyncTestClient(app=app) as other_client:
login_response = await other_client.post(
"access/login", data={"email": test_user2.email, "password": "password234"}
)
assert login_response.status_code == 201
result = login_response.json()
token = result["access_token"]
other_client.headers.update({"Authorization": f"Bearer {token}"})
yield other_client
# Service fixtures
from chitai import services
@pytest.fixture
async def user_service(
sessionmaker: async_sessionmaker[AsyncSession],
) -> AsyncGenerator[services.UserService, None]:
"""Create UserService instance."""
async with sessionmaker() as session:
async with services.UserService.new(session) as service:
yield service
@pytest.fixture
async def library_service(
sessionmaker: async_sessionmaker[AsyncSession],
) -> AsyncGenerator[services.LibraryService, None]:
"""Create LibraryService instance."""
async with sessionmaker() as session:
async with services.LibraryService.new(session) as service:
yield service
@pytest.fixture
async def books_service(
sessionmaker: async_sessionmaker[AsyncSession],
) -> AsyncGenerator[services.BookService, None]:
"""Create BookService instance."""
async with sessionmaker() as session:
async with services.BookService.new(
session,
load=[
selectinload(m.Book.author_links).selectinload(m.BookAuthorLink.author),
selectinload(m.Book.tag_links).selectinload(m.BookTagLink.tag),
m.Book.publisher,
m.Book.files,
m.Book.identifiers,
m.Book.series,
],
) as service:
yield service
@pytest.fixture
async def bookshelf_service(
sessionmaker: async_sessionmaker[AsyncSession],
) -> AsyncGenerator[services.ShelfService]:
"""Create ShelfService instance."""
async with sessionmaker() as session:
async with services.ShelfService.new(session) as service:
yield service
# Data fixtures
@pytest.fixture
async def test_user(session: AsyncSession) -> AsyncGenerator[m.User, None]:
"""Create a test user."""
unique_id = str(uuid4())[:8]
user = m.User(
email=f"user{unique_id}@example.com",
password="password123",
)
session.add(user)
await session.commit()
await session.refresh(user)
yield user
@pytest.fixture
async def test_user2(session: AsyncSession) -> AsyncGenerator[m.User, None]:
"""Create another test user."""
unique_id = str(uuid4())[:8]
user = m.User(
email=f"user{unique_id}@example.com",
password="password234",
)
session.add(user)
await session.commit()
await session.refresh(user)
yield user
@pytest.fixture
async def test_library(
session: AsyncSession, tmp_path: Path
) -> AsyncGenerator[m.Library, None]:
"""Create a test library."""
library = m.Library(
name="Testing Library",
slug="testing-library",
root_path=str(tmp_path),
path_template="{author_name}/{title}.{ext}",
read_only=False,
)
session.add(library)
await session.commit()
await session.refresh(library)
yield library

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

View File

@@ -0,0 +1,57 @@
from pathlib import Path
import pytest
from typing import Any
from chitai.database import models as m
@pytest.fixture(name="raw_users")
def fx_raw_users() -> list[m.User | dict[str, Any]]:
"""Unstructured user representations."""
return [
{"email": "user1@example.com", "password": "password123"},
{"email": "user2@example.com", "password": "password234"},
]
@pytest.fixture(name="raw_libraries")
def fx_raw_libraries(tmp_path: Path) -> list[m.Library | dict[str, Any]]:
"""Unstructured library representations."""
return [
{
"name": "Default Test Library",
"slug": "default-test-library",
"root_path": f"{tmp_path}/default_library",
"path_template": "{author}/{title}",
"read_only": False,
},
{
"name": "Test Textbook Library",
"slug": "test-textbook-library",
"root_path": f"{tmp_path}/textbooks",
"path_template": "{author}/{title}",
"read_only": False,
},
]
@pytest.fixture(name="raw_books")
def fx_raw_books() -> list[m.Book | dict[str, Any]]:
"""Unstructured book representations."""
return [
{
"library_id": 1,
"title": "The Fellowship of the Ring",
"path": "books/J.R.R Tolkien/Lord of the Rings/01 - The Fellowship of The Ring",
"pages": 427,
"authors": [{"name": "J.R.R Tolkien"}],
"tags": [{"name": "Fantasy"}, {"name": "Adventure"}],
"identifiers": {"isbn-13": "9780261102354"},
"series": {"name": "The Lord of the Rings"},
"series_position": "1",
}
]

View File

@@ -0,0 +1,157 @@
import json
from pathlib import Path
from httpx import AsyncClient
import pytest
from collections.abc import AsyncGenerator
from typing import Any
from sqlalchemy.ext.asyncio import (
AsyncSession,
AsyncEngine,
async_sessionmaker,
)
from advanced_alchemy.base import UUIDAuditBase
from chitai.database import models as m
from chitai import services
from chitai.database.config import config
@pytest.fixture(autouse=True)
def _patch_db(
engine: AsyncEngine,
sessionmaker: async_sessionmaker[AsyncSession],
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(config, "session_maker", sessionmaker)
monkeypatch.setattr(config, "engine_instance", engine)
@pytest.fixture(autouse=True)
async def _seed_db(
engine: AsyncEngine,
sessionmaker: async_sessionmaker[AsyncSession],
raw_users: list[m.User | dict[str, Any]],
raw_libraries: list[m.Library | dict[str, Any]],
) -> AsyncGenerator[None, None]:
"""Populate test database with.
Args:
engine: The SQLAlchemy engine instance.
sessionmaker: The SQLAlchemy sessionmaker factory.
raw_users: Test users to add to the database
raw_teams: Test teams to add to the database
"""
metadata = UUIDAuditBase.registry.metadata
async with engine.begin() as conn:
await conn.run_sync(metadata.drop_all)
await conn.run_sync(metadata.create_all)
async with services.UserService.new(sessionmaker()) as users_service:
await users_service.create_many(raw_users, auto_commit=True)
async with services.LibraryService.new(sessionmaker()) as library_service:
await library_service.create_many(raw_libraries, auto_commit=True)
yield
@pytest.fixture
async def populated_authenticated_client(
authenticated_client: AsyncClient, books_to_upload: list[tuple[Path, dict]]
) -> AsyncClient:
# Upload books
for path, data in books_to_upload:
await upload_book_via_api(authenticated_client, data, path)
await create_bookshelf(
authenticated_client, {"title": "Favourites", "library_id": 1}
)
await add_books_to_bookshelf(authenticated_client, 1, [1, 2])
return authenticated_client
@pytest.fixture
def books_to_upload() -> list[tuple[Path, dict]]:
return [
(
Path(
"tests/data_files/The Adventures of Sherlock Holmes - Arthur Conan Doyle.epub"
),
{
"library_id": "1",
"title": "The Adventures of Sherlock Holmes",
"description": "Some description...",
"authors": "Arthur Conan Doyle",
"publisher": "Some Publisher",
"tags": "Mystery",
"edition": "2",
"language": "en",
"pages": "300",
"series": "Sherlock Holmes",
"series_position": "1",
"identifiers": json.dumps({"isbn-10": "1234567890"}),
},
),
(
Path("tests/data_files/Frankenstein - Mary Shelley.epub"),
{
"library_id": "1",
"title": "Frankenstein",
"description": "Some description...",
"authors": "Mary Shelley",
"tags": "Mystery",
"edition": "1",
"language": "en",
"pages": "250",
},
),
(
Path("tests/data_files/A Tale of Two Cities - Charles Dickens.epub"),
{
"library_id": "1",
"title": "A Tale of Two Cities",
"description": "Some description...",
"authors": "Charles Dickens",
"tags": "Classic",
"edition": "1",
"language": "en",
"pages": "500",
},
),
]
async def upload_book_via_api(
client: AsyncClient, book_data: dict, book_file_path: Path
) -> None:
with book_file_path.open("rb") as file:
files = {"files": file}
response = await client.post("/books?library_id=1", data=book_data, files=files)
assert response.status_code == 201
async def create_bookshelf(client: AsyncClient, shelf_data: dict) -> None:
response = await client.post("/shelves", json=shelf_data)
assert response.status_code == 201
async def add_books_to_bookshelf(
client: AsyncClient, shelf_id: int, book_ids: list[int]
) -> None:
query_params = ""
for id in book_ids:
query_params += f"book_ids={id}&"
response = await client.post(
f"/shelves/{shelf_id}/books", params={"book_ids": book_ids}
)
assert response.status_code == 201

View File

@@ -0,0 +1,91 @@
import pytest
from httpx import AsyncClient
from chitai.database import models as m
@pytest.mark.parametrize(
("email", "password", "expected_status_code"),
[
("user1@example.com", "password123", 201), # Valid credentials
("user1@example.com", "password234", 401), # Invalid password
("user2@example.com", "password234", 201), # Valid credentials
("user2@example.com", "password123", 401), # Invalid password
("nonexistentUser@example.com", "password123", 401), # Invalid email
],
)
async def test_user_login(
client: AsyncClient, email: str, password: str, expected_status_code: int
) -> None:
"""Test login functionality with valid and invalid credentials."""
response = await client.post(
"/access/login", data={"email": email, "password": password}
)
assert response.status_code == expected_status_code
if response.status_code == 201:
result = response.json()
assert result["access_token"] is not None
async def test_get_user_by_access_token(
authenticated_client: AsyncClient, test_user: m.User
) -> None:
"""Test getting user info via their access token."""
response = await authenticated_client.get("/access/me")
assert response.status_code == 200
result = response.json()
assert result["email"] == test_user.email
async def test_get_user_without_access_token(client: AsyncClient) -> None:
response = await client.get("/access/me")
assert response.status_code == 401
async def test_user_registration_weak_password(client: AsyncClient) -> None:
"""Test user registration with a weak password."""
response = await client.post(
"/access/signup", json={"email": "weak@example.com", "password": "weak"}
)
assert response.status_code == 400
msg = response.json()["extra"][0]["message"]
assert "Password must be at least 8 characters long" in msg
async def test_user_registration(client: AsyncClient) -> None:
"""Test registering a new user and successfully loggin in."""
user_data = {"email": "newuser@example.com", "password": "password123"}
signup_response = await client.post("/access/signup", json=user_data)
assert signup_response.status_code == 201
# Login using the same credentials
login_response = await client.post("/access/login", data=user_data)
assert login_response.status_code == 201
async def test_user_registration_with_duplicate_email(client: AsyncClient) -> None:
"""Test registerig a new user using a duplicate email."""
user_data = {"email": "user1@example.com", "password": "password12345"}
response = await client.post("/access/signup", json=user_data)
assert response.status_code == 409
result = response.json()
assert "A user with this email already exists" in result["detail"]

View File

@@ -0,0 +1,870 @@
import pytest
from httpx import AsyncClient
from pathlib import Path
@pytest.mark.parametrize(
("path_to_book", "library_id", "title", "authors"),
[
(
Path("tests/data_files/Metamorphosis - Franz Kafka.epub"),
1,
"Metamorphosis",
["Franz Kafka"],
),
(
Path("tests/data_files/Moby Dick; Or, The Whale - Herman Melville.epub"),
1,
"Moby Dick; Or, The Whale",
["Herman Melville"],
),
(
Path(
"tests/data_files/Relativity: The Special and General Theory - Albert Einstein.epub"
),
2,
"Relativity : the Special and General Theory",
["Albert Einstein"],
),
(
Path(
"/home/patrick/projects/chitai-api/tests/data_files/On The Origin of Species By Means of Natural Selection - Charles Darwin.epub"
),
2,
"On the Origin of Species By Means of Natural Selection / Or, the Preservation of Favoured Races in the Struggle for Life",
["Charles Darwin"],
),
(
Path("tests/data_files/Calculus Made Easy - Silvanus Thompson.pdf"),
2,
"The Project Gutenberg eBook #33283: Calculus Made Easy, 2nd Edition",
["Silvanus Phillips Thompson"],
),
],
)
async def test_upload_book_without_data(
authenticated_client: AsyncClient,
path_to_book: Path,
library_id: int,
title: str,
authors: list[str],
) -> None:
"""Test uploading a book file. Book information should be extracted from file."""
# Read file contents
file_content = path_to_book.read_bytes()
# Prepare multipart form data
files = [("files", (path_to_book.name, file_content, "application/epub+zip"))]
# The rest of the book data will be parsed from file
data = {
"library_id": library_id,
}
response = await authenticated_client.post(
f"/books?library_id={library_id}",
files=files,
data=data,
)
assert response.status_code == 201
book_data = response.json()
assert book_data["id"] is not None
assert book_data["title"] == title
assert book_data["library_id"] == library_id
# Check if authors were properly parsed
book_authors = [author["name"] for author in book_data["authors"]]
assert book_authors == authors
@pytest.mark.parametrize(
("path_to_book", "library_id", "title", "authors", "tags"),
[
(
Path("tests/data_files/Metamorphosis - Franz Kafka.epub"),
1,
"The Metamorphosis",
["Franz Kafka"],
["Psychological fiction"],
),
(
Path("tests/data_files/Moby Dick; Or, The Whale - Herman Melville.epub"),
1,
"Moby Dick",
["Herman Melville"],
["Classic Literature", "Whaling"],
),
(
Path(
"tests/data_files/Relativity: The Special and General Theory - Albert Einstein.epub"
),
2,
"Relativity: the Special and General Theory",
["Albert Einstein"],
["Physics", "Mathematics"],
),
(
Path(
"/home/patrick/projects/chitai-api/tests/data_files/On The Origin of Species By Means of Natural Selection - Charles Darwin.epub"
),
2,
"On the Origin of Species By Means of Natural Selection",
["Charles Darwin"],
["Biology"],
),
(
Path("tests/data_files/Calculus Made Easy - Silvanus Thompson.pdf"),
2,
"Calculus Made Easy",
["Silvanus Thompson"],
["Mathematics"],
),
],
)
async def test_upload_book_with_data(
authenticated_client: AsyncClient,
path_to_book: Path,
library_id: int,
title: str,
authors: list[str],
tags: list[str],
) -> None:
"""Test uploading a book file with some book information provided."""
# Read file contents
file_content = path_to_book.read_bytes()
# Prepare multipart form data
files = [("files", (path_to_book.name, file_content, "application/epub+zip"))]
# The rest of the book data will be parsed from file
data = {"library_id": library_id, "title": title, "authors": authors, "tags": tags}
response = await authenticated_client.post(
f"/books?library_id={library_id}",
files=files,
data=data,
)
assert response.status_code == 201
book_data = response.json()
assert book_data["id"] is not None
assert book_data["title"] == title
assert book_data["library_id"] == library_id
book_authors = [author["name"] for author in book_data["authors"]]
assert book_authors == authors
book_tags = [tag["name"] for tag in book_data["tags"]]
assert book_tags == tags
@pytest.mark.parametrize(
("path_to_book", "library_id"),
[
(Path("tests/data_files/Moby Dick; Or, The Whale - Herman Melville.epub"), 1),
(Path("tests/data_files/Calculus Made Easy - Silvanus Thompson.pdf"), 2),
],
)
async def test_get_book_file(
authenticated_client: AsyncClient, path_to_book: Path, library_id: int
) -> None:
"""Test uploading a book then downloading the book file."""
# Read file contents
file_content = path_to_book.read_bytes()
# Prepare multipart form data
files = [("files", (path_to_book.name, file_content, "application/epub+zip"))]
# The rest of the book data will be parsed from file
data = {
"library_id": library_id,
}
create_response = await authenticated_client.post(
f"/books?library_id={library_id}",
files=files,
data=data,
)
assert create_response.status_code == 201
book_data = create_response.json()
# Retrieve the book file
book_id = book_data["id"]
file_id = book_data["files"][0]["id"]
file_response = await authenticated_client.get(
f"/books/download/{book_id}/{file_id}"
)
assert file_response.status_code == 200
downloaded_content = file_response.content
assert len(downloaded_content) == len(file_content)
assert downloaded_content == file_content
async def test_get_book_by_id(populated_authenticated_client: AsyncClient) -> None:
"""Test retrieving a specific book by ID."""
# Retrieve the book
response = await populated_authenticated_client.get(f"/books/1")
assert response.status_code == 200
book_data = response.json()
assert book_data["id"] == 1
assert book_data["title"] == "The Adventures of Sherlock Holmes"
assert len(book_data["authors"]) == 1
assert book_data["authors"][0]["name"] == "Arthur Conan Doyle"
assert len(book_data["tags"]) == 1
assert book_data["tags"][0]["name"] == "Mystery"
async def test_list_books(populated_authenticated_client: AsyncClient) -> None:
"""Test listing books with pagination and filters."""
response = await populated_authenticated_client.get(
"/books?library_id=1&pageSize=10&page=1"
)
assert response.status_code == 200
data = response.json()
assert "items" in data # Pagination structure
assert isinstance(data.get("items"), list)
assert data["total"] == 3 # There should be 3 books in this library
async def test_list_books_with_tag_filter(
populated_authenticated_client: AsyncClient,
) -> None:
"""Test listing books filtered by tags."""
response = await populated_authenticated_client.get("/books?library_id=1&tags=1")
assert response.status_code == 200
result = response.json()
assert len(result["items"]) == 2 # Two books with the "Mystery" tag
async def test_list_books_with_author_filter(
populated_authenticated_client: AsyncClient,
) -> None:
"""Test listing books filtered by authors."""
response = await populated_authenticated_client.get("/books?library_id=1&authors=1")
assert response.status_code == 200
result = response.json()
assert len(result["items"]) == 1 # One book by "Arthur Conan Doyle"
async def test_list_books_with_search(
populated_authenticated_client: AsyncClient,
) -> None:
"""Test listing books with title search."""
response = await populated_authenticated_client.get(
"/books?library_id=1&searchString=frankenstein"
)
assert response.status_code == 200
result = response.json()
assert len(result["items"]) == 1 # One matching book
async def test_delete_book_metadata_only(
populated_authenticated_client: AsyncClient,
) -> None:
"""Test deleting a book record without deleting files."""
# Delete book without deleting files
response = await populated_authenticated_client.delete(
f"/books?book_ids=3&delete_files=false&library_id=1"
)
assert response.status_code == 204
# Verify book is deleted
get_response = await populated_authenticated_client.get(f"/books/3")
assert get_response.status_code == 404
async def test_delete_book_with_files(
populated_authenticated_client: AsyncClient,
) -> None:
"""Test deleting a book and its associated files."""
# Delete book and files
response = await populated_authenticated_client.delete(
f"/books?book_ids=3&delete_files=true&library_id=1"
)
assert response.status_code == 204
async def test_delete_specific_book_files(
populated_authenticated_client: AsyncClient,
) -> None:
"""Test deleting specific files from a book."""
# Delete specific file
response = await populated_authenticated_client.delete(
f"/books/1/files?file_ids=1",
)
assert response.status_code == 204
async def test_update_reading_progress(
populated_authenticated_client: AsyncClient,
) -> None:
"""Test updating reading progress for a book."""
# Update progress
progress_data = {
"progress": 0.5,
}
response = await populated_authenticated_client.post(
f"/books/progress/1",
json=progress_data,
)
assert response.status_code == 201
# Get the book and check if the progress is correct
response = await populated_authenticated_client.get("/books/1")
assert response.status_code == 200
book = response.json()
assert book["progress"]["progress"] == 0.5
async def test_create_multiple_books_from_directory(
authenticated_client: AsyncClient,
) -> None:
"""Test creating multiple books from a directory of files."""
path1 = Path("tests/data_files/Metamorphosis - Franz Kafka.epub")
path2 = Path("tests/data_files/Moby Dick; Or, The Whale - Herman Melville.epub")
files = [
("files", (path1.name, path1.read_bytes(), "application/epub+zip")),
("files", (path2.name, path2.read_bytes(), "application/epub+zip")),
]
response = await authenticated_client.post(
"/books/fromFiles?library_id=1",
files=files,
data={"library_id": 1},
)
assert response.status_code == 201
data = response.json()
assert len(data.get("items") or data.get("data")) >= 1
# async def test_delete_book_metadata(authenticated_client: AsyncClient) -> None:
# raise NotImplementedError()
# async def test_delete_book_metadata_and_files(
# authenticated_client: AsyncClient,
# ) -> None:
# raise NotImplementedError()
# async def test_edit_book_metadata(authenticated_client: AsyncClient) -> None:
# raise NotImplementedError()
import pytest
import aiofiles
from httpx import AsyncClient
from pathlib import Path
from litestar.status_codes import HTTP_400_BAD_REQUEST
import pytest
from httpx import AsyncClient
from datetime import date
class TestMetadataUpdates:
@pytest.mark.parametrize(
("updated_field", "new_value"),
[
("title", "New Title"),
("subtitle", "Updated Subtitle"),
("pages", 256),
("description", "An updated description"),
("edition", 2),
("language", "pl"),
("published_date", "1910-04-05"),
("series_position", "3"),
],
)
async def test_update_single_scalar_field(
self,
populated_authenticated_client: AsyncClient,
updated_field: str,
new_value: str,
) -> None:
"""Test updating a single field without affecting others."""
# Get original book state
original_response = await populated_authenticated_client.get("/books/1")
assert original_response.status_code == 200
original_data = original_response.json()
# Update field
response = await populated_authenticated_client.patch(
"/books/1",
json={updated_field: str(new_value)},
)
assert response.status_code == 200
updated_data = response.json()
# Verify updated field
assert updated_data[updated_field] == new_value
# Verify all other fields remain unchanged
for field, original_value in original_data.items():
if field != updated_field and field not in ["updated_at", "created_at"]:
assert updated_data[field] == original_value, (
f"Field {field} was unexpectedly changed"
)
@pytest.mark.parametrize(
("updated_field", "new_values", "assertion_func"),
[
(
"authors", # Update with new authors
["New Author 1", "New Author 2"],
lambda data: {a["name"] for a in data["authors"]}
== {"New Author 1", "New Author 2"},
),
(
"authors", # Clear authors
[],
lambda data: data["authors"] == [],
),
(
"tags", # Update with new tags
["Tag 1", "Tag 2", "Tag 3"],
lambda data: {t["name"] for t in data["tags"]}
== {"Tag 1", "Tag 2", "Tag 3"},
),
(
"tags", # Clear tags
[],
lambda data: data["tags"] == [],
),
(
"publisher", # Update with new publisher
"Updated Publisher",
lambda data: data["publisher"]["name"] == "Updated Publisher",
),
(
"publisher", # Clear publisher
None,
lambda data: data["publisher"] is None,
),
(
"identifiers", # Update with new identifiers
{"isbn-13": "978-1234567890", "doi": "10.example/id"},
lambda data: data["identifiers"]
== {"isbn-13": "978-1234567890", "doi": "10.example/id"},
),
(
"identifiers", # Clear identifiers
{},
lambda data: data["identifiers"] == {},
),
(
"series", # Update with new series
"Updated Series",
lambda data: data["series"]["title"] == "Updated Series",
),
(
"series", # Clear series
None,
lambda data: data["series"] is None,
),
],
)
async def test_update_relationship_fields(
self,
populated_authenticated_client: AsyncClient,
updated_field: str,
new_values: list[str],
assertion_func,
) -> None:
"""Test updating relationship fields (authors, tags, etc.)"""
# Get original book state
original_response = await populated_authenticated_client.get("/books/1")
assert original_response.status_code == 200
original_data = original_response.json()
response = await populated_authenticated_client.patch(
"/books/1",
json={updated_field: new_values},
)
assert response.status_code == 200
# Verify updated field
updated_data = response.json()
assert assertion_func(updated_data)
# Verify all other fields remain unchanged
for field, original_value in original_data.items():
if field != updated_field and field not in ["updated_at", "created_at"]:
assert updated_data[field] == original_value, (
f"Field {field} was unexpectedly changed"
)
async def test_update_with_invalid_pages_value(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test updating with invalid pages value (non-integer)."""
data = {
"pages": "not_a_number",
}
response = await populated_authenticated_client.patch(
"/books/1",
files=data,
)
assert response.status_code == 400
async def test_update_with_invalid_edition_value(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test updating with invalid edition value (non-integer)."""
data = {
"edition": "invalid_edition",
}
response = await populated_authenticated_client.patch(
"/books/1",
files=data,
)
assert response.status_code == 400
async def test_update_with_invalid_date_format(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test updating with invalid published_date format."""
data = {
"published_date": "invalid-date",
}
response = await populated_authenticated_client.patch(
"/books/1",
files=data,
)
assert response.status_code == 400
async def test_update_with_invalid_identifiers_json(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test updating with invalid JSON in identifiers."""
data = {
"identifiers": "not valid json",
}
response = await populated_authenticated_client.patch(
"/books/1",
files=data,
)
assert response.status_code == 400
async def test_update_multiple_fields_at_once(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test updating multiple fields simultaneously."""
data = {
"title": "New Title",
"description": "New description",
"publisher": "New Publisher",
"pages": 430,
}
response = await populated_authenticated_client.patch(
"/books/1",
json=data,
)
assert response.status_code == 200
book_data = response.json()
assert book_data["title"] == "New Title"
assert book_data["description"] == "New description"
assert book_data["publisher"]["name"] == "New Publisher"
assert book_data["pages"] == 430
async def test_update_nonexistent_book(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test updating a nonexistent book returns 404."""
data = {
"title": "New Title",
}
response = await populated_authenticated_client.patch(
"/books/99999",
json=data,
)
assert response.status_code == 404
assert "book does not exist" in response.text
@pytest.mark.parametrize(
("updated_field"),
[
("subtitle"),
("pages"),
("description"),
("edition"),
("language"),
("published_date"),
("series_position"),
],
)
async def test_update_clears_optional_field(
self, populated_authenticated_client: AsyncClient, updated_field: str
) -> None:
"""Test that optional fields can be cleared by passing None or empty values."""
data = {updated_field: None}
response = await populated_authenticated_client.patch(
"/books/1",
json=data,
)
assert response.status_code == 200
result = response.json()
assert result[updated_field] == None
@pytest.mark.parametrize(
("updated_field"),
[
("title"),
],
)
async def test_update_clears_required_field(
self, populated_authenticated_client: AsyncClient, updated_field: str
) -> None:
"""Test that optional fields can be cleared by passing None or empty values."""
data = {updated_field: None}
response = await populated_authenticated_client.patch(
"/books/1",
json=data,
)
assert response.status_code == 400
async def test_update_cover_successfully(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test successfully updating a book's cover image."""
path = Path("tests/data_files/cover.jpg")
file_content = path.read_bytes()
files = {"cover_image": (path.name, file_content, "image/jpeg")}
# Get original book state
original_response = await populated_authenticated_client.get("/books/1")
assert original_response.status_code == 200
original_data = original_response.json()
response = await populated_authenticated_client.put(
"/books/1/cover",
files=files,
)
assert response.status_code == 200
book_data = response.json()
assert book_data["id"] == 1
assert book_data["cover_image"] is not None
assert book_data["cover_image"] != original_data["cover_image"]
class TestFileManagement:
async def test_add_multiple_files_to_book(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test adding multiple files to a book in a single request."""
path1 = Path("tests/data_files/Metamorphosis - Franz Kafka.epub")
path2 = Path("tests/data_files/Calculus Made Easy - Silvanus Thompson.pdf")
file1_content = path1.read_bytes()
file2_content = path2.read_bytes()
files = [
("files", (path1.name, file1_content, "application/epub+zip")),
("files", (path2.name, file2_content, "application/pdf")),
]
response = await populated_authenticated_client.post(
"/books/1/files",
files=files,
)
assert response.status_code == 201
book_data = response.json()
assert len(book_data["files"]) >= 2
async def test_add_files_to_nonexistent_book(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test adding files to a nonexistent book returns 404."""
path = Path("tests/data_files/Metamorphosis - Franz Kafka.epub")
file_content = path.read_bytes()
files = [("files", (path.name, file_content, "application/epub+zip"))]
response = await populated_authenticated_client.post(
"/books/99999/files?library_id=1",
files=files,
)
assert response.status_code == 404
async def test_remove_multiple_files_at_once(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test removing multiple files from a book at once."""
# First add multiple files
path1 = Path("tests/data_files/Metamorphosis - Franz Kafka.epub")
path2 = Path("tests/data_files/Calculus Made Easy - Silvanus Thompson.pdf")
files = [
("files", (path1.name, path1.read_bytes(), "application/epub+zip")),
("files", (path2.name, path2.read_bytes(), "application/pdf")),
]
add_response = await populated_authenticated_client.post(
"/books/1/files",
files=files,
)
assert add_response.status_code == 201
book_data = add_response.json()
file_ids = [f["id"] for f in book_data["files"]]
# Remove multiple files
response = await populated_authenticated_client.delete(
f"/books/1/files?file_ids={file_ids[0]}&file_ids={file_ids[1]}&delete_files=false",
)
assert response.status_code == 204
# Verify files were removed
book = await populated_authenticated_client.get("/books/1")
book_data = book.json()
assert len(book_data["files"]) < len(file_ids)
async def test_remove_zero_files_raises_error(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test error when trying to remove zero files."""
response = await populated_authenticated_client.delete(
"/books/1/files?delete_files=false",
)
assert response.status_code == HTTP_400_BAD_REQUEST
assert "missing required query parameter 'file_ids'" in response.text.lower()
async def test_remove_files_from_nonexistent_book(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test removing files from a nonexistent book returns 404."""
response = await populated_authenticated_client.delete(
"/books/99999/files?file_ids=1&delete_files=false&library_id=1",
)
assert response.status_code == 404
async def test_remove_nonexistent_file_id_from_book(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test removing a file_id that doesn't exist on that book."""
# Try to remove a file that doesn't belong to this book
response = await populated_authenticated_client.delete(
"/books/1/files?file_ids=99999&delete_files=false",
)
# Should succeed (idempotent) or return 404, depending on implementation
assert response.status_code in [204, 404]
async def test_remove_file_with_delete_files_false_keeps_filesystem_file(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test that removing with delete_files=false keeps file on disk."""
# Get the book to retrieve file info
book_response = await populated_authenticated_client.get("/books/1")
book_data = book_response.json()
if not book_data["files"]:
pytest.skip("Book has no files")
file_id = book_data["files"][0]["id"]
filename = book_data["files"][0].get("path")
# Remove file without deleting from filesystem
response = await populated_authenticated_client.delete(
f"/books/1/files?file_ids={file_id}&delete_files=false",
)
assert response.status_code == 204
async def test_remove_file_with_delete_files_true_removes_filesystem_file(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test that removing with delete_files=true deletes from filesystem."""
# Add a new file first
path = Path("tests/data_files/Calculus Made Easy - Silvanus Thompson.pdf")
file_content = path.read_bytes()
files = [("files", (path.name, file_content, "application/pdf"))]
add_response = await populated_authenticated_client.post(
"/books/1/files",
files=files,
)
assert add_response.status_code == 201
book_data = add_response.json()
file_id = book_data["files"][-1]["id"]
file_path = book_data["files"][-1].get("path")
# Remove file with deletion from filesystem
response = await populated_authenticated_client.delete(
f"/books/1/files?file_ids={file_id}&delete_files=true",
)
assert response.status_code == 204
async def test_remove_file_idempotent(
self, populated_authenticated_client: AsyncClient
) -> None:
"""Test that removing the same file twice doesn't error on second attempt."""
book_response = await populated_authenticated_client.get("/books/1")
book_data = book_response.json()
if not book_data["files"]:
pytest.skip("Book has no files")
file_id = book_data["files"][0]["id"]
# Remove file first time
response1 = await populated_authenticated_client.delete(
f"/books/1/files?file_ids={file_id}&delete_files=false",
)
assert response1.status_code == 204
# Try to remove same file again
response2 = await populated_authenticated_client.delete(
f"/books/1/files?file_ids={file_id}&delete_files=false",
)
# Should succeed (idempotent)
assert response2.status_code == 204

View File

@@ -0,0 +1,286 @@
import pytest
from httpx import AsyncClient
async def test_get_shelves_without_auth(client: AsyncClient) -> None:
response = await client.get("/shelves")
assert response.status_code == 401
async def test_get_shelves_with_auth(authenticated_client: AsyncClient) -> None:
response = await authenticated_client.get("/shelves")
assert response.status_code == 200
result = response.json()
assert len(result["items"]) == 0
async def test_get_existing_shelves(
populated_authenticated_client: AsyncClient,
) -> None:
response = await populated_authenticated_client.get("/shelves")
assert response.status_code == 200
result = response.json()
assert len(result["items"]) == 1
async def test_create_shelf(authenticated_client: AsyncClient) -> None:
shelf_data = {"title": "Favourites"}
response = await authenticated_client.post("/shelves", json=shelf_data)
assert response.status_code == 201
result = response.json()
assert result["title"] == "Favourites"
assert result["library_id"] is None
async def test_create_shelf_in_nonexistent_library(
authenticated_client: AsyncClient,
) -> None:
shelf_data = {"title": "Favourites", "library_id": 5}
response = await authenticated_client.post("/shelves", json=shelf_data)
assert response.status_code == 400
assert f"Library with ID {shelf_data['library_id']} does not exist" in response.text
async def test_create_shelf_in_existing_library(
authenticated_client: AsyncClient,
) -> None:
shelf_data = {"title": "Favourites", "library_id": 1}
response = await authenticated_client.post("/shelves", json=shelf_data)
assert response.status_code == 201
result = response.json()
assert result["title"] == "Favourites"
assert result["library_id"] == 1
async def test_delete_shelf_without_auth(client: AsyncClient) -> None:
response = await client.delete("/shelves/1")
assert response.status_code == 401
async def test_delete_shelf_unauthorized(
authenticated_client: AsyncClient, other_authenticated_client: AsyncClient
) -> None:
"""Verify users can't delete shelves they don't own."""
# Create a shelf as authenticated_client
shelf_data = {"title": "My Shelf"}
shelf_response = await authenticated_client.post("/shelves", json=shelf_data)
shelf_id = shelf_response.json()["id"]
# Try to delete as other_authenticated_client
response = await other_authenticated_client.delete(f"/shelves/{shelf_id}")
assert response.status_code == 403
assert "do not have permission" in response.text
async def test_add_books_unauthorized(
authenticated_client: AsyncClient, other_authenticated_client: AsyncClient
) -> None:
"""Verify users can't add books to shelves they don't own."""
# Create a shelf as authenticated_client
shelf_data = {"title": "Other User's Shelf"}
shelf_response = await authenticated_client.post("/shelves", json=shelf_data)
shelf_id = shelf_response.json()["id"]
# Try to add books as other_authenticated_client
response = await other_authenticated_client.post(
f"/shelves/{shelf_id}/books", params={"book_ids": [1, 2]}
)
assert response.status_code == 403
assert "do not have permission" in response.text
async def test_remove_books_unauthorized(
authenticated_client: AsyncClient, other_authenticated_client: AsyncClient
) -> None:
"""Verify users can't remove books from shelves they don't own."""
# Create a shelf and add books as authenticated_client
shelf_data = {"title": "Other User's Shelf"}
shelf_response = await authenticated_client.post("/shelves", json=shelf_data)
shelf_id = shelf_response.json()["id"]
await authenticated_client.post(
f"/shelves/{shelf_id}/books", params={"book_ids": [1, 2]}
)
# Try to remove books as other_authenticated_client
response = await other_authenticated_client.delete(
f"/shelves/{shelf_id}/books", params={"book_ids": [1]}
)
assert response.status_code == 403
assert "do not have permission" in response.text
async def test_delete_nonexistent_shelf(authenticated_client: AsyncClient) -> None:
"""Verify 404 when deleting a shelf that doesn't exist."""
response = await authenticated_client.delete("/shelves/99999")
assert response.status_code == 404
async def test_add_books_to_shelf(populated_authenticated_client: AsyncClient) -> None:
"""Successfully add books to a shelf."""
shelf_data = {"title": "Test Shelf"}
shelf_response = await populated_authenticated_client.post(
"/shelves", json=shelf_data
)
shelf_id = shelf_response.json()["id"]
book_ids = [1, 2]
response = await populated_authenticated_client.post(
f"/shelves/{shelf_id}/books", params={"book_ids": book_ids}
)
assert response.status_code == 201
# Verify by listing books filtered by shelf
books_response = await populated_authenticated_client.get(
"/books", params={"shelves": shelf_id}
)
assert books_response.status_code == 200
assert len(books_response.json()["items"]) == 2
async def test_add_books_to_nonexistent_shelf(
authenticated_client: AsyncClient,
) -> None:
"""Verify 404 when adding to nonexistent shelf."""
response = await authenticated_client.post(
"/shelves/99999/books", params={"book_ids": [1, 2]}
)
assert response.status_code == 404
async def test_add_nonexistent_books(authenticated_client: AsyncClient) -> None:
"""Verify appropriate error when book IDs don't exist."""
shelf_data = {"title": "Test Shelf"}
shelf_response = await authenticated_client.post("/shelves", json=shelf_data)
shelf_id = shelf_response.json()["id"]
response = await authenticated_client.post(
f"/shelves/{shelf_id}/books", params={"book_ids": [99999, 99998]}
)
assert response.status_code == 400
async def test_remove_books_from_shelf(
populated_authenticated_client: AsyncClient,
) -> None:
"""Successfully remove books from a shelf."""
shelf_data = {"title": "Test Shelf"}
shelf_response = await populated_authenticated_client.post(
"/shelves", json=shelf_data
)
shelf_id = shelf_response.json()["id"]
# Add books first
book_ids = [1, 2, 3]
await populated_authenticated_client.post(
f"/shelves/{shelf_id}/books", params={"book_ids": book_ids}
)
# Remove one book
response = await populated_authenticated_client.delete(
f"/shelves/{shelf_id}/books", params={"book_ids": [1]}
)
assert response.status_code == 200
# Verify by listing books
books_response = await populated_authenticated_client.get(
"/books", params={"shelves": shelf_id}
)
assert books_response.status_code == 200
assert books_response.json()["total"] == 2
async def test_remove_books_from_nonexistent_shelf(
authenticated_client: AsyncClient,
) -> None:
"""Verify 404 when removing from nonexistent shelf."""
response = await authenticated_client.delete(
"/shelves/99999/books", params={"book_ids": [1, 2]}
)
assert response.status_code == 404
async def test_remove_nonexistent_books(
populated_authenticated_client: AsyncClient,
) -> None:
"""Verify appropriate error handling when removing nonexistent books."""
shelf_data = {"title": "Test Shelf"}
shelf_response = await populated_authenticated_client.post(
"/shelves", json=shelf_data
)
shelf_id = shelf_response.json()["id"]
# Add a book first
await populated_authenticated_client.post(
f"/shelves/{shelf_id}/books", params={"book_ids": [1]}
)
# Try to remove books that don't exist on shelf
response = await populated_authenticated_client.delete(
f"/shelves/{shelf_id}/books", params={"book_ids": [99999]}
)
# Idempotent behaviour
assert response.status_code == 200
async def test_add_duplicate_books(populated_authenticated_client: AsyncClient) -> None:
"""Verify behavior when adding books already on shelf."""
shelf_data = {"title": "Test Shelf"}
shelf_response = await populated_authenticated_client.post(
"/shelves", json=shelf_data
)
shelf_id = shelf_response.json()["id"]
# Add books
book_ids = [1, 2]
await populated_authenticated_client.post(
f"/shelves/{shelf_id}/books", params={"book_ids": book_ids}
)
# Try to add some of the same books again
response = await populated_authenticated_client.post(
f"/shelves/{shelf_id}/books", params={"book_ids": [1, 2, 3]}
)
assert response.status_code == 201
# Verify final state
books_response = await populated_authenticated_client.get(
"/books", params={"shelves": shelf_id}
)
result = books_response.json()["items"]
assert len(result) == 3
async def test_create_shelf_with_empty_title(authenticated_client: AsyncClient) -> None:
"""Verify validation rejects empty shelf titles."""
shelf_data = {"title": ""}
response = await authenticated_client.post("/shelves", json=shelf_data)
assert response.status_code == 400
assert "title" in response.text.lower()
async def test_create_shelf_with_whitespace_only_title(
authenticated_client: AsyncClient,
) -> None:
"""Verify validation rejects whitespace-only titles."""
shelf_data = {"title": " "}
response = await authenticated_client.post("/shelves", json=shelf_data)
assert response.status_code == 400

View File

@@ -0,0 +1,44 @@
import pytest
from httpx import AsyncClient
from pathlib import Path
async def test_get_libraries_without_auth(client: AsyncClient) -> None:
response = await client.get("/libraries?library_id=1")
assert response.status_code == 401
async def test_get_libraries_with_auth(authenticated_client: AsyncClient) -> None:
response = await authenticated_client.get("/libraries?library_id=1")
assert response.status_code == 200
@pytest.mark.parametrize(
("name", "read_only", "expected_status_code"),
[("Test Library", False, 201), ("Read-Only Library", True, 400)],
)
async def test_create_library(
authenticated_client: AsyncClient,
tmp_path: Path,
name: str,
read_only: bool,
expected_status_code: int,
) -> None:
library_data = {
"name": "Test Library",
"root_path": f"{tmp_path}/books",
"path_template": "{author}/{title}",
"read_only": read_only,
}
response = await authenticated_client.post("/libraries", json=library_data)
assert response.status_code == expected_status_code
if response.status_code == 201:
result = response.json()
assert result["name"] == "Test Library"
assert result["root_path"] == f"{tmp_path}/books"
assert result["path_template"] == "{author}/{title}"
assert result["read_only"] == False
assert result["description"] is None

View File

@@ -0,0 +1,17 @@
import pytest
from pathlib import Path
from datetime import date
from chitai.services.metadata_extractor import EpubExtractor
@pytest.mark.asyncio()
class TestEpubExtractor:
async def test_extraction_by_path(self):
path = Path("tests/data_files/Moby Dick; Or, The Whale - Herman Melville.epub")
metadata = await EpubExtractor.extract_metadata(path)
assert metadata["title"] == "Moby Dick; Or, The Whale"
assert metadata["authors"] == ["Herman Melville"]
assert metadata["published_date"] == date(year=2001, month=7, day=1)

View File

@@ -0,0 +1,74 @@
"""Tests for BookService"""
import pytest
from pathlib import Path
import aiofiles.os as aios
from chitai.schemas import BookCreate
from chitai.services import BookService
from chitai.database import models as m
@pytest.mark.asyncio
class TestBookServiceCRUD:
"""Test CRUD operation for libraries."""
async def test_update_book(
self, books_service: BookService, test_library: m.Library
) -> None:
book_data = BookCreate(
library_id=1,
title="Fellowship of the Ring",
authors=["J.R.R Tolkien"],
tags=["Fantasy"],
identifiers={"isbn-13": "9780261102354"},
pages=427,
)
book = await books_service.to_model_on_create(book_data.model_dump())
# Add path manually as it won't be generated (not using the create function, but manually inserting into db)
book.path = f"{test_library.root_path}/J.R.R Tolkien/The Fellowship of the Ring"
await aios.makedirs(book.path)
books_service.repository.session.add(book)
await books_service.repository.session.commit()
await books_service.repository.session.refresh(book)
await books_service.update(
book.id,
{
"title": "The Fellowship of the Ring",
"identifiers": {"isbn-10": "9780261102354"},
"edition": 3,
"publisher": "Tolkien Estate",
"series": "The Lord of the Rings",
"series_position": "1",
"tags": ["Fantasy", "Adventure"],
},
test_library,
)
updated_book = await books_service.get(book.id)
# Assert updated information is correct
assert updated_book.title == "The Fellowship of the Ring"
assert (
updated_book.path
== f"{test_library.root_path}/J.R.R Tolkien/The Lord of the Rings/01 - The Fellowship of the Ring"
)
assert len(updated_book.identifiers)
assert updated_book.identifiers[0].value == "9780261102354"
assert updated_book.edition == 3
assert updated_book.publisher.name == "Tolkien Estate"
assert len(updated_book.tags) == 2
# book = await books_service.create(book_data.model_dump())

View File

@@ -0,0 +1,298 @@
"""Tests for LibraryService"""
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from chitai.services import ShelfService
from chitai.database import models as m
import pytest
from sqlalchemy import select
from chitai.services.bookshelf import ShelfService
from chitai.services import BookService
from chitai.database.models.book_list import BookList, BookListLink
from chitai.database import models as m
@pytest.fixture
async def db_session(bookshelf_service: ShelfService) -> AsyncSession:
return bookshelf_service.repository.session
@pytest.fixture(autouse=True)
async def test_books(db_session: AsyncSession) -> list[m.Book]:
"""Create test books in the database."""
library = m.Library(
name="Default Library",
slug="default-library",
root_path="./path",
path_template="{author}/{title}",
read_only=False,
)
db_session.add(library)
books = [m.Book(title=f"Book {i}", library_id=1) for i in range(1, 8)]
db_session.add_all(books)
user = m.User(email="test_user@example.com", password="password123")
db_session.add(user)
await db_session.flush()
await db_session.commit()
return books
class TestAddBooks:
async def test_add_books_to_empty_shelf(
self, bookshelf_service: ShelfService, db_session
) -> None:
"""Successfully add books to an empty shelf."""
# Create a shelf
shelf = BookList(title="Test Shelf", user_id=1)
db_session.add(shelf)
await db_session.flush()
# Add books
await bookshelf_service.add_books(shelf.id, [1, 2, 3])
# Verify books were added
links = await db_session.execute(
select(BookListLink).where(BookListLink.list_id == shelf.id)
)
added_links = links.scalars().all()
assert len(added_links) == 3
assert {link.book_id for link in added_links} == {1, 2, 3}
async def test_add_books_preserves_positions(
self, bookshelf_service: ShelfService, db_session
) -> None:
"""Verify books are assigned correct positions on the shelf."""
shelf = BookList(title="Test Shelf", user_id=1)
db_session.add(shelf)
await db_session.flush()
await bookshelf_service.add_books(shelf.id, [1, 2, 3])
links = await db_session.execute(
select(BookListLink)
.where(BookListLink.list_id == shelf.id)
.order_by(BookListLink.position)
)
added_links = links.scalars().all()
# Verify positions are sequential starting from 0
assert [link.position for link in added_links] == [0, 1, 2]
async def test_add_books_to_shelf_with_existing_books(
self, bookshelf_service: ShelfService, db_session
) -> None:
"""Adding books to a shelf with existing books assigns correct positions."""
shelf = BookList(title="Test Shelf", user_id=1)
db_session.add(shelf)
await db_session.flush()
# Add initial books
await bookshelf_service.add_books(shelf.id, [1, 2])
# Add more books
await bookshelf_service.add_books(shelf.id, [3, 4])
links = await db_session.execute(
select(BookListLink)
.where(BookListLink.list_id == shelf.id)
.order_by(BookListLink.position)
)
added_links = links.scalars().all()
# Verify new books continue from position 2
assert len(added_links) == 4
assert [link.position for link in added_links] == [0, 1, 2, 3]
async def test_add_duplicate_books_is_idempotent(
self, bookshelf_service: ShelfService, db_session: AsyncSession
) -> None:
"""Adding books already on shelf should not create duplicates."""
shelf = BookList(title="Test Shelf", user_id=1)
db_session.add(shelf)
await db_session.flush()
# Add books
await bookshelf_service.add_books(shelf.id, [1, 2, 3])
# Try to add overlapping books
await bookshelf_service.add_books(shelf.id, [2, 3, 4])
links = await db_session.execute(
select(BookListLink).where(BookListLink.list_id == shelf.id)
)
added_links = links.scalars().all()
# Should have 4 books total (1, 2, 3, 4), not 7
assert len(added_links) == 4
assert {link.book_id for link in added_links} == {1, 2, 3, 4}
async def test_add_books_raises_on_nonexistent_shelf(
self, bookshelf_service: ShelfService
) -> None:
"""Adding books to nonexistent shelf raises error."""
with pytest.raises(Exception): # Could be SQLAlchemyError or specific error
await bookshelf_service.add_books(99999, [1, 2, 3])
async def test_add_nonexistent_books_raises_error(
self, bookshelf_service: ShelfService, db_session: AsyncSession
) -> None:
"""Adding nonexistent books raises ValueError."""
shelf = BookList(title="Test Shelf", user_id=1)
db_session.add(shelf)
await db_session.flush()
with pytest.raises(ValueError, match="One or more books not found"):
await bookshelf_service.add_books(shelf.id, [99999, 99998])
async def test_add_partial_nonexistent_books_raises_error(
self, bookshelf_service: ShelfService, db_session: AsyncSession
) -> None:
"""Adding a mix of existent and nonexistent books raises error."""
shelf = BookList(title="Test Shelf", user_id=1)
db_session.add(shelf)
await db_session.flush()
# Book 1 might exist, but 99999 doesn't
with pytest.raises(ValueError, match="One or more books not found"):
await bookshelf_service.add_books(shelf.id, [1, 99999])
async def test_add_empty_book_list(
self, bookshelf_service: ShelfService, db_session: AsyncSession
) -> None:
"""Adding empty book list should return without error."""
shelf = BookList(title="Test Shelf", user_id=1)
db_session.add(shelf)
await db_session.flush()
# Should not raise
await bookshelf_service.add_books(shelf.id, [])
class TestRemoveBooks:
async def test_remove_books_from_shelf(
self, bookshelf_service: ShelfService, db_session: AsyncSession
) -> None:
"""Successfully remove books from a shelf."""
shelf = BookList(title="Test Shelf", user_id=1)
db_session.add(shelf)
await db_session.flush()
# Add books
await bookshelf_service.add_books(shelf.id, [1, 2, 3, 4])
# Remove some books
await bookshelf_service.remove_books(shelf.id, [2, 3])
links = await db_session.execute(
select(BookListLink).where(BookListLink.list_id == shelf.id)
)
remaining_links = links.scalars().all()
assert len(remaining_links) == 2
assert {link.book_id for link in remaining_links} == {1, 4}
async def test_remove_all_books_from_shelf(
self, bookshelf_service: ShelfService, db_session: AsyncSession
) -> None:
"""Removing all books should leave shelf empty."""
shelf = BookList(title="Test Shelf", user_id=1)
db_session.add(shelf)
await db_session.flush()
await bookshelf_service.add_books(shelf.id, [1, 2, 3])
await bookshelf_service.remove_books(shelf.id, [1, 2, 3])
links = await db_session.execute(
select(BookListLink).where(BookListLink.list_id == shelf.id)
)
remaining_links = links.scalars().all()
assert len(remaining_links) == 0
async def test_remove_nonexistent_book_is_idempotent(
self, bookshelf_service: ShelfService, db_session: AsyncSession
) -> None:
"""Removing books not on shelf should not raise error."""
shelf = BookList(title="Test Shelf", user_id=1)
db_session.add(shelf)
await db_session.flush()
await bookshelf_service.add_books(shelf.id, [1, 2, 3])
# Remove books that don't exist on shelf - should not raise
await bookshelf_service.remove_books(shelf.id, [99, 100])
links = await db_session.execute(
select(BookListLink).where(BookListLink.list_id == shelf.id)
)
remaining_links = links.scalars().all()
# Original books should still be there
assert len(remaining_links) == 3
async def test_remove_books_raises_on_nonexistent_shelf(
self, bookshelf_service: ShelfService
) -> None:
"""Removing books from nonexistent shelf raises error."""
with pytest.raises(Exception):
await bookshelf_service.remove_books(99999, [1, 2, 3])
async def test_remove_empty_book_list(
self, bookshelf_service: ShelfService, db_session: AsyncSession
) -> None:
"""Removing empty book list should return without error."""
shelf = BookList(title="Test Shelf", user_id=1)
db_session.add(shelf)
await db_session.flush()
await bookshelf_service.add_books(shelf.id, [1, 2, 3])
# Should not raise
await bookshelf_service.remove_books(shelf.id, [])
links = await db_session.execute(
select(BookListLink).where(BookListLink.list_id == shelf.id)
)
remaining_links = links.scalars().all()
# All books should still be there
assert len(remaining_links) == 3
async def test_add_remove_add_sequence(
self, bookshelf_service: ShelfService, db_session
) -> None:
"""Add books, remove from middle, then add more - positions should be maintained."""
shelf = BookList(title="Test Shelf", user_id=1)
db_session.add(shelf)
await db_session.flush()
# Add initial books [1, 2, 3, 4, 5]
await bookshelf_service.add_books(shelf.id, [1, 2, 3, 4, 5])
# Remove books from middle [2, 3, 4]
await bookshelf_service.remove_books(shelf.id, [2, 3, 4])
# Add more books [6, 7]
await bookshelf_service.add_books(shelf.id, [6, 7])
links = await db_session.execute(
select(BookListLink)
.where(BookListLink.list_id == shelf.id)
.order_by(BookListLink.position)
)
final_links = links.scalars().all()
# Should have [1, 5, 6, 7] with positions [0, 1, 2, 3]
assert len(final_links) == 4
assert [link.book_id for link in final_links] == [1, 5, 6, 7]
assert [link.position for link in final_links] == [0, 1, 2, 3]

View File

@@ -0,0 +1,153 @@
"""Tests for LibraryService"""
import pytest
from pathlib import Path
from sqlalchemy.ext.asyncio import AsyncSession
from chitai.schemas.library import LibraryCreate
from chitai.services import LibraryService
from chitai.database import models as m
from chitai.services.utils import DirectoryDoesNotExist
@pytest.mark.asyncio
class TestLibraryServiceCRUD:
"""Test CRUD operation for libraries."""
async def test_create_library(
self, library_service: LibraryService, tmp_path: Path
) -> None:
"""Test creating a library with a valid root path."""
library_path = f"{tmp_path}/books"
library_data = LibraryCreate(
name="Test Library",
root_path=library_path,
path_template="{author}/{title}",
read_only=False,
)
library = await library_service.create(library_data)
assert library.name == "Test Library"
assert library.root_path == library_path
assert library.path_template == "{author}/{title}"
assert library.description == None
assert library.read_only == False
# Check if directory was created
assert Path(library.root_path).is_dir()
async def test_create_library_root_path_permission_error(
self, library_service: LibraryService, tmp_path: Path
) -> None:
"""Test creating a library with a root path that is not permitted."""
# Change permissions on the temp path
tmp_path.chmod(0o544)
library_path = f"{tmp_path}/books"
library_data = LibraryCreate(
name="Test Library",
root_path=library_path,
path_template="{author}/{title}",
read_only=False,
)
with pytest.raises(PermissionError) as exc_info:
library = await library_service.create(library_data)
# Check if directory was created
assert not Path(library_path).exists()
# Change permissions back
tmp_path.chmod(0o755)
async def test_create_library_read_only_path_exists(
self, library_service: LibraryService, tmp_path: Path
) -> None:
"""Test creating a read-only library with a root path that exists."""
# Create the path beforehand
library_path = f"{tmp_path}/books"
Path(library_path).mkdir()
library_data = LibraryCreate(
name="Test Library",
root_path=library_path,
path_template="{author}/{title}",
read_only=True,
)
library = await library_service.create(library_data)
assert library.name == "Test Library"
assert library.root_path == library_path
assert library.path_template == "{author}/{title}"
assert library.description == None
assert library.read_only == True
async def test_create_library_read_only_nonexistent_path(
self, library_service: LibraryService, tmp_path: Path
) -> None:
"""Test creating a read-only library with a nonexistent root path."""
library_path = f"{tmp_path}/books"
library_data = LibraryCreate(
name="Test Library",
root_path=library_path,
path_template="{author}/{title}",
read_only=True,
)
with pytest.raises(DirectoryDoesNotExist) as exc_info:
await library_service.create(library_data)
assert "Root directory" in str(exc_info.value)
assert "must exist for a read-only library" in str(exc_info.value)
# Check if directory was created
assert not Path(library_path).exists()
async def test_get_library(
self, session: AsyncSession, library_service: LibraryService
) -> None:
"""Test retrieving a library."""
# Add a library to the database
library_data = m.Library(
name="Testing Library",
slug="testing-library",
root_path="./books",
path_template="{author}/{title}",
description=None,
read_only=False,
)
session.add(library_data)
await session.commit()
await session.refresh(library_data)
library = await library_service.get(library_data.id)
assert library is not None
assert library.id == library_data.id
assert library.name == "Testing Library"
assert library.root_path == "./books"
assert library.path_template == "{author}/{title}"
assert library.description is None
assert library.read_only == False
# async def test_delete_library_keep_files(
# self, session: AsyncSession, library_service: LibraryService
# ) -> None:
# """Test deletion of a library's metadata and associated entities."""
# raise NotImplementedError()
# async def test_delete_library_delete_files(
# self, session: AsyncSession, library_service: LibraryService
# ) -> None:
# """Test deletion of a library's metadata, associated enties, and files/directories."""
# raise NotImplementedError()

View File

@@ -0,0 +1,116 @@
"""Tests for UserService"""
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from litestar.exceptions import PermissionDeniedException
from sqlalchemy.exc import IntegrityError
from chitai.services import UserService
from chitai.database import models as m
@pytest.mark.asyncio
class TestUserServiceAuthentication:
"""Test authentication functionality."""
async def test_authenticate_success(
self, session: AsyncSession, user_service: UserService
) -> None:
"""Test successful user authentication."""
# Create a user with a known password
password = "password123"
user = m.User(email=f"test@example.com", password=password)
session.add(user)
await session.commit()
# Authenticate user
authenticated_user = await user_service.authenticate(
"test@example.com", password
)
assert authenticated_user is not None
assert authenticated_user.email == "test@example.com"
assert authenticated_user.id == user.id
async def test_authenticate_user_not_found(
self, session: AsyncSession, user_service: UserService
) -> None:
"""Test authentication with non-existent user."""
with pytest.raises(PermissionDeniedException) as exc_info:
await user_service.authenticate("nonexistent@example.com", "password")
assert "User not found or password invalid" in str(exc_info.value)
async def test_authenticate_wrong_password(
self, session: AsyncSession, user_service: UserService
) -> None:
"""Test authentication with wrong password."""
# Create user
password = "password123"
user = m.User(email=f"test@example.com", password=password)
session.add(user)
await session.commit()
# Attempt authentication
with pytest.raises(PermissionDeniedException) as exc_info:
await user_service.authenticate("test@example.com", "WrongPassword")
assert "User not found or password invalid" in str(exc_info.value)
@pytest.mark.asyncio
class TestUserServiceCRUD:
"""Test basic CRUD operations."""
async def test_create_user_with_password_hashing(
self, session: AsyncSession, user_service: UserService
) -> None:
user_data = {"email": "newuser@example.com", "password": "password123"}
user = await user_service.create(data=user_data)
assert user.email == "newuser@example.com"
assert user.password is not None
assert user.password != "password123" # Password should be hashed
assert user.password.verify("password123")
async def test_get_user_by_email(
self, session: AsyncSession, user_service: UserService
) -> None:
"""Test getting user by email."""
user = m.User(email=f"test@example.com", password="password123")
session.add(user)
await session.commit()
found_user = await user_service.get_one_or_none(email="test@example.com")
assert found_user is not None
assert found_user.id == user.id
assert found_user.email == user.email
async def test_create_user_with_duplicate_email(
self, session: AsyncSession, user_service: UserService
) -> None:
"""Test creating a new user with a duplicate email."""
# Create first user
user = m.User(email=f"test@example.com", password="password123")
session.add(user)
await session.commit()
# Create second user
user = m.User(email=f"test@example.com", password="password12345")
with pytest.raises(IntegrityError) as exc_info:
session.add(user)
await session.commit()
assert "violates unique constraint" in str(exc_info.value)

1468
backend/uv.lock generated Normal file

File diff suppressed because it is too large Load Diff