Files
cryptography/matrix/matrix.go
2025-05-18 00:33:07 +02:00

243 lines
6.3 KiB
Go

// Package matrix provides very basic matrix operations for all numeric types.
// It is not a high-performance or feature-rich implementation and is mostly
// intended for collecting statistics.
package matrix
import (
"errors"
"golang.org/x/exp/constraints"
)
// Sentinel errors used as panic values by the functions in this package.
var (
	// ErrInvalidDimensions signals a request for a matrix with fewer than one
	// row or fewer than one column.
	ErrInvalidDimensions = errors.New("Invalid dimensions")

	// ErrIncompatibleMatrixDimensions signals that matrices combined in an
	// element-wise operation do not have identical dimensions.
	ErrIncompatibleMatrixDimensions = errors.New("Incompatible matrix dimensions")

	// ErrIncompatibleDataDimensions signals that supplied data does not match
	// the requested matrix size.
	ErrIncompatibleDataDimensions = errors.New("Incompatible data dimensions")
)
type (
	// Number is the set of element types a Matrix can be created for: all
	// integer, floating-point and complex types.
	Number interface {
		constraints.Integer | constraints.Float | constraints.Complex
	}

	// SimpleNumber is the subset of Number that excludes complex types. It is
	// the constraint used by Convert, which relies on plain numeric type
	// conversions between element types.
	SimpleNumber interface {
		constraints.Integer | constraints.Float
	}
)
// Matrix is a dense matrix whose elements are of the numeric type T.
// Use Create or CreateFromFlatSlice to obtain a properly sized instance.
type Matrix[T Number] struct {
	rows, cols int   // dimensions of the matrix
	values     [][]T // element storage indexed as values[row][col]
}
// createValues allocates zero-initialised element storage for a matrix: a
// slice of rows row slices, each holding cols elements.
func createValues[T Number](rows, cols int) [][]T {
	grid := make([][]T, rows)
	for r := range grid {
		grid[r] = make([]T, cols)
	}
	return grid
}
// Create returns a new matrix with the given number of rows and columns in
// which every element holds the zero value of T.
//
// Panics with ErrInvalidDimensions if rows < 1 or cols < 1.
func Create[T Number](rows, cols int) *Matrix[T] {
	if rows <= 0 || cols <= 0 {
		panic(ErrInvalidDimensions)
	}
	m := new(Matrix[T])
	m.rows, m.cols = rows, cols
	m.values = createValues[T](rows, cols)
	return m
}
// CreateFromFlatSlice returns a new rows x cols matrix filled from values in
// row-major order: left to right within a row, rows from top to bottom. The
// elements are copied, so later changes to values do not affect the matrix.
//
// Panics with ErrInvalidDimensions if rows < 1 or cols < 1.
// Panics with ErrIncompatibleDataDimensions if len(values) != rows*cols.
func CreateFromFlatSlice[T Number](rows, cols int, values []T) *Matrix[T] {
	if rows < 1 || cols < 1 {
		panic(ErrInvalidDimensions)
	}
	// Equivalent to len(values) != rows*cols, but computed by division so an
	// overflowing rows*cols cannot let mismatched data pass validation and
	// trigger an enormous allocation in Create.
	if len(values)%cols != 0 || len(values)/cols != rows {
		panic(ErrIncompatibleDataDimensions)
	}
	m := Create[T](rows, cols)
	for i, row := range m.values {
		// Row i is the contiguous window [i*cols, (i+1)*cols) of the input.
		copy(row, values[i*cols:(i+1)*cols])
	}
	return m
}
// Convert returns a new matrix with the same dimensions as in, whose elements
// are the elements of in converted from T to U with a Go type conversion.
// Only SimpleNumber element types are accepted. Conversions follow the Go
// spec: for example, a float-to-integer conversion truncates toward zero.
// The input matrix is not modified.
func Convert[U, T SimpleNumber](in *Matrix[T]) *Matrix[U] {
	out := Create[U](in.rows, in.cols)
	for r, srcRow := range in.values {
		dstRow := out.values[r]
		for c, v := range srcRow {
			dstRow[c] = U(v)
		}
	}
	return out
}
// Transform returns a new matrix with the same dimensions as in, where each
// element is transformFn applied to the corresponding element of in. The
// element type may change from T to U. Unlike Convert, it accepts every
// Number type, including complex ones. transformFn is called once per
// element in row-major order. The input matrix is not modified.
func Transform[U, T Number](in *Matrix[T], transformFn func(T) U) *Matrix[U] {
	out := Create[U](in.rows, in.cols)
	for r := 0; r < in.rows; r++ {
		for c := 0; c < in.cols; c++ {
			out.values[r][c] = transformFn(in.values[r][c])
		}
	}
	return out
}
// Sum returns a new matrix holding the element-wise sum of first and every
// matrix in additional. None of the arguments are modified.
//
// Panics with ErrIncompatibleMatrixDimensions if any matrix in additional has
// dimensions different from first.
func Sum[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] {
	total := first.Copy()
	total.Add(additional...)
	return total
}
// Copy returns a deep copy of m: a new matrix with the same dimensions and
// element values that shares no storage with m.
func (m *Matrix[T]) Copy() *Matrix[T] {
	dup := Create[T](m.rows, m.cols)
	for r, row := range dup.values {
		copy(row, m.values[r])
	}
	return dup
}
// Size reports the dimensions of m as (rows, columns).
func (m *Matrix[T]) Size() (int, int) {
	return m.Rows(), m.Cols()
}

// Rows reports how many rows m has.
func (m *Matrix[T]) Rows() int {
	n := m.rows
	return n
}

// Cols reports how many columns m has.
func (m *Matrix[T]) Cols() int {
	n := m.cols
	return n
}
// Set stores value at the zero-based position (row, col) of m. Out-of-range
// indices cause a runtime index-out-of-range panic.
func (m *Matrix[T]) Set(row, col int, value T) {
	target := m.values[row]
	target[col] = value
}
// Get returns the element at the zero-based position (row, col) of m.
// Out-of-range indices cause a runtime index-out-of-range panic.
func (m *Matrix[T]) Get(row, col int) T {
	return m.values[row][col]
}
// Add adds zero or more matrices element-wise to m in place and returns m.
//
// The receiver may appear any number of times among the arguments, in any
// position. Every occurrence contributes m's value as it was before the
// call, so m.Add(m) doubles m and m.Add(a, m) yields 2*m + a.
//
// Panics with ErrIncompatibleMatrixDimensions if any of the matrices don't
// have the same dimensions as m. All dimensions are checked before any
// element is modified.
func (m *Matrix[T]) Add(matrices ...*Matrix[T]) *Matrix[T] {
	dst := m.inPlaceTarget(matrices)
	for _, other := range matrices {
		for i, row := range dst {
			src := other.values[i]
			for j := range row {
				row[j] += src[j]
			}
		}
	}
	m.values = dst
	return m
}

// Subtract subtracts zero or more matrices element-wise from m in place and
// returns m.
//
// The receiver may appear any number of times among the arguments, in any
// position. Every occurrence contributes m's value as it was before the
// call, so m.Subtract(m) zeroes m and m.Subtract(a, m) yields -a.
//
// Panics with ErrIncompatibleMatrixDimensions if any of the matrices don't
// have the same dimensions as m. All dimensions are checked before any
// element is modified.
func (m *Matrix[T]) Subtract(matrices ...*Matrix[T]) *Matrix[T] {
	dst := m.inPlaceTarget(matrices)
	for _, other := range matrices {
		for i, row := range dst {
			src := other.values[i]
			for j := range row {
				row[j] -= src[j]
			}
		}
	}
	m.values = dst
	return m
}

// inPlaceTarget validates the operands of an in-place element-wise update of
// m and returns the storage the update should write into.
//
// If m itself is among the operands, it returns a fresh copy of m's values.
// Writing to that copy leaves m.values untouched until the caller swaps it
// in. Every occurrence of m is therefore read with its original values, no
// matter where it appears in the list. Otherwise m's own storage is returned.
func (m *Matrix[T]) inPlaceTarget(matrices []*Matrix[T]) [][]T {
	selfReferenced := false
	for _, other := range matrices {
		if m.rows != other.rows || m.cols != other.cols {
			panic(ErrIncompatibleMatrixDimensions)
		}
		if other == m {
			selfReferenced = true
		}
	}
	if selfReferenced {
		return m.Copy().values
	}
	return m.values
}
// Apply replaces every element of m, in place, with the result of calling fn
// on it. Elements are visited row by row, left to right. Returns m so calls
// can be chained.
func (m *Matrix[T]) Apply(fn func(T) T) *Matrix[T] {
	for _, row := range m.values {
		for j, v := range row {
			row[j] = fn(v)
		}
	}
	return m
}
// Scale multiplies every element of m by scalar, in place, and returns m so
// calls can be chained.
func (m *Matrix[T]) Scale(scalar T) *Matrix[T] {
	for r := 0; r < m.rows; r++ {
		row := m.values[r]
		for c := 0; c < m.cols; c++ {
			row[c] *= scalar
		}
	}
	return m
}