-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMainConv.java
250 lines (209 loc) · 7.53 KB
/
MainConv.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import java.awt.image.BufferedImage;
import java.text.DecimalFormat;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import jnn.Funcional;
import jnn.camadas.*;
import jnn.core.tensor.Tensor;
import jnn.modelos.Modelo;
import jnn.modelos.Sequencial;
import jnn.serializacao.Serializador;
import lib.ged.Dados;
import lib.ged.Ged;
import lib.geim.Geim;
public class MainConv {
/**
* Gerenciador de dados.
*/
static Ged ged = new Ged();
/**
* Gerenciador de imagens.
*/
static Geim geim = new Geim();
/**
* Interface da biblioteca.
*/
static Funcional jnn = new Funcional();
// dados de controle
// += 4min15s - 500 amostras - 8 epocas - 32 lote
static final int NUM_DIGITOS_TREINO = 10;
static final int NUM_DIGITOS_TESTE = NUM_DIGITOS_TREINO;
static final int NUM_AMOSTRAS_TREINO = 500;
static final int NUM_AMOSTRAS_TESTE = 100;
static final int TREINO_EPOCAS = 8;
static final int TREINO_LOTE = 32;
static final boolean TREINO_LOGS = true;
// caminhos de arquivos externos
static final String CAMINHO_TREINO = "./dados/mnist/treino/";
static final String CAMINHO_TESTE = "./dados/mnist/teste/";
static final String CAMINHO_SAIDA_MODELO = "./dados/modelos/modelo-treinado.nn";
static final String CAMINHO_HISTORICO = "historico-perda";
public static void main(String[] args) {
ged.limparConsole();
final Tensor[] treinoX = jnn.arrayParaTensores(carregarDadosMNIST(CAMINHO_TREINO, NUM_AMOSTRAS_TREINO, NUM_DIGITOS_TREINO));
final Tensor[] treinoY = jnn.arrayParaTensores(criarRotulosMNIST(NUM_AMOSTRAS_TREINO, NUM_DIGITOS_TREINO));
Sequencial modelo = criarModelo();
modelo.setHistorico(true);
modelo.print();
System.out.println("Treinando.");
long tempo = System.nanoTime();
modelo.treinar(treinoX, treinoY, TREINO_EPOCAS, TREINO_LOTE, TREINO_LOGS);
tempo = System.nanoTime() - tempo;
long segundosTotais = TimeUnit.NANOSECONDS.toSeconds(tempo);
long horas = segundosTotais / 3600;
long minutos = (segundosTotais % 3600) / 60;
long segundos = segundosTotais % 60;
System.out.println("\nTempo de treino: " + horas + "h " + minutos + "min " + segundos + "s");
System.out.print("Treino -> perda: " + modelo.avaliar(treinoX, treinoY).item() + " - ");
System.out.println("acurácia: " + formatarDecimal((modelo.avaliador().acuracia(treinoX, treinoY).item() * 100), 4) + "%");
System.out.println("\nCarregando dados de teste.");
final Tensor[] testeX = jnn.arrayParaTensores(carregarDadosMNIST(CAMINHO_TESTE, NUM_AMOSTRAS_TESTE, NUM_DIGITOS_TESTE));
final Tensor[] testeY = jnn.arrayParaTensores(criarRotulosMNIST(NUM_AMOSTRAS_TESTE, NUM_DIGITOS_TESTE));
System.out.print("Teste -> perda: " + modelo.avaliar(testeX, testeY).item() + " - ");
System.out.println("acurácia: " + formatarDecimal((modelo.avaliador().acuracia(testeX, testeY).item() * 100), 4) + "%");
exportarHistorico(modelo, CAMINHO_HISTORICO);
salvarModelo(modelo, CAMINHO_SAIDA_MODELO);
MainImg.executarComando("python grafico.py " + CAMINHO_HISTORICO);
}
/*
* Criação de modelos para testes.
*/
static Sequencial criarModelo() {
Sequencial modelo = new Sequencial(
new Entrada(1, 28, 28),
new Conv2D(20, new int[]{3, 3}, "relu"),
new MaxPool2D(new int[]{2, 2}),
new Conv2D(20, new int[]{3, 3}, "relu"),
new MaxPool2D(new int[]{2, 2}),
new Flatten(),
new Densa(100, "relu"),
new Dropout(0.5),
new Densa(NUM_DIGITOS_TREINO, "softmax")
);
modelo.compilar("adam", "entropia-cruzada");
return modelo;
}
/**
* Salva o modelo num arquivo externo.
* @param modelo instância de um modelo sequencial.
* @param caminho caminho de destino.
*/
static void salvarModelo(Sequencial modelo, String caminho) {
String tipo = "double";
System.out.println("Exportando modelo (" + tipo + ").");
new Serializador().salvar(modelo, caminho, tipo);
}
/**
* Converte uma imagem numa matriz contendo seus valores de brilho entre 0 e 1.
* @param caminho caminho da imagem.
* @return matriz contendo os valores de brilho da imagem.
*/
static double[][] imagemParaMatriz(String caminho) {
BufferedImage img = geim.lerImagem(caminho);
double[][] imagem = new double[img.getHeight()][img.getWidth()];
int[][] cinza = geim.obterCinza(img);
for (int y = 0; y < imagem.length; y++) {
for (int x = 0; x < imagem[y].length; x++) {
imagem[y][x] = (double)cinza[y][x] / 255.0;
}
}
return imagem;
}
/**
* Carrega as imagens do conjunto de dados {@code MNIST}.
* <p>
* Nota
* </p>
* O diretório deve conter subdiretórios, cada um contendo o conjunto de
* imagens de cada dígito, exemplo:
* <pre>
*"mnist/treino/0"
*"mnist/treino/1"
*"mnist/treino/2"
*"mnist/treino/3"
*"mnist/treino/4"
*"mnist/treino/5"
*"mnist/treino/6"
*"mnist/treino/7"
*"mnist/treino/8"
*"mnist/treino/9"
* </pre>
* @param caminho caminho do diretório das imagens.
* @param amostras quantidade de amostras por dígito
* @param digitos quantidade de dígitos, iniciando do dígito 0.
* @return dados carregados.
*/
static double[][][][] carregarDadosMNIST(String caminho, int amostras, int digitos) {
final double[][][][] imagens = new double[digitos * amostras][1][][];
final int numThreads = Runtime.getRuntime().availableProcessors() / 2;
try (ExecutorService exec = Executors.newFixedThreadPool(numThreads)) {
int id = 0;
for (int digito = 0; digito < digitos; digito++) {
for (int amostra = 0; amostra < amostras; amostra++) {
final String caminhoCompleto = caminho + digito + "/img_" + amostra + ".jpg";
final int indice = id;
exec.submit(() -> {
try {
double[][] imagem = imagemParaMatriz(caminhoCompleto);
imagens[indice][0] = imagem;
} catch (Exception e) {
System.out.println(e.getMessage());
System.exit(1);
}
});
id++;
}
}
} catch (Exception e) {
System.out.println(e.getMessage());
}
System.out.println("Imagens carregadas (" + imagens.length + ").");
return imagens;
}
/**
* Gera os rótulos do conjunto de dados {@code MNIST}.
* @param amostras quantidades de amostras por dítigo.
* @param digitos quantidade de dítigos, começando do 0.
* @return dados carregados.
*/
static double[][] criarRotulosMNIST(int amostras, int digitos) {
double[][] rotulos = new double[digitos * amostras][digitos];
for (int numero = 0; numero < digitos; numero++) {
for (int i = 0; i < amostras; i++) {
int indice = numero * amostras + i;
rotulos[indice][numero] = 1;
}
}
System.out.println("Rótulos gerados de 0 a " + (digitos-1) + ".");
return rotulos;
}
/**
* Formata o valor recebido para a quantidade de casas após o ponto
* flutuante.
* @param valor valor alvo.
* @param casas quantidade de casas após o ponto flutuante.
* @return
*/
static String formatarDecimal(double valor, int casas) {
String formato = "#.";
for (int i = 0; i < casas; i++) formato += "#";
DecimalFormat df = new DecimalFormat(formato);
return df.format(valor);
}
/**
* Salva um arquivo csv com o historico de desempenho do modelo.
* @param modelo modelo.
* @param caminho caminho onde será salvo o arquivo.
*/
static void exportarHistorico(Modelo modelo, String caminho) {
System.out.println("Exportando histórico de perda");
double[] perdas = modelo.hist();
double[][] dadosPerdas = new double[perdas.length][1];
for (int i = 0; i < dadosPerdas.length; i++) {
dadosPerdas[i][0] = perdas[i];
}
Dados dados = new Dados(dadosPerdas);
ged.exportarCsv(dados, caminho);
}
}