From 7b8a6f03c28951ac60d66626e947bc30486d7871 Mon Sep 17 00:00:00 2001 From: omicron Date: Thu, 22 May 2025 02:05:57 +0200 Subject: [PATCH] Implement JSON marshalling for the matrix class --- matrix/matrix.go | 48 ++++++++++++++++ matrix/matrix_test.go | 125 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+) diff --git a/matrix/matrix.go b/matrix/matrix.go index 5dd59bd..39a7e79 100644 --- a/matrix/matrix.go +++ b/matrix/matrix.go @@ -4,6 +4,7 @@ package matrix import ( + "encoding/json" "errors" "golang.org/x/exp/constraints" @@ -107,6 +108,17 @@ func CreateFromFlatSlice[T Number](rows, cols int, values []T) *Matrix[T] { return m } +// CreateFromJSON creates a new matrix from JSON data representing a +// two-dimensional array. Returns an error if the JSON is invalid, the array is +// empty, or the inner arrays have inconsistent lengths. +func CreateFromJSON[T Number](data []byte) (*Matrix[T], error) { + m := &Matrix[T]{} + if err := m.UnmarshalJSON(data); err != nil { + return nil, err + } + return m, nil +} + // Convert will take a matrix of type T and convert it into a matrix of type U // Only works on SimpleNumber matrices func Convert[U, T SimpleNumber](in *Matrix[T]) *Matrix[U] { @@ -320,3 +332,39 @@ func (m *Matrix[T]) Fill(value T) *Matrix[T] { } return m } + +// UnmarshalJSON implements json.Unmarshaler for Matrix, creating a matrix from +// a JSON two-dimensional array. +func (m *Matrix[T]) UnmarshalJSON(data []byte) error { + var values [][]T + if err := json.Unmarshal(data, &values); err != nil { + return err + } + + rows := len(values) + if rows == 0 { + return ErrInvalidDimensions + } + cols := len(values[0]) + if cols == 0 { + return ErrInvalidDimensions + } + + for i := range rows { + if len(values[i]) != cols { + return ErrIncompatibleDataDimensions + } + } + + m.rows = rows + m.cols = cols + m.values = values + + return nil +} + +// MarshalJSON implements json.Marshaler for Matrix, serializing the matrix as +// a JSON two-dimensional array. +func (m *Matrix[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(m.values) +} diff --git a/matrix/matrix_test.go b/matrix/matrix_test.go index 22782da..487d193 100644 --- a/matrix/matrix_test.go +++ b/matrix/matrix_test.go @@ -1,6 +1,7 @@ package matrix_test import ( + "encoding/json" "testing" "git.omicron.one/playground/cryptography/matrix" @@ -129,6 +130,46 @@ func TestCreateFromFlatSlice(t *testing.T) { }) } +func TestCreateFromJSON(t *testing.T) { + // data json + data := []byte(`[[1, 2, 3], [4, 5, 6]]`) + m, err := matrix.CreateFromJSON[int](data) + assert.Nil(t, err) + assert.NotNil(t, m) + assert.Equal(t, 2, m.Rows()) + assert.Equal(t, 3, m.Cols()) + assert.Equal(t, 1, m.Get(0, 0)) + assert.Equal(t, 2, m.Get(0, 1)) + assert.Equal(t, 3, m.Get(0, 2)) + assert.Equal(t, 4, m.Get(1, 0)) + assert.Equal(t, 5, m.Get(1, 1)) + assert.Equal(t, 6, m.Get(1, 2)) + + // invalid json + data = []byte(`[[1, 2, 3], [4, 5,`) + m, err = matrix.CreateFromJSON[int](data) + assert.NotNil(t, err) + assert.Nil(t, m) + + // empty matrix + data = []byte(`[]`) + m, err = matrix.CreateFromJSON[int](data) + assert.ErrorIs(t, err, matrix.ErrInvalidDimensions) + assert.Nil(t, m) + + // empty rows + data = []byte(`[[]]`) + m, err = matrix.CreateFromJSON[int](data) + assert.ErrorIs(t, err, matrix.ErrInvalidDimensions) + assert.Nil(t, m) + + // mixed row length + data = []byte(`[[1, 2, 3], [4, 5]]`) + m, err = matrix.CreateFromJSON[int](data) + assert.ErrorIs(t, err, matrix.ErrIncompatibleDataDimensions) + assert.Nil(t, m) +} + func TestSum(t *testing.T) { a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1}) @@ -528,3 +569,87 @@ func TestMatrix_Fill(t *testing.T) { assert.Equal(t, 3, a.Get(1, 1)) assert.Equal(t, 3, a.Get(1, 2)) } + +func TestMatrix_UnmarshalJSON(t *testing.T) { + // int matrix + data := []byte(`[[1,2,3],[4,5,6]]`) + var m *matrix.Matrix[int] + err := json.Unmarshal(data, &m) + assert.Nil(t, err) + + assert.Equal(t, 2, m.Rows()) + assert.Equal(t, 3, m.Cols()) + assert.Equal(t, 1, m.Get(0, 0)) + assert.Equal(t, 2, m.Get(0, 1)) + assert.Equal(t, 3, m.Get(0, 2)) + assert.Equal(t, 4, m.Get(1, 0)) + assert.Equal(t, 5, m.Get(1, 1)) + assert.Equal(t, 6, m.Get(1, 2)) + + // float matrix + data = []byte(`[[1.5,2.5],[3.5,4.5]]`) + var mf *matrix.Matrix[float64] + err = json.Unmarshal(data, &mf) + assert.Nil(t, err) + + assert.Equal(t, 2, mf.Rows()) + assert.Equal(t, 2, mf.Cols()) + assert.Equal(t, 1.5, mf.Get(0, 0)) + assert.Equal(t, 2.5, mf.Get(0, 1)) + assert.Equal(t, 3.5, mf.Get(1, 0)) + assert.Equal(t, 4.5, mf.Get(1, 1)) + + // via json.Unmarshal + matrices := []byte(`[[[1,2],[3,4]],[[5,6,7]]]`) + var ms []*matrix.Matrix[int] + err = json.Unmarshal(matrices, &ms) + assert.Nil(t, err) + assert.Len(t, ms, 2) + assert.Equal(t, 2, ms[0].Get(0, 1)) + assert.Equal(t, 7, ms[1].Get(0, 2)) + + // invalid JSON + err = m.UnmarshalJSON([]byte(`invalid`)) + assert.NotNil(t, err) + + // empty array + err = m.UnmarshalJSON([]byte(`[]`)) + assert.ErrorIs(t, err, matrix.ErrInvalidDimensions) + + // empty inner array + err = m.UnmarshalJSON([]byte(`[[]]`)) + assert.ErrorIs(t, err, matrix.ErrInvalidDimensions) + + // inconsistent lengths + err = m.UnmarshalJSON([]byte(`[[1,2],[3]]`)) + assert.ErrorIs(t, err, matrix.ErrIncompatibleDataDimensions) +} + +func TestMatrix_MarshallJSON(t *testing.T) { + // int matrix + m := matrix.CreateFromSlice([][]int{{1, 2, 3}, {4, 5, 6}}) + data, err := m.MarshalJSON() + assert.Nil(t, err) + assert.NotNil(t, data) + + expected := `[[1,2,3],[4,5,6]]` + assert.Equal(t, expected, string(data)) + + // float matrix + mf := matrix.CreateFromSlice([][]float64{{1.5, 2.5}, {3.5, 4.5}}) + data, err = mf.MarshalJSON() + assert.Nil(t, err) + assert.NotNil(t, data) + expectedFloat := `[[1.5,2.5],[3.5,4.5]]` + assert.Equal(t, expectedFloat, string(data)) + + // slice of matrices via json.Marshal + m1 := matrix.CreateFromSlice([][]int{{1, 2}, {3, 4}}) + m2 := matrix.CreateFromSlice([][]int{{5, 6, 7}}) + matrices := []*matrix.Matrix[int]{m1, m2} + + data, err = json.Marshal(matrices) + assert.Nil(t, err) + expected = `[[[1,2],[3,4]],[[5,6,7]]]` + assert.Equal(t, expected, string(data)) +}