Add basic matrix package
This commit is contained in:
14
go.mod
Normal file
14
go.mod
Normal file
@ -0,0 +1,14 @@
|
||||
module git.omicron.one/playground/cryptography
|
||||
|
||||
go 1.24.2
|
||||
|
||||
require (
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
12
go.sum
Normal file
12
go.sum
Normal file
@ -0,0 +1,12 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI=
|
||||
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
44
matrix/examples_test.go
Normal file
44
matrix/examples_test.go
Normal file
@ -0,0 +1,44 @@
|
||||
package matrix_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.omicron.one/playground/cryptography/matrix"
|
||||
)
|
||||
|
||||
func ExampleConvert() {
|
||||
intMatrix := matrix.Create[int](3, 4)
|
||||
|
||||
_ = matrix.Convert[float64](intMatrix)
|
||||
}
|
||||
|
||||
func ExampleCreateFromFlatSlice() {
|
||||
m := matrix.CreateFromFlatSlice(2, 2, []int{0, 1, 2, 3})
|
||||
|
||||
for i := range m.Rows() {
|
||||
for j := range m.Cols() {
|
||||
fmt.Printf(" %d", m.Get(i, j))
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
// Output:
|
||||
// 0 1
|
||||
// 2 3
|
||||
}
|
||||
|
||||
func ExampleTransform() {
|
||||
intMatrix := matrix.CreateFromFlatSlice(2, 2, []int{0, 1, 2, 3})
|
||||
|
||||
_ = matrix.Transform(intMatrix, func(in int) complex128 {
|
||||
return complex(float64(in), 0.0)
|
||||
})
|
||||
}
|
||||
|
||||
func ExampleMatrix_Apply() {
|
||||
m := matrix.CreateFromFlatSlice(2, 2, []int{0, 1, 2, 3})
|
||||
|
||||
m.Apply(func(in int) int {
|
||||
return in * in
|
||||
})
|
||||
}
|
242
matrix/matrix.go
Normal file
242
matrix/matrix.go
Normal file
@ -0,0 +1,242 @@
|
||||
// matrix provides very basic matrix operations with all numeric types. It is
|
||||
// not a high performance or feature rich implementation and is mostly intended
|
||||
// to use for collecting statistics
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidDimensions = errors.New("Invalid dimensions")
|
||||
ErrIncompatibleMatrixDimensions = errors.New("Incompatible matrix dimensions")
|
||||
ErrIncompatibleDataDimensions = errors.New("Incompatible data dimensions")
|
||||
)
|
||||
|
||||
// Matrices can be created for these underlying types
|
||||
type Number interface {
|
||||
constraints.Integer | constraints.Float | constraints.Complex
|
||||
}
|
||||
|
||||
type SimpleNumber interface {
|
||||
constraints.Integer | constraints.Float
|
||||
}
|
||||
|
||||
// Matrix represents a matrix with values of a specific Number type
|
||||
type Matrix[T Number] struct {
|
||||
rows int
|
||||
cols int
|
||||
values [][]T
|
||||
}
|
||||
|
||||
func createValues[T Number](rows, cols int) [][]T {
|
||||
values := make([][]T, rows)
|
||||
for i := range rows {
|
||||
values[i] = make([]T, cols)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// Create creates a new matrix with the given number of rows and columns. All
|
||||
// values of this matrix are set to zero. Returns the new matrix.
|
||||
//
|
||||
// Panics with ErrInvalidDimensions if rows < 1 or cols < 1.
|
||||
func Create[T Number](rows, cols int) *Matrix[T] {
|
||||
if rows < 1 || cols < 1 {
|
||||
panic(ErrInvalidDimensions)
|
||||
}
|
||||
return &Matrix[T]{
|
||||
rows: rows,
|
||||
cols: cols,
|
||||
values: createValues[T](rows, cols),
|
||||
}
|
||||
}
|
||||
|
||||
// Creates a new matrix of a given size and sets the values from the given
|
||||
// slice. The values are used to fill the matrix left to right and top to
|
||||
// bottom. Returns the newly created matrix.
|
||||
//
|
||||
// Panics with ErrInvalidDimensions if rows < 1 or cols < 1.
|
||||
// Panics with ErrIncompatibleDataDimensions if the length of the values doesn't
|
||||
// match the size of the matrix. This is the only possible error condition.
|
||||
func CreateFromFlatSlice[T Number](rows, cols int, values []T) *Matrix[T] {
|
||||
if rows < 1 || cols < 1 {
|
||||
panic(ErrInvalidDimensions)
|
||||
}
|
||||
if len(values) != rows*cols {
|
||||
panic(ErrIncompatibleDataDimensions)
|
||||
}
|
||||
m := Create[T](rows, cols)
|
||||
|
||||
n := 0
|
||||
for i := range rows {
|
||||
for j := range cols {
|
||||
m.values[i][j] = values[n]
|
||||
n++
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Convert will take a matrix of type T and convert it into a matrix of type U
|
||||
// Only works on SimpleNumber matrices
|
||||
func Convert[U, T SimpleNumber](in *Matrix[T]) *Matrix[U] {
|
||||
out := Create[U](in.rows, in.cols)
|
||||
for i := range in.rows {
|
||||
for j := range in.cols {
|
||||
out.values[i][j] = U(in.values[i][j])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Transform the values of the given matrix with a transfrom function. This
|
||||
// operation may change the type of the values. Unlike Convert it can be used
|
||||
// on all Number types, not just SimpleNumber.
|
||||
func Transform[U, T Number](in *Matrix[T], transformFn func(T) U) *Matrix[U] {
|
||||
out := Create[U](in.rows, in.cols)
|
||||
for i := range in.rows {
|
||||
for j := range in.cols {
|
||||
out.values[i][j] = transformFn(in.values[i][j])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Sum takes at least one matrix and sums it together.
|
||||
// Returns a new matrix that is the sum of the arguments.
|
||||
// Panics if the arguments don't have matching dimensions.
|
||||
func Sum[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] {
|
||||
return first.Copy().Add(additional...)
|
||||
}
|
||||
|
||||
// Copy creates a deep copy of the matrix and returns the new instance
|
||||
func (m *Matrix[T]) Copy() *Matrix[T] {
|
||||
mCopy := Create[T](m.rows, m.cols)
|
||||
for i := range m.rows {
|
||||
copy(mCopy.values[i], m.values[i])
|
||||
}
|
||||
return mCopy
|
||||
}
|
||||
|
||||
// Size returns the dimensions of the matrix as (rows, columns)
|
||||
func (m *Matrix[T]) Size() (int, int) {
|
||||
return m.rows, m.cols
|
||||
}
|
||||
|
||||
// Rows returns the number of rows in the matrix
|
||||
func (m *Matrix[T]) Rows() int {
|
||||
return m.rows
|
||||
}
|
||||
|
||||
// Cols returns the number of columns in the matrix
|
||||
func (m *Matrix[T]) Cols() int {
|
||||
return m.cols
|
||||
}
|
||||
|
||||
// Set sets the value of the matrix at the given row and col position to the
|
||||
// given value
|
||||
func (m *Matrix[T]) Set(row, col int, value T) {
|
||||
m.values[row][col] = value
|
||||
}
|
||||
|
||||
// Set assigns the specified value to the element at the given row and column
|
||||
func (m *Matrix[T]) Get(row, col int) T {
|
||||
return m.values[row][col]
|
||||
}
|
||||
|
||||
// Add performs in-place addition of zero or more matrices to this matrix and
|
||||
// returns the receiver.
|
||||
// Ensures correct behavior even if the matrix itself is passed as one or more
|
||||
// arguments.
|
||||
// Panics with ErrIncompatibleMatrixDimensions if any of the matrices don't
|
||||
// have matching dimensions.
|
||||
func (m *Matrix[T]) Add(matrices ...*Matrix[T]) *Matrix[T] {
|
||||
numSelf := 0
|
||||
for _, other := range matrices {
|
||||
if m.rows != other.rows || m.cols != other.cols {
|
||||
panic(ErrIncompatibleMatrixDimensions)
|
||||
}
|
||||
if other == m {
|
||||
numSelf += 1
|
||||
}
|
||||
}
|
||||
|
||||
// If we have multiple self references to add, work on duplicate data to
|
||||
// make addition behave as expected
|
||||
values := m.values
|
||||
if numSelf > 1 {
|
||||
values = m.Copy().values
|
||||
}
|
||||
|
||||
for _, other := range matrices {
|
||||
for i := range m.rows {
|
||||
for j := range m.cols {
|
||||
values[i][j] += other.values[i][j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.values = values
|
||||
return m
|
||||
}
|
||||
|
||||
// Subtract performs in-place subtraction of zero or more matrices from this matrix and
|
||||
// returns the receiver.
|
||||
// Ensures correct behavior even if the matrix itself is passed as one or more
|
||||
// arguments.
|
||||
// Panics with ErrIncompatibleMatrixDimensions if any of the matrices don't
|
||||
// have matching dimensions.
|
||||
func (m *Matrix[T]) Subtract(matrices ...*Matrix[T]) *Matrix[T] {
|
||||
numSelf := 0
|
||||
for _, other := range matrices {
|
||||
if m.rows != other.rows || m.cols != other.cols {
|
||||
panic(ErrIncompatibleMatrixDimensions)
|
||||
}
|
||||
if other == m {
|
||||
numSelf += 1
|
||||
}
|
||||
}
|
||||
|
||||
// If we have multiple self references to subtract, work on duplicate data to
|
||||
// make subtraction behave as expected
|
||||
values := m.values
|
||||
if numSelf > 1 {
|
||||
values = m.Copy().values
|
||||
}
|
||||
|
||||
for _, other := range matrices {
|
||||
for i := range m.rows {
|
||||
for j := range m.cols {
|
||||
values[i][j] -= other.values[i][j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.values = values
|
||||
return m
|
||||
}
|
||||
|
||||
// Apply performs an in-place transformation of each element of the matrix using
|
||||
// the provided function. Returns the receiver.
|
||||
func (m *Matrix[T]) Apply(fn func(T) T) *Matrix[T] {
|
||||
for i := range m.rows {
|
||||
for j := range m.cols {
|
||||
m.values[i][j] = fn(m.values[i][j])
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Scale does an in-place scalar multiplication of the matrix values. Returns
|
||||
// the receiver.
|
||||
func (m *Matrix[T]) Scale(scalar T) *Matrix[T] {
|
||||
for i := range m.rows {
|
||||
for j := range m.cols {
|
||||
m.values[i][j] *= scalar
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
368
matrix/matrix_test.go
Normal file
368
matrix/matrix_test.go
Normal file
@ -0,0 +1,368 @@
|
||||
package matrix_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.omicron.one/playground/cryptography/matrix"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConvert(t *testing.T) {
|
||||
m := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
assert.NotNil(t, m)
|
||||
|
||||
mf := matrix.Convert[float64](m)
|
||||
assert.Equal(t, m.Rows(), mf.Rows())
|
||||
assert.Equal(t, m.Cols(), mf.Cols())
|
||||
|
||||
for row := range mf.Rows() {
|
||||
for col := range mf.Cols() {
|
||||
assert.Equal(t, float64(m.Get(row, col)), mf.Get(row, col))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
m := matrix.Create[int](3, 4)
|
||||
assert.NotNil(t, m)
|
||||
|
||||
rows, cols := m.Size()
|
||||
assert.Equal(t, 3, rows)
|
||||
assert.Equal(t, 3, m.Rows())
|
||||
assert.Equal(t, 4, cols)
|
||||
assert.Equal(t, 4, m.Cols())
|
||||
|
||||
for row := range rows {
|
||||
for col := range cols {
|
||||
assert.Equal(t, 0, m.Get(row, col))
|
||||
}
|
||||
}
|
||||
|
||||
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||
matrix.Create[int](0, 1)
|
||||
})
|
||||
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||
matrix.Create[int](-1, 1)
|
||||
})
|
||||
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||
matrix.Create[int](1, 0)
|
||||
})
|
||||
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||
matrix.Create[int](1, -1)
|
||||
})
|
||||
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||
matrix.Create[int](-1, -1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateFromFlatSlice(t *testing.T) {
|
||||
m := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
assert.NotNil(t, m)
|
||||
|
||||
assert.Equal(t, 2, m.Rows())
|
||||
assert.Equal(t, 3, m.Cols())
|
||||
|
||||
assert.Equal(t, 1, m.Get(0, 0))
|
||||
assert.Equal(t, 2, m.Get(0, 1))
|
||||
assert.Equal(t, 3, m.Get(0, 2))
|
||||
assert.Equal(t, 4, m.Get(1, 0))
|
||||
assert.Equal(t, 5, m.Get(1, 1))
|
||||
assert.Equal(t, 6, m.Get(1, 2))
|
||||
|
||||
assert.PanicsWithValue(t, matrix.ErrIncompatibleDataDimensions, func() {
|
||||
matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5})
|
||||
})
|
||||
|
||||
assert.PanicsWithValue(t, matrix.ErrIncompatibleDataDimensions, func() {
|
||||
matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6, 7})
|
||||
})
|
||||
|
||||
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||
matrix.CreateFromFlatSlice(0, 1, []int{})
|
||||
})
|
||||
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||
matrix.CreateFromFlatSlice(-1, 1, []int{})
|
||||
})
|
||||
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||
matrix.CreateFromFlatSlice(1, 0, []int{})
|
||||
})
|
||||
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||
matrix.CreateFromFlatSlice(1, -1, []int{})
|
||||
})
|
||||
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||
matrix.CreateFromFlatSlice(-1, -1, []int{1})
|
||||
})
|
||||
}
|
||||
|
||||
func TestSum(t *testing.T) {
|
||||
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1})
|
||||
|
||||
// Sum only one matrix
|
||||
c := matrix.Sum(a)
|
||||
assert.NotNil(t, c)
|
||||
assert.NotSame(t, a, c)
|
||||
|
||||
assert.Equal(t, 2, c.Rows())
|
||||
assert.Equal(t, 3, c.Cols())
|
||||
|
||||
assert.Equal(t, 1, c.Get(0, 0))
|
||||
assert.Equal(t, 2, c.Get(0, 1))
|
||||
assert.Equal(t, 3, c.Get(0, 2))
|
||||
assert.Equal(t, 4, c.Get(1, 0))
|
||||
assert.Equal(t, 5, c.Get(1, 1))
|
||||
assert.Equal(t, 6, c.Get(1, 2))
|
||||
|
||||
// Sum one matrix multiple times
|
||||
c = matrix.Sum(a, a, a)
|
||||
assert.NotNil(t, c)
|
||||
assert.NotSame(t, a, c)
|
||||
|
||||
assert.Equal(t, 2, c.Rows())
|
||||
assert.Equal(t, 3, c.Cols())
|
||||
|
||||
assert.Equal(t, 3, c.Get(0, 0))
|
||||
assert.Equal(t, 6, c.Get(0, 1))
|
||||
assert.Equal(t, 9, c.Get(0, 2))
|
||||
assert.Equal(t, 12, c.Get(1, 0))
|
||||
assert.Equal(t, 15, c.Get(1, 1))
|
||||
assert.Equal(t, 18, c.Get(1, 2))
|
||||
|
||||
// Sum different matrices
|
||||
c = matrix.Sum(a, b)
|
||||
assert.NotNil(t, c)
|
||||
assert.NotEqual(t, a, c)
|
||||
|
||||
assert.Equal(t, 2, c.Rows())
|
||||
assert.Equal(t, 3, c.Cols())
|
||||
|
||||
assert.Equal(t, 2, c.Get(0, 0))
|
||||
assert.Equal(t, 3, c.Get(0, 1))
|
||||
assert.Equal(t, 4, c.Get(0, 2))
|
||||
assert.Equal(t, 5, c.Get(1, 0))
|
||||
assert.Equal(t, 6, c.Get(1, 1))
|
||||
assert.Equal(t, 7, c.Get(1, 2))
|
||||
|
||||
// Sum incorrect dimensions
|
||||
d := matrix.Create[int](3, 2)
|
||||
assert.PanicsWithValue(t, matrix.ErrIncompatibleMatrixDimensions, func() {
|
||||
matrix.Sum(a, d)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransform(t *testing.T) {
|
||||
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
b := matrix.Transform(a, func(value int) complex128 {
|
||||
return complex(float64(value), 3.0)
|
||||
})
|
||||
|
||||
assert.Equal(t, a.Rows(), b.Rows())
|
||||
assert.Equal(t, a.Cols(), b.Cols())
|
||||
|
||||
assert.Equal(t, complex(1.0, 3.0), b.Get(0, 0))
|
||||
assert.Equal(t, complex(2.0, 3.0), b.Get(0, 1))
|
||||
assert.Equal(t, complex(3.0, 3.0), b.Get(0, 2))
|
||||
assert.Equal(t, complex(4.0, 3.0), b.Get(1, 0))
|
||||
assert.Equal(t, complex(5.0, 3.0), b.Get(1, 1))
|
||||
assert.Equal(t, complex(6.0, 3.0), b.Get(1, 2))
|
||||
}
|
||||
|
||||
func TestMatrix_Add(t *testing.T) {
|
||||
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1})
|
||||
|
||||
// Add nothing
|
||||
c := a.Add()
|
||||
assert.Same(t, a, c)
|
||||
assert.Equal(t, 1, a.Get(0, 0))
|
||||
assert.Equal(t, 2, a.Get(0, 1))
|
||||
assert.Equal(t, 3, a.Get(0, 2))
|
||||
assert.Equal(t, 4, a.Get(1, 0))
|
||||
assert.Equal(t, 5, a.Get(1, 1))
|
||||
assert.Equal(t, 6, a.Get(1, 2))
|
||||
|
||||
// Add itself multiple times
|
||||
c = a.Add(a, a, a)
|
||||
assert.Same(t, a, c)
|
||||
|
||||
assert.Equal(t, 4, a.Get(0, 0))
|
||||
assert.Equal(t, 8, a.Get(0, 1))
|
||||
assert.Equal(t, 12, a.Get(0, 2))
|
||||
assert.Equal(t, 16, a.Get(1, 0))
|
||||
assert.Equal(t, 20, a.Get(1, 1))
|
||||
assert.Equal(t, 24, a.Get(1, 2))
|
||||
|
||||
// Add other matrix
|
||||
c = a.Add(b)
|
||||
assert.Same(t, a, c)
|
||||
|
||||
assert.Equal(t, 5, a.Get(0, 0))
|
||||
assert.Equal(t, 9, a.Get(0, 1))
|
||||
assert.Equal(t, 13, a.Get(0, 2))
|
||||
assert.Equal(t, 17, a.Get(1, 0))
|
||||
assert.Equal(t, 21, a.Get(1, 1))
|
||||
assert.Equal(t, 25, a.Get(1, 2))
|
||||
|
||||
// Add incorrect dimension
|
||||
assert.PanicsWithValue(
|
||||
t, matrix.ErrIncompatibleMatrixDimensions,
|
||||
func() {
|
||||
d := matrix.Create[int](3, 2)
|
||||
a.Add(b, a, a, d)
|
||||
},
|
||||
)
|
||||
|
||||
assert.Equal(t, 5, a.Get(0, 0))
|
||||
assert.Equal(t, 9, a.Get(0, 1))
|
||||
assert.Equal(t, 13, a.Get(0, 2))
|
||||
assert.Equal(t, 17, a.Get(1, 0))
|
||||
assert.Equal(t, 21, a.Get(1, 1))
|
||||
assert.Equal(t, 25, a.Get(1, 2))
|
||||
}
|
||||
|
||||
func TestMatrix_Subtract(t *testing.T) {
|
||||
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1})
|
||||
|
||||
// Subtract nothing
|
||||
c := a.Subtract()
|
||||
assert.Same(t, a, c)
|
||||
assert.Equal(t, 1, a.Get(0, 0))
|
||||
assert.Equal(t, 2, a.Get(0, 1))
|
||||
assert.Equal(t, 3, a.Get(0, 2))
|
||||
assert.Equal(t, 4, a.Get(1, 0))
|
||||
assert.Equal(t, 5, a.Get(1, 1))
|
||||
assert.Equal(t, 6, a.Get(1, 2))
|
||||
|
||||
// Add itself multiple times
|
||||
c = a.Subtract(a, a, a, a)
|
||||
assert.Same(t, a, c)
|
||||
|
||||
assert.Equal(t, -3, a.Get(0, 0))
|
||||
assert.Equal(t, -6, a.Get(0, 1))
|
||||
assert.Equal(t, -9, a.Get(0, 2))
|
||||
assert.Equal(t, -12, a.Get(1, 0))
|
||||
assert.Equal(t, -15, a.Get(1, 1))
|
||||
assert.Equal(t, -18, a.Get(1, 2))
|
||||
|
||||
// Add other matrix
|
||||
c = a.Subtract(b)
|
||||
assert.Same(t, a, c)
|
||||
|
||||
assert.Equal(t, -4, a.Get(0, 0))
|
||||
assert.Equal(t, -7, a.Get(0, 1))
|
||||
assert.Equal(t, -10, a.Get(0, 2))
|
||||
assert.Equal(t, -13, a.Get(1, 0))
|
||||
assert.Equal(t, -16, a.Get(1, 1))
|
||||
assert.Equal(t, -19, a.Get(1, 2))
|
||||
|
||||
// Add incorrect dimension
|
||||
assert.PanicsWithValue(
|
||||
t, matrix.ErrIncompatibleMatrixDimensions,
|
||||
func() {
|
||||
d := matrix.Create[int](3, 2)
|
||||
a.Subtract(b, a, a, d)
|
||||
},
|
||||
)
|
||||
|
||||
assert.Equal(t, -4, a.Get(0, 0))
|
||||
assert.Equal(t, -7, a.Get(0, 1))
|
||||
assert.Equal(t, -10, a.Get(0, 2))
|
||||
assert.Equal(t, -13, a.Get(1, 0))
|
||||
assert.Equal(t, -16, a.Get(1, 1))
|
||||
assert.Equal(t, -19, a.Get(1, 2))
|
||||
}
|
||||
|
||||
func TestMatrix_Apply(t *testing.T) {
|
||||
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
assert.NotNil(t, a)
|
||||
|
||||
a.Apply(func(x int) int { return x * x })
|
||||
|
||||
assert.Equal(t, 1, a.Get(0, 0))
|
||||
assert.Equal(t, 4, a.Get(0, 1))
|
||||
assert.Equal(t, 9, a.Get(0, 2))
|
||||
assert.Equal(t, 16, a.Get(1, 0))
|
||||
assert.Equal(t, 25, a.Get(1, 1))
|
||||
assert.Equal(t, 36, a.Get(1, 2))
|
||||
}
|
||||
|
||||
func TestMatrix_Copy(t *testing.T) {
|
||||
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
assert.NotNil(t, a)
|
||||
|
||||
b := a.Copy()
|
||||
assert.NotSame(t, a, b)
|
||||
assert.Equal(t, 1, b.Get(0, 0))
|
||||
assert.Equal(t, 2, b.Get(0, 1))
|
||||
assert.Equal(t, 3, b.Get(0, 2))
|
||||
assert.Equal(t, 4, b.Get(1, 0))
|
||||
assert.Equal(t, 5, b.Get(1, 1))
|
||||
assert.Equal(t, 6, b.Get(1, 2))
|
||||
}
|
||||
|
||||
func TestMatrix_Get(t *testing.T) {
|
||||
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
assert.NotNil(t, a)
|
||||
|
||||
assert.Equal(t, 1, a.Get(0, 0))
|
||||
assert.Equal(t, 2, a.Get(0, 1))
|
||||
assert.Equal(t, 3, a.Get(0, 2))
|
||||
assert.Equal(t, 4, a.Get(1, 0))
|
||||
assert.Equal(t, 5, a.Get(1, 1))
|
||||
assert.Equal(t, 6, a.Get(1, 2))
|
||||
|
||||
assert.Panics(t, func() {
|
||||
a.Get(-1, 0)
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
a.Get(0, -1)
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
a.Get(2, 0)
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
a.Get(0, 3)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatrix_Set(t *testing.T) {
|
||||
a := matrix.Create[int](2, 2)
|
||||
assert.NotNil(t, a)
|
||||
|
||||
for row := range a.Rows() {
|
||||
for col := range a.Cols() {
|
||||
assert.NotEqual(t, a.Get(row, col), 42)
|
||||
a.Set(row, col, 42)
|
||||
assert.Equal(t, 42, a.Get(row, col))
|
||||
}
|
||||
}
|
||||
|
||||
assert.Panics(t, func() {
|
||||
a.Set(-1, 0, 42)
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
a.Set(0, -1, 42)
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
a.Set(2, 0, 42)
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
a.Set(0, 2, 42)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatrix_Scale(t *testing.T) {
|
||||
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
|
||||
// Add nothing
|
||||
b := a.Scale(-4)
|
||||
assert.Same(t, a, b)
|
||||
|
||||
assert.Equal(t, -4, a.Get(0, 0))
|
||||
assert.Equal(t, -8, a.Get(0, 1))
|
||||
assert.Equal(t, -12, a.Get(0, 2))
|
||||
assert.Equal(t, -16, a.Get(1, 0))
|
||||
assert.Equal(t, -20, a.Get(1, 1))
|
||||
assert.Equal(t, -24, a.Get(1, 2))
|
||||
}
|
Reference in New Issue
Block a user