Add basic matrix package
This commit is contained in:
14
go.mod
Normal file
14
go.mod
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
module git.omicron.one/playground/cryptography
|
||||||
|
|
||||||
|
go 1.24.2
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/stretchr/testify v1.10.0
|
||||||
|
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
12
go.sum
Normal file
12
go.sum
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI=
|
||||||
|
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
44
matrix/examples_test.go
Normal file
44
matrix/examples_test.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package matrix_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.omicron.one/playground/cryptography/matrix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ExampleConvert() {
|
||||||
|
intMatrix := matrix.Create[int](3, 4)
|
||||||
|
|
||||||
|
_ = matrix.Convert[float64](intMatrix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleCreateFromFlatSlice() {
|
||||||
|
m := matrix.CreateFromFlatSlice(2, 2, []int{0, 1, 2, 3})
|
||||||
|
|
||||||
|
for i := range m.Rows() {
|
||||||
|
for j := range m.Cols() {
|
||||||
|
fmt.Printf(" %d", m.Get(i, j))
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// 0 1
|
||||||
|
// 2 3
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleTransform() {
|
||||||
|
intMatrix := matrix.CreateFromFlatSlice(2, 2, []int{0, 1, 2, 3})
|
||||||
|
|
||||||
|
_ = matrix.Transform(intMatrix, func(in int) complex128 {
|
||||||
|
return complex(float64(in), 0.0)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleMatrix_Apply() {
|
||||||
|
m := matrix.CreateFromFlatSlice(2, 2, []int{0, 1, 2, 3})
|
||||||
|
|
||||||
|
m.Apply(func(in int) int {
|
||||||
|
return in * in
|
||||||
|
})
|
||||||
|
}
|
242
matrix/matrix.go
Normal file
242
matrix/matrix.go
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
// matrix provides very basic matrix operations with all numeric types. It is
|
||||||
|
// not a high performance or feature rich implementation and is mostly intended
|
||||||
|
// to use for collecting statistics
|
||||||
|
package matrix
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"golang.org/x/exp/constraints"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidDimensions = errors.New("Invalid dimensions")
|
||||||
|
ErrIncompatibleMatrixDimensions = errors.New("Incompatible matrix dimensions")
|
||||||
|
ErrIncompatibleDataDimensions = errors.New("Incompatible data dimensions")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Matrices can be created for these underlying types
|
||||||
|
type Number interface {
|
||||||
|
constraints.Integer | constraints.Float | constraints.Complex
|
||||||
|
}
|
||||||
|
|
||||||
|
type SimpleNumber interface {
|
||||||
|
constraints.Integer | constraints.Float
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matrix represents a matrix with values of a specific Number type
|
||||||
|
type Matrix[T Number] struct {
|
||||||
|
rows int
|
||||||
|
cols int
|
||||||
|
values [][]T
|
||||||
|
}
|
||||||
|
|
||||||
|
func createValues[T Number](rows, cols int) [][]T {
|
||||||
|
values := make([][]T, rows)
|
||||||
|
for i := range rows {
|
||||||
|
values[i] = make([]T, cols)
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create creates a new matrix with the given number of rows and columns. All
|
||||||
|
// values of this matrix are set to zero. Returns the new matrix.
|
||||||
|
//
|
||||||
|
// Panics with ErrInvalidDimensions if rows < 1 or cols < 1.
|
||||||
|
func Create[T Number](rows, cols int) *Matrix[T] {
|
||||||
|
if rows < 1 || cols < 1 {
|
||||||
|
panic(ErrInvalidDimensions)
|
||||||
|
}
|
||||||
|
return &Matrix[T]{
|
||||||
|
rows: rows,
|
||||||
|
cols: cols,
|
||||||
|
values: createValues[T](rows, cols),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a new matrix of a given size and sets the values from the given
|
||||||
|
// slice. The values are used to fill the matrix left to right and top to
|
||||||
|
// bottom. Returns the newly created matrix.
|
||||||
|
//
|
||||||
|
// Panics with ErrInvalidDimensions if rows < 1 or cols < 1.
|
||||||
|
// Panics with ErrIncompatibleDataDimensions if the length of the values doesn't
|
||||||
|
// match the size of the matrix. This is the only possible error condition.
|
||||||
|
func CreateFromFlatSlice[T Number](rows, cols int, values []T) *Matrix[T] {
|
||||||
|
if rows < 1 || cols < 1 {
|
||||||
|
panic(ErrInvalidDimensions)
|
||||||
|
}
|
||||||
|
if len(values) != rows*cols {
|
||||||
|
panic(ErrIncompatibleDataDimensions)
|
||||||
|
}
|
||||||
|
m := Create[T](rows, cols)
|
||||||
|
|
||||||
|
n := 0
|
||||||
|
for i := range rows {
|
||||||
|
for j := range cols {
|
||||||
|
m.values[i][j] = values[n]
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert will take a matrix of type T and convert it into a matrix of type U
|
||||||
|
// Only works on SimpleNumber matrices
|
||||||
|
func Convert[U, T SimpleNumber](in *Matrix[T]) *Matrix[U] {
|
||||||
|
out := Create[U](in.rows, in.cols)
|
||||||
|
for i := range in.rows {
|
||||||
|
for j := range in.cols {
|
||||||
|
out.values[i][j] = U(in.values[i][j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transform the values of the given matrix with a transfrom function. This
|
||||||
|
// operation may change the type of the values. Unlike Convert it can be used
|
||||||
|
// on all Number types, not just SimpleNumber.
|
||||||
|
func Transform[U, T Number](in *Matrix[T], transformFn func(T) U) *Matrix[U] {
|
||||||
|
out := Create[U](in.rows, in.cols)
|
||||||
|
for i := range in.rows {
|
||||||
|
for j := range in.cols {
|
||||||
|
out.values[i][j] = transformFn(in.values[i][j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sum takes at least one matrix and sums it together.
|
||||||
|
// Returns a new matrix that is the sum of the arguments.
|
||||||
|
// Panics if the arguments don't have matching dimensions.
|
||||||
|
func Sum[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] {
|
||||||
|
return first.Copy().Add(additional...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy creates a deep copy of the matrix and returns the new instance
|
||||||
|
func (m *Matrix[T]) Copy() *Matrix[T] {
|
||||||
|
mCopy := Create[T](m.rows, m.cols)
|
||||||
|
for i := range m.rows {
|
||||||
|
copy(mCopy.values[i], m.values[i])
|
||||||
|
}
|
||||||
|
return mCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns the dimensions of the matrix as (rows, columns)
|
||||||
|
func (m *Matrix[T]) Size() (int, int) {
|
||||||
|
return m.rows, m.cols
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rows returns the number of rows in the matrix
|
||||||
|
func (m *Matrix[T]) Rows() int {
|
||||||
|
return m.rows
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cols returns the number of columns in the matrix
|
||||||
|
func (m *Matrix[T]) Cols() int {
|
||||||
|
return m.cols
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets the value of the matrix at the given row and col position to the
|
||||||
|
// given value
|
||||||
|
func (m *Matrix[T]) Set(row, col int, value T) {
|
||||||
|
m.values[row][col] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set assigns the specified value to the element at the given row and column
|
||||||
|
func (m *Matrix[T]) Get(row, col int) T {
|
||||||
|
return m.values[row][col]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add performs in-place addition of zero or more matrices to this matrix and
|
||||||
|
// returns the receiver.
|
||||||
|
// Ensures correct behavior even if the matrix itself is passed as one or more
|
||||||
|
// arguments.
|
||||||
|
// Panics with ErrIncompatibleMatrixDimensions if any of the matrices don't
|
||||||
|
// have matching dimensions.
|
||||||
|
func (m *Matrix[T]) Add(matrices ...*Matrix[T]) *Matrix[T] {
|
||||||
|
numSelf := 0
|
||||||
|
for _, other := range matrices {
|
||||||
|
if m.rows != other.rows || m.cols != other.cols {
|
||||||
|
panic(ErrIncompatibleMatrixDimensions)
|
||||||
|
}
|
||||||
|
if other == m {
|
||||||
|
numSelf += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have multiple self references to add, work on duplicate data to
|
||||||
|
// make addition behave as expected
|
||||||
|
values := m.values
|
||||||
|
if numSelf > 1 {
|
||||||
|
values = m.Copy().values
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, other := range matrices {
|
||||||
|
for i := range m.rows {
|
||||||
|
for j := range m.cols {
|
||||||
|
values[i][j] += other.values[i][j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.values = values
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subtract performs in-place subtraction of zero or more matrices from this matrix and
|
||||||
|
// returns the receiver.
|
||||||
|
// Ensures correct behavior even if the matrix itself is passed as one or more
|
||||||
|
// arguments.
|
||||||
|
// Panics with ErrIncompatibleMatrixDimensions if any of the matrices don't
|
||||||
|
// have matching dimensions.
|
||||||
|
func (m *Matrix[T]) Subtract(matrices ...*Matrix[T]) *Matrix[T] {
|
||||||
|
numSelf := 0
|
||||||
|
for _, other := range matrices {
|
||||||
|
if m.rows != other.rows || m.cols != other.cols {
|
||||||
|
panic(ErrIncompatibleMatrixDimensions)
|
||||||
|
}
|
||||||
|
if other == m {
|
||||||
|
numSelf += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have multiple self references to subtract, work on duplicate data to
|
||||||
|
// make subtraction behave as expected
|
||||||
|
values := m.values
|
||||||
|
if numSelf > 1 {
|
||||||
|
values = m.Copy().values
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, other := range matrices {
|
||||||
|
for i := range m.rows {
|
||||||
|
for j := range m.cols {
|
||||||
|
values[i][j] -= other.values[i][j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.values = values
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply performs an in-place transformation of each element of the matrix using
|
||||||
|
// the provided function. Returns the receiver.
|
||||||
|
func (m *Matrix[T]) Apply(fn func(T) T) *Matrix[T] {
|
||||||
|
for i := range m.rows {
|
||||||
|
for j := range m.cols {
|
||||||
|
m.values[i][j] = fn(m.values[i][j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scale does an in-place scalar multiplication of the matrix values. Returns
|
||||||
|
// the receiver.
|
||||||
|
func (m *Matrix[T]) Scale(scalar T) *Matrix[T] {
|
||||||
|
for i := range m.rows {
|
||||||
|
for j := range m.cols {
|
||||||
|
m.values[i][j] *= scalar
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
368
matrix/matrix_test.go
Normal file
368
matrix/matrix_test.go
Normal file
@ -0,0 +1,368 @@
|
|||||||
|
package matrix_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.omicron.one/playground/cryptography/matrix"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvert(t *testing.T) {
|
||||||
|
m := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
assert.NotNil(t, m)
|
||||||
|
|
||||||
|
mf := matrix.Convert[float64](m)
|
||||||
|
assert.Equal(t, m.Rows(), mf.Rows())
|
||||||
|
assert.Equal(t, m.Cols(), mf.Cols())
|
||||||
|
|
||||||
|
for row := range mf.Rows() {
|
||||||
|
for col := range mf.Cols() {
|
||||||
|
assert.Equal(t, float64(m.Get(row, col)), mf.Get(row, col))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreate(t *testing.T) {
|
||||||
|
m := matrix.Create[int](3, 4)
|
||||||
|
assert.NotNil(t, m)
|
||||||
|
|
||||||
|
rows, cols := m.Size()
|
||||||
|
assert.Equal(t, 3, rows)
|
||||||
|
assert.Equal(t, 3, m.Rows())
|
||||||
|
assert.Equal(t, 4, cols)
|
||||||
|
assert.Equal(t, 4, m.Cols())
|
||||||
|
|
||||||
|
for row := range rows {
|
||||||
|
for col := range cols {
|
||||||
|
assert.Equal(t, 0, m.Get(row, col))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||||
|
matrix.Create[int](0, 1)
|
||||||
|
})
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||||
|
matrix.Create[int](-1, 1)
|
||||||
|
})
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||||
|
matrix.Create[int](1, 0)
|
||||||
|
})
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||||
|
matrix.Create[int](1, -1)
|
||||||
|
})
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||||
|
matrix.Create[int](-1, -1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateFromFlatSlice(t *testing.T) {
|
||||||
|
m := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
assert.NotNil(t, m)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, m.Rows())
|
||||||
|
assert.Equal(t, 3, m.Cols())
|
||||||
|
|
||||||
|
assert.Equal(t, 1, m.Get(0, 0))
|
||||||
|
assert.Equal(t, 2, m.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, m.Get(0, 2))
|
||||||
|
assert.Equal(t, 4, m.Get(1, 0))
|
||||||
|
assert.Equal(t, 5, m.Get(1, 1))
|
||||||
|
assert.Equal(t, 6, m.Get(1, 2))
|
||||||
|
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrIncompatibleDataDimensions, func() {
|
||||||
|
matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5})
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrIncompatibleDataDimensions, func() {
|
||||||
|
matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6, 7})
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||||
|
matrix.CreateFromFlatSlice(0, 1, []int{})
|
||||||
|
})
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||||
|
matrix.CreateFromFlatSlice(-1, 1, []int{})
|
||||||
|
})
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||||
|
matrix.CreateFromFlatSlice(1, 0, []int{})
|
||||||
|
})
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||||
|
matrix.CreateFromFlatSlice(1, -1, []int{})
|
||||||
|
})
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrInvalidDimensions, func() {
|
||||||
|
matrix.CreateFromFlatSlice(-1, -1, []int{1})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSum(t *testing.T) {
|
||||||
|
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1})
|
||||||
|
|
||||||
|
// Sum only one matrix
|
||||||
|
c := matrix.Sum(a)
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
assert.NotSame(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, c.Rows())
|
||||||
|
assert.Equal(t, 3, c.Cols())
|
||||||
|
|
||||||
|
assert.Equal(t, 1, c.Get(0, 0))
|
||||||
|
assert.Equal(t, 2, c.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, c.Get(0, 2))
|
||||||
|
assert.Equal(t, 4, c.Get(1, 0))
|
||||||
|
assert.Equal(t, 5, c.Get(1, 1))
|
||||||
|
assert.Equal(t, 6, c.Get(1, 2))
|
||||||
|
|
||||||
|
// Sum one matrix multiple times
|
||||||
|
c = matrix.Sum(a, a, a)
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
assert.NotSame(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, c.Rows())
|
||||||
|
assert.Equal(t, 3, c.Cols())
|
||||||
|
|
||||||
|
assert.Equal(t, 3, c.Get(0, 0))
|
||||||
|
assert.Equal(t, 6, c.Get(0, 1))
|
||||||
|
assert.Equal(t, 9, c.Get(0, 2))
|
||||||
|
assert.Equal(t, 12, c.Get(1, 0))
|
||||||
|
assert.Equal(t, 15, c.Get(1, 1))
|
||||||
|
assert.Equal(t, 18, c.Get(1, 2))
|
||||||
|
|
||||||
|
// Sum different matrices
|
||||||
|
c = matrix.Sum(a, b)
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
assert.NotEqual(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, c.Rows())
|
||||||
|
assert.Equal(t, 3, c.Cols())
|
||||||
|
|
||||||
|
assert.Equal(t, 2, c.Get(0, 0))
|
||||||
|
assert.Equal(t, 3, c.Get(0, 1))
|
||||||
|
assert.Equal(t, 4, c.Get(0, 2))
|
||||||
|
assert.Equal(t, 5, c.Get(1, 0))
|
||||||
|
assert.Equal(t, 6, c.Get(1, 1))
|
||||||
|
assert.Equal(t, 7, c.Get(1, 2))
|
||||||
|
|
||||||
|
// Sum incorrect dimensions
|
||||||
|
d := matrix.Create[int](3, 2)
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrIncompatibleMatrixDimensions, func() {
|
||||||
|
matrix.Sum(a, d)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransform(t *testing.T) {
|
||||||
|
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
b := matrix.Transform(a, func(value int) complex128 {
|
||||||
|
return complex(float64(value), 3.0)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, a.Rows(), b.Rows())
|
||||||
|
assert.Equal(t, a.Cols(), b.Cols())
|
||||||
|
|
||||||
|
assert.Equal(t, complex(1.0, 3.0), b.Get(0, 0))
|
||||||
|
assert.Equal(t, complex(2.0, 3.0), b.Get(0, 1))
|
||||||
|
assert.Equal(t, complex(3.0, 3.0), b.Get(0, 2))
|
||||||
|
assert.Equal(t, complex(4.0, 3.0), b.Get(1, 0))
|
||||||
|
assert.Equal(t, complex(5.0, 3.0), b.Get(1, 1))
|
||||||
|
assert.Equal(t, complex(6.0, 3.0), b.Get(1, 2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrix_Add(t *testing.T) {
|
||||||
|
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1})
|
||||||
|
|
||||||
|
// Add nothing
|
||||||
|
c := a.Add()
|
||||||
|
assert.Same(t, a, c)
|
||||||
|
assert.Equal(t, 1, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 2, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 4, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 5, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 6, a.Get(1, 2))
|
||||||
|
|
||||||
|
// Add itself multiple times
|
||||||
|
c = a.Add(a, a, a)
|
||||||
|
assert.Same(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 4, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 8, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 12, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 16, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 20, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 24, a.Get(1, 2))
|
||||||
|
|
||||||
|
// Add other matrix
|
||||||
|
c = a.Add(b)
|
||||||
|
assert.Same(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 5, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 9, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 13, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 17, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 21, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 25, a.Get(1, 2))
|
||||||
|
|
||||||
|
// Add incorrect dimension
|
||||||
|
assert.PanicsWithValue(
|
||||||
|
t, matrix.ErrIncompatibleMatrixDimensions,
|
||||||
|
func() {
|
||||||
|
d := matrix.Create[int](3, 2)
|
||||||
|
a.Add(b, a, a, d)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Equal(t, 5, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 9, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 13, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 17, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 21, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 25, a.Get(1, 2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrix_Subtract(t *testing.T) {
|
||||||
|
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
b := matrix.CreateFromFlatSlice(2, 3, []int{1, 1, 1, 1, 1, 1})
|
||||||
|
|
||||||
|
// Subtract nothing
|
||||||
|
c := a.Subtract()
|
||||||
|
assert.Same(t, a, c)
|
||||||
|
assert.Equal(t, 1, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 2, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 4, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 5, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 6, a.Get(1, 2))
|
||||||
|
|
||||||
|
// Add itself multiple times
|
||||||
|
c = a.Subtract(a, a, a, a)
|
||||||
|
assert.Same(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, -3, a.Get(0, 0))
|
||||||
|
assert.Equal(t, -6, a.Get(0, 1))
|
||||||
|
assert.Equal(t, -9, a.Get(0, 2))
|
||||||
|
assert.Equal(t, -12, a.Get(1, 0))
|
||||||
|
assert.Equal(t, -15, a.Get(1, 1))
|
||||||
|
assert.Equal(t, -18, a.Get(1, 2))
|
||||||
|
|
||||||
|
// Add other matrix
|
||||||
|
c = a.Subtract(b)
|
||||||
|
assert.Same(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, -4, a.Get(0, 0))
|
||||||
|
assert.Equal(t, -7, a.Get(0, 1))
|
||||||
|
assert.Equal(t, -10, a.Get(0, 2))
|
||||||
|
assert.Equal(t, -13, a.Get(1, 0))
|
||||||
|
assert.Equal(t, -16, a.Get(1, 1))
|
||||||
|
assert.Equal(t, -19, a.Get(1, 2))
|
||||||
|
|
||||||
|
// Add incorrect dimension
|
||||||
|
assert.PanicsWithValue(
|
||||||
|
t, matrix.ErrIncompatibleMatrixDimensions,
|
||||||
|
func() {
|
||||||
|
d := matrix.Create[int](3, 2)
|
||||||
|
a.Subtract(b, a, a, d)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Equal(t, -4, a.Get(0, 0))
|
||||||
|
assert.Equal(t, -7, a.Get(0, 1))
|
||||||
|
assert.Equal(t, -10, a.Get(0, 2))
|
||||||
|
assert.Equal(t, -13, a.Get(1, 0))
|
||||||
|
assert.Equal(t, -16, a.Get(1, 1))
|
||||||
|
assert.Equal(t, -19, a.Get(1, 2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrix_Apply(t *testing.T) {
|
||||||
|
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
assert.NotNil(t, a)
|
||||||
|
|
||||||
|
a.Apply(func(x int) int { return x * x })
|
||||||
|
|
||||||
|
assert.Equal(t, 1, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 4, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 9, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 16, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 25, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 36, a.Get(1, 2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrix_Copy(t *testing.T) {
|
||||||
|
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
assert.NotNil(t, a)
|
||||||
|
|
||||||
|
b := a.Copy()
|
||||||
|
assert.NotSame(t, a, b)
|
||||||
|
assert.Equal(t, 1, b.Get(0, 0))
|
||||||
|
assert.Equal(t, 2, b.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, b.Get(0, 2))
|
||||||
|
assert.Equal(t, 4, b.Get(1, 0))
|
||||||
|
assert.Equal(t, 5, b.Get(1, 1))
|
||||||
|
assert.Equal(t, 6, b.Get(1, 2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrix_Get(t *testing.T) {
|
||||||
|
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
assert.NotNil(t, a)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 2, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 4, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 5, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 6, a.Get(1, 2))
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
a.Get(-1, 0)
|
||||||
|
})
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
a.Get(0, -1)
|
||||||
|
})
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
a.Get(2, 0)
|
||||||
|
})
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
a.Get(0, 3)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrix_Set(t *testing.T) {
|
||||||
|
a := matrix.Create[int](2, 2)
|
||||||
|
assert.NotNil(t, a)
|
||||||
|
|
||||||
|
for row := range a.Rows() {
|
||||||
|
for col := range a.Cols() {
|
||||||
|
assert.NotEqual(t, a.Get(row, col), 42)
|
||||||
|
a.Set(row, col, 42)
|
||||||
|
assert.Equal(t, 42, a.Get(row, col))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
a.Set(-1, 0, 42)
|
||||||
|
})
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
a.Set(0, -1, 42)
|
||||||
|
})
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
a.Set(2, 0, 42)
|
||||||
|
})
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
a.Set(0, 2, 42)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatrix_Scale(t *testing.T) {
|
||||||
|
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
|
||||||
|
// Add nothing
|
||||||
|
b := a.Scale(-4)
|
||||||
|
assert.Same(t, a, b)
|
||||||
|
|
||||||
|
assert.Equal(t, -4, a.Get(0, 0))
|
||||||
|
assert.Equal(t, -8, a.Get(0, 1))
|
||||||
|
assert.Equal(t, -12, a.Get(0, 2))
|
||||||
|
assert.Equal(t, -16, a.Get(1, 0))
|
||||||
|
assert.Equal(t, -20, a.Get(1, 1))
|
||||||
|
assert.Equal(t, -24, a.Get(1, 2))
|
||||||
|
}
|
Reference in New Issue
Block a user