Compare commits

..

7 Commits

Author SHA1 Message Date
7b8a6f03c2 Implement JSON marshalling for the matrix class
All checks were successful
Validate the build / validate-build (push) Successful in 1m1s
2025-05-22 12:56:45 +02:00
9729fe6dcb Add matrix.CreateFromSlice function 2025-05-22 12:56:38 +02:00
5f450e206f Add ci/cd commit validation and a makefile to do common operations
All checks were successful
Validate the build / validate-build (push) Successful in 1m1s
2025-05-20 22:33:48 +02:00
eae48cf9ae Reduce go mod version compatibility to 1.23 2025-05-20 22:01:37 +02:00
2730496b35 Rework speck tests, improve coverage, add benchmarks 2025-05-20 17:08:28 +02:00
1ca7d572f9 Add util package with a DeHex function 2025-05-20 17:07:43 +02:00
7b8df3b046 Fix speck.New not panicking early on all invalid parameters 2025-05-20 16:57:39 +02:00
12 changed files with 733 additions and 60 deletions

View File

@@ -0,0 +1,40 @@
name: Validate the build
run-name: ${{ gitea.actor }} is validating
on: [push]
jobs:
validate-build:
runs-on: ubuntu-latest
container:
image: node:current-alpine
steps:
- name: Install dependencies
run: |
echo "https://dl-cdn.alpinelinux.org/alpine/edge/main" >> /etc/apk/repositories
echo "https://dl-cdn.alpinelinux.org/alpine/edge/community" >> /etc/apk/repositories
apk update
apk add --no-cache git make bash go
GOBIN=/usr/local/bin go install mvdan.cc/gofumpt@latest
export "PATH=$PATH:/root/go/bin"
echo "---------------------"
echo "Go version:"
go version
echo "---------------------"
- name: Check out repository code
uses: actions/checkout@v4
- name: Fetch dependencies
run: |
go mod download
- name: Validate the code and formatting
run: |
make validate
- name: Run tests
run: |
make test

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
/reports
/bin

34
Makefile Normal file
View File

@@ -0,0 +1,34 @@
BINARY_DIR = bin
BINARIES = $(patsubst cmd/%/,%,$(wildcard cmd/*/))
.PHONY: all build test coverage validate clean purge $(BINARIES)
all: build
build: $(BINARIES)
$(BINARY_DIR):
mkdir -p $(BINARY_DIR)
$(BINARIES): %: $(BINARY_DIR)
go build -o $(BINARY_DIR)/$@ ./cmd/$@/
test:
go test ./... -cover
coverage:
mkdir -p reports/
go test -coverprofile=reports/coverage.out ./... && go tool cover -html=reports/coverage.out
validate:
@test -z "$(shell gofumpt -l .)" && echo "No files need formatting" || (echo "Incorrect formatting in:"; gofumpt -l .; exit 1)
go vet ./...
clean:
rm -rf $(BINARY_DIR)
go clean
purge: clean
rm -rf reports

View File

@@ -0,0 +1,172 @@
package impl_test
import (
"crypto/rand"
"io"
"testing"
"git.omicron.one/playground/cryptography/cipher/speck/impl"
"github.com/stretchr/testify/assert"
)
func BenchmarkKeyschedule128128(b *testing.B) {
key := make([]byte, impl.KeySize128128)
_, err := io.ReadFull(rand.Reader, key)
assert.Nil(b, err)
_, err = impl.New128(key)
assert.Nil(b, err)
b.ResetTimer()
for range b.N {
impl.New128(key)
}
}
func BenchmarkKeyschedule128192(b *testing.B) {
key := make([]byte, impl.KeySize128192)
_, err := io.ReadFull(rand.Reader, key)
assert.Nil(b, err)
_, err = impl.New128(key)
assert.Nil(b, err)
b.ResetTimer()
for range b.N {
impl.New128(key)
}
}
func BenchmarkKeyschedule128256(b *testing.B) {
key := make([]byte, impl.KeySize128256)
_, err := io.ReadFull(rand.Reader, key)
assert.Nil(b, err)
_, err = impl.New128(key)
assert.Nil(b, err)
b.ResetTimer()
for range b.N {
impl.New128(key)
}
}
func BenchmarkEncrypt128128(b *testing.B) {
key := make([]byte, impl.KeySize128128)
_, err := io.ReadFull(rand.Reader, key)
assert.Nil(b, err)
ctx, err := impl.New128(key)
assert.Nil(b, err)
b.SetBytes(int64(ctx.BlockSize()))
ciphertext := make([]byte, ctx.BlockSize())
plaintext := make([]byte, ctx.BlockSize())
_, err = io.ReadFull(rand.Reader, plaintext)
assert.Nil(b, err)
b.ResetTimer()
for range b.N {
ctx.Encrypt(ciphertext, plaintext)
}
}
func BenchmarkDecrypt128128(b *testing.B) {
key := make([]byte, impl.KeySize128128)
_, err := io.ReadFull(rand.Reader, key)
assert.Nil(b, err)
ctx, err := impl.New128(key)
assert.Nil(b, err)
b.SetBytes(int64(ctx.BlockSize()))
plaintext := make([]byte, ctx.BlockSize())
ciphertext := make([]byte, ctx.BlockSize())
_, err = io.ReadFull(rand.Reader, ciphertext)
assert.Nil(b, err)
b.ResetTimer()
for range b.N {
ctx.Decrypt(plaintext, ciphertext)
}
}
func BenchmarkEncrypt128192(b *testing.B) {
key := make([]byte, impl.KeySize128192)
_, err := io.ReadFull(rand.Reader, key)
assert.Nil(b, err)
ctx, err := impl.New128(key)
assert.Nil(b, err)
b.SetBytes(int64(ctx.BlockSize()))
ciphertext := make([]byte, ctx.BlockSize())
plaintext := make([]byte, ctx.BlockSize())
_, err = io.ReadFull(rand.Reader, plaintext)
assert.Nil(b, err)
b.ResetTimer()
for range b.N {
ctx.Encrypt(ciphertext, plaintext)
}
}
func BenchmarkDecrypt128192(b *testing.B) {
key := make([]byte, impl.KeySize128192)
_, err := io.ReadFull(rand.Reader, key)
assert.Nil(b, err)
ctx, err := impl.New128(key)
assert.Nil(b, err)
b.SetBytes(int64(ctx.BlockSize()))
plaintext := make([]byte, ctx.BlockSize())
ciphertext := make([]byte, ctx.BlockSize())
_, err = io.ReadFull(rand.Reader, ciphertext)
assert.Nil(b, err)
b.ResetTimer()
for range b.N {
ctx.Decrypt(plaintext, ciphertext)
}
}
func BenchmarkEncrypt128256(b *testing.B) {
key := make([]byte, impl.KeySize128256)
_, err := io.ReadFull(rand.Reader, key)
assert.Nil(b, err)
ctx, err := impl.New128(key)
assert.Nil(b, err)
b.SetBytes(int64(ctx.BlockSize()))
ciphertext := make([]byte, ctx.BlockSize())
plaintext := make([]byte, ctx.BlockSize())
_, err = io.ReadFull(rand.Reader, plaintext)
assert.Nil(b, err)
b.ResetTimer()
for range b.N {
ctx.Encrypt(ciphertext, plaintext)
}
}
func BenchmarkDecrypt128256(b *testing.B) {
key := make([]byte, impl.KeySize128256)
_, err := io.ReadFull(rand.Reader, key)
assert.Nil(b, err)
ctx, err := impl.New128(key)
assert.Nil(b, err)
b.SetBytes(int64(ctx.BlockSize()))
plaintext := make([]byte, ctx.BlockSize())
ciphertext := make([]byte, ctx.BlockSize())
_, err = io.ReadFull(rand.Reader, ciphertext)
assert.Nil(b, err)
b.ResetTimer()
for range b.N {
ctx.Decrypt(plaintext, ciphertext)
}
}

View File

@@ -0,0 +1,134 @@
package impl_test
import (
"slices"
"testing"
"git.omicron.one/playground/cryptography/cipher"
"git.omicron.one/playground/cryptography/cipher/speck/impl"
. "git.omicron.one/playground/cryptography/util"
"github.com/stretchr/testify/assert"
)
func testVector128(t *testing.T, key, plaintext, ciphertext []byte, bs int, name string) {
t.Helper()
buffer := make([]byte, len(plaintext))
ctx, err := impl.New128(key)
assert.Nil(t, err)
assert.NotNil(t, ctx)
assert.Equal(t, bs, ctx.BlockSize())
assert.Equal(t, name, ctx.Algorithm())
// Two buffers
pt := slices.Clone(plaintext)
ctx.Encrypt(buffer, pt)
assert.Equal(t, plaintext, pt)
assert.Equal(t, ciphertext, buffer)
clear(buffer)
ct := slices.Clone(ciphertext)
ctx.Decrypt(buffer, ct)
assert.Equal(t, ciphertext, ct)
assert.Equal(t, plaintext, buffer)
// In-place
copy(buffer, plaintext)
ctx.Encrypt(buffer, buffer)
assert.Equal(t, ciphertext, buffer)
ctx.Decrypt(buffer, buffer)
assert.Equal(t, plaintext, buffer)
}
func TestVector128128(t *testing.T) {
var (
key = DeHex("0f0e0d0c0b0a09080706050403020100")
plaintext = DeHex("6c617669757165207469206564616d20")
ciphertext = DeHex("a65d9851797832657860fedf5c570d18")
bs = impl.BlockSize128
name = "Speck128/128"
)
testVector128(t, key, plaintext, ciphertext, bs, name)
}
func TestVector128192(t *testing.T) {
var (
key = DeHex("17161514131211100f0e0d0c0b0a09080706050403020100")
plaintext = DeHex("726148206665696843206f7420746e65")
ciphertext = DeHex("1be4cf3a13135566f9bc185de03c1886")
bs = impl.BlockSize128
name = "Speck128/192"
)
testVector128(t, key, plaintext, ciphertext, bs, name)
}
func TestVector128256(t *testing.T) {
var (
key = DeHex("1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100")
plaintext = DeHex("65736f6874206e49202e72656e6f6f70")
ciphertext = DeHex("4109010405c0f53e4eeeb48d9c188f43")
bs = impl.BlockSize128
name = "Speck128/256"
)
testVector128(t, key, plaintext, ciphertext, bs, name)
}
func TestInvalidKey128(t *testing.T) {
ctx, err := impl.New128(DeHex("deadbeef"))
assert.ErrorIs(t, cipher.ErrInvalidKeyLength, err)
assert.Nil(t, ctx)
}
func TestDecryptBlockSize128(t *testing.T) {
ctx, err := impl.New128(DeHex("0f0e0d0c0b0a09080706050403020100"))
assert.Nil(t, err)
assert.NotNil(t, ctx)
assert.Panics(t, func() {
buffer := make([]byte, ctx.BlockSize()-1)
ctx.Decrypt(nil, buffer)
})
assert.Panics(t, func() {
buffer := make([]byte, ctx.BlockSize()-1)
ctx.Decrypt(buffer, nil)
})
assert.Panics(t, func() {
buffer := make([]byte, ctx.BlockSize()-1)
ctx.Decrypt(buffer, buffer)
})
assert.Panics(t, func() {
buffer := make([]byte, ctx.BlockSize()+1)
ctx.Decrypt(buffer, buffer)
})
assert.NotPanics(t, func() {
buffer := make([]byte, ctx.BlockSize())
ctx.Decrypt(buffer, buffer)
})
}
func TestEncryptBlockSize128(t *testing.T) {
ctx, err := impl.New128(DeHex("0f0e0d0c0b0a09080706050403020100"))
assert.Nil(t, err)
assert.NotNil(t, ctx)
assert.Panics(t, func() {
buffer := make([]byte, ctx.BlockSize()-1)
ctx.Encrypt(nil, buffer)
})
assert.Panics(t, func() {
buffer := make([]byte, ctx.BlockSize()-1)
ctx.Encrypt(buffer, nil)
})
assert.Panics(t, func() {
buffer := make([]byte, ctx.BlockSize()-1)
ctx.Encrypt(buffer, buffer)
})
assert.Panics(t, func() {
buffer := make([]byte, ctx.BlockSize()+1)
ctx.Encrypt(buffer, buffer)
})
assert.NotPanics(t, func() {
buffer := make([]byte, ctx.BlockSize())
ctx.Encrypt(buffer, buffer)
})
}

View File

@@ -41,7 +41,7 @@ var keySizes = []int{
// New creates a new speck block cipher context. // New creates a new speck block cipher context.
// Returns the created block cipher or an error. // Returns the created block cipher or an error.
func New(key []byte, param SpeckParameters) (cipher.Block, error) { func New(key []byte, param SpeckParameters) (cipher.Block, error) {
if param == 0 || int(param) > len(keySizes) { if param <= 0 || int(param) >= len(keySizes) {
panic("Invalid parameters") panic("Invalid parameters")
} }
keySize := keySizes[param] keySize := keySizes[param]

View File

@@ -1,75 +1,93 @@
package speck_test package speck_test
import ( import (
"encoding/hex"
"slices"
"testing" "testing"
"git.omicron.one/playground/cryptography/cipher"
"git.omicron.one/playground/cryptography/cipher/speck" "git.omicron.one/playground/cryptography/cipher/speck"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func DeHex(s string) []byte { func testKey(param speck.SpeckParameters) []byte {
decoded, err := hex.DecodeString(s) switch param {
if err != nil { case speck.Speck3264:
panic("invalid hex string") return make([]byte, 64/8)
case speck.Speck4872:
return make([]byte, 72/8)
case speck.Speck4896, speck.Speck6496, speck.Speck9696:
return make([]byte, 96/8)
case speck.Speck64128, speck.Speck128128:
return make([]byte, 128/8)
case speck.Speck96144:
return make([]byte, 144/8)
case speck.Speck128192:
return make([]byte, 192/8)
case speck.Speck128256:
return make([]byte, 256/8)
} }
return decoded panic("unreachable")
} }
type TestVector struct { func TestNew(t *testing.T) {
Key []byte notImplemented := []speck.SpeckParameters{
Plaintext []byte speck.Speck3264,
Ciphertext []byte speck.Speck4872,
Param speck.SpeckParameters speck.Speck4896,
} speck.Speck6496,
speck.Speck64128,
speck.Speck9696,
speck.Speck96144,
}
implemented := []speck.SpeckParameters{
speck.Speck128128,
speck.Speck128192,
speck.Speck128256,
}
var vectors []TestVector = []TestVector{ for _, param := range notImplemented {
// Speck128/128 test vector key := testKey(param)
{ ctx, err := speck.New(key, param)
Key: DeHex("0f0e0d0c0b0a09080706050403020100"), assert.Nil(t, ctx)
Plaintext: DeHex("6c617669757165207469206564616d20"), assert.ErrorContains(t, err, "Not implemented")
Ciphertext: DeHex("a65d9851797832657860fedf5c570d18"), }
Param: speck.Speck128128,
},
{
Key: DeHex("17161514131211100f0e0d0c0b0a09080706050403020100"),
Plaintext: DeHex("726148206665696843206f7420746e65"),
Ciphertext: DeHex("1be4cf3a13135566f9bc185de03c1886"),
Param: speck.Speck128192,
},
{
Key: DeHex("1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100"),
Plaintext: DeHex("65736f6874206e49202e72656e6f6f70"),
Ciphertext: DeHex("4109010405c0f53e4eeeb48d9c188f43"),
Param: speck.Speck128256,
},
}
func TestVectors(t *testing.T) { for _, param := range implemented {
for _, vector := range vectors { key := testKey(param)
ctx, err := speck.New(vector.Key, vector.Param) ctx, err := speck.New(key, param)
assert.NotNil(t, ctx)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, ctx)
// Test in place
buffer := slices.Clone(vector.Plaintext)
ctx.Encrypt(buffer, buffer)
assert.Equal(t, vector.Ciphertext, buffer, ctx.Algorithm())
ctx.Decrypt(buffer, buffer)
assert.Equal(t, vector.Plaintext, buffer, ctx.Algorithm())
// Test two buffers
dst := make([]byte, len(vector.Ciphertext))
src := slices.Clone(vector.Plaintext)
ctx.Encrypt(dst, src)
assert.Equal(t, vector.Plaintext, src, ctx.Algorithm())
assert.Equal(t, vector.Ciphertext, dst, ctx.Algorithm())
dst = make([]byte, len(vector.Plaintext))
src = slices.Clone(vector.Ciphertext)
ctx.Decrypt(dst, src)
assert.Equal(t, vector.Ciphertext, src, ctx.Algorithm())
assert.Equal(t, vector.Plaintext, dst, ctx.Algorithm())
} }
} }
func TestInvalidKeyLength(t *testing.T) {
params := []speck.SpeckParameters{
speck.Speck3264,
speck.Speck4872,
speck.Speck4896,
speck.Speck6496,
speck.Speck64128,
speck.Speck9696,
speck.Speck96144,
speck.Speck128128,
speck.Speck128192,
speck.Speck128256,
}
for _, param := range params {
key := testKey(param)
ctx, err := speck.New(key[1:], param)
assert.Nil(t, ctx)
assert.ErrorIs(t, cipher.ErrInvalidKeyLength, err)
}
}
func TestInvalidParam(t *testing.T) {
assert.PanicsWithValue(t, "Invalid parameters", func() {
speck.New(nil, -1)
})
assert.PanicsWithValue(t, "Invalid parameters", func() {
speck.New(nil, 0)
})
assert.PanicsWithValue(t, "Invalid parameters", func() {
speck.New(nil, speck.Speck128256+1)
})
}

2
go.mod
View File

@@ -1,6 +1,6 @@
module git.omicron.one/playground/cryptography module git.omicron.one/playground/cryptography
go 1.24.2 go 1.23.0
require ( require (
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0

View File

@@ -4,6 +4,7 @@
package matrix package matrix
import ( import (
"encoding/json"
"errors" "errors"
"golang.org/x/exp/constraints" "golang.org/x/exp/constraints"
@@ -54,6 +55,33 @@ func Create[T Number](rows, cols int) *Matrix[T] {
} }
} }
// Creates a new matrix from a given 2D slice of values. The first index of the
// slice denotes the rows and the second index denotes the columns. Returns the
// newly created matrix.
//
// Panics with ErrIncompatibleDataDimensions if any of the lengths are 0 or if
// not all rows have the same length.
func CreateFromSlice[T Number](values [][]T) *Matrix[T] {
rows := len(values)
if rows == 0 {
panic(ErrInvalidDimensions)
}
cols := len(values[0])
if cols == 0 {
panic(ErrInvalidDimensions)
}
m := Create[T](rows, cols)
for i := range rows {
if len(values[i]) != cols {
panic(ErrIncompatibleDataDimensions)
}
copy(m.values[i], values[i])
}
return m
}
// Creates a new matrix of a given size and sets the values from the given // Creates a new matrix of a given size and sets the values from the given
// slice. The values are used to fill the matrix left to right and top to // slice. The values are used to fill the matrix left to right and top to
// bottom. Returns the newly created matrix. // bottom. Returns the newly created matrix.
@@ -80,6 +108,17 @@ func CreateFromFlatSlice[T Number](rows, cols int, values []T) *Matrix[T] {
return m return m
} }
// CreateFromJSON creates a new matrix from JSON data representing a
// two-dimensional array. Returns an error if the JSON is invalid, the array is
// empty, or the inner arrays have inconsistent lengths.
func CreateFromJSON[T Number](data []byte) (*Matrix[T], error) {
m := &Matrix[T]{}
if err := m.UnmarshalJSON(data); err != nil {
return nil, err
}
return m, nil
}
// Convert will take a matrix of type T and convert it into a matrix of type U // Convert will take a matrix of type T and convert it into a matrix of type U
// Only works on SimpleNumber matrices // Only works on SimpleNumber matrices
func Convert[U, T SimpleNumber](in *Matrix[T]) *Matrix[U] { func Convert[U, T SimpleNumber](in *Matrix[T]) *Matrix[U] {
@@ -293,3 +332,39 @@ func (m *Matrix[T]) Fill(value T) *Matrix[T] {
} }
return m return m
} }
// UnmarshalJSON implements json.Unmarshaler for Matrix, creating a matrix from
// a JSON two-dimensional array.
func (m *Matrix[T]) UnmarshalJSON(data []byte) error {
var values [][]T
if err := json.Unmarshal(data, &values); err != nil {
return err
}
rows := len(values)
if rows == 0 {
return ErrInvalidDimensions
}
cols := len(values[0])
if cols == 0 {
return ErrInvalidDimensions
}
for i := range rows {
if len(values[i]) != cols {
return ErrIncompatibleDataDimensions
}
}
m.rows = rows
m.cols = cols
m.values = values
return nil
}
// MarshalJSON implements json.Marshaler for Matrix, serializing the matrix as
// a JSON two-dimensional array.
func (m *Matrix[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(m.values)
}

View File

@@ -1,6 +1,7 @@
package matrix_test package matrix_test
import ( import (
"encoding/json"
"testing" "testing"
"git.omicron.one/playground/cryptography/matrix" "git.omicron.one/playground/cryptography/matrix"
@@ -55,6 +56,41 @@ func TestCreate(t *testing.T) {
}) })
} }
func TestCreateFromSlice(t *testing.T) {
m := matrix.CreateFromSlice([][]int{
{1, 2, 3},
{4, 5, 6},
})
assert.NotNil(t, m)
assert.Equal(t, 2, m.Rows())
assert.Equal(t, 3, m.Cols())
assert.Equal(t, 1, m.Get(0, 0))
assert.Equal(t, 2, m.Get(0, 1))
assert.Equal(t, 3, m.Get(0, 2))
assert.Equal(t, 4, m.Get(1, 0))
assert.Equal(t, 5, m.Get(1, 1))
assert.Equal(t, 6, m.Get(1, 2))
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
matrix.CreateFromSlice([][]int{})
})
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
matrix.CreateFromSlice([][]int{
{},
})
})
assert.PanicsWithValue(t, matrix.ErrIncompatibleDataDimensions, func() {
matrix.CreateFromSlice([][]int{
{1, 2, 3},
{4, 5},
})
})
}
func TestCreateFromFlatSlice(t *testing.T) { func TestCreateFromFlatSlice(t *testing.T) {
m := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) m := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
assert.NotNil(t, m) assert.NotNil(t, m)
@@ -94,6 +130,46 @@ func TestCreateFromFlatSlice(t *testing.T) {
}) })
} }
func TestCreateFromJSON(t *testing.T) {
// data json
data := []byte(`[[1, 2, 3], [4, 5, 6]]`)
m, err := matrix.CreateFromJSON[int](data)
assert.Nil(t, err)
assert.NotNil(t, m)
assert.Equal(t, 2, m.Rows())
assert.Equal(t, 3, m.Cols())
assert.Equal(t, 1, m.Get(0, 0))
assert.Equal(t, 2, m.Get(0, 1))
assert.Equal(t, 3, m.Get(0, 2))
assert.Equal(t, 4, m.Get(1, 0))
assert.Equal(t, 5, m.Get(1, 1))
assert.Equal(t, 6, m.Get(1, 2))
// invalid json
data = []byte(`[[1, 2, 3], [4, 5,`)
m, err = matrix.CreateFromJSON[int](data)
assert.NotNil(t, err)
assert.Nil(t, m)
// empty matrix
data = []byte(`[]`)
m, err = matrix.CreateFromJSON[int](data)
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
assert.Nil(t, m)
// empty rows
data = []byte(`[[]]`)
m, err = matrix.CreateFromJSON[int](data)
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
assert.Nil(t, m)
// mixed row length
data = []byte(`[[1, 2, 3], [4, 5]]`)
m, err = matrix.CreateFromJSON[int](data)
assert.ErrorIs(t, err, matrix.ErrIncompatibleDataDimensions)
assert.Nil(t, m)
}
func TestSum(t *testing.T) { func TestSum(t *testing.T) {
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1}) b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1})
@@ -493,3 +569,87 @@ func TestMatrix_Fill(t *testing.T) {
assert.Equal(t, 3, a.Get(1, 1)) assert.Equal(t, 3, a.Get(1, 1))
assert.Equal(t, 3, a.Get(1, 2)) assert.Equal(t, 3, a.Get(1, 2))
} }
func TestMatrix_UnmarshalJSON(t *testing.T) {
// int matrix
data := []byte(`[[1,2,3],[4,5,6]]`)
var m *matrix.Matrix[int]
err := json.Unmarshal(data, &m)
assert.Nil(t, err)
assert.Equal(t, 2, m.Rows())
assert.Equal(t, 3, m.Cols())
assert.Equal(t, 1, m.Get(0, 0))
assert.Equal(t, 2, m.Get(0, 1))
assert.Equal(t, 3, m.Get(0, 2))
assert.Equal(t, 4, m.Get(1, 0))
assert.Equal(t, 5, m.Get(1, 1))
assert.Equal(t, 6, m.Get(1, 2))
// float matrix
data = []byte(`[[1.5,2.5],[3.5,4.5]]`)
var mf *matrix.Matrix[float64]
err = json.Unmarshal(data, &mf)
assert.Nil(t, err)
assert.Equal(t, 2, mf.Rows())
assert.Equal(t, 2, mf.Cols())
assert.Equal(t, 1.5, mf.Get(0, 0))
assert.Equal(t, 2.5, mf.Get(0, 1))
assert.Equal(t, 3.5, mf.Get(1, 0))
assert.Equal(t, 4.5, mf.Get(1, 1))
// via json.Unmarshal
matrices := []byte(`[[[1,2],[3,4]],[[5,6,7]]]`)
var ms []*matrix.Matrix[int]
err = json.Unmarshal(matrices, &ms)
assert.Nil(t, err)
assert.Len(t, ms, 2)
assert.Equal(t, 2, ms[0].Get(0, 1))
assert.Equal(t, 7, ms[1].Get(0, 2))
// invalid JSON
err = m.UnmarshalJSON([]byte(`invalid`))
assert.NotNil(t, err)
// empty array
err = m.UnmarshalJSON([]byte(`[]`))
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
// empty inner array
err = m.UnmarshalJSON([]byte(`[[]]`))
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
// inconsistent lengths
err = m.UnmarshalJSON([]byte(`[[1,2],[3]]`))
assert.ErrorIs(t, err, matrix.ErrIncompatibleDataDimensions)
}
func TestMatrix_MarshallJSON(t *testing.T) {
// int matrix
m := matrix.CreateFromSlice([][]int{{1, 2, 3}, {4, 5, 6}})
data, err := m.MarshalJSON()
assert.Nil(t, err)
assert.NotNil(t, data)
expected := `[[1,2,3],[4,5,6]]`
assert.Equal(t, expected, string(data))
// float matrix
mf := matrix.CreateFromSlice([][]float64{{1.5, 2.5}, {3.5, 4.5}})
data, err = mf.MarshalJSON()
assert.Nil(t, err)
assert.NotNil(t, data)
expectedFloat := `[[1.5,2.5],[3.5,4.5]]`
assert.Equal(t, expectedFloat, string(data))
// slice of matrices via json.Marshal
m1 := matrix.CreateFromSlice([][]int{{1, 2}, {3, 4}})
m2 := matrix.CreateFromSlice([][]int{{5, 6, 7}})
matrices := []*matrix.Matrix[int]{m1, m2}
data, err = json.Marshal(matrices)
assert.Nil(t, err)
expected = `[[[1,2],[3,4]],[[5,6,7]]]`
assert.Equal(t, expected, string(data))
}

13
util/util.go Normal file
View File

@@ -0,0 +1,13 @@
// package util provides small utility functions that make it easier to express common things
package util
import "encoding/hex"
// DeHex decodes a hexadecimal string into a byte slice. Panics if the string is invalid.
func DeHex(s string) []byte {
decoded, err := hex.DecodeString(s)
if err != nil {
panic("invalid hex string")
}
return decoded
}

25
util/util_test.go Normal file
View File

@@ -0,0 +1,25 @@
package util_test
import (
"testing"
"git.omicron.one/playground/cryptography/util"
"github.com/stretchr/testify/assert"
)
func TestDeHex(t *testing.T) {
b := util.DeHex("")
assert.NotNil(t, b)
assert.Len(t, b, 0)
b = util.DeHex("deadbeef")
assert.NotNil(t, b)
assert.Equal(t, []byte("\xde\xad\xbe\xef"), b)
assert.PanicsWithValue(t, "invalid hex string", func() {
util.DeHex("dead serious this is not a hex string")
})
assert.PanicsWithValue(t, "invalid hex string", func() {
util.DeHex("deada55")
})
}