Implement db init and db backup commands in the cli
This commit is contained in:
68
mft/cli.py
68
mft/cli.py
@@ -1,8 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
|
import sqlite3
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import mft.config
|
import mft.config
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
@@ -33,6 +35,17 @@ def parse_args() -> argparse.Namespace:
|
|||||||
run_parser.add_argument(
|
run_parser.add_argument(
|
||||||
"--debug", action="store_true", help="Run in debug mode with auto-reload"
|
"--debug", action="store_true", help="Run in debug mode with auto-reload"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# db subcommand with nested operations
|
||||||
|
db_parser = subparsers.add_parser("db", help="Database operations")
|
||||||
|
db_subparsers = db_parser.add_subparsers(dest="db_command", required=True)
|
||||||
|
|
||||||
|
# db init
|
||||||
|
db_init_parser = db_subparsers.add_parser("init", help="Initialize database")
|
||||||
|
|
||||||
|
# db backup
|
||||||
|
db_backup_parser = db_subparsers.add_parser("backup", help="Create database backup")
|
||||||
|
|
||||||
return transform_args(parser.parse_args())
|
return transform_args(parser.parse_args())
|
||||||
|
|
||||||
|
|
||||||
@@ -58,6 +71,8 @@ def main():
|
|||||||
|
|
||||||
if args.command == "run":
|
if args.command == "run":
|
||||||
run_command(args, settings)
|
run_command(args, settings)
|
||||||
|
elif args.command == "db":
|
||||||
|
db_command(args, settings)
|
||||||
|
|
||||||
|
|
||||||
def run_command(args, settings):
|
def run_command(args, settings):
|
||||||
@@ -70,5 +85,58 @@ def run_command(args, settings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def db_command(args, settings):
|
||||||
|
if args.db_command == "init":
|
||||||
|
db_init_command(args, settings)
|
||||||
|
elif args.db_command == "backup":
|
||||||
|
db_backup_command(args, settings)
|
||||||
|
|
||||||
|
|
||||||
|
def db_init_command(args, settings):
|
||||||
|
from mft.database import get_db, db_init, SchemaError
|
||||||
|
|
||||||
|
with get_db() as conn:
|
||||||
|
try:
|
||||||
|
db_init(conn)
|
||||||
|
except SchemaError as e:
|
||||||
|
print(
|
||||||
|
f"Database initialization failed with a SchemaError: {e}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
print("Database initialized")
|
||||||
|
|
||||||
|
|
||||||
|
def db_backup_command(args, settings):
|
||||||
|
db_path = settings.database_path
|
||||||
|
|
||||||
|
if not db_path.exists():
|
||||||
|
print(f"Error: Database file `{db_path}` does not exist.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Generate backup filename with timestamp
|
||||||
|
timestamp = datetime.now().strftime("%Y-%m-%d.%H%M%S")
|
||||||
|
backup_path = db_path.parent / f"{db_path.name}.backup-{timestamp}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Open source database
|
||||||
|
source = sqlite3.connect(str(db_path))
|
||||||
|
|
||||||
|
# Create backup database
|
||||||
|
backup = sqlite3.connect(str(backup_path))
|
||||||
|
|
||||||
|
# Use SQLite's backup API to safely copy
|
||||||
|
with backup:
|
||||||
|
source.backup(backup)
|
||||||
|
|
||||||
|
backup.close()
|
||||||
|
source.close()
|
||||||
|
|
||||||
|
print(f"Database backed up successfully to: {backup_path}")
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
print(f"Error creating backup: {e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -1,9 +1,19 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from mft.settings import settings
|
from mft.settings import settings
|
||||||
|
|
||||||
|
SCHEMA_PATH = Path(__file__).parent / "schema"
|
||||||
|
SCHEMA_VERSION = 1
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaError(RuntimeError):
|
||||||
|
"""Signals a database schema related error"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_db() -> Generator[sqlite3.Connection, None, None]:
|
def get_db() -> Generator[sqlite3.Connection, None, None]:
|
||||||
@@ -14,3 +24,66 @@ def get_db() -> Generator[sqlite3.Connection, None, None]:
|
|||||||
yield conn
|
yield conn
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def db_validate_schema(conn: sqlite3.Connection) -> None:
|
||||||
|
"""
|
||||||
|
Verifies if the current database has the correct schema
|
||||||
|
|
||||||
|
Raises SchemaError if the schema is incorrect
|
||||||
|
"""
|
||||||
|
version = db_schema_version(conn)
|
||||||
|
if version != SCHEMA_VERSION:
|
||||||
|
raise SchemaError(f"Schema version is {version} but expected {SCHEMA_VERSION}")
|
||||||
|
|
||||||
|
|
||||||
|
def db_schema_version(conn: sqlite3.Connection) -> int:
|
||||||
|
"""
|
||||||
|
Return the version number of the schema in the database
|
||||||
|
|
||||||
|
Returns 0 for an empty database.
|
||||||
|
Raises SchemaError when the database schema is in an unknown or
|
||||||
|
incompatible state.
|
||||||
|
"""
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT COUNT(*) FROM sqlite_master")
|
||||||
|
if cursor.fetchone()[0] == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='config'"
|
||||||
|
)
|
||||||
|
if not cursor.fetchone():
|
||||||
|
raise SchemaError("Config table missing")
|
||||||
|
|
||||||
|
cursor.execute("SELECT value FROM config WHERE key = 'schema.version'")
|
||||||
|
result = cursor.fetchone()
|
||||||
|
if not result:
|
||||||
|
raise SchemaError("Config table exists but schema.version is missing")
|
||||||
|
|
||||||
|
try:
|
||||||
|
version = int(result[0])
|
||||||
|
if version <= 0:
|
||||||
|
raise SchemaError(f"Invalid schema version value: {version}")
|
||||||
|
return version
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
raise SchemaError(f"Invalid schema version value: {result[0]}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def db_init(conn: sqlite3.Connection) -> None:
|
||||||
|
"""
|
||||||
|
Initializes the database with the default schema.
|
||||||
|
|
||||||
|
Raises SchemaError if the database is in an unknown, incompatible or
|
||||||
|
already initialized state.
|
||||||
|
"""
|
||||||
|
version = db_schema_version(conn)
|
||||||
|
|
||||||
|
if version != 0:
|
||||||
|
raise SchemaError(f"Database is already initialized (schema version {version})")
|
||||||
|
|
||||||
|
# Read and execute the schema file
|
||||||
|
schema_file = SCHEMA_PATH / "schema.sql"
|
||||||
|
schema_sql = schema_file.read_text()
|
||||||
|
conn.executescript(schema_sql)
|
||||||
|
conn.commit()
|
||||||
|
|||||||
Reference in New Issue
Block a user