-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path05_split_dataset.go
68 lines (53 loc) · 1.27 KB
/
05_split_dataset.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
package main
import (
"bufio"
"log"
"os"
"path/filepath"
"github.com/go-gota/gota/dataframe"
)
// Locations of the cleaned input dataset and of the two split outputs.
// All three live in $MLGO/storage/data. If MLGO is unset, filepath.Join
// ignores the empty element and the paths become relative to the working
// directory.
var (
	fileName     = "clean_loan_data.csv"
	trainingName = "clean_loan_training.csv"
	testName     = "clean_loan_test.csv"

	// dataDir is the shared parent directory of every file below.
	dataDir = filepath.Join(os.Getenv("MLGO"), "storage", "data")

	filePath     = filepath.Join(dataDir, fileName)
	trainingPath = filepath.Join(dataDir, trainingName)
	testPath     = filepath.Join(dataDir, testName)
)
func main() {
f, err := os.Open(filePath)
if err != nil {
log.Fatal(err)
}
defer f.Close()
df := dataframe.ReadCSV(f)
trainingNum := (4 * df.Nrow()) / 5
testNum := df.Nrow() / 5
if trainingNum+testNum < df.Nrow() {
trainingNum++
}
trainingIdx := make([]int, trainingNum)
testIdx := make([]int, testNum)
for i := 0; i < trainingNum; i++ {
trainingIdx[i] = i
}
for i := 0; i < testNum; i++ {
testIdx[i] = trainingNum + i
}
trainingDF := df.Subset(trainingIdx)
testDF := df.Subset(testIdx)
setMap := map[int]dataframe.DataFrame{
0: trainingDF,
1: testDF,
}
for idx, setName := range []string{trainingPath, testPath} {
f, err := os.Create(setName)
if err != nil {
log.Fatal(err)
}
w := bufio.NewWriter(f)
if err := setMap[idx].WriteCSV(w); err != nil {
log.Fatal(err)
}
}
}