Implement db init and db backup commands in the cli
This commit is contained in:
68
mft/cli.py
68
mft/cli.py
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import sqlite3
|
||||
import uvicorn
|
||||
import mft.config
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -33,6 +35,17 @@ def parse_args() -> argparse.Namespace:
|
||||
run_parser.add_argument(
|
||||
"--debug", action="store_true", help="Run in debug mode with auto-reload"
|
||||
)
|
||||
|
||||
# db subcommand with nested operations
|
||||
db_parser = subparsers.add_parser("db", help="Database operations")
|
||||
db_subparsers = db_parser.add_subparsers(dest="db_command", required=True)
|
||||
|
||||
# db init
|
||||
db_init_parser = db_subparsers.add_parser("init", help="Initialize database")
|
||||
|
||||
# db backup
|
||||
db_backup_parser = db_subparsers.add_parser("backup", help="Create database backup")
|
||||
|
||||
return transform_args(parser.parse_args())
|
||||
|
||||
|
||||
@@ -58,6 +71,8 @@ def main():
|
||||
|
||||
if args.command == "run":
|
||||
run_command(args, settings)
|
||||
elif args.command == "db":
|
||||
db_command(args, settings)
|
||||
|
||||
|
||||
def run_command(args, settings):
|
||||
@@ -70,5 +85,58 @@ def run_command(args, settings):
|
||||
)
|
||||
|
||||
|
||||
def db_command(args, settings):
|
||||
if args.db_command == "init":
|
||||
db_init_command(args, settings)
|
||||
elif args.db_command == "backup":
|
||||
db_backup_command(args, settings)
|
||||
|
||||
|
||||
def db_init_command(args, settings):
|
||||
from mft.database import get_db, db_init, SchemaError
|
||||
|
||||
with get_db() as conn:
|
||||
try:
|
||||
db_init(conn)
|
||||
except SchemaError as e:
|
||||
print(
|
||||
f"Database initialization failed with a SchemaError: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
print("Database initialized")
|
||||
|
||||
|
||||
def db_backup_command(args, settings):
|
||||
db_path = settings.database_path
|
||||
|
||||
if not db_path.exists():
|
||||
print(f"Error: Database file `{db_path}` does not exist.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Generate backup filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d.%H%M%S")
|
||||
backup_path = db_path.parent / f"{db_path.name}.backup-{timestamp}"
|
||||
|
||||
try:
|
||||
# Open source database
|
||||
source = sqlite3.connect(str(db_path))
|
||||
|
||||
# Create backup database
|
||||
backup = sqlite3.connect(str(backup_path))
|
||||
|
||||
# Use SQLite's backup API to safely copy
|
||||
with backup:
|
||||
source.backup(backup)
|
||||
|
||||
backup.close()
|
||||
source.close()
|
||||
|
||||
print(f"Database backed up successfully to: {backup_path}")
|
||||
except sqlite3.Error as e:
|
||||
print(f"Error creating backup: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
from pathlib import Path
|
||||
|
||||
from mft.settings import settings
|
||||
|
||||
SCHEMA_PATH = Path(__file__).parent / "schema"
|
||||
SCHEMA_VERSION = 1
|
||||
|
||||
|
||||
class SchemaError(RuntimeError):
|
||||
"""Signals a database schema related error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db() -> Generator[sqlite3.Connection, None, None]:
|
||||
@@ -14,3 +24,66 @@ def get_db() -> Generator[sqlite3.Connection, None, None]:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def db_validate_schema(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Verifies if the current database has the correct schema
|
||||
|
||||
Raises SchemaError if the schema is incorrect
|
||||
"""
|
||||
version = db_schema_version(conn)
|
||||
if version != SCHEMA_VERSION:
|
||||
raise SchemaError(f"Schema version is {version} but expected {SCHEMA_VERSION}")
|
||||
|
||||
|
||||
def db_schema_version(conn: sqlite3.Connection) -> int:
|
||||
"""
|
||||
Return the version number of the schema in the database
|
||||
|
||||
Returns 0 for an empty database.
|
||||
Raises SchemaError when the database schema is in an unknown or
|
||||
incompatible state.
|
||||
"""
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM sqlite_master")
|
||||
if cursor.fetchone()[0] == 0:
|
||||
return 0
|
||||
|
||||
cursor.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='config'"
|
||||
)
|
||||
if not cursor.fetchone():
|
||||
raise SchemaError("Config table missing")
|
||||
|
||||
cursor.execute("SELECT value FROM config WHERE key = 'schema.version'")
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
raise SchemaError("Config table exists but schema.version is missing")
|
||||
|
||||
try:
|
||||
version = int(result[0])
|
||||
if version <= 0:
|
||||
raise SchemaError(f"Invalid schema version value: {version}")
|
||||
return version
|
||||
except (ValueError, TypeError) as e:
|
||||
raise SchemaError(f"Invalid schema version value: {result[0]}") from e
|
||||
|
||||
|
||||
def db_init(conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Initializes the database with the default schema.
|
||||
|
||||
Raises SchemaError if the database is in an unknown, incompatible or
|
||||
already initialized state.
|
||||
"""
|
||||
version = db_schema_version(conn)
|
||||
|
||||
if version != 0:
|
||||
raise SchemaError(f"Database is already initialized (schema version {version})")
|
||||
|
||||
# Read and execute the schema file
|
||||
schema_file = SCHEMA_PATH / "schema.sql"
|
||||
schema_sql = schema_file.read_text()
|
||||
conn.executescript(schema_sql)
|
||||
conn.commit()
|
||||
|
||||
Reference in New Issue
Block a user