Add HadamardProduct function and HadamardMultiply method

This commit is contained in:
2025-05-19 14:35:23 +02:00
parent 2756894219
commit cf17a6fb72
2 changed files with 152 additions and 0 deletions

View File

@ -112,6 +112,13 @@ func Sum[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] {
return first.Copy().Add(additional...)
}
// HadamardProduct takes at least one matrix and performs component-wise
// multiplication of all given matrices. Returns a new matrix with the computed
// values. Panics if the arguments don't have matching dimensions.
func HadamardProduct[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] {
return first.Copy().HadamardMultiply(additional...)
}
// Copy creates a deep copy of the matrix and returns the new instance
func (m *Matrix[T]) Copy() *Matrix[T] {
mCopy := Create[T](m.rows, m.cols)
@ -240,3 +247,39 @@ func (m *Matrix[T]) Scale(scalar T) *Matrix[T] {
}
return m
}
// HadamardMultiply performs an in-place component-wise multiplication of this
// matrix with zero or more matrices.
// Ensures correct behavior even if the matrix itself is passed as one or more
// arguments.
// Panics with ErrIncompatibleMatrixDimensions if any of the matrices don't
// have matching dimensions.
func (m *Matrix[T]) HadamardMultiply(matrices ...*Matrix[T]) *Matrix[T] {
numSelf := 0
for _, other := range matrices {
if m.rows != other.rows || m.cols != other.cols {
panic(ErrIncompatibleMatrixDimensions)
}
if other == m {
numSelf += 1
}
}
// If we have multiple self references to multiply, work on duplicate data to
// make multiplication behave as expected
values := m.values
if numSelf > 1 {
values = m.Copy().values
}
for _, other := range matrices {
for i := range m.rows {
for j := range m.cols {
values[i][j] *= other.values[i][j]
}
}
}
m.values = values
return m
}

View File

@ -366,3 +366,112 @@ func TestMatrix_Scale(t *testing.T) {
assert.Equal(t, -20, a.Get(1, 1))
assert.Equal(t, -24, a.Get(1, 2))
}
func TestMatrix_HadamardMultiply(t *testing.T) {
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
b := matrix.CreateFromFlatSlice(2, 3, []int{2, 2, 2, 2, 2, 2})
// Multiply nothing
c := a.HadamardMultiply()
assert.Same(t, a, c)
assert.Equal(t, 1, a.Get(0, 0))
assert.Equal(t, 2, a.Get(0, 1))
assert.Equal(t, 3, a.Get(0, 2))
assert.Equal(t, 4, a.Get(1, 0))
assert.Equal(t, 5, a.Get(1, 1))
assert.Equal(t, 6, a.Get(1, 2))
// Multiply itself multiple times
c = a.HadamardMultiply(a, a, a)
assert.Same(t, a, c)
assert.Equal(t, 1, a.Get(0, 0))
assert.Equal(t, 16, a.Get(0, 1))
assert.Equal(t, 81, a.Get(0, 2))
assert.Equal(t, 256, a.Get(1, 0))
assert.Equal(t, 625, a.Get(1, 1))
assert.Equal(t, 1296, a.Get(1, 2))
// Multiply other matrix
c = a.HadamardMultiply(b)
assert.Same(t, a, c)
assert.Equal(t, 2, a.Get(0, 0))
assert.Equal(t, 32, a.Get(0, 1))
assert.Equal(t, 162, a.Get(0, 2))
assert.Equal(t, 512, a.Get(1, 0))
assert.Equal(t, 1250, a.Get(1, 1))
assert.Equal(t, 2592, a.Get(1, 2))
// Multiply incorrect dimension
assert.PanicsWithValue(
t, matrix.ErrIncompatibleMatrixDimensions,
func() {
d := matrix.Create[int](3, 2)
a.HadamardMultiply(b, a, a, d)
},
)
assert.Equal(t, 2, a.Get(0, 0))
assert.Equal(t, 32, a.Get(0, 1))
assert.Equal(t, 162, a.Get(0, 2))
assert.Equal(t, 512, a.Get(1, 0))
assert.Equal(t, 1250, a.Get(1, 1))
assert.Equal(t, 2592, a.Get(1, 2))
}
func TestHadamardProduct(t *testing.T) {
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
b := matrix.CreateFromFlatSlice(2, 3, []int{2, 2, 2, 2, 2, 2})
// Multiply only one matrix
c := matrix.HadamardProduct(a)
assert.NotNil(t, c)
assert.NotSame(t, a, c)
assert.Equal(t, 2, c.Rows())
assert.Equal(t, 3, c.Cols())
assert.Equal(t, 1, c.Get(0, 0))
assert.Equal(t, 2, c.Get(0, 1))
assert.Equal(t, 3, c.Get(0, 2))
assert.Equal(t, 4, c.Get(1, 0))
assert.Equal(t, 5, c.Get(1, 1))
assert.Equal(t, 6, c.Get(1, 2))
// Multilply one matrix multiple times
c = matrix.HadamardProduct(a, a, a)
assert.NotNil(t, c)
assert.NotSame(t, a, c)
assert.Equal(t, 2, c.Rows())
assert.Equal(t, 3, c.Cols())
assert.Equal(t, 1, c.Get(0, 0))
assert.Equal(t, 8, c.Get(0, 1))
assert.Equal(t, 27, c.Get(0, 2))
assert.Equal(t, 64, c.Get(1, 0))
assert.Equal(t, 125, c.Get(1, 1))
assert.Equal(t, 216, c.Get(1, 2))
// Multiply different matrices
c = matrix.HadamardProduct(a, b)
assert.NotNil(t, c)
assert.NotEqual(t, a, c)
assert.Equal(t, 2, c.Rows())
assert.Equal(t, 3, c.Cols())
assert.Equal(t, 2, c.Get(0, 0))
assert.Equal(t, 4, c.Get(0, 1))
assert.Equal(t, 6, c.Get(0, 2))
assert.Equal(t, 8, c.Get(1, 0))
assert.Equal(t, 10, c.Get(1, 1))
assert.Equal(t, 12, c.Get(1, 2))
// Multiply incorrect dimensions
d := matrix.Create[int](3, 2)
assert.PanicsWithValue(t, matrix.ErrIncompatibleMatrixDimensions, func() {
matrix.HadamardProduct(a, d)
})
}