Add HadamardProduct function and HadamardMultiply method
This commit is contained in:
@ -112,6 +112,13 @@ func Sum[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] {
|
|||||||
return first.Copy().Add(additional...)
|
return first.Copy().Add(additional...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HadamardProduct takes at least one matrix and performs component-wise
|
||||||
|
// multiplication of all given matrices. Returns a new matrix with the computed
|
||||||
|
// values. Panics if the arguments don't have matching dimensions.
|
||||||
|
func HadamardProduct[T Number](first *Matrix[T], additional ...*Matrix[T]) *Matrix[T] {
|
||||||
|
return first.Copy().HadamardMultiply(additional...)
|
||||||
|
}
|
||||||
|
|
||||||
// Copy creates a deep copy of the matrix and returns the new instance
|
// Copy creates a deep copy of the matrix and returns the new instance
|
||||||
func (m *Matrix[T]) Copy() *Matrix[T] {
|
func (m *Matrix[T]) Copy() *Matrix[T] {
|
||||||
mCopy := Create[T](m.rows, m.cols)
|
mCopy := Create[T](m.rows, m.cols)
|
||||||
@ -240,3 +247,39 @@ func (m *Matrix[T]) Scale(scalar T) *Matrix[T] {
|
|||||||
}
|
}
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HadamardMultiply performs an in-place component-wise multiplication of this
|
||||||
|
// matrix with zero or more matrices.
|
||||||
|
// Ensures correct behavior even if the matrix itself is passed as one or more
|
||||||
|
// arguments.
|
||||||
|
// Panics with ErrIncompatibleMatrixDimensions if any of the matrices don't
|
||||||
|
// have matching dimensions.
|
||||||
|
func (m *Matrix[T]) HadamardMultiply(matrices ...*Matrix[T]) *Matrix[T] {
|
||||||
|
numSelf := 0
|
||||||
|
for _, other := range matrices {
|
||||||
|
if m.rows != other.rows || m.cols != other.cols {
|
||||||
|
panic(ErrIncompatibleMatrixDimensions)
|
||||||
|
}
|
||||||
|
if other == m {
|
||||||
|
numSelf += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have multiple self references to multiply, work on duplicate data to
|
||||||
|
// make multiplication behave as expected
|
||||||
|
values := m.values
|
||||||
|
if numSelf > 1 {
|
||||||
|
values = m.Copy().values
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, other := range matrices {
|
||||||
|
for i := range m.rows {
|
||||||
|
for j := range m.cols {
|
||||||
|
values[i][j] *= other.values[i][j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.values = values
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
@ -366,3 +366,112 @@ func TestMatrix_Scale(t *testing.T) {
|
|||||||
assert.Equal(t, -20, a.Get(1, 1))
|
assert.Equal(t, -20, a.Get(1, 1))
|
||||||
assert.Equal(t, -24, a.Get(1, 2))
|
assert.Equal(t, -24, a.Get(1, 2))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMatrix_HadamardMultiply(t *testing.T) {
|
||||||
|
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
b := matrix.CreateFromFlatSlice(2, 3, []int{2, 2, 2, 2, 2, 2})
|
||||||
|
|
||||||
|
// Multiply nothing
|
||||||
|
c := a.HadamardMultiply()
|
||||||
|
assert.Same(t, a, c)
|
||||||
|
assert.Equal(t, 1, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 2, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 4, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 5, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 6, a.Get(1, 2))
|
||||||
|
|
||||||
|
// Multiply itself multiple times
|
||||||
|
c = a.HadamardMultiply(a, a, a)
|
||||||
|
assert.Same(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 16, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 81, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 256, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 625, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 1296, a.Get(1, 2))
|
||||||
|
|
||||||
|
// Multiply other matrix
|
||||||
|
c = a.HadamardMultiply(b)
|
||||||
|
assert.Same(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 32, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 162, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 512, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 1250, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 2592, a.Get(1, 2))
|
||||||
|
|
||||||
|
// Multiply incorrect dimension
|
||||||
|
assert.PanicsWithValue(
|
||||||
|
t, matrix.ErrIncompatibleMatrixDimensions,
|
||||||
|
func() {
|
||||||
|
d := matrix.Create[int](3, 2)
|
||||||
|
a.HadamardMultiply(b, a, a, d)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, a.Get(0, 0))
|
||||||
|
assert.Equal(t, 32, a.Get(0, 1))
|
||||||
|
assert.Equal(t, 162, a.Get(0, 2))
|
||||||
|
assert.Equal(t, 512, a.Get(1, 0))
|
||||||
|
assert.Equal(t, 1250, a.Get(1, 1))
|
||||||
|
assert.Equal(t, 2592, a.Get(1, 2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHadamardProduct(t *testing.T) {
|
||||||
|
a := matrix.CreateFromFlatSlice(2, 3, []int{1, 2, 3, 4, 5, 6})
|
||||||
|
b := matrix.CreateFromFlatSlice(2, 3, []int{2, 2, 2, 2, 2, 2})
|
||||||
|
|
||||||
|
// Multiply only one matrix
|
||||||
|
c := matrix.HadamardProduct(a)
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
assert.NotSame(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, c.Rows())
|
||||||
|
assert.Equal(t, 3, c.Cols())
|
||||||
|
|
||||||
|
assert.Equal(t, 1, c.Get(0, 0))
|
||||||
|
assert.Equal(t, 2, c.Get(0, 1))
|
||||||
|
assert.Equal(t, 3, c.Get(0, 2))
|
||||||
|
assert.Equal(t, 4, c.Get(1, 0))
|
||||||
|
assert.Equal(t, 5, c.Get(1, 1))
|
||||||
|
assert.Equal(t, 6, c.Get(1, 2))
|
||||||
|
|
||||||
|
// Multilply one matrix multiple times
|
||||||
|
c = matrix.HadamardProduct(a, a, a)
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
assert.NotSame(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, c.Rows())
|
||||||
|
assert.Equal(t, 3, c.Cols())
|
||||||
|
|
||||||
|
assert.Equal(t, 1, c.Get(0, 0))
|
||||||
|
assert.Equal(t, 8, c.Get(0, 1))
|
||||||
|
assert.Equal(t, 27, c.Get(0, 2))
|
||||||
|
assert.Equal(t, 64, c.Get(1, 0))
|
||||||
|
assert.Equal(t, 125, c.Get(1, 1))
|
||||||
|
assert.Equal(t, 216, c.Get(1, 2))
|
||||||
|
|
||||||
|
// Multiply different matrices
|
||||||
|
c = matrix.HadamardProduct(a, b)
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
assert.NotEqual(t, a, c)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, c.Rows())
|
||||||
|
assert.Equal(t, 3, c.Cols())
|
||||||
|
|
||||||
|
assert.Equal(t, 2, c.Get(0, 0))
|
||||||
|
assert.Equal(t, 4, c.Get(0, 1))
|
||||||
|
assert.Equal(t, 6, c.Get(0, 2))
|
||||||
|
assert.Equal(t, 8, c.Get(1, 0))
|
||||||
|
assert.Equal(t, 10, c.Get(1, 1))
|
||||||
|
assert.Equal(t, 12, c.Get(1, 2))
|
||||||
|
|
||||||
|
// Multiply incorrect dimensions
|
||||||
|
d := matrix.Create[int](3, 2)
|
||||||
|
assert.PanicsWithValue(t, matrix.ErrIncompatibleMatrixDimensions, func() {
|
||||||
|
matrix.HadamardProduct(a, d)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user