Skip to content

Commit fa62355

Browse files
committed
Fix #140
+ Fix SortIndex() + Add SortIndexStable()
1 parent 87b8a9a commit fa62355

File tree

2 files changed

+96
-27
lines changed

2 files changed

+96
-27
lines changed

api_utils.go

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package tensor
22

33
import (
4-
"log"
54
"math"
65
"math/rand"
76
"reflect"
@@ -10,41 +9,49 @@ import (
109
"github.com/chewxy/math32"
1110
)
1211

13-
// SortIndex is similar to numpy's argsort
14-
// TODO: tidy this up
12+
// SortIndex: Similar to numpy's argsort.
13+
// Returns indices for sorting a slice in increasing order.
14+
// Input slice remains unchanged.
15+
// SortIndex may not be stable; for stability, use SortIndexStable.
1516
func SortIndex(in interface{}) (out []int) {
17+
return sortIndex(in, sort.Slice)
18+
}
19+
20+
// SortIndexStable: Similar to SortIndex, but stable.
21+
// Returns indices for sorting a slice in increasing order.
22+
// Input slice remains unchanged.
23+
func SortIndexStable(in interface{}) (out []int) {
24+
return sortIndex(in, sort.SliceStable)
25+
}
26+
27+
func sortIndex(in interface{}, sortFunc func(x any, less func(i int, j int) bool)) (out []int) {
1628
switch list := in.(type) {
1729
case []int:
18-
orig := make([]int, len(list))
1930
out = make([]int, len(list))
20-
copy(orig, list)
21-
sort.Ints(list)
22-
for i, s := range list {
23-
for j, o := range orig {
24-
if o == s {
25-
out[i] = j
26-
break
27-
}
28-
}
31+
for i := 0; i < len(list); i++ {
32+
out[i] = i
2933
}
34+
sortFunc(out, func(i, j int) bool {
35+
return list[out[i]] < list[out[j]]
36+
})
3037
case []float64:
31-
orig := make([]float64, len(list))
3238
out = make([]int, len(list))
33-
copy(orig, list)
34-
sort.Float64s(list)
35-
36-
for i, s := range list {
37-
for j, o := range orig {
38-
if o == s {
39-
out[i] = j
40-
break
41-
}
42-
}
39+
for i := 0; i < len(list); i++ {
40+
out[i] = i
4341
}
42+
sortFunc(out, func(i, j int) bool {
43+
return list[out[i]] < list[out[j]]
44+
})
4445
case sort.Interface:
45-
sort.Sort(list)
46-
47-
log.Printf("TODO: SortIndex for sort.Interface not yet done.")
46+
out = make([]int, list.Len())
47+
for i := 0; i < list.Len(); i++ {
48+
out[i] = i
49+
}
50+
sortFunc(out, func(i, j int) bool {
51+
return list.Less(out[i], out[j])
52+
})
53+
default:
54+
panic("The slice type is not currently supported.")
4855
}
4956

5057
return

api_utils_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package tensor
2+
3+
import (
4+
"testing"
5+
)
6+
7+
type testInt []int
8+
9+
func (m testInt) Less(i, j int) bool { return m[i] < m[j] }
10+
func (m testInt) Len() int { return len(m) }
11+
func (m testInt) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
12+
13+
func TestSortIndexInts(t *testing.T) {
14+
in := []int{9, 8, 7, 6, 5, 4, 10, -1, -2, -4, 11, 13, 15, 100, 99}
15+
inCopy := make([]int, len(in))
16+
copy(inCopy, in)
17+
out := SortIndex(in)
18+
for i := 1; i < len(out); i++ {
19+
if inCopy[out[i]] < inCopy[out[i-1]] {
20+
t.Fatalf("Unexpected output")
21+
}
22+
}
23+
for i := range in {
24+
if in[i] != inCopy[i] {
25+
t.Fatalf("The input slice should not be changed")
26+
}
27+
}
28+
}
29+
30+
func TestSortIndexFloats(t *testing.T) {
31+
in := []float64{.9, .8, .7, .6, .5, .4, .10, -.1, -.2, -.4, .11, .13, .15, .100, .99}
32+
inCopy := make([]float64, len(in))
33+
copy(inCopy, in)
34+
out := SortIndex(in)
35+
for i := 1; i < len(out); i++ {
36+
if inCopy[out[i]] < inCopy[out[i-1]] {
37+
t.Fatalf("Unexpected output")
38+
}
39+
}
40+
for i := range in {
41+
if in[i] != inCopy[i] {
42+
t.Fatalf("The input slice should not be changed")
43+
}
44+
}
45+
}
46+
47+
func TestSortIndexSortInterface(t *testing.T) {
48+
in := testInt{9, 8, 7, 6, 5, 4, 10, -1, -2, -4, 11, 13, 15, 100, 99}
49+
inCopy := make(testInt, len(in))
50+
copy(inCopy, in)
51+
out := SortIndex(in)
52+
for i := 1; i < len(out); i++ {
53+
if inCopy[out[i]] < inCopy[out[i-1]] {
54+
t.Fatalf("Unexpected output")
55+
}
56+
}
57+
for i := range in {
58+
if in[i] != inCopy[i] {
59+
t.Fatalf("The input slice should not be changed")
60+
}
61+
}
62+
}

0 commit comments

Comments
 (0)