Implement db init and db backup commands in the cli

This commit is contained in:
2025-12-24 12:38:29 +01:00
parent b4c84ab7ea
commit bb6575d403
2 changed files with 141 additions and 0 deletions

View File

@@ -1,8 +1,10 @@
import os import os
import sys import sys
import argparse import argparse
import sqlite3
import uvicorn import uvicorn
import mft.config import mft.config
from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -33,6 +35,17 @@ def parse_args() -> argparse.Namespace:
run_parser.add_argument( run_parser.add_argument(
"--debug", action="store_true", help="Run in debug mode with auto-reload" "--debug", action="store_true", help="Run in debug mode with auto-reload"
) )
# db subcommand with nested operations
db_parser = subparsers.add_parser("db", help="Database operations")
db_subparsers = db_parser.add_subparsers(dest="db_command", required=True)
# db init
db_init_parser = db_subparsers.add_parser("init", help="Initialize database")
# db backup
db_backup_parser = db_subparsers.add_parser("backup", help="Create database backup")
return transform_args(parser.parse_args()) return transform_args(parser.parse_args())
@@ -58,6 +71,8 @@ def main():
if args.command == "run": if args.command == "run":
run_command(args, settings) run_command(args, settings)
elif args.command == "db":
db_command(args, settings)
def run_command(args, settings): def run_command(args, settings):
@@ -70,5 +85,58 @@ def run_command(args, settings):
) )
def db_command(args, settings):
if args.db_command == "init":
db_init_command(args, settings)
elif args.db_command == "backup":
db_backup_command(args, settings)
def db_init_command(args, settings):
from mft.database import get_db, db_init, SchemaError
with get_db() as conn:
try:
db_init(conn)
except SchemaError as e:
print(
f"Database initialization failed with a SchemaError: {e}",
file=sys.stderr,
)
sys.exit(1)
print("Database initialized")
def db_backup_command(args, settings):
db_path = settings.database_path
if not db_path.exists():
print(f"Error: Database file `{db_path}` does not exist.", file=sys.stderr)
sys.exit(1)
# Generate backup filename with timestamp
timestamp = datetime.now().strftime("%Y-%m-%d.%H%M%S")
backup_path = db_path.parent / f"{db_path.name}.backup-{timestamp}"
try:
# Open source database
source = sqlite3.connect(str(db_path))
# Create backup database
backup = sqlite3.connect(str(backup_path))
# Use SQLite's backup API to safely copy
with backup:
source.backup(backup)
backup.close()
source.close()
print(f"Database backed up successfully to: {backup_path}")
except sqlite3.Error as e:
print(f"Error creating backup: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,9 +1,19 @@
import sqlite3 import sqlite3
from contextlib import contextmanager from contextlib import contextmanager
from typing import Generator from typing import Generator
from pathlib import Path
from mft.settings import settings from mft.settings import settings
SCHEMA_PATH = Path(__file__).parent / "schema"
SCHEMA_VERSION = 1
class SchemaError(RuntimeError):
"""Signals a database schema related error"""
pass
@contextmanager @contextmanager
def get_db() -> Generator[sqlite3.Connection, None, None]: def get_db() -> Generator[sqlite3.Connection, None, None]:
@@ -14,3 +24,66 @@ def get_db() -> Generator[sqlite3.Connection, None, None]:
yield conn yield conn
finally: finally:
conn.close() conn.close()
def db_validate_schema(conn: sqlite3.Connection) -> None:
"""
Verifies if the current database has the correct schema
Raises SchemaError if the schema is incorrect
"""
version = db_schema_version(conn)
if version != SCHEMA_VERSION:
raise SchemaError(f"Schema version is {version} but expected {SCHEMA_VERSION}")
def db_schema_version(conn: sqlite3.Connection) -> int:
"""
Return the version number of the schema in the database
Returns 0 for an empty database.
Raises SchemaError when the database schema is in an unknown or
incompatible state.
"""
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM sqlite_master")
if cursor.fetchone()[0] == 0:
return 0
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='config'"
)
if not cursor.fetchone():
raise SchemaError("Config table missing")
cursor.execute("SELECT value FROM config WHERE key = 'schema.version'")
result = cursor.fetchone()
if not result:
raise SchemaError("Config table exists but schema.version is missing")
try:
version = int(result[0])
if version <= 0:
raise SchemaError(f"Invalid schema version value: {version}")
return version
except (ValueError, TypeError) as e:
raise SchemaError(f"Invalid schema version value: {result[0]}") from e
def db_init(conn: sqlite3.Connection) -> None:
"""
Initializes the database with the default schema.
Raises SchemaError if the database is in an unknown, incompatible or
already initialized state.
"""
version = db_schema_version(conn)
if version != 0:
raise SchemaError(f"Database is already initialized (schema version {version})")
# Read and execute the schema file
schema_file = SCHEMA_PATH / "schema.sql"
schema_sql = schema_file.read_text()
conn.executescript(schema_sql)
conn.commit()