diff --git a/matrix/matrix.go b/matrix/matrix.go index 174b9d7..90d02e0 100644 --- a/matrix/matrix.go +++ b/matrix/matrix.go @@ -283,3 +283,13 @@ func (m *Matrix[T]) HadamardMultiply(matrices ...*Matrix[T]) *Matrix[T] { m.values = values return m } + +// Fill sets all components of this matrix to the given value. Returns the receiver. +func (m *Matrix[T]) Fill(value T) *Matrix[T] { + for i := range m.rows { + for j := range m.cols { + m.values[i][j] = value + } + } + return m +} diff --git a/matrix/matrix_test.go b/matrix/matrix_test.go index fabc971..4b06d6c 100644 --- a/matrix/matrix_test.go +++ b/matrix/matrix_test.go @@ -475,3 +475,21 @@ func TestHadamardProduct(t *testing.T) { matrix.HadamardProduct(a, d) }) } + +func TestMatrix_Fill(t *testing.T) { + a := matrix.Create[int](2, 3) + assert.Equal(t, 0, a.Get(0, 0)) + assert.Equal(t, 0, a.Get(0, 1)) + assert.Equal(t, 0, a.Get(0, 2)) + assert.Equal(t, 0, a.Get(1, 0)) + assert.Equal(t, 0, a.Get(1, 1)) + assert.Equal(t, 0, a.Get(1, 2)) + + a.Fill(3) + assert.Equal(t, 3, a.Get(0, 0)) + assert.Equal(t, 3, a.Get(0, 1)) + assert.Equal(t, 3, a.Get(0, 2)) + assert.Equal(t, 3, a.Get(1, 0)) + assert.Equal(t, 3, a.Get(1, 1)) + assert.Equal(t, 3, a.Get(1, 2)) +}