diff --git a/mft/cli.py b/mft/cli.py index fc08c8c..5aa1609 100644 --- a/mft/cli.py +++ b/mft/cli.py @@ -1,8 +1,10 @@ import os import sys import argparse +import sqlite3 import uvicorn import mft.config +from datetime import datetime from pathlib import Path @@ -33,6 +35,17 @@ def parse_args() -> argparse.Namespace: run_parser.add_argument( "--debug", action="store_true", help="Run in debug mode with auto-reload" ) + + # db subcommand with nested operations + db_parser = subparsers.add_parser("db", help="Database operations") + db_subparsers = db_parser.add_subparsers(dest="db_command", required=True) + + # db init + db_init_parser = db_subparsers.add_parser("init", help="Initialize database") + + # db backup + db_backup_parser = db_subparsers.add_parser("backup", help="Create database backup") + return transform_args(parser.parse_args()) @@ -58,6 +71,8 @@ def main(): if args.command == "run": run_command(args, settings) + elif args.command == "db": + db_command(args, settings) def run_command(args, settings): @@ -70,5 +85,58 @@ def run_command(args, settings): ) +def db_command(args, settings): + if args.db_command == "init": + db_init_command(args, settings) + elif args.db_command == "backup": + db_backup_command(args, settings) + + +def db_init_command(args, settings): + from mft.database import get_db, db_init, SchemaError + + with get_db() as conn: + try: + db_init(conn) + except SchemaError as e: + print( + f"Database initialization failed with a SchemaError: {e}", + file=sys.stderr, + ) + sys.exit(1) + print("Database initialized") + + +def db_backup_command(args, settings): + db_path = settings.database_path + + if not db_path.exists(): + print(f"Error: Database file `{db_path}` does not exist.", file=sys.stderr) + sys.exit(1) + + # Generate backup filename with timestamp + timestamp = datetime.now().strftime("%Y-%m-%d.%H%M%S") + backup_path = db_path.parent / f"{db_path.name}.backup-{timestamp}" + + try: + # Open source database + source = sqlite3.connect(str(db_path)) + + # Create backup database + backup = sqlite3.connect(str(backup_path)) + + # Use SQLite's backup API to safely copy + with backup: + source.backup(backup) + + backup.close() + source.close() + + print(f"Database backed up successfully to: {backup_path}") + except sqlite3.Error as e: + print(f"Error creating backup: {e}", file=sys.stderr) + sys.exit(1) + + if __name__ == "__main__": main() diff --git a/mft/database.py b/mft/database.py index 6559930..92b0bdb 100644 --- a/mft/database.py +++ b/mft/database.py @@ -1,9 +1,19 @@ import sqlite3 from contextlib import contextmanager from typing import Generator +from pathlib import Path from mft.settings import settings +SCHEMA_PATH = Path(__file__).parent / "schema" +SCHEMA_VERSION = 1 + + +class SchemaError(RuntimeError): + """Signals a database schema related error""" + + pass + @contextmanager def get_db() -> Generator[sqlite3.Connection, None, None]: @@ -14,3 +24,66 @@ def get_db() -> Generator[sqlite3.Connection, None, None]: yield conn finally: conn.close() + + +def db_validate_schema(conn: sqlite3.Connection) -> None: + """ + Verifies if the current database has the correct schema + + Raises SchemaError if the schema is incorrect + """ + version = db_schema_version(conn) + if version != SCHEMA_VERSION: + raise SchemaError(f"Schema version is {version} but expected {SCHEMA_VERSION}") + + +def db_schema_version(conn: sqlite3.Connection) -> int: + """ + Return the version number of the schema in the database + + Returns 0 for an empty database. + Raises SchemaError when the database schema is in an unknown or + incompatible state. + """ + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM sqlite_master") + if cursor.fetchone()[0] == 0: + return 0 + + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='config'" + ) + if not cursor.fetchone(): + raise SchemaError("Config table missing") + + cursor.execute("SELECT value FROM config WHERE key = 'schema.version'") + result = cursor.fetchone() + if not result: + raise SchemaError("Config table exists but schema.version is missing") + + try: + version = int(result[0]) + if version <= 0: + raise SchemaError(f"Invalid schema version value: {version}") + return version + except (ValueError, TypeError) as e: + raise SchemaError(f"Invalid schema version value: {result[0]}") from e + + +def db_init(conn: sqlite3.Connection) -> None: + """ + Initializes the database with the default schema. + + Raises SchemaError if the database is in an unknown, incompatible or + already initialized state. + """ + version = db_schema_version(conn) + + if version != 0: + raise SchemaError(f"Database is already initialized (schema version {version})") + + # Read and execute the schema file + schema_file = SCHEMA_PATH / "schema.sql" + schema_sql = schema_file.read_text() + conn.executescript(schema_sql) + conn.commit()