Compare commits
9 Commits
2756894219
...
main
Author | SHA1 | Date | |
---|---|---|---|
7b8a6f03c2 | |||
9729fe6dcb | |||
5f450e206f | |||
eae48cf9ae | |||
2730496b35 | |||
1ca7d572f9 | |||
7b8df3b046 | |||
35e848ec43 | |||
cf17a6fb72 |
40
.gitea/workflows/validate.yaml
Normal file
40
.gitea/workflows/validate.yaml
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
name: Validate the build
|
||||||
|
run-name: ${{ gitea.actor }} is validating
|
||||||
|
on: [push]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
validate-build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container:
|
||||||
|
image: node:current-alpine
|
||||||
|
steps:
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
echo "https://dl-cdn.alpinelinux.org/alpine/edge/main" >> /etc/apk/repositories
|
||||||
|
echo "https://dl-cdn.alpinelinux.org/alpine/edge/community" >> /etc/apk/repositories
|
||||||
|
apk update
|
||||||
|
apk add --no-cache git make bash go
|
||||||
|
|
||||||
|
GOBIN=/usr/local/bin go install mvdan.cc/gofumpt@latest
|
||||||
|
|
||||||
|
export "PATH=$PATH:/root/go/bin"
|
||||||
|
|
||||||
|
echo "---------------------"
|
||||||
|
echo "Go version:"
|
||||||
|
go version
|
||||||
|
echo "---------------------"
|
||||||
|
|
||||||
|
- name: Check out repository code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Fetch dependencies
|
||||||
|
run: |
|
||||||
|
go mod download
|
||||||
|
|
||||||
|
- name: Validate the code and formatting
|
||||||
|
run: |
|
||||||
|
make validate
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
make test
|
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
/reports
|
||||||
|
/bin
|
34
Makefile
Normal file
34
Makefile
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
BINARY_DIR = bin
|
||||||
|
BINARIES = $(patsubst cmd/%/,%,$(wildcard cmd/*/))
|
||||||
|
|
||||||
|
.PHONY: all build test coverage validate clean purge $(BINARIES)
|
||||||
|
|
||||||
|
all: build
|
||||||
|
|
||||||
|
|
||||||
|
build: $(BINARIES)
|
||||||
|
|
||||||
|
|
||||||
|
$(BINARY_DIR):
|
||||||
|
mkdir -p $(BINARY_DIR)
|
||||||
|
|
||||||
|
$(BINARIES): %: $(BINARY_DIR)
|
||||||
|
go build -o $(BINARY_DIR)/$@ ./cmd/$@/
|
||||||
|
|
||||||
|
test:
|
||||||
|
go test ./... -cover
|
||||||
|
|
||||||
|
coverage:
|
||||||
|
mkdir -p reports/
|
||||||
|
go test -coverprofile=reports/coverage.out ./... && go tool cover -html=reports/coverage.out
|
||||||
|
|
||||||
|
validate:
|
||||||
|
@test -z "$(shell gofumpt -l .)" && echo "No files need formatting" || (echo "Incorrect formatting in:"; gofumpt -l .; exit 1)
|
||||||
|
go vet ./...
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -rf $(BINARY_DIR)
|
||||||
|
go clean
|
||||||
|
|
||||||
|
purge: clean
|
||||||
|
rm -rf reports
|
172
cipher/speck/impl/benchmark128_test.go
Normal file
172
cipher/speck/impl/benchmark128_test.go
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
package impl_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.omicron.one/playground/cryptography/cipher/speck/impl"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkKeyschedule128128(b *testing.B) {
|
||||||
|
key := make([]byte, impl.KeySize128128)
|
||||||
|
_, err := io.ReadFull(rand.Reader, key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
_, err = impl.New128(key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
impl.New128(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkKeyschedule128192(b *testing.B) {
|
||||||
|
key := make([]byte, impl.KeySize128192)
|
||||||
|
_, err := io.ReadFull(rand.Reader, key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
_, err = impl.New128(key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
impl.New128(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkKeyschedule128256(b *testing.B) {
|
||||||
|
key := make([]byte, impl.KeySize128256)
|
||||||
|
_, err := io.ReadFull(rand.Reader, key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
_, err = impl.New128(key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
impl.New128(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncrypt128128(b *testing.B) {
|
||||||
|
key := make([]byte, impl.KeySize128128)
|
||||||
|
_, err := io.ReadFull(rand.Reader, key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
ctx, err := impl.New128(key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
b.SetBytes(int64(ctx.BlockSize()))
|
||||||
|
|
||||||
|
ciphertext := make([]byte, ctx.BlockSize())
|
||||||
|
plaintext := make([]byte, ctx.BlockSize())
|
||||||
|
_, err = io.ReadFull(rand.Reader, plaintext)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
ctx.Encrypt(ciphertext, plaintext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecrypt128128(b *testing.B) {
|
||||||
|
key := make([]byte, impl.KeySize128128)
|
||||||
|
_, err := io.ReadFull(rand.Reader, key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
ctx, err := impl.New128(key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
b.SetBytes(int64(ctx.BlockSize()))
|
||||||
|
|
||||||
|
plaintext := make([]byte, ctx.BlockSize())
|
||||||
|
ciphertext := make([]byte, ctx.BlockSize())
|
||||||
|
_, err = io.ReadFull(rand.Reader, ciphertext)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
ctx.Decrypt(plaintext, ciphertext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncrypt128192(b *testing.B) {
|
||||||
|
key := make([]byte, impl.KeySize128192)
|
||||||
|
_, err := io.ReadFull(rand.Reader, key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
ctx, err := impl.New128(key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
b.SetBytes(int64(ctx.BlockSize()))
|
||||||
|
|
||||||
|
ciphertext := make([]byte, ctx.BlockSize())
|
||||||
|
plaintext := make([]byte, ctx.BlockSize())
|
||||||
|
_, err = io.ReadFull(rand.Reader, plaintext)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
ctx.Encrypt(ciphertext, plaintext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecrypt128192(b *testing.B) {
|
||||||
|
key := make([]byte, impl.KeySize128192)
|
||||||
|
_, err := io.ReadFull(rand.Reader, key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
ctx, err := impl.New128(key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
b.SetBytes(int64(ctx.BlockSize()))
|
||||||
|
|
||||||
|
plaintext := make([]byte, ctx.BlockSize())
|
||||||
|
ciphertext := make([]byte, ctx.BlockSize())
|
||||||
|
_, err = io.ReadFull(rand.Reader, ciphertext)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
ctx.Decrypt(plaintext, ciphertext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncrypt128256(b *testing.B) {
|
||||||
|
key := make([]byte, impl.KeySize128256)
|
||||||
|
_, err := io.ReadFull(rand.Reader, key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
ctx, err := impl.New128(key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
b.SetBytes(int64(ctx.BlockSize()))
|
||||||
|
|
||||||
|
ciphertext := make([]byte, ctx.BlockSize())
|
||||||
|
plaintext := make([]byte, ctx.BlockSize())
|
||||||
|
_, err = io.ReadFull(rand.Reader, plaintext)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
ctx.Encrypt(ciphertext, plaintext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecrypt128256(b *testing.B) {
|
||||||
|
key := make([]byte, impl.KeySize128256)
|
||||||
|
_, err := io.ReadFull(rand.Reader, key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
ctx, err := impl.New128(key)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
b.SetBytes(int64(ctx.BlockSize()))
|
||||||
|
|
||||||
|
plaintext := make([]byte, ctx.BlockSize())
|
||||||
|
ciphertext := make([]byte, ctx.BlockSize())
|
||||||
|
_, err = io.ReadFull(rand.Reader, ciphertext)
|
||||||
|
assert.Nil(b, err)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
ctx.Decrypt(plaintext, ciphertext)
|
||||||
|
}
|
||||||
|
}
|
134
cipher/speck/impl/speck128_test.go
Normal file
134
cipher/speck/impl/speck128_test.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package impl_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.omicron.one/playground/cryptography/cipher"
|
||||||
|
"git.omicron.one/playground/cryptography/cipher/speck/impl"
|
||||||
|
. "git.omicron.one/playground/cryptography/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testVector128(t *testing.T, key, plaintext, ciphertext []byte, bs int, name string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
buffer := make([]byte, len(plaintext))
|
||||||
|
ctx, err := impl.New128(key)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, ctx)
|
||||||
|
assert.Equal(t, bs, ctx.BlockSize())
|
||||||
|
assert.Equal(t, name, ctx.Algorithm())
|
||||||
|
|
||||||
|
// Two buffers
|
||||||
|
pt := slices.Clone(plaintext)
|
||||||
|
ctx.Encrypt(buffer, pt)
|
||||||
|
assert.Equal(t, plaintext, pt)
|
||||||
|
assert.Equal(t, ciphertext, buffer)
|
||||||
|
|
||||||
|
clear(buffer)
|
||||||
|
ct := slices.Clone(ciphertext)
|
||||||
|
ctx.Decrypt(buffer, ct)
|
||||||
|
assert.Equal(t, ciphertext, ct)
|
||||||
|
assert.Equal(t, plaintext, buffer)
|
||||||
|
|
||||||
|
// In-place
|
||||||
|
copy(buffer, plaintext)
|
||||||
|
ctx.Encrypt(buffer, buffer)
|
||||||
|
assert.Equal(t, ciphertext, buffer)
|
||||||
|
ctx.Decrypt(buffer, buffer)
|
||||||
|
assert.Equal(t, plaintext, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVector128128(t *testing.T) {
|
||||||
|
var (
|
||||||
|
key = DeHex("0f0e0d0c0b0a09080706050403020100")
|
||||||
|
plaintext = DeHex("6c617669757165207469206564616d20")
|
||||||
|
ciphertext = DeHex("a65d9851797832657860fedf5c570d18")
|
||||||
|
bs = impl.BlockSize128
|
||||||
|
name = "Speck128/128"
|
||||||
|
)
|
||||||
|
testVector128(t, key, plaintext, ciphertext, bs, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVector128192(t *testing.T) {
|
||||||
|
var (
|
||||||
|
key = DeHex("17161514131211100f0e0d0c0b0a09080706050403020100")
|
||||||
|
plaintext = DeHex("726148206665696843206f7420746e65")
|
||||||
|
ciphertext = DeHex("1be4cf3a13135566f9bc185de03c1886")
|
||||||
|
bs = impl.BlockSize128
|
||||||
|
name = "Speck128/192"
|
||||||
|
)
|
||||||
|
testVector128(t, key, plaintext, ciphertext, bs, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVector128256(t *testing.T) {
|
||||||
|
var (
|
||||||
|
key = DeHex("1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100")
|
||||||
|
plaintext = DeHex("65736f6874206e49202e72656e6f6f70")
|
||||||
|
ciphertext = DeHex("4109010405c0f53e4eeeb48d9c188f43")
|
||||||
|
bs = impl.BlockSize128
|
||||||
|
name = "Speck128/256"
|
||||||
|
)
|
||||||
|
testVector128(t, key, plaintext, ciphertext, bs, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidKey128(t *testing.T) {
|
||||||
|
ctx, err := impl.New128(DeHex("deadbeef"))
|
||||||
|
assert.ErrorIs(t, cipher.ErrInvalidKeyLength, err)
|
||||||
|
assert.Nil(t, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptBlockSize128(t *testing.T) {
|
||||||
|
ctx, err := impl.New128(DeHex("0f0e0d0c0b0a09080706050403020100"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, ctx)
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
buffer := make([]byte, ctx.BlockSize()-1)
|
||||||
|
ctx.Decrypt(nil, buffer)
|
||||||
|
})
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
buffer := make([]byte, ctx.BlockSize()-1)
|
||||||
|
ctx.Decrypt(buffer, nil)
|
||||||
|
})
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
buffer := make([]byte, ctx.BlockSize()-1)
|
||||||
|
ctx.Decrypt(buffer, buffer)
|
||||||
|
})
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
buffer := make([]byte, ctx.BlockSize()+1)
|
||||||
|
ctx.Decrypt(buffer, buffer)
|
||||||
|
})
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
buffer := make([]byte, ctx.BlockSize())
|
||||||
|
ctx.Decrypt(buffer, buffer)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptBlockSize128(t *testing.T) {
|
||||||
|
ctx, err := impl.New128(DeHex("0f0e0d0c0b0a09080706050403020100"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, ctx)
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
buffer := make([]byte, ctx.BlockSize()-1)
|
||||||
|
ctx.Encrypt(nil, buffer)
|
||||||
|
})
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
buffer := make([]byte, ctx.BlockSize()-1)
|
||||||
|
ctx.Encrypt(buffer, nil)
|
||||||
|
})
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
buffer := make([]byte, ctx.BlockSize()-1)
|
||||||
|
ctx.Encrypt(buffer, buffer)
|
||||||
|
})
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
buffer := make([]byte, ctx.BlockSize()+1)
|
||||||
|
ctx.Encrypt(buffer, buffer)
|
||||||
|
})
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
buffer := make([]byte, ctx.BlockSize())
|
||||||
|
ctx.Encrypt(buffer, buffer)
|
||||||
|
})
|
||||||
|
}
|
@@ -41,7 +41,7 @@ var keySizes = []int{
|
|||||||
// New creates a new speck block cipher context.
|
// New creates a new speck block cipher context.
|
||||||
// Returns the created block cipher or an error.
|
// Returns the created block cipher or an error.
|
||||||
func New(key []byte, param SpeckParameters) (cipher.Block, error) {
|
func New(key []byte, param SpeckParameters) (cipher.Block, error) {
|
||||||
if param == 0 || int(param) > len(keySizes) {
|
if param <= 0 || int(param) >= len(keySizes) {
|
||||||
panic("Invalid parameters")
|
panic("Invalid parameters")
|
||||||
}
|
}
|
||||||
keySize := keySizes[param]
|
keySize := keySizes[param]
|
||||||
|
@@ -1,75 +1,93 @@
|
|||||||
package speck_test
|
package speck_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
|
||||||
"slices"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"git.omicron.one/playground/cryptography/cipher"
|
||||||
"git.omicron.one/playground/cryptography/cipher/speck"
|
"git.omicron.one/playground/cryptography/cipher/speck"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func DeHex(s string) []byte {
|
func testKey(param speck.SpeckParameters) []byte {
|
||||||
decoded, err := hex.DecodeString(s)
|
switch param {
|
||||||
if err != nil {
|
case speck.Speck3264:
|
||||||
panic("invalid hex string")
|
return make([]byte, 64/8)
|
||||||
|
case speck.Speck4872:
|
||||||
|
return make([]byte, 72/8)
|
||||||
|
case speck.Speck4896, speck.Speck6496, speck.Speck9696:
|
||||||
|
return make([]byte, 96/8)
|
||||||
|
case speck.Speck64128, speck.Speck128128:
|
||||||
|
return make([]byte, 128/8)
|
||||||
|
case speck.Speck96144:
|
||||||
|
return make([]byte, 144/8)
|
||||||
|
case speck.Speck128192:
|
||||||
|
return make([]byte, 192/8)
|
||||||
|
case speck.Speck128256:
|
||||||
|
return make([]byte, 256/8)
|
||||||
}
|
}
|
||||||
return decoded
|
panic("unreachable")
|
||||||
}
|
}
|
||||||
|
|
||||||
type TestVector struct {
|
func TestNew(t *testing.T) {
|
||||||
Key []byte
|
notImplemented := []speck.SpeckParameters{
|
||||||
Plaintext []byte
|
speck.Speck3264,
|
||||||
Ciphertext []byte
|
speck.Speck4872,
|
||||||
Param speck.SpeckParameters
|
speck.Speck4896,
|
||||||
|
speck.Speck6496,
|
||||||
|
speck.Speck64128,
|
||||||
|
speck.Speck9696,
|
||||||
|
speck.Speck96144,
|
||||||
|
}
|
||||||
|
implemented := []speck.SpeckParameters{
|
||||||
|
speck.Speck128128,
|
||||||
|
speck.Speck128192,
|
||||||
|
speck.Speck128256,
|
||||||
}
|
}
|
||||||
|
|
||||||
var vectors []TestVector = []TestVector{
|
for _, param := range notImplemented {
|
||||||
// Speck128/128 test vector
|
key := testKey(param)
|
||||||
{
|
ctx, err := speck.New(key, param)
|
||||||
Key: DeHex("0f0e0d0c0b0a09080706050403020100"),
|
assert.Nil(t, ctx)
|
||||||
Plaintext: DeHex("6c617669757165207469206564616d20"),
|
assert.ErrorContains(t, err, "Not implemented")
|
||||||
Ciphertext: DeHex("a65d9851797832657860fedf5c570d18"),
|
|
||||||
Param: speck.Speck128128,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: DeHex("17161514131211100f0e0d0c0b0a09080706050403020100"),
|
|
||||||
Plaintext: DeHex("726148206665696843206f7420746e65"),
|
|
||||||
Ciphertext: DeHex("1be4cf3a13135566f9bc185de03c1886"),
|
|
||||||
Param: speck.Speck128192,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: DeHex("1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100"),
|
|
||||||
Plaintext: DeHex("65736f6874206e49202e72656e6f6f70"),
|
|
||||||
Ciphertext: DeHex("4109010405c0f53e4eeeb48d9c188f43"),
|
|
||||||
Param: speck.Speck128256,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVectors(t *testing.T) {
|
for _, param := range implemented {
|
||||||
for _, vector := range vectors {
|
key := testKey(param)
|
||||||
ctx, err := speck.New(vector.Key, vector.Param)
|
ctx, err := speck.New(key, param)
|
||||||
assert.NotNil(t, ctx)
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, ctx)
|
||||||
// Test in place
|
|
||||||
buffer := slices.Clone(vector.Plaintext)
|
|
||||||
ctx.Encrypt(buffer, buffer)
|
|
||||||
assert.Equal(t, vector.Ciphertext, buffer, ctx.Algorithm())
|
|
||||||
ctx.Decrypt(buffer, buffer)
|
|
||||||
assert.Equal(t, vector.Plaintext, buffer, ctx.Algorithm())
|
|
||||||
|
|
||||||
// Test two buffers
|
|
||||||
dst := make([]byte, len(vector.Ciphertext))
|
|
||||||
src := slices.Clone(vector.Plaintext)
|
|
||||||
ctx.Encrypt(dst, src)
|
|
||||||
assert.Equal(t, vector.Plaintext, src, ctx.Algorithm())
|
|
||||||
assert.Equal(t, vector.Ciphertext, dst, ctx.Algorithm())
|
|
||||||
|
|
||||||
dst = make([]byte, len(vector.Plaintext))
|
|
||||||
src = slices.Clone(vector.Ciphertext)
|
|
||||||
ctx.Decrypt(dst, src)
|
|
||||||
assert.Equal(t, vector.Ciphertext, src, ctx.Algorithm())
|
|
||||||
assert.Equal(t, vector.Plaintext, dst, ctx.Algorithm())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInvalidKeyLength(t *testing.T) {
|
||||||
|
params := []speck.SpeckParameters{
|
||||||
|
speck.Speck3264,
|
||||||
|
speck.Speck4872,
|
||||||
|
speck.Speck4896,
|
||||||
|
speck.Speck6496,
|
||||||
|
speck.Speck64128,
|
||||||
|
speck.Speck9696,
|
||||||
|
speck.Speck96144,
|
||||||
|
speck.Speck128128,
|
||||||
|
speck.Speck128192,
|
||||||
|
speck.Speck128256,
|
||||||
|
}
|
||||||
|
for _, param := range params {
|
||||||
|
key := testKey(param)
|
||||||
|
ctx, err := speck.New(key[1:], param)
|
||||||
|
assert.Nil(t, ctx)
|
||||||
|
assert.ErrorIs(t, cipher.ErrInvalidKeyLength, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidParam(t *testing.T) {
|
||||||
|
assert.PanicsWithValue(t, "Invalid parameters", func() {
|
||||||
|
speck.New(nil, -1)
|
||||||
|
})
|
||||||
|
assert.PanicsWithValue(t, "Invalid parameters", func() {
|
||||||
|
speck.New(nil, 0)
|
||||||
|
})
|
||||||
|
assert.PanicsWithValue(t, "Invalid parameters", func() {
|
||||||
|
speck.New(nil, speck.Speck128256+1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
2
go.mod
2
go.mod
@@ -1,6 +1,6 @@
|
|||||||
module git.omicron.one/playground/cryptography
|
module git.omicron.one/playground/cryptography
|
||||||
|
|
||||||
go 1.24.2
|
go 1.23.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
|
128
matrix/matrix.go
128
matrix/matrix.go
@@ -4,6 +4,7 @@
|
|||||||
package matrix
|
package matrix
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"golang.org/x/exp/constraints"
|
"golang.org/x/exp/constraints"
|
||||||
@@ -54,6 +55,33 @@ func Create[T Number](rows, cols int) *Matrix[T] {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Creates a new matrix from a given 2D slice of values. The first index of the
|
||||||
|
// slice denotes the rows and the second index denotes the columns. Returns the
|
||||||
|
// newly created matrix.
|
||||||
|
//
|
||||||
|
// Panics with ErrIncompatibleDataDimensions if any of the lengths are 0 or if
|
||||||
|
// not all rows have the same length.
|
||||||
|
func CreateFromSlice[T Number](values [][]T) *Matrix[T] {
|
||||||
|
rows := len(values)
|
||||||
|
if rows == 0 {
|
||||||
|
panic(ErrInvalidDimensions)
|
||||||
|
}
|
||||||
|
cols := len(values[0])
|
||||||
|
if cols == 0 {
|
||||||
|
panic(ErrInvalidDimensions)
|
||||||
|
}
|
||||||
|
|
||||||
|
m := Create[T](rows, cols)
|
||||||
|
for i := range rows {
|
||||||
|
if len(values[i]) != cols {
|
||||||
|
panic(ErrIncompatibleDataDimensions)
|
||||||
|
}
|
||||||
|
copy(m.values[i], values[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
// Creates a new matrix of a given size and sets the values from the given
|
// Creates a new matrix of a given size and sets the values from the given
|
||||||
// slice. The values are used to fill the matrix left to right and top to
|
// slice. The values are used to fill the matrix left to right and top to
|
||||||
// bottom. Returns the newly created matrix.
|
// bottom. Returns the newly created matrix.
|
||||||
@@ -80,6 +108,17 @@ func CreateFromFlatSlice[T Number](rows, cols int, values []T) *Matrix[T] {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateFromJSON creates a new matrix from JSON data representing a
|
||||||
|
// two-dimensional array. Returns an error if the JSON is invalid, the array is
|
||||||
|
// empty, or the inner arrays have inconsistent lengths.
|
||||||
|
func CreateFromJSON[T Number](data []byte) (*Matrix[T], error) {
|
||||||
|
m := &Matrix[T]{}
|
||||||
|
if err := m.UnmarshalJSON(data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Convert will take a matrix of type T and convert it into a matrix of type U
|
// Convert will take a matrix of type T and convert it into a matrix of type U
|
||||||
// Only works on SimpleNumber matrices
|
// Only works on SimpleNumber matrices
|
||||||
func Convert[U, T SimpleNumber](in *Matrix[T]) *Matrix[U] {
|
func Convert[U, T SimpleNumber](in *Matrix[T]) *Matrix[U] {
|
||||||
@@ -112,6 +151,13 @@ func Sum[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] {
|
|||||||
return first.Copy().Add(additional...)
|
return first.Copy().Add(additional...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HadamardProduct takes at least one matrix and performs component-wise
|
||||||
|
// multiplication of all given matrices. Returns a new matrix with the computed
|
||||||
|
// values. Panics if the arguments don't have matching dimensions.
|
||||||
|
func HadamardProduct[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] {
|
||||||
|
return first.Copy().HadamardMultiply(additional...)
|
||||||
|
}
|
||||||
|
|
||||||
// Copy creates a deep copy of the matrix and returns the new instance
|
// Copy creates a deep copy of the matrix and returns the new instance
|
||||||
func (m *Matrix[T]) Copy() *Matrix[T] {
|
func (m *Matrix[T]) Copy() *Matrix[T] {
|
||||||
mCopy := Create[T](m.rows, m.cols)
|
mCopy := Create[T](m.rows, m.cols)
|
||||||
@@ -240,3 +286,85 @@ func (m *Matrix[T]) Scale(scalar T) *Matrix[T] {
|
|||||||
}
|
}
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HadamardMultiply performs an in-place component-wise multiplication of this
|
||||||
|
// matrix with zero or more matrices.
|
||||||
|
// Ensures correct behavior even if the matrix itself is passed as one or more
|
||||||
|
// arguments.
|
||||||
|
// Panics with ErrIncompatibleMatrixDimensions if any of the matrices don't
|
||||||
|
// have matching dimensions.
|
||||||
|
func (m *Matrix[T]) HadamardMultiply(matrices ...*Matrix[T]) *Matrix[T] {
|
||||||
|
numSelf := 0
|
||||||
|
for _, other := range matrices {
|
||||||
|
if m.rows != other.rows || m.cols != other.cols {
|
||||||
|
panic(ErrIncompatibleMatrixDimensions)
|
||||||
|
}
|
||||||
|
if other == m {
|
||||||
|
numSelf += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have multiple self references to multiply, work on duplicate data to
|
||||||
|
// make multiplication behave as expected
|
||||||
|
values := m.values
|
||||||
|
if numSelf > 1 {
|
||||||
|
values = m.Copy().values
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, other := range matrices {
|
||||||
|
for i := range m.rows {
|
||||||
|
for j := range m.cols {
|
||||||
|
values[i][j] *= other.values[i][j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.values = values
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill sets all components of this matrix to the given value. Returns the receiver.
|
||||||
|
func (m *Matrix[T]) Fill(value T) *Matrix[T] {
|
||||||
|
for i := range m.rows {
|
||||||
|
for j := range m.cols {
|
||||||
|
m.values[i][j] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler for Matrix, creating a matrix from
|
||||||
|
// a JSON two-dimensional array.
|
||||||
|
func (m *Matrix[T]) UnmarshalJSON(data []byte) error {
|
||||||
|
var values [][]T
|
||||||
|
if err := json.Unmarshal(data, &values); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := len(values)
|
||||||
|
if rows == 0 {
|
||||||
|
return ErrInvalidDimensions
|
||||||
|
}
|
||||||
|
cols := len(values[0])
|
||||||
|
if cols == 0 {
|
||||||
|
return ErrInvalidDimensions
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range rows {
|
||||||
|
if len(values[i]) != cols {
|
||||||
|
return ErrIncompatibleDataDimensions
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.rows = rows
|
||||||
|
m.cols = cols
|
||||||
|
m.values = values
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaler for Matrix, serializing the matrix as
|
||||||
|
// a JSON two-dimensional array.
|
||||||
|
func (m *Matrix[T]) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(m.values)
|
||||||
|
}
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
package matrix_test
|
package matrix_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.omicron.one/playground/cryptography/matrix"
|
"git.omicron.one/playground/cryptography/matrix"
|
||||||
@@ -55,6 +56,41 @@ func TestCreate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateFromSlice(t *testing.T) {
|
||||||
|
m := matrix.CreateFromSlice([][]int{
|
||||||
|
{1, 2, 3},
|
||||||
|
{4, 5, 6},
|
||||||
|
})
|
||||||
|
assert.NotNil(t, m)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, m.Rows())
|
||||||
|
assert.Equal(t, 3, m.Cols())
|
||||||
|
|
||||||
|
assert.Equal(t, 1, m.Get(0, 0))
|
||||||
|
assert.Equal(t, 2, m.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, m.Get(0, 2))
|
||||||
|
assert.Equal(t, 4, m.Get(1, 0))
|
||||||
|
assert.Equal(t, 5, m.Get(1, 1))
|
||||||
|
assert.Equal(t, 6, m.Get(1, 2))
|
||||||
|
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||||
|
matrix.CreateFromSlice([][]int{})
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||||
|
matrix.CreateFromSlice([][]int{
|
||||||
|
{},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrIncompatibleDataDimensions, func() {
|
||||||
|
matrix.CreateFromSlice([][]int{
|
||||||
|
{1, 2, 3},
|
||||||
|
{4, 5},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateFromFlatSlice(t *testing.T) {
|
func TestCreateFromFlatSlice(t *testing.T) {
|
||||||
m := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
m := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
assert.NotNil(t, m)
|
assert.NotNil(t, m)
|
||||||
@@ -94,6 +130,46 @@ func TestCreateFromFlatSlice(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateFromJSON(t *testing.T) {
|
||||||
|
// data json
|
||||||
|
data := []byte(`[[1, 2, 3], [4, 5, 6]]`)
|
||||||
|
m, err := matrix.CreateFromJSON[int](data)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, m)
|
||||||
|
assert.Equal(t, 2, m.Rows())
|
||||||
|
assert.Equal(t, 3, m.Cols())
|
||||||
|
assert.Equal(t, 1, m.Get(0, 0))
|
||||||
|
assert.Equal(t, 2, m.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, m.Get(0, 2))
|
||||||
|
assert.Equal(t, 4, m.Get(1, 0))
|
||||||
|
assert.Equal(t, 5, m.Get(1, 1))
|
||||||
|
assert.Equal(t, 6, m.Get(1, 2))
|
||||||
|
|
||||||
|
// invalid json
|
||||||
|
data = []byte(`[[1, 2, 3], [4, 5,`)
|
||||||
|
m, err = matrix.CreateFromJSON[int](data)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Nil(t, m)
|
||||||
|
|
||||||
|
// empty matrix
|
||||||
|
data = []byte(`[]`)
|
||||||
|
m, err = matrix.CreateFromJSON[int](data)
|
||||||
|
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
|
||||||
|
assert.Nil(t, m)
|
||||||
|
|
||||||
|
// empty rows
|
||||||
|
data = []byte(`[[]]`)
|
||||||
|
m, err = matrix.CreateFromJSON[int](data)
|
||||||
|
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
|
||||||
|
assert.Nil(t, m)
|
||||||
|
|
||||||
|
// mixed row length
|
||||||
|
data = []byte(`[[1, 2, 3], [4, 5]]`)
|
||||||
|
m, err = matrix.CreateFromJSON[int](data)
|
||||||
|
assert.ErrorIs(t, err, matrix.ErrIncompatibleDataDimensions)
|
||||||
|
assert.Nil(t, m)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSum(t *testing.T) {
|
func TestSum(t *testing.T) {
|
||||||
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1})
|
b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1})
|
||||||
@@ -366,3 +442,214 @@ func TestMatrix_Scale(t *testing.T) {
|
|||||||
assert.Equal(t, -20, a.Get(1, 1))
|
assert.Equal(t, -20, a.Get(1, 1))
|
||||||
assert.Equal(t, -24, a.Get(1, 2))
|
assert.Equal(t, -24, a.Get(1, 2))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMatrix_HadamardMultiply(t *testing.T) {
|
||||||
|
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
b := matrix.CreateFromFlatSlice(2, 3, []int{2, 2, 2, 2, 2, 2})
|
||||||
|
|
||||||
|
// Multiply nothing
|
||||||
|
c := a.HadamardMultiply()
|
||||||
|
assert.Same(t, a, c)
|
||||||
|
assert.Equal(t, 1, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 2, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 4, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 5, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 6, a.Get(1, 2))
|
||||||
|
|
||||||
|
// Multiply itself multiple times
|
||||||
|
c = a.HadamardMultiply(a, a, a)
|
||||||
|
assert.Same(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 16, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 81, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 256, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 625, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 1296, a.Get(1, 2))
|
||||||
|
|
||||||
|
// Multiply other matrix
|
||||||
|
c = a.HadamardMultiply(b)
|
||||||
|
assert.Same(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 32, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 162, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 512, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 1250, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 2592, a.Get(1, 2))
|
||||||
|
|
||||||
|
// Multiply incorrect dimension
|
||||||
|
assert.PanicsWithValue(
|
||||||
|
t, matrix.ErrIncompatibleMatrixDimensions,
|
||||||
|
func() {
|
||||||
|
d := matrix.Create[int](3, 2)
|
||||||
|
a.HadamardMultiply(b, a, a, d)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 32, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 162, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 512, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 1250, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 2592, a.Get(1, 2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHadamardProduct(t *testing.T) {
|
||||||
|
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
b := matrix.CreateFromFlatSlice(2, 3, []int{2, 2, 2, 2, 2, 2})
|
||||||
|
|
||||||
|
// Multiply only one matrix
|
||||||
|
c := matrix.HadamardProduct(a)
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
assert.NotSame(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, c.Rows())
|
||||||
|
assert.Equal(t, 3, c.Cols())
|
||||||
|
|
||||||
|
assert.Equal(t, 1, c.Get(0, 0))
|
||||||
|
assert.Equal(t, 2, c.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, c.Get(0, 2))
|
||||||
|
assert.Equal(t, 4, c.Get(1, 0))
|
||||||
|
assert.Equal(t, 5, c.Get(1, 1))
|
||||||
|
assert.Equal(t, 6, c.Get(1, 2))
|
||||||
|
|
||||||
|
// Multilply one matrix multiple times
|
||||||
|
c = matrix.HadamardProduct(a, a, a)
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
assert.NotSame(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, c.Rows())
|
||||||
|
assert.Equal(t, 3, c.Cols())
|
||||||
|
|
||||||
|
assert.Equal(t, 1, c.Get(0, 0))
|
||||||
|
assert.Equal(t, 8, c.Get(0, 1))
|
||||||
|
assert.Equal(t, 27, c.Get(0, 2))
|
||||||
|
assert.Equal(t, 64, c.Get(1, 0))
|
||||||
|
assert.Equal(t, 125, c.Get(1, 1))
|
||||||
|
assert.Equal(t, 216, c.Get(1, 2))
|
||||||
|
|
||||||
|
// Multiply different matrices
|
||||||
|
c = matrix.HadamardProduct(a, b)
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
assert.NotEqual(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, c.Rows())
|
||||||
|
assert.Equal(t, 3, c.Cols())
|
||||||
|
|
||||||
|
assert.Equal(t, 2, c.Get(0, 0))
|
||||||
|
assert.Equal(t, 4, c.Get(0, 1))
|
||||||
|
assert.Equal(t, 6, c.Get(0, 2))
|
||||||
|
assert.Equal(t, 8, c.Get(1, 0))
|
||||||
|
assert.Equal(t, 10, c.Get(1, 1))
|
||||||
|
assert.Equal(t, 12, c.Get(1, 2))
|
||||||
|
|
||||||
|
// Multiply incorrect dimensions
|
||||||
|
d := matrix.Create[int](3, 2)
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrIncompatibleMatrixDimensions, func() {
|
||||||
|
matrix.HadamardProduct(a, d)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrix_Fill(t *testing.T) {
|
||||||
|
a := matrix.Create[int](2, 3)
|
||||||
|
assert.Equal(t, 0, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 0, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 0, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 0, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 0, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 0, a.Get(1, 2))
|
||||||
|
|
||||||
|
a.Fill(3)
|
||||||
|
assert.Equal(t, 3, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 3, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 3, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 3, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 3, a.Get(1, 2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrix_UnmarshalJSON(t *testing.T) {
|
||||||
|
// int matrix
|
||||||
|
data := []byte(`[[1,2,3],[4,5,6]]`)
|
||||||
|
var m *matrix.Matrix[int]
|
||||||
|
err := json.Unmarshal(data, &m)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, m.Rows())
|
||||||
|
assert.Equal(t, 3, m.Cols())
|
||||||
|
assert.Equal(t, 1, m.Get(0, 0))
|
||||||
|
assert.Equal(t, 2, m.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, m.Get(0, 2))
|
||||||
|
assert.Equal(t, 4, m.Get(1, 0))
|
||||||
|
assert.Equal(t, 5, m.Get(1, 1))
|
||||||
|
assert.Equal(t, 6, m.Get(1, 2))
|
||||||
|
|
||||||
|
// float matrix
|
||||||
|
data = []byte(`[[1.5,2.5],[3.5,4.5]]`)
|
||||||
|
var mf *matrix.Matrix[float64]
|
||||||
|
err = json.Unmarshal(data, &mf)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, mf.Rows())
|
||||||
|
assert.Equal(t, 2, mf.Cols())
|
||||||
|
assert.Equal(t, 1.5, mf.Get(0, 0))
|
||||||
|
assert.Equal(t, 2.5, mf.Get(0, 1))
|
||||||
|
assert.Equal(t, 3.5, mf.Get(1, 0))
|
||||||
|
assert.Equal(t, 4.5, mf.Get(1, 1))
|
||||||
|
|
||||||
|
// via json.Unmarshal
|
||||||
|
matrices := []byte(`[[[1,2],[3,4]],[[5,6,7]]]`)
|
||||||
|
var ms []*matrix.Matrix[int]
|
||||||
|
err = json.Unmarshal(matrices, &ms)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Len(t, ms, 2)
|
||||||
|
assert.Equal(t, 2, ms[0].Get(0, 1))
|
||||||
|
assert.Equal(t, 7, ms[1].Get(0, 2))
|
||||||
|
|
||||||
|
// invalid JSON
|
||||||
|
err = m.UnmarshalJSON([]byte(`invalid`))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
|
// empty array
|
||||||
|
err = m.UnmarshalJSON([]byte(`[]`))
|
||||||
|
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
|
||||||
|
|
||||||
|
// empty inner array
|
||||||
|
err = m.UnmarshalJSON([]byte(`[[]]`))
|
||||||
|
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
|
||||||
|
|
||||||
|
// inconsistent lengths
|
||||||
|
err = m.UnmarshalJSON([]byte(`[[1,2],[3]]`))
|
||||||
|
assert.ErrorIs(t, err, matrix.ErrIncompatibleDataDimensions)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrix_MarshallJSON(t *testing.T) {
|
||||||
|
// int matrix
|
||||||
|
m := matrix.CreateFromSlice([][]int{{1, 2, 3}, {4, 5, 6}})
|
||||||
|
data, err := m.MarshalJSON()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, data)
|
||||||
|
|
||||||
|
expected := `[[1,2,3],[4,5,6]]`
|
||||||
|
assert.Equal(t, expected, string(data))
|
||||||
|
|
||||||
|
// float matrix
|
||||||
|
mf := matrix.CreateFromSlice([][]float64{{1.5, 2.5}, {3.5, 4.5}})
|
||||||
|
data, err = mf.MarshalJSON()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, data)
|
||||||
|
expectedFloat := `[[1.5,2.5],[3.5,4.5]]`
|
||||||
|
assert.Equal(t, expectedFloat, string(data))
|
||||||
|
|
||||||
|
// slice of matrices via json.Marshal
|
||||||
|
m1 := matrix.CreateFromSlice([][]int{{1, 2}, {3, 4}})
|
||||||
|
m2 := matrix.CreateFromSlice([][]int{{5, 6, 7}})
|
||||||
|
matrices := []*matrix.Matrix[int]{m1, m2}
|
||||||
|
|
||||||
|
data, err = json.Marshal(matrices)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
expected = `[[[1,2],[3,4]],[[5,6,7]]]`
|
||||||
|
assert.Equal(t, expected, string(data))
|
||||||
|
}
|
||||||
|
13
util/util.go
Normal file
13
util/util.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
// package util provides small utility functions that make it easier to express common things
|
||||||
|
package util
|
||||||
|
|
||||||
|
import "encoding/hex"
|
||||||
|
|
||||||
|
// DeHex decodes a hexadecimal string into a byte slice. Panics if the string is invalid.
|
||||||
|
func DeHex(s string) []byte {
|
||||||
|
decoded, err := hex.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
panic("invalid hex string")
|
||||||
|
}
|
||||||
|
return decoded
|
||||||
|
}
|
25
util/util_test.go
Normal file
25
util/util_test.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package util_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.omicron.one/playground/cryptography/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeHex(t *testing.T) {
|
||||||
|
b := util.DeHex("")
|
||||||
|
assert.NotNil(t, b)
|
||||||
|
assert.Len(t, b, 0)
|
||||||
|
|
||||||
|
b = util.DeHex("deadbeef")
|
||||||
|
assert.NotNil(t, b)
|
||||||
|
assert.Equal(t, []byte("\xde\xad\xbe\xef"), b)
|
||||||
|
|
||||||
|
assert.PanicsWithValue(t, "invalid hex string", func() {
|
||||||
|
util.DeHex("dead serious this is not a hex string")
|
||||||
|
})
|
||||||
|
assert.PanicsWithValue(t, "invalid hex string", func() {
|
||||||
|
util.DeHex("deada55")
|
||||||
|
})
|
||||||
|
}
|
Reference in New Issue
Block a user