From cf17a6fb7289c77e5d16331273a9df170df86a35 Mon Sep 17 00:00:00 2001 From: omicron Date: Mon, 19 May 2025 14:35:23 +0200 Subject: [PATCH] Add HadamardProduct function and HadamardMultiply method --- matrix/matrix.go | 43 +++++++++++++++++ matrix/matrix_test.go | 109 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 152 insertions(+) diff --git a/matrix/matrix.go b/matrix/matrix.go index 147e505..174b9d7 100644 --- a/matrix/matrix.go +++ b/matrix/matrix.go @@ -112,6 +112,13 @@ func Sum[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] { return first.Copy().Add(additional...) } +// HadamardProduct takes at least one matrix and performs component-wise +// multiplication of all given matrices. Returns a new matrix with the computed +// values. Panics if the arguments don't have matching dimensions. +func HadamardProduct[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] { + return first.Copy().HadamardMultiply(additional...) +} + // Copy creates a deep copy of the matrix and returns the new instance func (m *Matrix[T]) Copy() *Matrix[T] { mCopy := Create[T](m.rows, m.cols) @@ -240,3 +247,39 @@ func (m *Matrix[T]) Scale(scalar T) *Matrix[T] { } return m } + +// HadamardMultiply performs an in-place component-wise multiplication of this +// matrix with zero or more matrices. +// Ensures correct behavior even if the matrix itself is passed as one or more +// arguments. +// Panics with ErrIncompatibleMatrixDimensions if any of the matrices don't +// have matching dimensions. +func (m *Matrix[T]) HadamardMultiply(matrices ...*Matrix[T]) *Matrix[T] { + numSelf := 0 + for _, other := range matrices { + if m.rows != other.rows || m.cols != other.cols { + panic(ErrIncompatibleMatrixDimensions) + } + if other == m { + numSelf += 1 + } + } + + // If we have multiple self references to multiply, work on duplicate data to + // make multiplication behave as expected + values := m.values + if numSelf > 1 { + values = m.Copy().values + } + + for _, other := range matrices { + for i := range m.rows { + for j := range m.cols { + values[i][j] *= other.values[i][j] + } + } + } + + m.values = values + return m +} diff --git a/matrix/matrix_test.go b/matrix/matrix_test.go index 1fda57a..fabc971 100644 --- a/matrix/matrix_test.go +++ b/matrix/matrix_test.go @@ -366,3 +366,112 @@ func TestMatrix_Scale(t *testing.T) { assert.Equal(t, -20, a.Get(1, 1)) assert.Equal(t, -24, a.Get(1, 2)) } + +func TestMatrix_HadamardMultiply(t *testing.T) { + a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) + b := matrix.CreateFromFlatSlice(2, 3, []int{2, 2, 2, 2, 2, 2}) + + // Multiply nothing + c := a.HadamardMultiply() + assert.Same(t, a, c) + assert.Equal(t, 1, a.Get(0, 0)) + assert.Equal(t, 2, a.Get(0, 1)) + assert.Equal(t, 3, a.Get(0, 2)) + assert.Equal(t, 4, a.Get(1, 0)) + assert.Equal(t, 5, a.Get(1, 1)) + assert.Equal(t, 6, a.Get(1, 2)) + + // Multiply itself multiple times + c = a.HadamardMultiply(a, a, a) + assert.Same(t, a, c) + + assert.Equal(t, 1, a.Get(0, 0)) + assert.Equal(t, 16, a.Get(0, 1)) + assert.Equal(t, 81, a.Get(0, 2)) + assert.Equal(t, 256, a.Get(1, 0)) + assert.Equal(t, 625, a.Get(1, 1)) + assert.Equal(t, 1296, a.Get(1, 2)) + + // Multiply other matrix + c = a.HadamardMultiply(b) + assert.Same(t, a, c) + + assert.Equal(t, 2, a.Get(0, 0)) + assert.Equal(t, 32, a.Get(0, 1)) + assert.Equal(t, 162, a.Get(0, 2)) + assert.Equal(t, 512, a.Get(1, 0)) + assert.Equal(t, 1250, a.Get(1, 1)) + assert.Equal(t, 2592, a.Get(1, 2)) + + // Multiply incorrect dimension + assert.PanicsWithValue( + t, matrix.ErrIncompatibleMatrixDimensions, + func() { + d := matrix.Create[int](3, 2) + a.HadamardMultiply(b, a, a, d) + }, + ) + + assert.Equal(t, 2, a.Get(0, 0)) + assert.Equal(t, 32, a.Get(0, 1)) + assert.Equal(t, 162, a.Get(0, 2)) + assert.Equal(t, 512, a.Get(1, 0)) + assert.Equal(t, 1250, a.Get(1, 1)) + assert.Equal(t, 2592, a.Get(1, 2)) +} + +func TestHadamardProduct(t *testing.T) { + a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6}) + b := matrix.CreateFromFlatSlice(2, 3, []int{2, 2, 2, 2, 2, 2}) + + // Multiply only one matrix + c := matrix.HadamardProduct(a) + assert.NotNil(t, c) + assert.NotSame(t, a, c) + + assert.Equal(t, 2, c.Rows()) + assert.Equal(t, 3, c.Cols()) + + assert.Equal(t, 1, c.Get(0, 0)) + assert.Equal(t, 2, c.Get(0, 1)) + assert.Equal(t, 3, c.Get(0, 2)) + assert.Equal(t, 4, c.Get(1, 0)) + assert.Equal(t, 5, c.Get(1, 1)) + assert.Equal(t, 6, c.Get(1, 2)) + + // Multilply one matrix multiple times + c = matrix.HadamardProduct(a, a, a) + assert.NotNil(t, c) + assert.NotSame(t, a, c) + + assert.Equal(t, 2, c.Rows()) + assert.Equal(t, 3, c.Cols()) + + assert.Equal(t, 1, c.Get(0, 0)) + assert.Equal(t, 8, c.Get(0, 1)) + assert.Equal(t, 27, c.Get(0, 2)) + assert.Equal(t, 64, c.Get(1, 0)) + assert.Equal(t, 125, c.Get(1, 1)) + assert.Equal(t, 216, c.Get(1, 2)) + + // Multiply different matrices + c = matrix.HadamardProduct(a, b) + assert.NotNil(t, c) + assert.NotEqual(t, a, c) + + assert.Equal(t, 2, c.Rows()) + assert.Equal(t, 3, c.Cols()) + + assert.Equal(t, 2, c.Get(0, 0)) + assert.Equal(t, 4, c.Get(0, 1)) + assert.Equal(t, 6, c.Get(0, 2)) + assert.Equal(t, 8, c.Get(1, 0)) + assert.Equal(t, 10, c.Get(1, 1)) + assert.Equal(t, 12, c.Get(1, 2)) + + // Multiply incorrect dimensions + d := matrix.Create[int](3, 2) + assert.PanicsWithValue(t, matrix.ErrIncompatibleMatrixDimensions, func() { + matrix.HadamardProduct(a, d) + }) +}