From 2730496b355d432649d8f6435252d3d0498c3a4c Mon Sep 17 00:00:00 2001
From: omicron <omicron.me@protonmail.com>
Date: Tue, 20 May 2025 17:08:28 +0200
Subject: [PATCH] Rework speck tests, improve coverage, add benchmarks

---
 cipher/speck/impl/benchmark128_test.go | 172 +++++++++++++++++++++++++
 cipher/speck/impl/speck128_test.go     | 134 +++++++++++++++++++
 cipher/speck/speck_test.go             | 134 ++++++++++---------
 3 files changed, 382 insertions(+), 58 deletions(-)
 create mode 100644 cipher/speck/impl/benchmark128_test.go
 create mode 100644 cipher/speck/impl/speck128_test.go

diff --git a/cipher/speck/impl/benchmark128_test.go b/cipher/speck/impl/benchmark128_test.go
new file mode 100644
index 0000000..d65fc28
--- /dev/null
+++ b/cipher/speck/impl/benchmark128_test.go
@@ -0,0 +1,172 @@
+package impl_test
+
+import (
+	"crypto/rand"
+	"io"
+	"testing"
+
+	"git.omicron.one/playground/cryptography/cipher/speck/impl"
+	"github.com/stretchr/testify/assert"
+)
+
+func BenchmarkKeyschedule128128(b *testing.B) {
+	key := make([]byte, impl.KeySize128128)
+	_, err := io.ReadFull(rand.Reader, key)
+	assert.Nil(b, err)
+
+	_, err = impl.New128(key)
+	assert.Nil(b, err)
+
+	b.ResetTimer()
+	for range b.N {
+		impl.New128(key)
+	}
+}
+
+func BenchmarkKeyschedule128192(b *testing.B) {
+	key := make([]byte, impl.KeySize128192)
+	_, err := io.ReadFull(rand.Reader, key)
+	assert.Nil(b, err)
+
+	_, err = impl.New128(key)
+	assert.Nil(b, err)
+
+	b.ResetTimer()
+	for range b.N {
+		impl.New128(key)
+	}
+}
+
+func BenchmarkKeyschedule128256(b *testing.B) {
+	key := make([]byte, impl.KeySize128256)
+	_, err := io.ReadFull(rand.Reader, key)
+	assert.Nil(b, err)
+
+	_, err = impl.New128(key)
+	assert.Nil(b, err)
+
+	b.ResetTimer()
+	for range b.N {
+		impl.New128(key)
+	}
+}
+
+func BenchmarkEncrypt128128(b *testing.B) {
+	key := make([]byte, impl.KeySize128128)
+	_, err := io.ReadFull(rand.Reader, key)
+	assert.Nil(b, err)
+
+	ctx, err := impl.New128(key)
+	assert.Nil(b, err)
+	b.SetBytes(int64(ctx.BlockSize()))
+
+	ciphertext := make([]byte, ctx.BlockSize())
+	plaintext := make([]byte, ctx.BlockSize())
+	_, err = io.ReadFull(rand.Reader, plaintext)
+	assert.Nil(b, err)
+
+	b.ResetTimer()
+	for range b.N {
+		ctx.Encrypt(ciphertext, plaintext)
+	}
+}
+
+func BenchmarkDecrypt128128(b *testing.B) {
+	key := make([]byte, impl.KeySize128128)
+	_, err := io.ReadFull(rand.Reader, key)
+	assert.Nil(b, err)
+
+	ctx, err := impl.New128(key)
+	assert.Nil(b, err)
+	b.SetBytes(int64(ctx.BlockSize()))
+
+	plaintext := make([]byte, ctx.BlockSize())
+	ciphertext := make([]byte, ctx.BlockSize())
+	_, err = io.ReadFull(rand.Reader, ciphertext)
+	assert.Nil(b, err)
+
+	b.ResetTimer()
+	for range b.N {
+		ctx.Decrypt(plaintext, ciphertext)
+	}
+}
+
+func BenchmarkEncrypt128192(b *testing.B) {
+	key := make([]byte, impl.KeySize128192)
+	_, err := io.ReadFull(rand.Reader, key)
+	assert.Nil(b, err)
+
+	ctx, err := impl.New128(key)
+	assert.Nil(b, err)
+	b.SetBytes(int64(ctx.BlockSize()))
+
+	ciphertext := make([]byte, ctx.BlockSize())
+	plaintext := make([]byte, ctx.BlockSize())
+	_, err = io.ReadFull(rand.Reader, plaintext)
+	assert.Nil(b, err)
+
+	b.ResetTimer()
+	for range b.N {
+		ctx.Encrypt(ciphertext, plaintext)
+	}
+}
+
+func BenchmarkDecrypt128192(b *testing.B) {
+	key := make([]byte, impl.KeySize128192)
+	_, err := io.ReadFull(rand.Reader, key)
+	assert.Nil(b, err)
+
+	ctx, err := impl.New128(key)
+	assert.Nil(b, err)
+	b.SetBytes(int64(ctx.BlockSize()))
+
+	plaintext := make([]byte, ctx.BlockSize())
+	ciphertext := make([]byte, ctx.BlockSize())
+	_, err = io.ReadFull(rand.Reader, ciphertext)
+	assert.Nil(b, err)
+
+	b.ResetTimer()
+	for range b.N {
+		ctx.Decrypt(plaintext, ciphertext)
+	}
+}
+
+func BenchmarkEncrypt128256(b *testing.B) {
+	key := make([]byte, impl.KeySize128256)
+	_, err := io.ReadFull(rand.Reader, key)
+	assert.Nil(b, err)
+
+	ctx, err := impl.New128(key)
+	assert.Nil(b, err)
+	b.SetBytes(int64(ctx.BlockSize()))
+
+	ciphertext := make([]byte, ctx.BlockSize())
+	plaintext := make([]byte, ctx.BlockSize())
+	_, err = io.ReadFull(rand.Reader, plaintext)
+	assert.Nil(b, err)
+
+	b.ResetTimer()
+	for range b.N {
+		ctx.Encrypt(ciphertext, plaintext)
+	}
+}
+
+func BenchmarkDecrypt128256(b *testing.B) {
+	key := make([]byte, impl.KeySize128256)
+	_, err := io.ReadFull(rand.Reader, key)
+	assert.Nil(b, err)
+
+	ctx, err := impl.New128(key)
+	assert.Nil(b, err)
+	b.SetBytes(int64(ctx.BlockSize()))
+
+	plaintext := make([]byte, ctx.BlockSize())
+	ciphertext := make([]byte, ctx.BlockSize())
+	_, err = io.ReadFull(rand.Reader, ciphertext)
+	assert.Nil(b, err)
+
+	b.ResetTimer()
+	for range b.N {
+		ctx.Decrypt(plaintext, ciphertext)
+	}
+}
diff --git a/cipher/speck/impl/speck128_test.go b/cipher/speck/impl/speck128_test.go
new file mode 100644
index 0000000..4f53870
--- /dev/null
+++ b/cipher/speck/impl/speck128_test.go
@@ -0,0 +1,134 @@
+package impl_test
+
+import (
+	"slices"
+	"testing"
+
+	"git.omicron.one/playground/cryptography/cipher"
+	"git.omicron.one/playground/cryptography/cipher/speck/impl"
+	. "git.omicron.one/playground/cryptography/util"
+	"github.com/stretchr/testify/assert"
+)
+
+func testVector128(t *testing.T, key, plaintext, ciphertext []byte, bs int, name string) {
+	t.Helper()
+
+	buffer := make([]byte, len(plaintext))
+	ctx, err := impl.New128(key)
+	assert.Nil(t, err)
+	assert.NotNil(t, ctx)
+	assert.Equal(t, bs, ctx.BlockSize())
+	assert.Equal(t, name, ctx.Algorithm())
+
+	// Two buffers
+	pt := slices.Clone(plaintext)
+	ctx.Encrypt(buffer, pt)
+	assert.Equal(t, plaintext, pt)
+	assert.Equal(t, ciphertext, buffer)
+
+	clear(buffer)
+	ct := slices.Clone(ciphertext)
+	ctx.Decrypt(buffer, ct)
+	assert.Equal(t, ciphertext, ct)
+	assert.Equal(t, plaintext, buffer)
+
+	// In-place
+	copy(buffer, plaintext)
+	ctx.Encrypt(buffer, buffer)
+	assert.Equal(t, ciphertext, buffer)
+	ctx.Decrypt(buffer, buffer)
+	assert.Equal(t, plaintext, buffer)
+}
+
+func TestVector128128(t *testing.T) {
+	var (
+		key        = DeHex("0f0e0d0c0b0a09080706050403020100")
+		plaintext  = DeHex("6c617669757165207469206564616d20")
+		ciphertext = DeHex("a65d9851797832657860fedf5c570d18")
+		bs         = impl.BlockSize128
+		name       = "Speck128/128"
+	)
+	testVector128(t, key, plaintext, ciphertext, bs, name)
+}
+
+func TestVector128192(t *testing.T) {
+	var (
+		key        = DeHex("17161514131211100f0e0d0c0b0a09080706050403020100")
+		plaintext  = DeHex("726148206665696843206f7420746e65")
+		ciphertext = DeHex("1be4cf3a13135566f9bc185de03c1886")
+		bs         = impl.BlockSize128
+		name       = "Speck128/192"
+	)
+	testVector128(t, key, plaintext, ciphertext, bs, name)
+}
+
+func TestVector128256(t *testing.T) {
+	var (
+		key        = DeHex("1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100")
+		plaintext  = DeHex("65736f6874206e49202e72656e6f6f70")
+		ciphertext = DeHex("4109010405c0f53e4eeeb48d9c188f43")
+		bs         = impl.BlockSize128
+		name       = "Speck128/256"
+	)
+	testVector128(t, key, plaintext, ciphertext, bs, name)
+}
+
+func TestInvalidKey128(t *testing.T) {
+	ctx, err := impl.New128(DeHex("deadbeef"))
+	assert.ErrorIs(t, cipher.ErrInvalidKeyLength, err)
+	assert.Nil(t, ctx)
+}
+
+func TestDecryptBlockSize128(t *testing.T) {
+	ctx, err := impl.New128(DeHex("0f0e0d0c0b0a09080706050403020100"))
+	assert.Nil(t, err)
+	assert.NotNil(t, ctx)
+
+	assert.Panics(t, func() {
+		buffer := make([]byte, ctx.BlockSize()-1)
+		ctx.Decrypt(nil, buffer)
+	})
+	assert.Panics(t, func() {
+		buffer := make([]byte, ctx.BlockSize()-1)
+		ctx.Decrypt(buffer, nil)
+	})
+	assert.Panics(t, func() {
+		buffer := make([]byte, ctx.BlockSize()-1)
+		ctx.Decrypt(buffer, buffer)
+	})
+	assert.Panics(t, func() {
+		buffer := make([]byte, ctx.BlockSize()+1)
+		ctx.Decrypt(buffer, buffer)
+	})
+	assert.NotPanics(t, func() {
+		buffer := make([]byte, ctx.BlockSize())
+		ctx.Decrypt(buffer, buffer)
+	})
+}
+
+func TestEncryptBlockSize128(t *testing.T) {
+	ctx, err := impl.New128(DeHex("0f0e0d0c0b0a09080706050403020100"))
+	assert.Nil(t, err)
+	assert.NotNil(t, ctx)
+
+	assert.Panics(t, func() {
+		buffer := make([]byte, ctx.BlockSize()-1)
+		ctx.Encrypt(nil, buffer)
+	})
+	assert.Panics(t, func() {
+		buffer := make([]byte, ctx.BlockSize()-1)
+		ctx.Encrypt(buffer, nil)
+	})
+	assert.Panics(t, func() {
+		buffer := make([]byte, ctx.BlockSize()-1)
+		ctx.Encrypt(buffer, buffer)
+	})
+	assert.Panics(t, func() {
+		buffer := make([]byte, ctx.BlockSize()+1)
+		ctx.Encrypt(buffer, buffer)
+	})
+	assert.NotPanics(t, func() {
+		buffer := make([]byte, ctx.BlockSize())
+		ctx.Encrypt(buffer, buffer)
+	})
+}
diff --git a/cipher/speck/speck_test.go b/cipher/speck/speck_test.go
index fdc105c..e9fb8c7 100644
--- a/cipher/speck/speck_test.go
+++ b/cipher/speck/speck_test.go
@@ -1,75 +1,93 @@
 package speck_test
 
 import (
-	"encoding/hex"
-	"slices"
 	"testing"
 
+	"git.omicron.one/playground/cryptography/cipher"
 	"git.omicron.one/playground/cryptography/cipher/speck"
 	"github.com/stretchr/testify/assert"
 )
 
-func DeHex(s string) []byte {
-	decoded, err := hex.DecodeString(s)
-	if err != nil {
-		panic("invalid hex string")
+func testKey(param speck.SpeckParameters) []byte {
+	switch param {
+	case speck.Speck3264:
+		return make([]byte, 64/8)
+	case speck.Speck4872:
+		return make([]byte, 72/8)
+	case speck.Speck4896, speck.Speck6496, speck.Speck9696:
+		return make([]byte, 96/8)
+	case speck.Speck64128, speck.Speck128128:
+		return make([]byte, 128/8)
+	case speck.Speck96144:
+		return make([]byte, 144/8)
+	case speck.Speck128192:
+		return make([]byte, 192/8)
+	case speck.Speck128256:
+		return make([]byte, 256/8)
 	}
-	return decoded
+	panic("unreachable")
 }
 
-type TestVector struct {
-	Key        []byte
-	Plaintext  []byte
-	Ciphertext []byte
-	Param      speck.SpeckParameters
-}
+func TestNew(t *testing.T) {
+	notImplemented := []speck.SpeckParameters{
+		speck.Speck3264,
+		speck.Speck4872,
+		speck.Speck4896,
+		speck.Speck6496,
+		speck.Speck64128,
+		speck.Speck9696,
+		speck.Speck96144,
+	}
+	implemented := []speck.SpeckParameters{
+		speck.Speck128128,
+		speck.Speck128192,
+		speck.Speck128256,
+	}
 
-var vectors []TestVector = []TestVector{
-	// Speck128/128 test vector
-	{
-		Key:        DeHex("0f0e0d0c0b0a09080706050403020100"),
-		Plaintext:  DeHex("6c617669757165207469206564616d20"),
-		Ciphertext: DeHex("a65d9851797832657860fedf5c570d18"),
-		Param:      speck.Speck128128,
-	},
-	{
-		Key:        DeHex("17161514131211100f0e0d0c0b0a09080706050403020100"),
-		Plaintext:  DeHex("726148206665696843206f7420746e65"),
-		Ciphertext: DeHex("1be4cf3a13135566f9bc185de03c1886"),
-		Param:      speck.Speck128192,
-	},
-	{
-		Key:        DeHex("1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100"),
-		Plaintext:  DeHex("65736f6874206e49202e72656e6f6f70"),
-		Ciphertext: DeHex("4109010405c0f53e4eeeb48d9c188f43"),
-		Param:      speck.Speck128256,
-	},
-}
+	for _, param := range notImplemented {
+		key := testKey(param)
+		ctx, err := speck.New(key, param)
+		assert.Nil(t, ctx)
+		assert.ErrorContains(t, err, "Not implemented")
+	}
 
-func TestVectors(t *testing.T) {
-	for _, vector := range vectors {
-		ctx, err := speck.New(vector.Key, vector.Param)
-		assert.NotNil(t, ctx)
+	for _, param := range implemented {
+		key := testKey(param)
+		ctx, err := speck.New(key, param)
 		assert.Nil(t, err)
-
-		// Test in place
-		buffer := slices.Clone(vector.Plaintext)
-		ctx.Encrypt(buffer, buffer)
-		assert.Equal(t, vector.Ciphertext, buffer, ctx.Algorithm())
-		ctx.Decrypt(buffer, buffer)
-		assert.Equal(t, vector.Plaintext, buffer, ctx.Algorithm())
-
-		// Test two buffers
-		dst := make([]byte, len(vector.Ciphertext))
-		src := slices.Clone(vector.Plaintext)
-		ctx.Encrypt(dst, src)
-		assert.Equal(t, vector.Plaintext, src, ctx.Algorithm())
-		assert.Equal(t, vector.Ciphertext, dst, ctx.Algorithm())
-
-		dst = make([]byte, len(vector.Plaintext))
-		src = slices.Clone(vector.Ciphertext)
-		ctx.Decrypt(dst, src)
-		assert.Equal(t, vector.Ciphertext, src, ctx.Algorithm())
-		assert.Equal(t, vector.Plaintext, dst, ctx.Algorithm())
+		assert.NotNil(t, ctx)
 	}
 }
+
+func TestInvalidKeyLength(t *testing.T) {
+	params := []speck.SpeckParameters{
+		speck.Speck3264,
+		speck.Speck4872,
+		speck.Speck4896,
+		speck.Speck6496,
+		speck.Speck64128,
+		speck.Speck9696,
+		speck.Speck96144,
+		speck.Speck128128,
+		speck.Speck128192,
+		speck.Speck128256,
+	}
+	for _, param := range params {
+		key := testKey(param)
+		ctx, err := speck.New(key[1:], param)
+		assert.Nil(t, ctx)
+		assert.ErrorIs(t, cipher.ErrInvalidKeyLength, err)
+	}
+}
+
+func TestInvalidParam(t *testing.T) {
+	assert.PanicsWithValue(t, "Invalid parameters", func() {
+		speck.New(nil, -1)
+	})
+	assert.PanicsWithValue(t, "Invalid parameters", func() {
+		speck.New(nil, 0)
+	})
+	assert.PanicsWithValue(t, "Invalid parameters", func() {
+		speck.New(nil, speck.Speck128256+1)
+	})
+}