Compare commits
7 Commits
35e848ec43
...
main
Author | SHA1 | Date | |
---|---|---|---|
7b8a6f03c2 | |||
9729fe6dcb | |||
5f450e206f | |||
eae48cf9ae | |||
2730496b35 | |||
1ca7d572f9 | |||
7b8df3b046 |
40
.gitea/workflows/validate.yaml
Normal file
40
.gitea/workflows/validate.yaml
Normal file
@@ -0,0 +1,40 @@
|
||||
name: Validate the build
|
||||
run-name: ${{ gitea.actor }} is validating
|
||||
on: [push]
|
||||
|
||||
jobs:
|
||||
validate-build:
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: node:current-alpine
|
||||
steps:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
echo "https://dl-cdn.alpinelinux.org/alpine/edge/main" >> /etc/apk/repositories
|
||||
echo "https://dl-cdn.alpinelinux.org/alpine/edge/community" >> /etc/apk/repositories
|
||||
apk update
|
||||
apk add --no-cache git make bash go
|
||||
|
||||
GOBIN=/usr/local/bin go install mvdan.cc/gofumpt@latest
|
||||
|
||||
export "PATH=$PATH:/root/go/bin"
|
||||
|
||||
echo "---------------------"
|
||||
echo "Go version:"
|
||||
go version
|
||||
echo "---------------------"
|
||||
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Fetch dependencies
|
||||
run: |
|
||||
go mod download
|
||||
|
||||
- name: Validate the code and formatting
|
||||
run: |
|
||||
make validate
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
make test
|
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
/reports
|
||||
/bin
|
34
Makefile
Normal file
34
Makefile
Normal file
@@ -0,0 +1,34 @@
|
||||
BINARY_DIR = bin
|
||||
BINARIES = $(patsubst cmd/%/,%,$(wildcard cmd/*/))
|
||||
|
||||
.PHONY: all build test coverage validate clean purge $(BINARIES)
|
||||
|
||||
all: build
|
||||
|
||||
|
||||
build: $(BINARIES)
|
||||
|
||||
|
||||
$(BINARY_DIR):
|
||||
mkdir -p $(BINARY_DIR)
|
||||
|
||||
$(BINARIES): %: $(BINARY_DIR)
|
||||
go build -o $(BINARY_DIR)/$@ ./cmd/$@/
|
||||
|
||||
test:
|
||||
go test ./... -cover
|
||||
|
||||
coverage:
|
||||
mkdir -p reports/
|
||||
go test -coverprofile=reports/coverage.out ./... && go tool cover -html=reports/coverage.out
|
||||
|
||||
validate:
|
||||
@test -z "$(shell gofumpt -l .)" && echo "No files need formatting" || (echo "Incorrect formatting in:"; gofumpt -l .; exit 1)
|
||||
go vet ./...
|
||||
|
||||
clean:
|
||||
rm -rf $(BINARY_DIR)
|
||||
go clean
|
||||
|
||||
purge: clean
|
||||
rm -rf reports
|
172
cipher/speck/impl/benchmark128_test.go
Normal file
172
cipher/speck/impl/benchmark128_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package impl_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"git.omicron.one/playground/cryptography/cipher/speck/impl"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func BenchmarkKeyschedule128128(b *testing.B) {
|
||||
key := make([]byte, impl.KeySize128128)
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
assert.Nil(b, err)
|
||||
|
||||
_, err = impl.New128(key)
|
||||
assert.Nil(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
impl.New128(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkKeyschedule128192(b *testing.B) {
|
||||
key := make([]byte, impl.KeySize128192)
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
assert.Nil(b, err)
|
||||
|
||||
_, err = impl.New128(key)
|
||||
assert.Nil(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
impl.New128(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkKeyschedule128256(b *testing.B) {
|
||||
key := make([]byte, impl.KeySize128256)
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
assert.Nil(b, err)
|
||||
|
||||
_, err = impl.New128(key)
|
||||
assert.Nil(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
impl.New128(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncrypt128128(b *testing.B) {
|
||||
key := make([]byte, impl.KeySize128128)
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
assert.Nil(b, err)
|
||||
|
||||
ctx, err := impl.New128(key)
|
||||
assert.Nil(b, err)
|
||||
b.SetBytes(int64(ctx.BlockSize()))
|
||||
|
||||
ciphertext := make([]byte, ctx.BlockSize())
|
||||
plaintext := make([]byte, ctx.BlockSize())
|
||||
_, err = io.ReadFull(rand.Reader, plaintext)
|
||||
assert.Nil(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
ctx.Encrypt(ciphertext, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecrypt128128(b *testing.B) {
|
||||
key := make([]byte, impl.KeySize128128)
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
assert.Nil(b, err)
|
||||
|
||||
ctx, err := impl.New128(key)
|
||||
assert.Nil(b, err)
|
||||
b.SetBytes(int64(ctx.BlockSize()))
|
||||
|
||||
plaintext := make([]byte, ctx.BlockSize())
|
||||
ciphertext := make([]byte, ctx.BlockSize())
|
||||
_, err = io.ReadFull(rand.Reader, ciphertext)
|
||||
assert.Nil(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
ctx.Decrypt(plaintext, ciphertext)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncrypt128192(b *testing.B) {
|
||||
key := make([]byte, impl.KeySize128192)
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
assert.Nil(b, err)
|
||||
|
||||
ctx, err := impl.New128(key)
|
||||
assert.Nil(b, err)
|
||||
b.SetBytes(int64(ctx.BlockSize()))
|
||||
|
||||
ciphertext := make([]byte, ctx.BlockSize())
|
||||
plaintext := make([]byte, ctx.BlockSize())
|
||||
_, err = io.ReadFull(rand.Reader, plaintext)
|
||||
assert.Nil(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
ctx.Encrypt(ciphertext, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecrypt128192(b *testing.B) {
|
||||
key := make([]byte, impl.KeySize128192)
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
assert.Nil(b, err)
|
||||
|
||||
ctx, err := impl.New128(key)
|
||||
assert.Nil(b, err)
|
||||
b.SetBytes(int64(ctx.BlockSize()))
|
||||
|
||||
plaintext := make([]byte, ctx.BlockSize())
|
||||
ciphertext := make([]byte, ctx.BlockSize())
|
||||
_, err = io.ReadFull(rand.Reader, ciphertext)
|
||||
assert.Nil(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
ctx.Decrypt(plaintext, ciphertext)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncrypt128256(b *testing.B) {
|
||||
key := make([]byte, impl.KeySize128256)
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
assert.Nil(b, err)
|
||||
|
||||
ctx, err := impl.New128(key)
|
||||
assert.Nil(b, err)
|
||||
b.SetBytes(int64(ctx.BlockSize()))
|
||||
|
||||
ciphertext := make([]byte, ctx.BlockSize())
|
||||
plaintext := make([]byte, ctx.BlockSize())
|
||||
_, err = io.ReadFull(rand.Reader, plaintext)
|
||||
assert.Nil(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
ctx.Encrypt(ciphertext, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecrypt128256(b *testing.B) {
|
||||
key := make([]byte, impl.KeySize128256)
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
assert.Nil(b, err)
|
||||
|
||||
ctx, err := impl.New128(key)
|
||||
assert.Nil(b, err)
|
||||
b.SetBytes(int64(ctx.BlockSize()))
|
||||
|
||||
plaintext := make([]byte, ctx.BlockSize())
|
||||
ciphertext := make([]byte, ctx.BlockSize())
|
||||
_, err = io.ReadFull(rand.Reader, ciphertext)
|
||||
assert.Nil(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
ctx.Decrypt(plaintext, ciphertext)
|
||||
}
|
||||
}
|
134
cipher/speck/impl/speck128_test.go
Normal file
134
cipher/speck/impl/speck128_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package impl_test
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"git.omicron.one/playground/cryptography/cipher"
|
||||
"git.omicron.one/playground/cryptography/cipher/speck/impl"
|
||||
. "git.omicron.one/playground/cryptography/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testVector128(t *testing.T, key, plaintext, ciphertext []byte, bs int, name string) {
|
||||
t.Helper()
|
||||
|
||||
buffer := make([]byte, len(plaintext))
|
||||
ctx, err := impl.New128(key)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, ctx)
|
||||
assert.Equal(t, bs, ctx.BlockSize())
|
||||
assert.Equal(t, name, ctx.Algorithm())
|
||||
|
||||
// Two buffers
|
||||
pt := slices.Clone(plaintext)
|
||||
ctx.Encrypt(buffer, pt)
|
||||
assert.Equal(t, plaintext, pt)
|
||||
assert.Equal(t, ciphertext, buffer)
|
||||
|
||||
clear(buffer)
|
||||
ct := slices.Clone(ciphertext)
|
||||
ctx.Decrypt(buffer, ct)
|
||||
assert.Equal(t, ciphertext, ct)
|
||||
assert.Equal(t, plaintext, buffer)
|
||||
|
||||
// In-place
|
||||
copy(buffer, plaintext)
|
||||
ctx.Encrypt(buffer, buffer)
|
||||
assert.Equal(t, ciphertext, buffer)
|
||||
ctx.Decrypt(buffer, buffer)
|
||||
assert.Equal(t, plaintext, buffer)
|
||||
}
|
||||
|
||||
func TestVector128128(t *testing.T) {
|
||||
var (
|
||||
key = DeHex("0f0e0d0c0b0a09080706050403020100")
|
||||
plaintext = DeHex("6c617669757165207469206564616d20")
|
||||
ciphertext = DeHex("a65d9851797832657860fedf5c570d18")
|
||||
bs = impl.BlockSize128
|
||||
name = "Speck128/128"
|
||||
)
|
||||
testVector128(t, key, plaintext, ciphertext, bs, name)
|
||||
}
|
||||
|
||||
func TestVector128192(t *testing.T) {
|
||||
var (
|
||||
key = DeHex("17161514131211100f0e0d0c0b0a09080706050403020100")
|
||||
plaintext = DeHex("726148206665696843206f7420746e65")
|
||||
ciphertext = DeHex("1be4cf3a13135566f9bc185de03c1886")
|
||||
bs = impl.BlockSize128
|
||||
name = "Speck128/192"
|
||||
)
|
||||
testVector128(t, key, plaintext, ciphertext, bs, name)
|
||||
}
|
||||
|
||||
func TestVector128256(t *testing.T) {
|
||||
var (
|
||||
key = DeHex("1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100")
|
||||
plaintext = DeHex("65736f6874206e49202e72656e6f6f70")
|
||||
ciphertext = DeHex("4109010405c0f53e4eeeb48d9c188f43")
|
||||
bs = impl.BlockSize128
|
||||
name = "Speck128/256"
|
||||
)
|
||||
testVector128(t, key, plaintext, ciphertext, bs, name)
|
||||
}
|
||||
|
||||
func TestInvalidKey128(t *testing.T) {
|
||||
ctx, err := impl.New128(DeHex("deadbeef"))
|
||||
assert.ErrorIs(t, cipher.ErrInvalidKeyLength, err)
|
||||
assert.Nil(t, ctx)
|
||||
}
|
||||
|
||||
func TestDecryptBlockSize128(t *testing.T) {
|
||||
ctx, err := impl.New128(DeHex("0f0e0d0c0b0a09080706050403020100"))
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, ctx)
|
||||
|
||||
assert.Panics(t, func() {
|
||||
buffer := make([]byte, ctx.BlockSize()-1)
|
||||
ctx.Decrypt(nil, buffer)
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
buffer := make([]byte, ctx.BlockSize()-1)
|
||||
ctx.Decrypt(buffer, nil)
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
buffer := make([]byte, ctx.BlockSize()-1)
|
||||
ctx.Decrypt(buffer, buffer)
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
buffer := make([]byte, ctx.BlockSize()+1)
|
||||
ctx.Decrypt(buffer, buffer)
|
||||
})
|
||||
assert.NotPanics(t, func() {
|
||||
buffer := make([]byte, ctx.BlockSize())
|
||||
ctx.Decrypt(buffer, buffer)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEncryptBlockSize128(t *testing.T) {
|
||||
ctx, err := impl.New128(DeHex("0f0e0d0c0b0a09080706050403020100"))
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, ctx)
|
||||
|
||||
assert.Panics(t, func() {
|
||||
buffer := make([]byte, ctx.BlockSize()-1)
|
||||
ctx.Encrypt(nil, buffer)
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
buffer := make([]byte, ctx.BlockSize()-1)
|
||||
ctx.Encrypt(buffer, nil)
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
buffer := make([]byte, ctx.BlockSize()-1)
|
||||
ctx.Encrypt(buffer, buffer)
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
buffer := make([]byte, ctx.BlockSize()+1)
|
||||
ctx.Encrypt(buffer, buffer)
|
||||
})
|
||||
assert.NotPanics(t, func() {
|
||||
buffer := make([]byte, ctx.BlockSize())
|
||||
ctx.Encrypt(buffer, buffer)
|
||||
})
|
||||
}
|
@@ -41,7 +41,7 @@ var keySizes = []int{
|
||||
// New creates a new speck block cipher context.
|
||||
// Returns the created block cipher or an error.
|
||||
func New(key []byte, param SpeckParameters) (cipher.Block, error) {
|
||||
if param == 0 || int(param) > len(keySizes) {
|
||||
if param <= 0 || int(param) >= len(keySizes) {
|
||||
panic("Invalid parameters")
|
||||
}
|
||||
keySize := keySizes[param]
|
||||
|
@@ -1,75 +1,93 @@
|
||||
package speck_test
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"git.omicron.one/playground/cryptography/cipher"
|
||||
"git.omicron.one/playground/cryptography/cipher/speck"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func DeHex(s string) []byte {
|
||||
decoded, err := hex.DecodeString(s)
|
||||
if err != nil {
|
||||
panic("invalid hex string")
|
||||
func testKey(param speck.SpeckParameters) []byte {
|
||||
switch param {
|
||||
case speck.Speck3264:
|
||||
return make([]byte, 64/8)
|
||||
case speck.Speck4872:
|
||||
return make([]byte, 72/8)
|
||||
case speck.Speck4896, speck.Speck6496, speck.Speck9696:
|
||||
return make([]byte, 96/8)
|
||||
case speck.Speck64128, speck.Speck128128:
|
||||
return make([]byte, 128/8)
|
||||
case speck.Speck96144:
|
||||
return make([]byte, 144/8)
|
||||
case speck.Speck128192:
|
||||
return make([]byte, 192/8)
|
||||
case speck.Speck128256:
|
||||
return make([]byte, 256/8)
|
||||
}
|
||||
return decoded
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
type TestVector struct {
|
||||
Key []byte
|
||||
Plaintext []byte
|
||||
Ciphertext []byte
|
||||
Param speck.SpeckParameters
|
||||
func TestNew(t *testing.T) {
|
||||
notImplemented := []speck.SpeckParameters{
|
||||
speck.Speck3264,
|
||||
speck.Speck4872,
|
||||
speck.Speck4896,
|
||||
speck.Speck6496,
|
||||
speck.Speck64128,
|
||||
speck.Speck9696,
|
||||
speck.Speck96144,
|
||||
}
|
||||
implemented := []speck.SpeckParameters{
|
||||
speck.Speck128128,
|
||||
speck.Speck128192,
|
||||
speck.Speck128256,
|
||||
}
|
||||
|
||||
var vectors []TestVector = []TestVector{
|
||||
// Speck128/128 test vector
|
||||
{
|
||||
Key: DeHex("0f0e0d0c0b0a09080706050403020100"),
|
||||
Plaintext: DeHex("6c617669757165207469206564616d20"),
|
||||
Ciphertext: DeHex("a65d9851797832657860fedf5c570d18"),
|
||||
Param: speck.Speck128128,
|
||||
},
|
||||
{
|
||||
Key: DeHex("17161514131211100f0e0d0c0b0a09080706050403020100"),
|
||||
Plaintext: DeHex("726148206665696843206f7420746e65"),
|
||||
Ciphertext: DeHex("1be4cf3a13135566f9bc185de03c1886"),
|
||||
Param: speck.Speck128192,
|
||||
},
|
||||
{
|
||||
Key: DeHex("1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100"),
|
||||
Plaintext: DeHex("65736f6874206e49202e72656e6f6f70"),
|
||||
Ciphertext: DeHex("4109010405c0f53e4eeeb48d9c188f43"),
|
||||
Param: speck.Speck128256,
|
||||
},
|
||||
for _, param := range notImplemented {
|
||||
key := testKey(param)
|
||||
ctx, err := speck.New(key, param)
|
||||
assert.Nil(t, ctx)
|
||||
assert.ErrorContains(t, err, "Not implemented")
|
||||
}
|
||||
|
||||
func TestVectors(t *testing.T) {
|
||||
for _, vector := range vectors {
|
||||
ctx, err := speck.New(vector.Key, vector.Param)
|
||||
assert.NotNil(t, ctx)
|
||||
for _, param := range implemented {
|
||||
key := testKey(param)
|
||||
ctx, err := speck.New(key, param)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Test in place
|
||||
buffer := slices.Clone(vector.Plaintext)
|
||||
ctx.Encrypt(buffer, buffer)
|
||||
assert.Equal(t, vector.Ciphertext, buffer, ctx.Algorithm())
|
||||
ctx.Decrypt(buffer, buffer)
|
||||
assert.Equal(t, vector.Plaintext, buffer, ctx.Algorithm())
|
||||
|
||||
// Test two buffers
|
||||
dst := make([]byte, len(vector.Ciphertext))
|
||||
src := slices.Clone(vector.Plaintext)
|
||||
ctx.Encrypt(dst, src)
|
||||
assert.Equal(t, vector.Plaintext, src, ctx.Algorithm())
|
||||
assert.Equal(t, vector.Ciphertext, dst, ctx.Algorithm())
|
||||
|
||||
dst = make([]byte, len(vector.Plaintext))
|
||||
src = slices.Clone(vector.Ciphertext)
|
||||
ctx.Decrypt(dst, src)
|
||||
assert.Equal(t, vector.Ciphertext, src, ctx.Algorithm())
|
||||
assert.Equal(t, vector.Plaintext, dst, ctx.Algorithm())
|
||||
assert.NotNil(t, ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidKeyLength(t *testing.T) {
|
||||
params := []speck.SpeckParameters{
|
||||
speck.Speck3264,
|
||||
speck.Speck4872,
|
||||
speck.Speck4896,
|
||||
speck.Speck6496,
|
||||
speck.Speck64128,
|
||||
speck.Speck9696,
|
||||
speck.Speck96144,
|
||||
speck.Speck128128,
|
||||
speck.Speck128192,
|
||||
speck.Speck128256,
|
||||
}
|
||||
for _, param := range params {
|
||||
key := testKey(param)
|
||||
ctx, err := speck.New(key[1:], param)
|
||||
assert.Nil(t, ctx)
|
||||
assert.ErrorIs(t, cipher.ErrInvalidKeyLength, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidParam(t *testing.T) {
|
||||
assert.PanicsWithValue(t, "Invalid parameters", func() {
|
||||
speck.New(nil, -1)
|
||||
})
|
||||
assert.PanicsWithValue(t, "Invalid parameters", func() {
|
||||
speck.New(nil, 0)
|
||||
})
|
||||
assert.PanicsWithValue(t, "Invalid parameters", func() {
|
||||
speck.New(nil, speck.Speck128256+1)
|
||||
})
|
||||
}
|
||||
|
2
go.mod
2
go.mod
@@ -1,6 +1,6 @@
|
||||
module git.omicron.one/playground/cryptography
|
||||
|
||||
go 1.24.2
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
github.com/stretchr/testify v1.10.0
|
||||
|
@@ -4,6 +4,7 @@
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
@@ -54,6 +55,33 @@ func Create[T Number](rows, cols int) *Matrix[T] {
|
||||
}
|
||||
}
|
||||
|
||||
// Creates a new matrix from a given 2D slice of values. The first index of the
|
||||
// slice denotes the rows and the second index denotes the columns. Returns the
|
||||
// newly created matrix.
|
||||
//
|
||||
// Panics with ErrIncompatibleDataDimensions if any of the lengths are 0 or if
|
||||
// not all rows have the same length.
|
||||
func CreateFromSlice[T Number](values [][]T) *Matrix[T] {
|
||||
rows := len(values)
|
||||
if rows == 0 {
|
||||
panic(ErrInvalidDimensions)
|
||||
}
|
||||
cols := len(values[0])
|
||||
if cols == 0 {
|
||||
panic(ErrInvalidDimensions)
|
||||
}
|
||||
|
||||
m := Create[T](rows, cols)
|
||||
for i := range rows {
|
||||
if len(values[i]) != cols {
|
||||
panic(ErrIncompatibleDataDimensions)
|
||||
}
|
||||
copy(m.values[i], values[i])
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Creates a new matrix of a given size and sets the values from the given
|
||||
// slice. The values are used to fill the matrix left to right and top to
|
||||
// bottom. Returns the newly created matrix.
|
||||
@@ -80,6 +108,17 @@ func CreateFromFlatSlice[T Number](rows, cols int, values []T) *Matrix[T] {
|
||||
return m
|
||||
}
|
||||
|
||||
// CreateFromJSON creates a new matrix from JSON data representing a
|
||||
// two-dimensional array. Returns an error if the JSON is invalid, the array is
|
||||
// empty, or the inner arrays have inconsistent lengths.
|
||||
func CreateFromJSON[T Number](data []byte) (*Matrix[T], error) {
|
||||
m := &Matrix[T]{}
|
||||
if err := m.UnmarshalJSON(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Convert will take a matrix of type T and convert it into a matrix of type U
|
||||
// Only works on SimpleNumber matrices
|
||||
func Convert[U, T SimpleNumber](in *Matrix[T]) *Matrix[U] {
|
||||
@@ -293,3 +332,39 @@ func (m *Matrix[T]) Fill(value T) *Matrix[T] {
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler for Matrix, creating a matrix from
|
||||
// a JSON two-dimensional array.
|
||||
func (m *Matrix[T]) UnmarshalJSON(data []byte) error {
|
||||
var values [][]T
|
||||
if err := json.Unmarshal(data, &values); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows := len(values)
|
||||
if rows == 0 {
|
||||
return ErrInvalidDimensions
|
||||
}
|
||||
cols := len(values[0])
|
||||
if cols == 0 {
|
||||
return ErrInvalidDimensions
|
||||
}
|
||||
|
||||
for i := range rows {
|
||||
if len(values[i]) != cols {
|
||||
return ErrIncompatibleDataDimensions
|
||||
}
|
||||
}
|
||||
|
||||
m.rows = rows
|
||||
m.cols = cols
|
||||
m.values = values
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for Matrix, serializing the matrix as
|
||||
// a JSON two-dimensional array.
|
||||
func (m *Matrix[T]) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(m.values)
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package matrix_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"git.omicron.one/playground/cryptography/matrix"
|
||||
@@ -55,6 +56,41 @@ func TestCreate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateFromSlice(t *testing.T) {
|
||||
m := matrix.CreateFromSlice([][]int{
|
||||
{1, 2, 3},
|
||||
{4, 5, 6},
|
||||
})
|
||||
assert.NotNil(t, m)
|
||||
|
||||
assert.Equal(t, 2, m.Rows())
|
||||
assert.Equal(t, 3, m.Cols())
|
||||
|
||||
assert.Equal(t, 1, m.Get(0, 0))
|
||||
assert.Equal(t, 2, m.Get(0, 1))
|
||||
assert.Equal(t, 3, m.Get(0, 2))
|
||||
assert.Equal(t, 4, m.Get(1, 0))
|
||||
assert.Equal(t, 5, m.Get(1, 1))
|
||||
assert.Equal(t, 6, m.Get(1, 2))
|
||||
|
||||
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||
matrix.CreateFromSlice([][]int{})
|
||||
})
|
||||
|
||||
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||
matrix.CreateFromSlice([][]int{
|
||||
{},
|
||||
})
|
||||
})
|
||||
|
||||
assert.PanicsWithValue(t, matrix.ErrIncompatibleDataDimensions, func() {
|
||||
matrix.CreateFromSlice([][]int{
|
||||
{1, 2, 3},
|
||||
{4, 5},
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateFromFlatSlice(t *testing.T) {
|
||||
m := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
assert.NotNil(t, m)
|
||||
@@ -94,6 +130,46 @@ func TestCreateFromFlatSlice(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateFromJSON(t *testing.T) {
|
||||
// data json
|
||||
data := []byte(`[[1, 2, 3], [4, 5, 6]]`)
|
||||
m, err := matrix.CreateFromJSON[int](data)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, m)
|
||||
assert.Equal(t, 2, m.Rows())
|
||||
assert.Equal(t, 3, m.Cols())
|
||||
assert.Equal(t, 1, m.Get(0, 0))
|
||||
assert.Equal(t, 2, m.Get(0, 1))
|
||||
assert.Equal(t, 3, m.Get(0, 2))
|
||||
assert.Equal(t, 4, m.Get(1, 0))
|
||||
assert.Equal(t, 5, m.Get(1, 1))
|
||||
assert.Equal(t, 6, m.Get(1, 2))
|
||||
|
||||
// invalid json
|
||||
data = []byte(`[[1, 2, 3], [4, 5,`)
|
||||
m, err = matrix.CreateFromJSON[int](data)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, m)
|
||||
|
||||
// empty matrix
|
||||
data = []byte(`[]`)
|
||||
m, err = matrix.CreateFromJSON[int](data)
|
||||
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
|
||||
assert.Nil(t, m)
|
||||
|
||||
// empty rows
|
||||
data = []byte(`[[]]`)
|
||||
m, err = matrix.CreateFromJSON[int](data)
|
||||
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
|
||||
assert.Nil(t, m)
|
||||
|
||||
// mixed row length
|
||||
data = []byte(`[[1, 2, 3], [4, 5]]`)
|
||||
m, err = matrix.CreateFromJSON[int](data)
|
||||
assert.ErrorIs(t, err, matrix.ErrIncompatibleDataDimensions)
|
||||
assert.Nil(t, m)
|
||||
}
|
||||
|
||||
func TestSum(t *testing.T) {
|
||||
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1})
|
||||
@@ -493,3 +569,87 @@ func TestMatrix_Fill(t *testing.T) {
|
||||
assert.Equal(t, 3, a.Get(1, 1))
|
||||
assert.Equal(t, 3, a.Get(1, 2))
|
||||
}
|
||||
|
||||
func TestMatrix_UnmarshalJSON(t *testing.T) {
|
||||
// int matrix
|
||||
data := []byte(`[[1,2,3],[4,5,6]]`)
|
||||
var m *matrix.Matrix[int]
|
||||
err := json.Unmarshal(data, &m)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 2, m.Rows())
|
||||
assert.Equal(t, 3, m.Cols())
|
||||
assert.Equal(t, 1, m.Get(0, 0))
|
||||
assert.Equal(t, 2, m.Get(0, 1))
|
||||
assert.Equal(t, 3, m.Get(0, 2))
|
||||
assert.Equal(t, 4, m.Get(1, 0))
|
||||
assert.Equal(t, 5, m.Get(1, 1))
|
||||
assert.Equal(t, 6, m.Get(1, 2))
|
||||
|
||||
// float matrix
|
||||
data = []byte(`[[1.5,2.5],[3.5,4.5]]`)
|
||||
var mf *matrix.Matrix[float64]
|
||||
err = json.Unmarshal(data, &mf)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 2, mf.Rows())
|
||||
assert.Equal(t, 2, mf.Cols())
|
||||
assert.Equal(t, 1.5, mf.Get(0, 0))
|
||||
assert.Equal(t, 2.5, mf.Get(0, 1))
|
||||
assert.Equal(t, 3.5, mf.Get(1, 0))
|
||||
assert.Equal(t, 4.5, mf.Get(1, 1))
|
||||
|
||||
// via json.Unmarshal
|
||||
matrices := []byte(`[[[1,2],[3,4]],[[5,6,7]]]`)
|
||||
var ms []*matrix.Matrix[int]
|
||||
err = json.Unmarshal(matrices, &ms)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, ms, 2)
|
||||
assert.Equal(t, 2, ms[0].Get(0, 1))
|
||||
assert.Equal(t, 7, ms[1].Get(0, 2))
|
||||
|
||||
// invalid JSON
|
||||
err = m.UnmarshalJSON([]byte(`invalid`))
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// empty array
|
||||
err = m.UnmarshalJSON([]byte(`[]`))
|
||||
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
|
||||
|
||||
// empty inner array
|
||||
err = m.UnmarshalJSON([]byte(`[[]]`))
|
||||
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
|
||||
|
||||
// inconsistent lengths
|
||||
err = m.UnmarshalJSON([]byte(`[[1,2],[3]]`))
|
||||
assert.ErrorIs(t, err, matrix.ErrIncompatibleDataDimensions)
|
||||
}
|
||||
|
||||
func TestMatrix_MarshallJSON(t *testing.T) {
|
||||
// int matrix
|
||||
m := matrix.CreateFromSlice([][]int{{1, 2, 3}, {4, 5, 6}})
|
||||
data, err := m.MarshalJSON()
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, data)
|
||||
|
||||
expected := `[[1,2,3],[4,5,6]]`
|
||||
assert.Equal(t, expected, string(data))
|
||||
|
||||
// float matrix
|
||||
mf := matrix.CreateFromSlice([][]float64{{1.5, 2.5}, {3.5, 4.5}})
|
||||
data, err = mf.MarshalJSON()
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, data)
|
||||
expectedFloat := `[[1.5,2.5],[3.5,4.5]]`
|
||||
assert.Equal(t, expectedFloat, string(data))
|
||||
|
||||
// slice of matrices via json.Marshal
|
||||
m1 := matrix.CreateFromSlice([][]int{{1, 2}, {3, 4}})
|
||||
m2 := matrix.CreateFromSlice([][]int{{5, 6, 7}})
|
||||
matrices := []*matrix.Matrix[int]{m1, m2}
|
||||
|
||||
data, err = json.Marshal(matrices)
|
||||
assert.Nil(t, err)
|
||||
expected = `[[[1,2],[3,4]],[[5,6,7]]]`
|
||||
assert.Equal(t, expected, string(data))
|
||||
}
|
||||
|
13
util/util.go
Normal file
13
util/util.go
Normal file
@@ -0,0 +1,13 @@
|
||||
// package util provides small utility functions that make it easier to express common things
|
||||
package util
|
||||
|
||||
import "encoding/hex"
|
||||
|
||||
// DeHex decodes a hexadecimal string into a byte slice. Panics if the string is invalid.
|
||||
func DeHex(s string) []byte {
|
||||
decoded, err := hex.DecodeString(s)
|
||||
if err != nil {
|
||||
panic("invalid hex string")
|
||||
}
|
||||
return decoded
|
||||
}
|
25
util/util_test.go
Normal file
25
util/util_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package util_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.omicron.one/playground/cryptography/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDeHex(t *testing.T) {
|
||||
b := util.DeHex("")
|
||||
assert.NotNil(t, b)
|
||||
assert.Len(t, b, 0)
|
||||
|
||||
b = util.DeHex("deadbeef")
|
||||
assert.NotNil(t, b)
|
||||
assert.Equal(t, []byte("\xde\xad\xbe\xef"), b)
|
||||
|
||||
assert.PanicsWithValue(t, "invalid hex string", func() {
|
||||
util.DeHex("dead serious this is not a hex string")
|
||||
})
|
||||
assert.PanicsWithValue(t, "invalid hex string", func() {
|
||||
util.DeHex("deada55")
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user