Add HadamardProduct function and HadamardMultiply method
This commit is contained in:
@ -112,6 +112,13 @@ func Sum[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] {
|
||||
return first.Copy().Add(additional...)
|
||||
}
|
||||
|
||||
// HadamardProduct takes at least one matrix and performs component-wise
|
||||
// multiplication of all given matrices. Returns a new matrix with the computed
|
||||
// values. Panics if the arguments don't have matching dimensions.
|
||||
func HadamardProduct[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] {
|
||||
return first.Copy().HadamardMultiply(additional...)
|
||||
}
|
||||
|
||||
// Copy creates a deep copy of the matrix and returns the new instance
|
||||
func (m *Matrix[T]) Copy() *Matrix[T] {
|
||||
mCopy := Create[T](m.rows, m.cols)
|
||||
@ -240,3 +247,39 @@ func (m *Matrix[T]) Scale(scalar T) *Matrix[T] {
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// HadamardMultiply performs an in-place component-wise multiplication of this
|
||||
// matrix with zero or more matrices.
|
||||
// Ensures correct behavior even if the matrix itself is passed as one or more
|
||||
// arguments.
|
||||
// Panics with ErrIncompatibleMatrixDimensions if any of the matrices don't
|
||||
// have matching dimensions.
|
||||
func (m *Matrix[T]) HadamardMultiply(matrices ...*Matrix[T]) *Matrix[T] {
|
||||
numSelf := 0
|
||||
for _, other := range matrices {
|
||||
if m.rows != other.rows || m.cols != other.cols {
|
||||
panic(ErrIncompatibleMatrixDimensions)
|
||||
}
|
||||
if other == m {
|
||||
numSelf += 1
|
||||
}
|
||||
}
|
||||
|
||||
// If we have multiple self references to multiply, work on duplicate data to
|
||||
// make multiplication behave as expected
|
||||
values := m.values
|
||||
if numSelf > 1 {
|
||||
values = m.Copy().values
|
||||
}
|
||||
|
||||
for _, other := range matrices {
|
||||
for i := range m.rows {
|
||||
for j := range m.cols {
|
||||
values[i][j] *= other.values[i][j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.values = values
|
||||
return m
|
||||
}
|
||||
|
@ -366,3 +366,112 @@ func TestMatrix_Scale(t *testing.T) {
|
||||
assert.Equal(t, -20, a.Get(1, 1))
|
||||
assert.Equal(t, -24, a.Get(1, 2))
|
||||
}
|
||||
|
||||
func TestMatrix_HadamardMultiply(t *testing.T) {
|
||||
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
b := matrix.CreateFromFlatSlice(2, 3, []int{2, 2, 2, 2, 2, 2})
|
||||
|
||||
// Multiply nothing
|
||||
c := a.HadamardMultiply()
|
||||
assert.Same(t, a, c)
|
||||
assert.Equal(t, 1, a.Get(0, 0))
|
||||
assert.Equal(t, 2, a.Get(0, 1))
|
||||
assert.Equal(t, 3, a.Get(0, 2))
|
||||
assert.Equal(t, 4, a.Get(1, 0))
|
||||
assert.Equal(t, 5, a.Get(1, 1))
|
||||
assert.Equal(t, 6, a.Get(1, 2))
|
||||
|
||||
// Multiply itself multiple times
|
||||
c = a.HadamardMultiply(a, a, a)
|
||||
assert.Same(t, a, c)
|
||||
|
||||
assert.Equal(t, 1, a.Get(0, 0))
|
||||
assert.Equal(t, 16, a.Get(0, 1))
|
||||
assert.Equal(t, 81, a.Get(0, 2))
|
||||
assert.Equal(t, 256, a.Get(1, 0))
|
||||
assert.Equal(t, 625, a.Get(1, 1))
|
||||
assert.Equal(t, 1296, a.Get(1, 2))
|
||||
|
||||
// Multiply other matrix
|
||||
c = a.HadamardMultiply(b)
|
||||
assert.Same(t, a, c)
|
||||
|
||||
assert.Equal(t, 2, a.Get(0, 0))
|
||||
assert.Equal(t, 32, a.Get(0, 1))
|
||||
assert.Equal(t, 162, a.Get(0, 2))
|
||||
assert.Equal(t, 512, a.Get(1, 0))
|
||||
assert.Equal(t, 1250, a.Get(1, 1))
|
||||
assert.Equal(t, 2592, a.Get(1, 2))
|
||||
|
||||
// Multiply incorrect dimension
|
||||
assert.PanicsWithValue(
|
||||
t, matrix.ErrIncompatibleMatrixDimensions,
|
||||
func() {
|
||||
d := matrix.Create[int](3, 2)
|
||||
a.HadamardMultiply(b, a, a, d)
|
||||
},
|
||||
)
|
||||
|
||||
assert.Equal(t, 2, a.Get(0, 0))
|
||||
assert.Equal(t, 32, a.Get(0, 1))
|
||||
assert.Equal(t, 162, a.Get(0, 2))
|
||||
assert.Equal(t, 512, a.Get(1, 0))
|
||||
assert.Equal(t, 1250, a.Get(1, 1))
|
||||
assert.Equal(t, 2592, a.Get(1, 2))
|
||||
}
|
||||
|
||||
func TestHadamardProduct(t *testing.T) {
|
||||
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||
b := matrix.CreateFromFlatSlice(2, 3, []int{2, 2, 2, 2, 2, 2})
|
||||
|
||||
// Multiply only one matrix
|
||||
c := matrix.HadamardProduct(a)
|
||||
assert.NotNil(t, c)
|
||||
assert.NotSame(t, a, c)
|
||||
|
||||
assert.Equal(t, 2, c.Rows())
|
||||
assert.Equal(t, 3, c.Cols())
|
||||
|
||||
assert.Equal(t, 1, c.Get(0, 0))
|
||||
assert.Equal(t, 2, c.Get(0, 1))
|
||||
assert.Equal(t, 3, c.Get(0, 2))
|
||||
assert.Equal(t, 4, c.Get(1, 0))
|
||||
assert.Equal(t, 5, c.Get(1, 1))
|
||||
assert.Equal(t, 6, c.Get(1, 2))
|
||||
|
||||
// Multilply one matrix multiple times
|
||||
c = matrix.HadamardProduct(a, a, a)
|
||||
assert.NotNil(t, c)
|
||||
assert.NotSame(t, a, c)
|
||||
|
||||
assert.Equal(t, 2, c.Rows())
|
||||
assert.Equal(t, 3, c.Cols())
|
||||
|
||||
assert.Equal(t, 1, c.Get(0, 0))
|
||||
assert.Equal(t, 8, c.Get(0, 1))
|
||||
assert.Equal(t, 27, c.Get(0, 2))
|
||||
assert.Equal(t, 64, c.Get(1, 0))
|
||||
assert.Equal(t, 125, c.Get(1, 1))
|
||||
assert.Equal(t, 216, c.Get(1, 2))
|
||||
|
||||
// Multiply different matrices
|
||||
c = matrix.HadamardProduct(a, b)
|
||||
assert.NotNil(t, c)
|
||||
assert.NotEqual(t, a, c)
|
||||
|
||||
assert.Equal(t, 2, c.Rows())
|
||||
assert.Equal(t, 3, c.Cols())
|
||||
|
||||
assert.Equal(t, 2, c.Get(0, 0))
|
||||
assert.Equal(t, 4, c.Get(0, 1))
|
||||
assert.Equal(t, 6, c.Get(0, 2))
|
||||
assert.Equal(t, 8, c.Get(1, 0))
|
||||
assert.Equal(t, 10, c.Get(1, 1))
|
||||
assert.Equal(t, 12, c.Get(1, 2))
|
||||
|
||||
// Multiply incorrect dimensions
|
||||
d := matrix.Create[int](3, 2)
|
||||
assert.PanicsWithValue(t, matrix.ErrIncompatibleMatrixDimensions, func() {
|
||||
matrix.HadamardProduct(a, d)
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user