Compare commits

...

11 Commits

Author SHA1 Message Date
50e1e3345b Add database functionality to manipulate the links table 2025-05-11 02:33:04 +02:00
914836cf34 Add an index to the schema for queries that exclude private links 2025-05-11 01:42:35 +02:00
0ec233e7c5 Add test coverage to the makefile 2025-05-10 23:58:48 +02:00
925d588f71 Add tests for db manager functionality 2025-05-10 23:15:07 +02:00
eac3bc4ff5 Export Transaction from the database 2025-05-10 22:12:09 +02:00
a9a9b4d9bb Add Option type 2025-05-10 13:49:14 +02:00
f7c72626ee Use cobra to turn linkctl into a proper cli
Most commands are currently placeholders but version and db init work
2025-05-09 03:37:52 +02:00
e66d800881 Add version package and update the makefile
The makefile will grab the version info from git and pass it to the
linker so that version information based on tags, commits and commit
times is available in the code.
2025-05-09 03:37:49 +02:00
9acc9a03aa Add basic database interaction 2025-05-09 03:36:29 +02:00
8a896b27d3 Add database schema 2025-05-08 15:11:54 +02:00
e4923b500a Move formatting to gofumpt (from gofmt) 2025-05-04 00:50:44 +02:00
16 changed files with 1398 additions and 11 deletions

1
.gitignore vendored
View File

@@ -1 +1,2 @@
/bin /bin
/reports

View File

@@ -3,6 +3,16 @@ BINARIES = $(patsubst cmd/%/,%,$(wildcard cmd/*/))
.PHONY: all build test validate clean run $(BINARIES) .PHONY: all build test validate clean run $(BINARIES)
VERSION := $(shell git describe --tags --always --dirty)
COMMIT := $(shell git rev-parse --short HEAD)
COMMIT_DATETIME := $(shell git log -1 --format=%cd --date=iso8601)
LDFLAGS := -X git.omicron.one/omicron/linkshare/internal/version.Version=$(VERSION) \
-X git.omicron.one/omicron/linkshare/internal/version.GitCommit=$(COMMIT) \
-X "git.omicron.one/omicron/linkshare/internal/version.CommitDateTime=$(COMMIT_DATETIME)"
OPEN = xdg-open
all: build all: build
@@ -13,17 +23,21 @@ $(BINARY_DIR):
mkdir -p $(BINARY_DIR) mkdir -p $(BINARY_DIR)
$(BINARIES): %: $(BINARY_DIR) $(BINARIES): %: $(BINARY_DIR)
go build -o $(BINARY_DIR)/$@ ./cmd/$@/ go build -ldflags '$(LDFLAGS)' -o $(BINARY_DIR)/$@ ./cmd/$@/
test: test:
go test ./... mkdir -p reports/coverage/
go test ./... -coverprofile=reports/coverage/coverage.out
go tool cover -html=reports/coverage/coverage.out -o reports/coverage/coverage.html && $(OPEN) reports/coverage/coverage.html
validate: validate:
@test -z "$(shell gofmt -l .)" || (echo "Incorrect formatting in:"; gofmt -l .; exit 1) @test -z "$(shell gofumpt -l .)" && echo "No files need formatting" || (echo "Incorrect formatting in:"; gofumpt -l .; exit 1)
go vet ./... go vet ./...
clean: clean:
rm -rf $(BINARY_DIR) rm -rf $(BINARY_DIR)
rm -rf reports
go clean go clean
run: $(LINKSERV) run: $(LINKSERV)

27
cmd/linkctl/config.go Normal file
View File

@@ -0,0 +1,27 @@
package main
import (
"fmt"
"github.com/spf13/cobra"
)
func configPreRun(cmd *cobra.Command, args []string) error {
return setupDb()
}
func configPostRun(cmd *cobra.Command, args []string) error {
return cleanupDb()
}
func configSetHandler(cmd *cobra.Command, args []string) {
fmt.Println("Not implemented")
}
func configGetHandler(cmd *cobra.Command, args []string) {
fmt.Println("Not implemented")
}
func configListHandler(cmd *cobra.Command, args []string) {
fmt.Println("Not implemented")
}

48
cmd/linkctl/db.go Normal file
View File

@@ -0,0 +1,48 @@
package main
import (
"fmt"
"git.omicron.one/omicron/linkshare/internal/database"
"git.omicron.one/omicron/linkshare/internal/util"
"git.omicron.one/omicron/linkshare/internal/version"
"github.com/spf13/cobra"
)
func openDB() (*database.DB, error) {
paths, err := util.FindDirectories(dbPath)
if err != nil {
return nil, err
}
return database.Open(paths.DatabaseFile)
}
func dbPreRun(cmd *cobra.Command, args []string) error {
return setupDb()
}
func dbPostRun(cmd *cobra.Command, args []string) error {
return cleanupDb()
}
func dbInitHandler(cmd *cobra.Command, args []string) {
err := db.Initialize(paths.SchemaDir)
if err == database.ErrAlreadyInitialized {
fmt.Printf("Database %q is already initialized\n", dbPath)
return
}
if err == nil {
fmt.Printf("Initialized database %q with schema version %d\n", dbPath, version.SchemaVersion)
return
}
fmt.Printf("Failed to initialize database %q: %v\n", dbPath, err)
}
func dbBackupHandler(cmd *cobra.Command, args []string) {
fmt.Println("Not implemented")
}
func dbUpdateHandler(cmd *cobra.Command, args []string) {
fmt.Println("Not implemented")
}

View File

@@ -1,15 +1,157 @@
package main package main
import "fmt" import (
import "git.omicron.one/omicron/linkshare/internal/util" "fmt"
"os"
func main() { "git.omicron.one/omicron/linkshare/internal/database"
paths, err := util.FindDirectories("") "git.omicron.one/omicron/linkshare/internal/util"
"git.omicron.one/omicron/linkshare/internal/version"
"github.com/spf13/cobra"
)
var (
dbPath string
verbosity int
)
var (
paths *util.AppPaths
db *database.DB
)
func exitIfError(err error) {
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1)
}
} }
fmt.Println("Paths:") func setupPaths() error {
fmt.Println(" Schema:", paths.SchemaDir) if paths != nil {
fmt.Println(" Database:", paths.DatabaseFile) return nil
}
paths_, err := util.FindDirectories(dbPath)
if err != nil {
return err
}
paths = paths_
return nil
}
func setupDb() error {
if db != nil {
return nil
}
err := setupPaths()
if err != nil {
return nil
}
db_, err := database.Open(dbPath)
if err != nil {
return nil
}
db = db_
return nil
}
func cleanupDb() error {
if db != nil {
err := db.Close()
if err != nil {
return err
}
}
return nil
}
func main() {
rootCmd := &cobra.Command{
Use: "linkctl",
Short: "LinkShare CLI tool",
Long: `Command line tool to manage your self-hosted LinkShare service.`,
Run: func(cmd *cobra.Command, args []string) {
cmd.Help()
},
}
rootCmd.CompletionOptions.DisableDefaultCmd = true
rootCmd.PersistentFlags().StringVarP(&dbPath, "db", "d", "", "Database file path")
rootCmd.PersistentFlags().CountVarP(&verbosity, "verbose", "v", "Increase verbosity level")
configCmd := &cobra.Command{
Use: "config",
Short: "Configuration commands",
Run: func(cmd *cobra.Command, args []string) {
cmd.Help()
},
PersistentPreRunE: configPreRun,
PersistentPostRunE: configPostRun,
}
configSetCmd := &cobra.Command{
Use: "set",
Short: "Set a configuration value",
Run: configSetHandler,
}
configGetCmd := &cobra.Command{
Use: "get",
Short: "Get a configuration value",
Run: configGetHandler,
}
configListCmd := &cobra.Command{
Use: "list",
Short: "List all configuration values",
Run: configListHandler,
}
configCmd.AddCommand(configSetCmd, configGetCmd, configListCmd)
dbCmd := &cobra.Command{
Use: "db",
Short: "Database commands",
Run: func(cmd *cobra.Command, args []string) {
cmd.Help()
},
PersistentPreRunE: dbPreRun,
PersistentPostRunE: dbPostRun,
}
dbInitCmd := &cobra.Command{
Use: "init",
Short: "Initialize the database",
Run: dbInitHandler,
}
dbBackupCmd := &cobra.Command{
Use: "backup",
Short: "Backup the database",
Run: dbBackupHandler,
}
dbUpdateCmd := &cobra.Command{
Use: "update",
Short: "Update the database schema",
Run: dbUpdateHandler,
}
dbCmd.AddCommand(dbInitCmd, dbBackupCmd, dbUpdateCmd)
versionCmd := &cobra.Command{
Use: "version",
Short: "Display version information",
Run: func(cmd *cobra.Command, args []string) {
version.Print()
},
}
rootCmd.AddCommand(configCmd, dbCmd, versionCmd)
if err := rootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(1)
}
} }

10
go.mod
View File

@@ -1,3 +1,13 @@
module git.omicron.one/omicron/linkshare module git.omicron.one/omicron/linkshare
go 1.24 go 1.24
require (
github.com/mattn/go-sqlite3 v1.14.28
github.com/spf13/cobra v1.9.1
)
require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/pflag v1.0.6 // indirect
)

12
go.sum
View File

@@ -0,0 +1,12 @@
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,192 @@
package links
import (
"database/sql"
"time"
"git.omicron.one/omicron/linkshare/internal/database"
. "git.omicron.one/omicron/linkshare/internal/util/option"
)
// Link represents a stored link
type Link struct {
ID int64
URL string
Title string
CreatedAt time.Time
UpdatedAt Option[time.Time]
IsPrivate bool
}
// Repository handles link storage operations
type Repository struct {
db *database.DB
}
// NewRepository creates a new link repository
func NewRepository(db *database.DB) *Repository {
return &Repository{db: db}
}
// Create adds a new link to the database
func (r *Repository) Create(url, title string, isPrivate bool) (int64, error) {
var id int64
err := r.db.Transaction(func(tx *sql.Tx) error {
now := time.Now().UTC().Format(time.RFC3339)
result, err := tx.Exec(
"INSERT INTO links (url, title, created_at, is_private) VALUES (?, ?, ?, ?)",
url, title, now, isPrivate,
)
if err != nil {
return err
}
id, err = result.LastInsertId()
return err
})
return id, err
}
// Get retrieves a single link by ID
func (r *Repository) Get(id int64) (*Link, error) {
var (
link Link
createdAt string
updatedAt sql.NullString
)
err := r.db.Transaction(func(tx *sql.Tx) error {
row := tx.QueryRow(
"SELECT id, url, title, created_at, updated_at, is_private FROM links WHERE id = ?",
id,
)
err := row.Scan(&link.ID, &link.URL, &link.Title, &createdAt, &updatedAt, &link.IsPrivate)
if err != nil {
return err
}
created, err := time.Parse(time.RFC3339, createdAt)
if err != nil {
return err
}
link.CreatedAt = created
if updatedAt.Valid {
updated, err := time.Parse(time.RFC3339, updatedAt.String)
if err != nil {
return err
}
link.UpdatedAt = Some(updated)
} else {
link.UpdatedAt = None[time.Time]()
}
return nil
})
if err != nil {
return nil, err
}
return &link, nil
}
// Update updates an existing link's fields
func (r *Repository) Update(id int64, url, title string, isPrivate bool) error {
return r.db.Transaction(func(tx *sql.Tx) error {
now := time.Now().UTC().Format(time.RFC3339)
_, err := tx.Exec(
"UPDATE links SET url = ?, title = ?, updated_at = ?, is_private = ? WHERE id = ?",
url, title, now, isPrivate, id,
)
return err
})
}
// Delete removes a link from the database
func (r *Repository) Delete(id int64) error {
return r.db.Transaction(func(tx *sql.Tx) error {
_, err := tx.Exec("DELETE FROM links WHERE id = ?", id)
return err
})
}
// List returns a paginated list of links
func (r *Repository) List(includePrivate bool, offset, limit int) ([]*Link, error) {
var links []*Link
err := r.db.Transaction(func(tx *sql.Tx) error {
var rows *sql.Rows
var err error
if includePrivate {
rows, err = tx.Query(
`SELECT id, url, title, created_at, updated_at, is_private
FROM links ORDER BY created_at DESC LIMIT ? OFFSET ?`,
limit, offset,
)
} else {
rows, err = tx.Query(
`SELECT id, url, title, created_at, updated_at, is_private
FROM links WHERE is_private = 0 ORDER BY created_at DESC LIMIT ? OFFSET ?`,
limit, offset,
)
}
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var (
link Link
createdAt string
updatedAt sql.NullString
)
err := rows.Scan(&link.ID, &link.URL, &link.Title, &createdAt, &updatedAt, &link.IsPrivate)
if err != nil {
return err
}
created, err := time.Parse(time.RFC3339, createdAt)
if err != nil {
return err
}
link.CreatedAt = created
if updatedAt.Valid {
updated, err := time.Parse(time.RFC3339, updatedAt.String)
if err != nil {
return err
}
link.UpdatedAt = Some(updated)
} else {
link.UpdatedAt = None[time.Time]()
}
links = append(links, &link)
}
return rows.Err()
})
if err != nil {
return nil, err
}
return links, nil
}
// Count returns the total number of links in the database
func (r *Repository) Count(includePrivate bool) (int, error) {
var count int
err := r.db.Transaction(func(tx *sql.Tx) error {
var row *sql.Row
if includePrivate {
row = tx.QueryRow("SELECT COUNT(*) FROM links")
} else {
row = tx.QueryRow("SELECT COUNT(*) FROM links WHERE is_private")
}
return row.Scan(&count)
})
return count, err
}

View File

@@ -0,0 +1,329 @@
package links_test
import (
"os"
"testing"
"time"
"git.omicron.one/omicron/linkshare/internal/database"
"git.omicron.one/omicron/linkshare/internal/database/links"
)
func setupTestDB(t *testing.T) (*database.DB, string) {
t.Helper()
cwd, err := os.Getwd()
t.Logf("Current working directory: %s", cwd)
// Create temp file for database
tempFile, err := os.CreateTemp("", "linkshare-links-test-*.db")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
tempFile.Close()
dbPath := tempFile.Name()
// Open database
db, err := database.Open(dbPath)
if err != nil {
os.Remove(dbPath)
t.Fatalf("Failed to open database: %v", err)
}
// Initialize database with schema
err = db.Initialize("../../../schema")
if err != nil {
db.Close()
os.Remove(dbPath)
t.Fatalf("Failed to initialize database: %v", err)
}
return db, dbPath
}
func TestRepository_Create(t *testing.T) {
db, dbPath := setupTestDB(t)
defer func() {
db.Close()
os.Remove(dbPath)
}()
repo := links.NewRepository(db)
// Test creating a link
id, err := repo.Create("https://example.com", "Example", false)
if err != nil {
t.Fatalf("Failed to create link: %v", err)
}
if id <= 0 {
t.Fatalf("Expected positive ID, got %d", id)
}
// Verify link was created by retrieving it
link, err := repo.Get(id)
if err != nil {
t.Fatalf("Failed to get link: %v", err)
}
if link.URL != "https://example.com" {
t.Errorf("Expected URL 'https://example.com', got '%s'", link.URL)
}
}
func TestRepository_Get(t *testing.T) {
db, dbPath := setupTestDB(t)
defer func() {
db.Close()
os.Remove(dbPath)
}()
repo := links.NewRepository(db)
// Insert test data
id, err := repo.Create("https://example.com", "Example", true)
if err != nil {
t.Fatalf("Failed to create link: %v", err)
}
// Test getting a link
link, err := repo.Get(id)
if err != nil {
t.Fatalf("Failed to get link: %v", err)
}
if link.ID != id {
t.Errorf("Expected ID %d, got %d", id, link.ID)
}
if link.URL != "https://example.com" {
t.Errorf("Expected URL 'https://example.com', got '%s'", link.URL)
}
if link.Title != "Example" {
t.Errorf("Expected Title 'Example', got '%s'", link.Title)
}
if link.IsPrivate != true {
t.Errorf("Expected IsPrivate true, got %v", link.IsPrivate)
}
if link.UpdatedAt.IsSome() {
t.Errorf("Expected UpdatedAt to be None, got %v", link.UpdatedAt)
}
// Test getting non-existent link
_, err = repo.Get(id + 1)
if err == nil {
t.Fatal("Expected error when getting non-existent link")
}
}
func TestRepository_Update(t *testing.T) {
db, dbPath := setupTestDB(t)
defer func() {
db.Close()
os.Remove(dbPath)
}()
repo := links.NewRepository(db)
// Insert test data
id, err := repo.Create("https://example.com", "Example", false)
if err != nil {
t.Fatalf("Failed to create link: %v", err)
}
// Test updating a link
err = repo.Update(id, "https://updated.com", "Updated", true)
if err != nil {
t.Fatalf("Failed to update link: %v", err)
}
// Verify link was updated
link, err := repo.Get(id)
if err != nil {
t.Fatalf("Failed to get link: %v", err)
}
if link.URL != "https://updated.com" {
t.Errorf("Expected URL 'https://updated.com', got '%s'", link.URL)
}
if link.Title != "Updated" {
t.Errorf("Expected Title 'Updated', got '%s'", link.Title)
}
if link.IsPrivate != true {
t.Errorf("Expected IsPrivate true, got %v", link.IsPrivate)
}
if !link.UpdatedAt.IsSome() {
t.Error("Expected UpdatedAt to be set")
}
}
func TestRepository_Delete(t *testing.T) {
db, dbPath := setupTestDB(t)
defer func() {
db.Close()
os.Remove(dbPath)
}()
repo := links.NewRepository(db)
// Insert test data
id, err := repo.Create("https://example.com", "Example", false)
if err != nil {
t.Fatalf("Failed to create link: %v", err)
}
// Test deleting a link
err = repo.Delete(id)
if err != nil {
t.Fatalf("Failed to delete link: %v", err)
}
// Verify link was deleted
_, err = repo.Get(id)
if err == nil {
t.Fatal("Expected error after deletion")
}
}
func TestRepository_List(t *testing.T) {
db, dbPath := setupTestDB(t)
defer func() {
db.Close()
os.Remove(dbPath)
}()
repo := links.NewRepository(db)
// Insert test data
urls := []struct {
url string
isPrivate bool
}{
{"https://example1.com", true},
{"https://example2.com", false},
{"https://example3.com", false},
{"https://example4.com", true},
{"https://example5.com", false},
}
for i, info := range urls {
_, err := repo.Create(info.url, "Example "+string(rune('A'+i)), info.isPrivate)
if err != nil {
t.Fatalf("Failed to create link: %v", err)
}
// Add a small delay to ensure different created_at times
time.Sleep(10 * time.Millisecond)
}
// Test full listing with pagination
links, err := repo.List(true, 0, 3)
if err != nil {
t.Fatalf("Failed to list links: %v", err)
}
if len(links) != 3 {
t.Fatalf("Expected 3 links, got %d", len(links))
}
// Check order (newest first)
for i := 0; i < len(links)-1; i++ {
if links[i].CreatedAt.Before(links[i+1].CreatedAt) {
t.Errorf("Links not in correct order")
}
}
// Test second page of full listing
links, err = repo.List(true, 3, 2)
if err != nil {
t.Fatalf("Failed to list links: %v", err)
}
if len(links) != 2 {
t.Fatalf("Expected 2 links, got %d", len(links))
}
// Test public listing
links, err = repo.List(false, 0, 3)
if err != nil {
t.Fatalf("Failed to list links: %v", err)
}
if len(links) != 3 {
t.Fatalf("Expected 3 links, got %d", len(links))
}
for _, link := range links {
if link.IsPrivate {
t.Fatalf("private link in public listing %v", link)
}
}
// Try to get more public links
links, err = repo.List(false, 3, 3)
if err != nil {
t.Fatalf("Failed to list links: %v", err)
}
if len(links) != 0 {
t.Fatalf("Expected 0 links, got %d", len(links))
}
}
func TestRepository_Count(t *testing.T) {
db, dbPath := setupTestDB(t)
defer func() {
db.Close()
os.Remove(dbPath)
}()
repo := links.NewRepository(db)
// Check full count with empty table
count, err := repo.Count(true)
if err != nil {
t.Fatalf("Failed to count links: %v", err)
}
if count != 0 {
t.Fatalf("Expected 0 links, got %d", count)
}
// Check public count with empty table
count, err = repo.Count(false)
if err != nil {
t.Fatalf("Failed to count links: %v", err)
}
if count != 0 {
t.Fatalf("Expected 0 links, got %d", count)
}
// Insert test data
numLinks := 5
for i := 0; i < numLinks; i++ {
_, err := repo.Create(
"https://example"+string(rune('1'+i))+".com",
"Example "+string(rune('A'+i)),
i%2 == 1,
)
if err != nil {
t.Fatalf("Failed to create link: %v", err)
}
}
pubLinks := numLinks / 2
// Check full count again
count, err = repo.Count(true)
if err != nil {
t.Fatalf("Failed to count links: %v", err)
}
if count != numLinks {
t.Fatalf("Expected %d links, got %d", numLinks, count)
}
// Check public count again
count, err = repo.Count(false)
if err != nil {
t.Fatalf("Failed to count links: %v", err)
}
if count != pubLinks {
t.Fatalf("Expected %d links, got %d", pubLinks, count)
}
}

View File

@@ -0,0 +1,166 @@
// Package database provides all database interactions for linkshare.
// This includes functions to read and write structured link data, setting and
// getting configurations, updating and initializing the schema and backing up
// data
package database
import (
"database/sql"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
_ "github.com/mattn/go-sqlite3"
"git.omicron.one/omicron/linkshare/internal/version"
)
// DB represents a database connection
type DB struct {
conn *sql.DB
}
var (
ErrNotInitialized = errors.New("database not initialized")
ErrAlreadyInitialized = errors.New("database already initialized")
ErrSchemaOutdated = errors.New("database schema needs updating")
ErrSchemaUnsupported = errors.New("database schema is too new for the server")
ErrMigrationFailed = errors.New("migration failed")
)
// Open opens a connection to the sqlite database at the given path
func Open(dbPath string) (*DB, error) {
conn, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
conn.SetMaxOpenConns(1) // SQLite only supports one writer at a time
if err := conn.Ping(); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to ping database: %w", err)
}
_, err = conn.Exec("PRAGMA foreign_keys = ON")
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to enable foreign key constraints: %w", err)
}
return &DB{conn: conn}, nil
}
// Close closes the database connection if it's open
func (db *DB) Close() error {
if db.conn != nil {
return db.conn.Close()
}
return nil
}
// Initialize the database schema
func (db *DB) Initialize(schemaPath string) error {
err := db.CheckInitialized()
if err == nil {
return ErrAlreadyInitialized
}
currentSchema := filepath.Join(schemaPath, "current.sql")
schema, err := os.ReadFile(currentSchema)
if err != nil {
return fmt.Errorf("failed to read schema file: %w", err)
}
_, err = db.conn.Exec(string(schema))
if err != nil {
return fmt.Errorf("failed to initialize database: %w", err)
}
return nil
}
// CheckInitialized returns nil if the database is initialized and an error otherwise
func (db *DB) CheckInitialized() error {
var count int
err := db.conn.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='settings'").Scan(&count)
if err != nil {
return fmt.Errorf("failed to check if database is initialized: %w", err)
}
if count == 0 {
return ErrNotInitialized
}
return nil
}
// GetSchemaVersion returns the schema version or an error
func (db *DB) GetSchemaVersion() (int, error) {
var version string
err := db.conn.QueryRow("SELECT value FROM settings WHERE key='schema-version'").Scan(&version)
if err != nil {
return 0, fmt.Errorf("failed to get schema version: %w", err)
}
versionInt, err := strconv.Atoi(version)
if err != nil {
return 0, fmt.Errorf("invalid schema version: %w", err)
}
if versionInt < 1 {
return 0, fmt.Errorf("invalid schema version %d", versionInt)
}
return versionInt, nil
}
// CheckSchemaVersion verifies that the schema is initialized and has the correct version
func (db *DB) CheckSchemaVersion() error {
err := db.CheckInitialized()
if err != nil {
return err
}
version_, err := db.GetSchemaVersion()
if err != nil {
return err
}
if version_ < version.SchemaVersion {
return ErrSchemaOutdated
} else if version_ > version.SchemaVersion {
return ErrSchemaUnsupported
}
return nil
}
// Transaction executes the provided function within a SQL transaction.
// If the function returns an error, the transaction is rolled back.
// If the function panics, the transaction is rolled back and the panic is re-thrown.
// The function receives a *sql.Tx that can be used for database operations.
func (db *DB) Transaction(fn func(*sql.Tx) error) error {
tx, err := db.conn.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
}
}()
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("error rolling back transaction: %v (original error: %w)", rbErr, err)
}
return err
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}

View File

@@ -0,0 +1,314 @@
package database
import (
"database/sql"
"os"
"path/filepath"
"testing"
"git.omicron.one/omicron/linkshare/internal/version"
)
func TestOpenClose(t *testing.T) {
// Create temp file for database
tempFile, err := os.CreateTemp("", "linkshare-test-*.db")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tempFile.Name())
tempFile.Close()
// Test opening
db, err := Open(tempFile.Name())
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
// Test closing
err = db.Close()
if err != nil {
t.Fatalf("Failed to close database: %v", err)
}
}
func TestInitialize(t *testing.T) {
// Create temp directory for test data
tempDir, err := os.MkdirTemp("", "linkshare-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Create schema directory and current.sql file
schemaDir := filepath.Join(tempDir, "schema")
err = os.Mkdir(schemaDir, 0o755)
if err != nil {
t.Fatalf("Failed to create schema directory: %v", err)
}
// Write test schema to file
schemaContent := `CREATE TABLE settings (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
kind TEXT CHECK(kind IN ('int', 'string', 'bool', 'json', 'glob')) NOT NULL
);
INSERT INTO settings (key, value, kind) VALUES ('schema-version', '1', 'int');`
err = os.WriteFile(filepath.Join(schemaDir, "current.sql"), []byte(schemaContent), 0o644)
if err != nil {
t.Fatalf("Failed to write schema file: %v", err)
}
// Create temp database file
dbPath := filepath.Join(tempDir, "test.db")
// Open database
db, err := Open(dbPath)
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Test initialization
err = db.Initialize(schemaDir)
if err != nil {
t.Fatalf("Failed to initialize database: %v", err)
}
// Test already initialized error
err = db.Initialize(schemaDir)
if err != ErrAlreadyInitialized {
t.Fatalf("Expected ErrAlreadyInitialized, got: %v", err)
}
}
func TestCheckInitialized(t *testing.T) {
// Create temp file for database
tempFile, err := os.CreateTemp("", "linkshare-test-*.db")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tempFile.Name())
tempFile.Close()
// Open database
db, err := Open(tempFile.Name())
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Test not initialized
err = db.CheckInitialized()
if err != ErrNotInitialized {
t.Fatalf("Expected ErrNotInitialized, got: %v", err)
}
// Initialize the database manually for testing
_, err = db.conn.Exec("CREATE TABLE settings (key TEXT PRIMARY KEY, value TEXT NOT NULL, kind TEXT NOT NULL)")
if err != nil {
t.Fatalf("Failed to create settings table: %v", err)
}
// Test initialized
err = db.CheckInitialized()
if err != nil {
t.Fatalf("Expected nil error after initialization, got: %v", err)
}
}
func TestGetSchemaVersion(t *testing.T) {
// Create temp file for database
tempFile, err := os.CreateTemp("", "linkshare-test-*.db")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tempFile.Name())
tempFile.Close()
// Open database
db, err := Open(tempFile.Name())
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Initialize the database manually for testing
_, err = db.conn.Exec("CREATE TABLE settings (key TEXT PRIMARY KEY, value TEXT NOT NULL, kind TEXT NOT NULL)")
if err != nil {
t.Fatalf("Failed to create settings table: %v", err)
}
_, err = db.conn.Exec("INSERT INTO settings (key, value, kind) VALUES ('schema-version', '1', 'int')")
if err != nil {
t.Fatalf("Failed to insert schema version: %v", err)
}
// Test schema version
version, err := db.GetSchemaVersion()
if err != nil {
t.Fatalf("Failed to get schema version: %v", err)
}
if version != 1 {
t.Fatalf("Expected schema version 1, got: %d", version)
}
// Test invalid schema version
_, err = db.conn.Exec("UPDATE settings SET value = 'invalid' WHERE key = 'schema-version'")
if err != nil {
t.Fatalf("Failed to update schema version: %v", err)
}
_, err = db.GetSchemaVersion()
if err == nil {
t.Fatal("Expected error for invalid schema version, got nil")
}
}
func TestCheckSchemaVersion(t *testing.T) {
// Create temp file for database
tempFile, err := os.CreateTemp("", "linkshare-test-*.db")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tempFile.Name())
tempFile.Close()
// Open database
db, err := Open(tempFile.Name())
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Test not initialized
err = db.CheckSchemaVersion()
if err != ErrNotInitialized {
t.Fatalf("Expected ErrNotInitialized, got: %v", err)
}
// Initialize the database manually
_, err = db.conn.Exec("CREATE TABLE settings (key TEXT PRIMARY KEY, value TEXT NOT NULL, kind TEXT NOT NULL)")
if err != nil {
t.Fatalf("Failed to create settings table: %v", err)
}
// Store current schema version
originalSchemaVersion := version.SchemaVersion
defer func() {
// Restore original schema version after test
version.SchemaVersion = originalSchemaVersion
}()
// Test version match
_, err = db.conn.Exec("INSERT INTO settings (key, value, kind) VALUES ('schema-version', '1', 'int')")
if err != nil {
t.Fatalf("Failed to insert schema version: %v", err)
}
version.SchemaVersion = 1
err = db.CheckSchemaVersion()
if err != nil {
t.Fatalf("Expected nil error for matching schema versions, got: %v", err)
}
// Test outdated version
version.SchemaVersion = 2
err = db.CheckSchemaVersion()
if err != ErrSchemaOutdated {
t.Fatalf("Expected ErrSchemaOutdated, got: %v", err)
}
// Test unsupported version
version.SchemaVersion = 1
_, err = db.conn.Exec("UPDATE settings SET value = '2' WHERE key = 'schema-version'")
if err != nil {
t.Fatalf("Failed to update schema version: %v", err)
}
err = db.CheckSchemaVersion()
if err != ErrSchemaUnsupported {
t.Fatalf("Expected ErrSchemaUnsupported, got: %v", err)
}
}
func TestTransaction(t *testing.T) {
// Create temp file for database
tempFile, err := os.CreateTemp("", "linkshare-test-*.db")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tempFile.Name())
tempFile.Close()
// Open database
db, err := Open(tempFile.Name())
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Initialize the database manually for testing
_, err = db.conn.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)")
if err != nil {
t.Fatalf("Failed to create test table: %v", err)
}
// Test successful transaction
err = db.Transaction(func(tx *sql.Tx) error {
_, err := tx.Exec("INSERT INTO test (value) VALUES (?)", "test-value")
return err
})
if err != nil {
t.Fatalf("Transaction failed: %v", err)
}
// Verify data was inserted
var value string
err = db.conn.QueryRow("SELECT value FROM test WHERE id = 1").Scan(&value)
if err != nil {
t.Fatalf("Failed to query test value: %v", err)
}
if value != "test-value" {
t.Fatalf("Expected 'test-value', got: %s", value)
}
// Test failed transaction
err = db.Transaction(func(tx *sql.Tx) error {
_, err := tx.Exec("INSERT INTO test (value) VALUES (?)", "should-rollback")
if err != nil {
return err
}
return sql.ErrTxDone // Force rollback
})
if err == nil {
t.Fatal("Expected error from failed transaction, got nil")
}
// Verify data was not inserted (rollback worked)
var count int
err = db.conn.QueryRow("SELECT COUNT(*) FROM test WHERE value = 'should-rollback'").Scan(&count)
if err != nil {
t.Fatalf("Failed to query test count: %v", err)
}
if count != 0 {
t.Fatalf("Expected count 0 after rollback, got: %d", count)
}
// Test panic in transaction
panicked := false
func() {
defer func() {
if r := recover(); r != nil {
panicked = true
}
}()
_ = db.Transaction(func(tx *sql.Tx) error {
panic("test panic")
})
}()
if !panicked {
t.Fatal("Expected panic to be propagated")
}
}

View File

@@ -110,7 +110,7 @@ func FindDirectories(dbPath string) (*AppPaths, error) {
// CreateDirectories ensures all application managed directories are created // CreateDirectories ensures all application managed directories are created
func CreateDirectories(paths *AppPaths) error { func CreateDirectories(paths *AppPaths) error {
err := os.MkdirAll(filepath.Dir(paths.DatabaseFile), 0750) err := os.MkdirAll(filepath.Dir(paths.DatabaseFile), 0o750)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -0,0 +1,41 @@
package option
type Option[T any] struct {
hasValue bool
value T
}
func Some[T any](value T) Option[T] {
return Option[T]{
hasValue: true,
value: value,
}
}
func None[T any]() Option[T] {
return Option[T]{
hasValue: false,
}
}
func (o Option[T]) IsSome() bool {
return o.hasValue
}
func (o Option[T]) IsNone() bool {
return !o.hasValue
}
func (o Option[T]) Value() T {
if !o.hasValue {
panic("Option has no value")
}
return o.value
}
func (o Option[T]) ValueOr(defaultValue T) T {
if !o.hasValue {
return defaultValue
}
return o.value
}

View File

@@ -0,0 +1,54 @@
package option_test
import (
"testing"
. "git.omicron.one/omicron/linkshare/internal/util/option"
)
func TestSome(t *testing.T) {
opt := Some(42)
if !opt.IsSome() {
t.Error("Expected IsSome() to be true for Some(42)")
}
if opt.IsNone() {
t.Error("Expected IsNone() to be false for Some(42)")
}
if opt.Value() != 42 {
t.Errorf("Expected Value() to be 42, got %v", opt.Value())
}
if opt.ValueOr(0) != 42 {
t.Errorf("Expected ValueOr(0) to be 42, got %v", opt.ValueOr(0))
}
}
func TestNone(t *testing.T) {
opt := None[int]()
if opt.IsSome() {
t.Error("Expected IsSome() to be false for None[int]()")
}
if !opt.IsNone() {
t.Error("Expected IsNone() to be true for None[int]()")
}
if opt.ValueOr(99) != 99 {
t.Errorf("Expected ValueOr(99) to be 99, got %v", opt.ValueOr(99))
}
}
func TestPanic(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("Expected Value() to panic on None")
}
}()
opt := None[string]()
_ = opt.Value() // This should panic
}

View File

@@ -0,0 +1,17 @@
package version
import "fmt"
var (
Version = "dev"
GitCommit = "unknown"
CommitDateTime = "unknown"
SchemaVersion = 1
)
// PrintVersionInfo prints formatted version information to stdout
func Print() {
fmt.Printf("Version: %s\n", Version)
fmt.Printf("Git commit: %s %s\n", GitCommit, CommitDateTime)
fmt.Printf("Schema: v%d\n", SchemaVersion)
}

20
schema/current.sql Normal file
View File

@@ -0,0 +1,20 @@
CREATE TABLE settings (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
kind TEXT CHECK(kind IN ('int', 'string', 'bool', 'json', 'glob')) NOT NULL
);
INSERT INTO settings (key, value, kind) VALUES ('schema-version', '1', 'int');
CREATE TABLE links (
id INTEGER PRIMARY KEY,
url TEXT NOT NULL,
title TEXT NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT,
is_private BOOLEAN NOT NULL DEFAULT 0
);
CREATE INDEX idx_links_created_at ON links(created_at);
CREATE INDEX idx_links_is_private_created_at ON links(is_private, created_at);
CREATE INDEX idx_links_url ON links(url);