Add basic database interaction

This commit is contained in:
2025-05-08 14:18:54 +02:00
parent 8a896b27d3
commit cd0fb4edd2
3 changed files with 166 additions and 0 deletions

2
go.mod
View File

@ -1,3 +1,5 @@
module git.omicron.one/omicron/linkshare
go 1.24
require github.com/mattn/go-sqlite3 v1.14.28

2
go.sum
View File

@ -0,0 +1,2 @@
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=

View File

@ -0,0 +1,162 @@
// Package database provides all database interactions for linkshare.
// This includes functions to read and write structured link data, setting and
// getting configurations, updating and initializing the schema and backing up
// data
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
_ "github.com/mattn/go-sqlite3"
)
const expectedSchemaVersion = 1
// DB represents a database connection
type DB struct {
conn *sql.DB
}
var (
ErrDatabaseNotInitialized = errors.New("database not initialized")
ErrDatabaseSchemaOutdated = errors.New("database schema needs updating")
ErrDatabaseSchemaUnsupported = errors.New("database schema is too new for the server")
ErrMigrationFailed = errors.New("migration failed")
)
// Open opens a connection to the sqlite database at the given path
func Open(dbPath string) (*DB, error) {
conn, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
conn.SetMaxOpenConns(1) // SQLite only supports one writer at a time
if err := conn.Ping(); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to ping database: %w", err)
}
_, err = conn.Exec("PRAGMA foreign_keys = ON")
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to enable foreign key constraints: %w", err)
}
return &DB{conn: conn}, nil
}
// Close closes the database connection if it's open
func (db *DB) Close() error {
if db.conn != nil {
return db.conn.Close()
}
return nil
}
// Initialize the database schema
func (db *DB) Initialize(schemaPath string) error {
err := db.checkIsInitialized()
if err == nil {
return nil
}
currentSchema := filepath.Join(schemaPath, "current.sql")
schema, err := os.ReadFile(currentSchema)
if err != nil {
return fmt.Errorf("failed to read schema file: %w", err)
}
_, err = db.conn.Exec(string(schema))
if err != nil {
return fmt.Errorf("failed to initialize database: %w", err)
}
return nil
}
func (db *DB) checkIsInitialized() error {
var count int
err := db.conn.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='settings'").Scan(&count)
if err != nil {
return fmt.Errorf("failed to check if database is initialized: %w", err)
}
if count == 0 {
return ErrDatabaseNotInitialized
}
return nil
}
func (db *DB) getSchemaVersion() (int, error) {
var version string
err := db.conn.QueryRow("SELECT value FROM settings WHERE key='schema-version'").Scan(&version)
if err != nil {
return 0, fmt.Errorf("failed to get schema version: %w", err)
}
versionInt, err := strconv.Atoi(version)
if err != nil {
return 0, fmt.Errorf("invalid schema version: %w", err)
}
if versionInt < 1 {
return 0, fmt.Errorf("invalid schema version %d", versionInt)
}
return versionInt, nil
}
// CheckSchemaVersion verifies that the schema is initialized and has the correct version
func (db *DB) CheckSchemaVersion() error {
err := db.checkIsInitialized()
if err != nil {
return err
}
version, err := db.getSchemaVersion()
if err != nil {
return err
}
if version < expectedSchemaVersion {
return ErrDatabaseSchemaOutdated
} else if version > expectedSchemaVersion {
return ErrDatabaseSchemaUnsupported
}
return nil
}
// Transaction executes a function within a database transaction,
// handling commit/rollback automatically.
func (db *DB) Transaction(ctx context.Context, fn func(*sql.Tx) error) error {
tx, err := db.conn.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
}
}()
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("error rolling back transaction: %v (original error: %w)", rbErr, err)
}
return err
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}