Compare commits

...

2 Commits

Author SHA1 Message Date
7b8a6f03c2 Implement JSON marshalling for the matrix class
All checks were successful
Validate the build / validate-build (push) Successful in 1m1s
2025-05-22 12:56:45 +02:00
9729fe6dcb Add matrix.CreateFromSlice function 2025-05-22 12:56:38 +02:00
2 changed files with 235 additions and 0 deletions

View File

@ -4,6 +4,7 @@
package matrix
import (
"encoding/json"
"errors"
"golang.org/x/exp/constraints"
@ -54,6 +55,33 @@ func Create[T Number](rows, cols int) *Matrix[T] {
}
}
// Creates a new matrix from a given 2D slice of values. The first index of the
// slice denotes the rows and the second index denotes the columns. Returns the
// newly created matrix.
//
// Panics with ErrIncompatibleDataDimensions if any of the lengths are 0 or if
// not all rows have the same length.
func CreateFromSlice[T Number](values [][]T) *Matrix[T] {
rows := len(values)
if rows == 0 {
panic(ErrInvalidDimensions)
}
cols := len(values[0])
if cols == 0 {
panic(ErrInvalidDimensions)
}
m := Create[T](rows, cols)
for i := range rows {
if len(values[i]) != cols {
panic(ErrIncompatibleDataDimensions)
}
copy(m.values[i], values[i])
}
return m
}
// Creates a new matrix of a given size and sets the values from the given
// slice. The values are used to fill the matrix left to right and top to
// bottom. Returns the newly created matrix.
@ -80,6 +108,17 @@ func CreateFromFlatSlice[T Number](rows, cols int, values []T) *Matrix[T] {
return m
}
// CreateFromJSON creates a new matrix from JSON data representing a
// two-dimensional array. Returns an error if the JSON is invalid, the array is
// empty, or the inner arrays have inconsistent lengths.
func CreateFromJSON[T Number](data []byte) (*Matrix[T], error) {
m := &Matrix[T]{}
if err := m.UnmarshalJSON(data); err != nil {
return nil, err
}
return m, nil
}
// Convert will take a matrix of type T and convert it into a matrix of type U
// Only works on SimpleNumber matrices
func Convert[U, T SimpleNumber](in *Matrix[T]) *Matrix[U] {
@ -293,3 +332,39 @@ func (m *Matrix[T]) Fill(value T) *Matrix[T] {
}
return m
}
// UnmarshalJSON implements json.Unmarshaler for Matrix, creating a matrix from
// a JSON two-dimensional array.
func (m *Matrix[T]) UnmarshalJSON(data []byte) error {
var values [][]T
if err := json.Unmarshal(data, &values); err != nil {
return err
}
rows := len(values)
if rows == 0 {
return ErrInvalidDimensions
}
cols := len(values[0])
if cols == 0 {
return ErrInvalidDimensions
}
for i := range rows {
if len(values[i]) != cols {
return ErrIncompatibleDataDimensions
}
}
m.rows = rows
m.cols = cols
m.values = values
return nil
}
// MarshalJSON implements json.Marshaler for Matrix, serializing the matrix as
// a JSON two-dimensional array.
func (m *Matrix[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(m.values)
}

View File

@ -1,6 +1,7 @@
package matrix_test
import (
"encoding/json"
"testing"
"git.omicron.one/playground/cryptography/matrix"
@ -55,6 +56,41 @@ func TestCreate(t *testing.T) {
})
}
func TestCreateFromSlice(t *testing.T) {
m := matrix.CreateFromSlice([][]int{
{1, 2, 3},
{4, 5, 6},
})
assert.NotNil(t, m)
assert.Equal(t, 2, m.Rows())
assert.Equal(t, 3, m.Cols())
assert.Equal(t, 1, m.Get(0, 0))
assert.Equal(t, 2, m.Get(0, 1))
assert.Equal(t, 3, m.Get(0, 2))
assert.Equal(t, 4, m.Get(1, 0))
assert.Equal(t, 5, m.Get(1, 1))
assert.Equal(t, 6, m.Get(1, 2))
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
matrix.CreateFromSlice([][]int{})
})
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
matrix.CreateFromSlice([][]int{
{},
})
})
assert.PanicsWithValue(t, matrix.ErrIncompatibleDataDimensions, func() {
matrix.CreateFromSlice([][]int{
{1, 2, 3},
{4, 5},
})
})
}
func TestCreateFromFlatSlice(t *testing.T) {
m := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
assert.NotNil(t, m)
@ -94,6 +130,46 @@ func TestCreateFromFlatSlice(t *testing.T) {
})
}
func TestCreateFromJSON(t *testing.T) {
// data json
data := []byte(`[[1, 2, 3], [4, 5, 6]]`)
m, err := matrix.CreateFromJSON[int](data)
assert.Nil(t, err)
assert.NotNil(t, m)
assert.Equal(t, 2, m.Rows())
assert.Equal(t, 3, m.Cols())
assert.Equal(t, 1, m.Get(0, 0))
assert.Equal(t, 2, m.Get(0, 1))
assert.Equal(t, 3, m.Get(0, 2))
assert.Equal(t, 4, m.Get(1, 0))
assert.Equal(t, 5, m.Get(1, 1))
assert.Equal(t, 6, m.Get(1, 2))
// invalid json
data = []byte(`[[1, 2, 3], [4, 5,`)
m, err = matrix.CreateFromJSON[int](data)
assert.NotNil(t, err)
assert.Nil(t, m)
// empty matrix
data = []byte(`[]`)
m, err = matrix.CreateFromJSON[int](data)
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
assert.Nil(t, m)
// empty rows
data = []byte(`[[]]`)
m, err = matrix.CreateFromJSON[int](data)
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
assert.Nil(t, m)
// mixed row length
data = []byte(`[[1, 2, 3], [4, 5]]`)
m, err = matrix.CreateFromJSON[int](data)
assert.ErrorIs(t, err, matrix.ErrIncompatibleDataDimensions)
assert.Nil(t, m)
}
func TestSum(t *testing.T) {
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1})
@ -493,3 +569,87 @@ func TestMatrix_Fill(t *testing.T) {
assert.Equal(t, 3, a.Get(1, 1))
assert.Equal(t, 3, a.Get(1, 2))
}
func TestMatrix_UnmarshalJSON(t *testing.T) {
// int matrix
data := []byte(`[[1,2,3],[4,5,6]]`)
var m *matrix.Matrix[int]
err := json.Unmarshal(data, &m)
assert.Nil(t, err)
assert.Equal(t, 2, m.Rows())
assert.Equal(t, 3, m.Cols())
assert.Equal(t, 1, m.Get(0, 0))
assert.Equal(t, 2, m.Get(0, 1))
assert.Equal(t, 3, m.Get(0, 2))
assert.Equal(t, 4, m.Get(1, 0))
assert.Equal(t, 5, m.Get(1, 1))
assert.Equal(t, 6, m.Get(1, 2))
// float matrix
data = []byte(`[[1.5,2.5],[3.5,4.5]]`)
var mf *matrix.Matrix[float64]
err = json.Unmarshal(data, &mf)
assert.Nil(t, err)
assert.Equal(t, 2, mf.Rows())
assert.Equal(t, 2, mf.Cols())
assert.Equal(t, 1.5, mf.Get(0, 0))
assert.Equal(t, 2.5, mf.Get(0, 1))
assert.Equal(t, 3.5, mf.Get(1, 0))
assert.Equal(t, 4.5, mf.Get(1, 1))
// via json.Unmarshal
matrices := []byte(`[[[1,2],[3,4]],[[5,6,7]]]`)
var ms []*matrix.Matrix[int]
err = json.Unmarshal(matrices, &ms)
assert.Nil(t, err)
assert.Len(t, ms, 2)
assert.Equal(t, 2, ms[0].Get(0, 1))
assert.Equal(t, 7, ms[1].Get(0, 2))
// invalid JSON
err = m.UnmarshalJSON([]byte(`invalid`))
assert.NotNil(t, err)
// empty array
err = m.UnmarshalJSON([]byte(`[]`))
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
// empty inner array
err = m.UnmarshalJSON([]byte(`[[]]`))
assert.ErrorIs(t, err, matrix.ErrInvalidDimensions)
// inconsistent lengths
err = m.UnmarshalJSON([]byte(`[[1,2],[3]]`))
assert.ErrorIs(t, err, matrix.ErrIncompatibleDataDimensions)
}
func TestMatrix_MarshallJSON(t *testing.T) {
// int matrix
m := matrix.CreateFromSlice([][]int{{1, 2, 3}, {4, 5, 6}})
data, err := m.MarshalJSON()
assert.Nil(t, err)
assert.NotNil(t, data)
expected := `[[1,2,3],[4,5,6]]`
assert.Equal(t, expected, string(data))
// float matrix
mf := matrix.CreateFromSlice([][]float64{{1.5, 2.5}, {3.5, 4.5}})
data, err = mf.MarshalJSON()
assert.Nil(t, err)
assert.NotNil(t, data)
expectedFloat := `[[1.5,2.5],[3.5,4.5]]`
assert.Equal(t, expectedFloat, string(data))
// slice of matrices via json.Marshal
m1 := matrix.CreateFromSlice([][]int{{1, 2}, {3, 4}})
m2 := matrix.CreateFromSlice([][]int{{5, 6, 7}})
matrices := []*matrix.Matrix[int]{m1, m2}
data, err = json.Marshal(matrices)
assert.Nil(t, err)
expected = `[[[1,2],[3,4]],[[5,6,7]]]`
assert.Equal(t, expected, string(data))
}