Implement db init and db backup commands in the cli

This commit is contained in:
2025-12-24 12:38:29 +01:00
parent b4c84ab7ea
commit bb6575d403
2 changed files with 141 additions and 0 deletions

View File

@@ -1,8 +1,10 @@
import os
import sys
import argparse
import sqlite3
import uvicorn
import mft.config
from datetime import datetime
from pathlib import Path
@@ -33,6 +35,17 @@ def parse_args() -> argparse.Namespace:
run_parser.add_argument(
"--debug", action="store_true", help="Run in debug mode with auto-reload"
)
# db subcommand with nested operations
db_parser = subparsers.add_parser("db", help="Database operations")
db_subparsers = db_parser.add_subparsers(dest="db_command", required=True)
# db init
db_init_parser = db_subparsers.add_parser("init", help="Initialize database")
# db backup
db_backup_parser = db_subparsers.add_parser("backup", help="Create database backup")
return transform_args(parser.parse_args())
@@ -58,6 +71,8 @@ def main():
if args.command == "run":
run_command(args, settings)
elif args.command == "db":
db_command(args, settings)
def run_command(args, settings):
@@ -70,5 +85,58 @@ def run_command(args, settings):
)
def db_command(args, settings):
if args.db_command == "init":
db_init_command(args, settings)
elif args.db_command == "backup":
db_backup_command(args, settings)
def db_init_command(args, settings):
from mft.database import get_db, db_init, SchemaError
with get_db() as conn:
try:
db_init(conn)
except SchemaError as e:
print(
f"Database initialization failed with a SchemaError: {e}",
file=sys.stderr,
)
sys.exit(1)
print("Database initialized")
def db_backup_command(args, settings):
db_path = settings.database_path
if not db_path.exists():
print(f"Error: Database file `{db_path}` does not exist.", file=sys.stderr)
sys.exit(1)
# Generate backup filename with timestamp
timestamp = datetime.now().strftime("%Y-%m-%d.%H%M%S")
backup_path = db_path.parent / f"{db_path.name}.backup-{timestamp}"
try:
# Open source database
source = sqlite3.connect(str(db_path))
# Create backup database
backup = sqlite3.connect(str(backup_path))
# Use SQLite's backup API to safely copy
with backup:
source.backup(backup)
backup.close()
source.close()
print(f"Database backed up successfully to: {backup_path}")
except sqlite3.Error as e:
print(f"Error creating backup: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,9 +1,19 @@
import sqlite3
from contextlib import contextmanager
from typing import Generator
from pathlib import Path
from mft.settings import settings
SCHEMA_PATH = Path(__file__).parent / "schema"
SCHEMA_VERSION = 1
class SchemaError(RuntimeError):
"""Signals a database schema related error"""
pass
@contextmanager
def get_db() -> Generator[sqlite3.Connection, None, None]:
@@ -14,3 +24,66 @@ def get_db() -> Generator[sqlite3.Connection, None, None]:
yield conn
finally:
conn.close()
def db_validate_schema(conn: sqlite3.Connection) -> None:
"""
Verifies if the current database has the correct schema
Raises SchemaError if the schema is incorrect
"""
version = db_schema_version(conn)
if version != SCHEMA_VERSION:
raise SchemaError(f"Schema version is {version} but expected {SCHEMA_VERSION}")
def db_schema_version(conn: sqlite3.Connection) -> int:
"""
Return the version number of the schema in the database
Returns 0 for an empty database.
Raises SchemaError when the database schema is in an unknown or
incompatible state.
"""
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM sqlite_master")
if cursor.fetchone()[0] == 0:
return 0
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='config'"
)
if not cursor.fetchone():
raise SchemaError("Config table missing")
cursor.execute("SELECT value FROM config WHERE key = 'schema.version'")
result = cursor.fetchone()
if not result:
raise SchemaError("Config table exists but schema.version is missing")
try:
version = int(result[0])
if version <= 0:
raise SchemaError(f"Invalid schema version value: {version}")
return version
except (ValueError, TypeError) as e:
raise SchemaError(f"Invalid schema version value: {result[0]}") from e
def db_init(conn: sqlite3.Connection) -> None:
"""
Initializes the database with the default schema.
Raises SchemaError if the database is in an unknown, incompatible or
already initialized state.
"""
version = db_schema_version(conn)
if version != 0:
raise SchemaError(f"Database is already initialized (schema version {version})")
# Read and execute the schema file
schema_file = SCHEMA_PATH / "schema.sql"
schema_sql = schema_file.read_text()
conn.executescript(schema_sql)
conn.commit()