-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathtable_test.go
45 lines (35 loc) · 1.1 KB
/
table_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
package q_test
import (
"testing"
. "github.com/aunum/gold/pkg/v1/agent/q"
"github.com/stretchr/testify/require"
"gorgonia.org/tensor"
)
// TestMemTable exercises the in-memory Q-table created by NewMemTable.
// It stores Q-values for two observations, reads each back with Get, and
// then checks that GetMax returns the highest-valued action for a state
// that holds two entries.
func TestMemTable(t *testing.T) {
	actionSpaceSize := 6
	table := NewMemTable(actionSpaceSize)

	// First state: a 2x4 observation backed by tensor.Range(Float32, 0, 8).
	observation1 := tensor.New(tensor.WithShape(2, 4), tensor.WithBacking(tensor.Range(tensor.Float32, 0, 8)))
	qVal1 := float32(0.5)
	action1 := 0
	state1 := HashState(observation1)
	err := table.Set(state1, action1, qVal1)
	require.NoError(t, err)
	qRes1, err := table.Get(state1, action1)
	// Assert the error before the value so a failed lookup reports its real
	// cause rather than a confusing value mismatch.
	require.NoError(t, err)
	require.Equal(t, qVal1, qRes1)

	// Second state: same shape, different contents (tensor.Range(Float32, 8, 16)).
	observation2 := tensor.New(tensor.WithShape(2, 4), tensor.WithBacking(tensor.Range(tensor.Float32, 8, 16)))
	qVal2 := float32(0.2)
	action2 := 1
	state2 := HashState(observation2)
	err = table.Set(state2, action2, qVal2)
	require.NoError(t, err)
	qRes2, err := table.Get(state2, action2)
	require.NoError(t, err)
	require.Equal(t, qVal2, qRes2)

	// Give state1 a second, lower-valued action; GetMax must still select
	// action1 with its original Q-value.
	err = table.Set(state1, action2, qVal2)
	require.NoError(t, err)
	action, qval, err := table.GetMax(state1)
	require.NoError(t, err)
	// testify's Equal takes (expected, actual); keeping that order makes
	// failure messages label the values correctly.
	require.Equal(t, action1, action)
	require.Equal(t, qVal1, qval)
}