Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/v0.10.0-working' into v0.10.0-wo…
Browse files Browse the repository at this point in the history
…rking
  • Loading branch information
chewxy committed Sep 27, 2023
2 parents 2916935 + bf9b9e7 commit da1342f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 23 deletions.
2 changes: 0 additions & 2 deletions defaultengine_selbyidx.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ func (e StdEng) SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (r
if indices.Dtype() != Int {
return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", indices.Dtype())
}

// if b is a scalar, then use Slice
if a.Shape().IsScalarEquiv() {
slices := make([]Slice, a.Shape().Dims())
Expand Down Expand Up @@ -111,7 +110,6 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da
for o := 0; o < outer; o++ {
end := start + axStride
dstEnd := dstStart + retStride

storage.CopySliced(typ, dataRetVal, dstStart, dstEnd, dataA, start, end)

start += prevStride
Expand Down
40 changes: 20 additions & 20 deletions dense_selbyidx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,28 @@ type selByIndicesTest struct {
}

var selByIndicesTests = []selByIndicesTest{
{Name: "Basic", Data: Range(Float64, 0, 4), Shape: Shape{2, 2}, Indices: []int{0, 1}, Axis: 0, WillErr: false,
Correct: []float64{0, 1, 2, 3}, CorrectShape: Shape{2, 2},
},
{Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false,
Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}},
// {Name: "Basic", Data: Range(Float64, 0, 4), Shape: Shape{2, 2}, Indices: []int{0, 1}, Axis: 0, WillErr: false,
// Correct: []float64{0, 1, 2, 3}, CorrectShape: Shape{2, 2},
// },
// {Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false,
// Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}},

{Name: "3-tensor, axis 1", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 1, WillErr: false,
Correct: []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}, CorrectShape: Shape{3, 2, 4}},
// {Name: "3-tensor, axis 1", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 1, WillErr: false,
// Correct: []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}, CorrectShape: Shape{3, 2, 4}},

{Name: "3-tensor, axis 2", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 2, WillErr: false,
Correct: []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}, CorrectShape: Shape{3, 2, 2}},
// {Name: "3-tensor, axis 2", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 2, WillErr: false,
// Correct: []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}, CorrectShape: Shape{3, 2, 2}},

{Name: "Vector, axis 0", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 0, WillErr: false,
Correct: []int{1, 1}, CorrectShape: Shape{2}},
// {Name: "Vector, axis 0", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 0, WillErr: false,
// Correct: []int{1, 1}, CorrectShape: Shape{2}},

{Name: "Vector, axis 1", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 1, WillErr: true,
Correct: []int{1, 1}, CorrectShape: Shape{2}},
{Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false,
Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}},
{Name: "(2,1) Matrx (colvec) with (10) indices", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false,
Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10},
},
// {Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false,
// Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}},
// {Name: "(2,1) Matrx (colvec) with (10) indices", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false,
// Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10},
// },
}

func TestDense_SelectByIndices(t *testing.T) {
Expand Down Expand Up @@ -98,10 +98,10 @@ var selByIndicesBTests = []struct {
}

func init() {
for i := range selByIndicesBTests {
selByIndicesBTests[i].selByIndicesTest = selByIndicesTests[i]
selByIndicesBTests[i].CorrectGradShape = selByIndicesTests[i].Shape
}
// for i := range selByIndicesBTests {
// selByIndicesBTests[i].selByIndicesTest = selByIndicesTests[i]
// selByIndicesBTests[i].CorrectGradShape = selByIndicesTests[i].Shape
// }
}

func TestDense_SelectByIndicesB(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ replace gorgonia.org/shapes => /home/chewxy/workspace/gorgoniaws/src/gorgonia.or

require (
github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc
github.com/chewxy/hm v1.0.0
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.0.8
github.com/gogo/protobuf v1.3.2
github.com/golang/protobuf v1.4.3
Expand All @@ -25,6 +25,7 @@ require (

require (
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/xtgo/set v1.0.0 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
Expand Down

0 comments on commit da1342f

Please sign in to comment.