From 18009f87543e2e02cdb5b0f30ffc06a4b6826227 Mon Sep 17 00:00:00 2001 From: omicron Date: Sun, 18 May 2025 00:33:07 +0200 Subject: [PATCH] Add basic matrix package --- go.mod | 14 ++ go.sum | 12 ++ matrix/examples_test.go | 44 +++++ matrix/matrix.go | 242 ++++++++++++++++++++++++++ matrix/matrix_test.go | 368 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 680 insertions(+) create mode 100644 go.mod create mode 100644 go.sum create mode 100644 matrix/examples_test.go create mode 100644 matrix/matrix.go create mode 100644 matrix/matrix_test.go diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..41b6a3d --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module git.omicron.one/playground/cryptography + +go 1.24.2 + +require ( + github.com/stretchr/testify v1.10.0 + golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..fb3828e --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI= +golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/matrix/examples_test.go b/matrix/examples_test.go new file mode 100644 index 0000000..8b37949 --- /dev/null +++ b/matrix/examples_test.go @@ -0,0 +1,44 @@ +package matrix_test + +import ( + "fmt" + + "git.omicron.one/playground/cryptography/matrix" +) + +func ExampleConvert() { + intMatrix := matrix.Create[int](3, 4) + + _ = matrix.Convert[float64](intMatrix) +} + +func ExampleCreateFromFlatSlice() { + m := matrix.CreateFromFlatSlice(2, 2, []int{0, 1, 2, 3}) + + for i := range m.Rows() { + for j := range m.Cols() { + fmt.Printf(" %d", m.Get(i, j)) + } + fmt.Println() + } + + // Output: + // 0 1 + // 2 3 +} + +func ExampleTransform() { + intMatrix := matrix.CreateFromFlatSlice(2, 2, []int{0, 1, 2, 3}) + + _ = matrix.Transform(intMatrix, func(in int) complex128 { + return complex(float64(in), 0.0) + }) +} + +func ExampleMatrix_Apply() { + m := matrix.CreateFromFlatSlice(2, 2, []int{0, 1, 2, 3}) + + m.Apply(func(in int) int { + return in * in + }) +} diff --git a/matrix/matrix.go b/matrix/matrix.go new file mode 100644 index 0000000..147e505 --- /dev/null +++ b/matrix/matrix.go @@ -0,0 +1,242 @@ +// matrix provides very basic matrix operations with all numeric types. It is +// not a high performance or feature rich implementation and is mostly intended +// to use for collecting statistics +package matrix + +import ( + "errors" + + "golang.org/x/exp/constraints" +) + +var ( + ErrInvalidDimensions = errors.New("Invalid dimensions") + ErrIncompatibleMatrixDimensions = errors.New("Incompatible matrix dimensions") + ErrIncompatibleDataDimensions = errors.New("Incompatible data dimensions") +) + +// Matrices can be created for these underlying types +type Number interface { + constraints.Integer | constraints.Float | constraints.Complex +} + +type SimpleNumber interface { + constraints.Integer | constraints.Float +} + +// Matrix represents a matrix with values of a specific Number type +type Matrix[T Number] struct { + rows int + cols int + values [][]T +} + +func createValues[T Number](rows, cols int) [][]T { + values := make([][]T, rows) + for i := range rows { + values[i] = make([]T, cols) + } + return values +} + +// Create creates a new matrix with the given number of rows and columns. All +// values of this matrix are set to zero. Returns the new matrix. +// +// Panics with ErrInvalidDimensions if rows < 1 or cols < 1. +func Create[T Number](rows, cols int) *Matrix[T] { + if rows < 1 || cols < 1 { + panic(ErrInvalidDimensions) + } + return &Matrix[T]{ + rows: rows, + cols: cols, + values: createValues[T](rows, cols), + } +} + +// Creates a new matrix of a given size and sets the values from the given +// slice. The values are used to fill the matrix left to right and top to +// bottom. Returns the newly created matrix. +// +// Panics with ErrInvalidDimensions if rows < 1 or cols < 1. +// Panics with ErrIncompatibleDataDimensions if the length of the values doesn't +// match the size of the matrix. This is the only possible error condition. +func CreateFromFlatSlice[T Number](rows, cols int, values []T) *Matrix[T] { + if rows < 1 || cols < 1 { + panic(ErrInvalidDimensions) + } + if len(values) != rows*cols { + panic(ErrIncompatibleDataDimensions) + } + m := Create[T](rows, cols) + + n := 0 + for i := range rows { + for j := range cols { + m.values[i][j] = values[n] + n++ + } + } + return m +} + +// Convert will take a matrix of type T and convert it into a matrix of type U +// Only works on SimpleNumber matrices +func Convert[U, T SimpleNumber](in *Matrix[T]) *Matrix[U] { + out := Create[U](in.rows, in.cols) + for i := range in.rows { + for j := range in.cols { + out.values[i][j] = U(in.values[i][j]) + } + } + return out +} + +// Transform the values of the given matrix with a transfrom function. This +// operation may change the type of the values. Unlike Convert it can be used +// on all Number types, not just SimpleNumber. +func Transform[U, T Number](in *Matrix[T], transformFn func(T) U) *Matrix[U] { + out := Create[U](in.rows, in.cols) + for i := range in.rows { + for j := range in.cols { + out.values[i][j] = transformFn(in.values[i][j]) + } + } + return out +} + +// Sum takes at least one matrix and sums it together. +// Returns a new matrix that is the sum of the arguments. +// Panics if the arguments don't have matching dimensions. +func Sum[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] { + return first.Copy().Add(additional...) +} + +// Copy creates a deep copy of the matrix and returns the new instance +func (m *Matrix[T]) Copy() *Matrix[T] { + mCopy := Create[T](m.rows, m.cols) + for i := range m.rows { + copy(mCopy.values[i], m.values[i]) + } + return mCopy +} + +// Size returns the dimensions of the matrix as (rows, columns) +func (m *Matrix[T]) Size() (int, int) { + return m.rows, m.cols +} + +// Rows returns the number of rows in the matrix +func (m *Matrix[T]) Rows() int { + return m.rows +} + +// Cols returns the number of columns in the matrix +func (m *Matrix[T]) Cols() int { + return m.cols +} + +// Set sets the value of the matrix at the given row and col position to the +// given value +func (m *Matrix[T]) Set(row, col int, value T) { + m.values[row][col] = value +} + +// Set assigns the specified value to the element at the given row and column +func (m *Matrix[T]) Get(row, col int) T { + return m.values[row][col] +} + +// Add performs in-place addition of zero or more matrices to this matrix and +// returns the receiver. +// Ensures correct behavior even if the matrix itself is passed as one or more +// arguments. +// Panics with ErrIncompatibleMatrixDimensions if any of the matrices don't +// have matching dimensions. +func (m *Matrix[T]) Add(matrices ...*Matrix[T]) *Matrix[T] { + numSelf := 0 + for _, other := range matrices { + if m.rows != other.rows || m.cols != other.cols { + panic(ErrIncompatibleMatrixDimensions) + } + if other == m { + numSelf += 1 + } + } + + // If we have multiple self references to add, work on duplicate data to + // make addition behave as expected + values := m.values + if numSelf > 1 { + values = m.Copy().values + } + + for _, other := range matrices { + for i := range m.rows { + for j := range m.cols { + values[i][j] += other.values[i][j] + } + } + } + + m.values = values + return m +} + +// Subtract performs in-place subtraction of zero or more matrices from this matrix and +// returns the receiver. +// Ensures correct behavior even if the matrix itself is passed as one or more +// arguments. +// Panics with ErrIncompatibleMatrixDimensions if any of the matrices don't +// have matching dimensions. +func (m *Matrix[T]) Subtract(matrices ...*Matrix[T]) *Matrix[T] { + numSelf := 0 + for _, other := range matrices { + if m.rows != other.rows || m.cols != other.cols { + panic(ErrIncompatibleMatrixDimensions) + } + if other == m { + numSelf += 1 + } + } + + // If we have multiple self references to subtract, work on duplicate data to + // make subtraction behave as expected + values := m.values + if numSelf > 1 { + values = m.Copy().values + } + + for _, other := range matrices { + for i := range m.rows { + for j := range m.cols { + values[i][j] -= other.values[i][j] + } + } + } + + m.values = values + return m +} + +// Apply performs an in-place transformation of each element of the matrix using +// the provided function. Returns the receiver. +func (m *Matrix[T]) Apply(fn func(T) T) *Matrix[T] { + for i := range m.rows { + for j := range m.cols { + m.values[i][j] = fn(m.values[i][j]) + } + } + return m +} + +// Scale does an in-place scalar multiplication of the matrix values. Returns +// the receiver. +func (m *Matrix[T]) Scale(scalar T) *Matrix[T] { + for i := range m.rows { + for j := range m.cols { + m.values[i][j] *= scalar + } + } + return m +} diff --git a/matrix/matrix_test.go b/matrix/matrix_test.go new file mode 100644 index 0000000..1fda57a --- /dev/null +++ b/matrix/matrix_test.go @@ -0,0 +1,368 @@ +package matrix_test + +import ( + "testing" + + "git.omicron.one/playground/cryptography/matrix" + "github.com/stretchr/testify/assert" +) + +func TestConvert(t *testing.T) { + m := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) + assert.NotNil(t, m) + + mf := matrix.Convert[float64](m) + assert.Equal(t, m.Rows(), mf.Rows()) + assert.Equal(t, m.Cols(), mf.Cols()) + + for row := range mf.Rows() { + for col := range mf.Cols() { + assert.Equal(t, float64(m.Get(row, col)), mf.Get(row, col)) + } + } +} + +func TestCreate(t *testing.T) { + m := matrix.Create[int](3, 4) + assert.NotNil(t, m) + + rows, cols := m.Size() + assert.Equal(t, 3, rows) + assert.Equal(t, 3, m.Rows()) + assert.Equal(t, 4, cols) + assert.Equal(t, 4, m.Cols()) + + for row := range rows { + for col := range cols { + assert.Equal(t, 0, m.Get(row, col)) + } + } + + assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() { + matrix.Create[int](0, 1) + }) + assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() { + matrix.Create[int](-1, 1) + }) + assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() { + matrix.Create[int](1, 0) + }) + assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() { + matrix.Create[int](1, -1) + }) + assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() { + matrix.Create[int](-1, -1) + }) +} + +func TestCreateFromFlatSlice(t *testing.T) { + m := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) + assert.NotNil(t, m) + + assert.Equal(t, 2, m.Rows()) + assert.Equal(t, 3, m.Cols()) + + assert.Equal(t, 1, m.Get(0, 0)) + assert.Equal(t, 2, m.Get(0, 1)) + assert.Equal(t, 3, m.Get(0, 2)) + assert.Equal(t, 4, m.Get(1, 0)) + assert.Equal(t, 5, m.Get(1, 1)) + assert.Equal(t, 6, m.Get(1, 2)) + + assert.PanicsWithValue(t, matrix.ErrIncompatibleDataDimensions, func() { + matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5}) + }) + + assert.PanicsWithValue(t, matrix.ErrIncompatibleDataDimensions, func() { + matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6, 7}) + }) + + assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() { + matrix.CreateFromFlatSlice(0, 1, []int{}) + }) + assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() { + matrix.CreateFromFlatSlice(-1, 1, []int{}) + }) + assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() { + matrix.CreateFromFlatSlice(1, 0, []int{}) + }) + assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() { + matrix.CreateFromFlatSlice(1, -1, []int{}) + }) + assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() { + matrix.CreateFromFlatSlice(-1, -1, []int{1}) + }) +} + +func TestSum(t *testing.T) { + a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) + b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1}) + + // Sum only one matrix + c := matrix.Sum(a) + assert.NotNil(t, c) + assert.NotSame(t, a, c) + + assert.Equal(t, 2, c.Rows()) + assert.Equal(t, 3, c.Cols()) + + assert.Equal(t, 1, c.Get(0, 0)) + assert.Equal(t, 2, c.Get(0, 1)) + assert.Equal(t, 3, c.Get(0, 2)) + assert.Equal(t, 4, c.Get(1, 0)) + assert.Equal(t, 5, c.Get(1, 1)) + assert.Equal(t, 6, c.Get(1, 2)) + + // Sum one matrix multiple times + c = matrix.Sum(a, a, a) + assert.NotNil(t, c) + assert.NotSame(t, a, c) + + assert.Equal(t, 2, c.Rows()) + assert.Equal(t, 3, c.Cols()) + + assert.Equal(t, 3, c.Get(0, 0)) + assert.Equal(t, 6, c.Get(0, 1)) + assert.Equal(t, 9, c.Get(0, 2)) + assert.Equal(t, 12, c.Get(1, 0)) + assert.Equal(t, 15, c.Get(1, 1)) + assert.Equal(t, 18, c.Get(1, 2)) + + // Sum different matrices + c = matrix.Sum(a, b) + assert.NotNil(t, c) + assert.NotEqual(t, a, c) + + assert.Equal(t, 2, c.Rows()) + assert.Equal(t, 3, c.Cols()) + + assert.Equal(t, 2, c.Get(0, 0)) + assert.Equal(t, 3, c.Get(0, 1)) + assert.Equal(t, 4, c.Get(0, 2)) + assert.Equal(t, 5, c.Get(1, 0)) + assert.Equal(t, 6, c.Get(1, 1)) + assert.Equal(t, 7, c.Get(1, 2)) + + // Sum incorrect dimensions + d := matrix.Create[int](3, 2) + assert.PanicsWithValue(t, matrix.ErrIncompatibleMatrixDimensions, func() { + matrix.Sum(a, d) + }) +} + +func TestTransform(t *testing.T) { + a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) + b := matrix.Transform(a, func(value int) complex128 { + return complex(float64(value), 3.0) + }) + + assert.Equal(t, a.Rows(), b.Rows()) + assert.Equal(t, a.Cols(), b.Cols()) + + assert.Equal(t, complex(1.0, 3.0), b.Get(0, 0)) + assert.Equal(t, complex(2.0, 3.0), b.Get(0, 1)) + assert.Equal(t, complex(3.0, 3.0), b.Get(0, 2)) + assert.Equal(t, complex(4.0, 3.0), b.Get(1, 0)) + assert.Equal(t, complex(5.0, 3.0), b.Get(1, 1)) + assert.Equal(t, complex(6.0, 3.0), b.Get(1, 2)) +} + +func TestMatrix_Add(t *testing.T) { + a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) + b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1}) + + // Add nothing + c := a.Add() + assert.Same(t, a, c) + assert.Equal(t, 1, a.Get(0, 0)) + assert.Equal(t, 2, a.Get(0, 1)) + assert.Equal(t, 3, a.Get(0, 2)) + assert.Equal(t, 4, a.Get(1, 0)) + assert.Equal(t, 5, a.Get(1, 1)) + assert.Equal(t, 6, a.Get(1, 2)) + + // Add itself multiple times + c = a.Add(a, a, a) + assert.Same(t, a, c) + + assert.Equal(t, 4, a.Get(0, 0)) + assert.Equal(t, 8, a.Get(0, 1)) + assert.Equal(t, 12, a.Get(0, 2)) + assert.Equal(t, 16, a.Get(1, 0)) + assert.Equal(t, 20, a.Get(1, 1)) + assert.Equal(t, 24, a.Get(1, 2)) + + // Add other matrix + c = a.Add(b) + assert.Same(t, a, c) + + assert.Equal(t, 5, a.Get(0, 0)) + assert.Equal(t, 9, a.Get(0, 1)) + assert.Equal(t, 13, a.Get(0, 2)) + assert.Equal(t, 17, a.Get(1, 0)) + assert.Equal(t, 21, a.Get(1, 1)) + assert.Equal(t, 25, a.Get(1, 2)) + + // Add incorrect dimension + assert.PanicsWithValue( + t, matrix.ErrIncompatibleMatrixDimensions, + func() { + d := matrix.Create[int](3, 2) + a.Add(b, a, a, d) + }, + ) + + assert.Equal(t, 5, a.Get(0, 0)) + assert.Equal(t, 9, a.Get(0, 1)) + assert.Equal(t, 13, a.Get(0, 2)) + assert.Equal(t, 17, a.Get(1, 0)) + assert.Equal(t, 21, a.Get(1, 1)) + assert.Equal(t, 25, a.Get(1, 2)) +} + +func TestMatrix_Subtract(t *testing.T) { + a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) + b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1}) + + // Subtract nothing + c := a.Subtract() + assert.Same(t, a, c) + assert.Equal(t, 1, a.Get(0, 0)) + assert.Equal(t, 2, a.Get(0, 1)) + assert.Equal(t, 3, a.Get(0, 2)) + assert.Equal(t, 4, a.Get(1, 0)) + assert.Equal(t, 5, a.Get(1, 1)) + assert.Equal(t, 6, a.Get(1, 2)) + + // Add itself multiple times + c = a.Subtract(a, a, a, a) + assert.Same(t, a, c) + + assert.Equal(t, -3, a.Get(0, 0)) + assert.Equal(t, -6, a.Get(0, 1)) + assert.Equal(t, -9, a.Get(0, 2)) + assert.Equal(t, -12, a.Get(1, 0)) + assert.Equal(t, -15, a.Get(1, 1)) + assert.Equal(t, -18, a.Get(1, 2)) + + // Add other matrix + c = a.Subtract(b) + assert.Same(t, a, c) + + assert.Equal(t, -4, a.Get(0, 0)) + assert.Equal(t, -7, a.Get(0, 1)) + assert.Equal(t, -10, a.Get(0, 2)) + assert.Equal(t, -13, a.Get(1, 0)) + assert.Equal(t, -16, a.Get(1, 1)) + assert.Equal(t, -19, a.Get(1, 2)) + + // Add incorrect dimension + assert.PanicsWithValue( + t, matrix.ErrIncompatibleMatrixDimensions, + func() { + d := matrix.Create[int](3, 2) + a.Subtract(b, a, a, d) + }, + ) + + assert.Equal(t, -4, a.Get(0, 0)) + assert.Equal(t, -7, a.Get(0, 1)) + assert.Equal(t, -10, a.Get(0, 2)) + assert.Equal(t, -13, a.Get(1, 0)) + assert.Equal(t, -16, a.Get(1, 1)) + assert.Equal(t, -19, a.Get(1, 2)) +} + +func TestMatrix_Apply(t *testing.T) { + a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) + assert.NotNil(t, a) + + a.Apply(func(x int) int { return x * x }) + + assert.Equal(t, 1, a.Get(0, 0)) + assert.Equal(t, 4, a.Get(0, 1)) + assert.Equal(t, 9, a.Get(0, 2)) + assert.Equal(t, 16, a.Get(1, 0)) + assert.Equal(t, 25, a.Get(1, 1)) + assert.Equal(t, 36, a.Get(1, 2)) +} + +func TestMatrix_Copy(t *testing.T) { + a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) + assert.NotNil(t, a) + + b := a.Copy() + assert.NotSame(t, a, b) + assert.Equal(t, 1, b.Get(0, 0)) + assert.Equal(t, 2, b.Get(0, 1)) + assert.Equal(t, 3, b.Get(0, 2)) + assert.Equal(t, 4, b.Get(1, 0)) + assert.Equal(t, 5, b.Get(1, 1)) + assert.Equal(t, 6, b.Get(1, 2)) +} + +func TestMatrix_Get(t *testing.T) { + a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) + assert.NotNil(t, a) + + assert.Equal(t, 1, a.Get(0, 0)) + assert.Equal(t, 2, a.Get(0, 1)) + assert.Equal(t, 3, a.Get(0, 2)) + assert.Equal(t, 4, a.Get(1, 0)) + assert.Equal(t, 5, a.Get(1, 1)) + assert.Equal(t, 6, a.Get(1, 2)) + + assert.Panics(t, func() { + a.Get(-1, 0) + }) + assert.Panics(t, func() { + a.Get(0, -1) + }) + assert.Panics(t, func() { + a.Get(2, 0) + }) + assert.Panics(t, func() { + a.Get(0, 3) + }) +} + +func TestMatrix_Set(t *testing.T) { + a := matrix.Create[int](2, 2) + assert.NotNil(t, a) + + for row := range a.Rows() { + for col := range a.Cols() { + assert.NotEqual(t, a.Get(row, col), 42) + a.Set(row, col, 42) + assert.Equal(t, 42, a.Get(row, col)) + } + } + + assert.Panics(t, func() { + a.Set(-1, 0, 42) + }) + assert.Panics(t, func() { + a.Set(0, -1, 42) + }) + assert.Panics(t, func() { + a.Set(2, 0, 42) + }) + assert.Panics(t, func() { + a.Set(0, 2, 42) + }) +} + +func TestMatrix_Scale(t *testing.T) { + a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) + + // Add nothing + b := a.Scale(-4) + assert.Same(t, a, b) + + assert.Equal(t, -4, a.Get(0, 0)) + assert.Equal(t, -8, a.Get(0, 1)) + assert.Equal(t, -12, a.Get(0, 2)) + assert.Equal(t, -16, a.Get(1, 0)) + assert.Equal(t, -20, a.Get(1, 1)) + assert.Equal(t, -24, a.Get(1, 2)) +}