-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_GAN.d2
120 lines (118 loc) · 2.46 KB
/
train_GAN.d2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Training-loop diagram for a GAN-style segmentation setup, drawn as three
# side-by-side panels (one per epoch type). Edge colors are explained in Legend.
# In the labels, n = number of input channels and m = number of output classes.
train_GAN(): {
  grid-columns: 3

  # Epoch over labelled data (image + mask): seg_model is trained with a
  # supervised BCELoss plus an adversarial term, and disc_model is trained on
  # both real masks and seg_model predictions.
  segmented_training_epoch(): {
    # training the seg_model
    seg_train_dl: {
      raw image: {
        text: |md
          n input channels
        |
      }
      mask: {
        text: |md
          m output classes
        |
      }
    }
    seg_model: {
      text: |md n -> m |
    }
    disc_model: {
      text: |md m -> 1 |
    }
    disc_loss: {
      text: |md 1 -> 1 |
    }
    seg_gen_loss: {
      BCELoss
      adv_loss
    }
    seg_train_dl.raw image -> seg_model
    seg_model -> seg_gen_loss.BCELoss
    seg_train_dl.mask -> seg_gen_loss.BCELoss
    disc_model -> disc_loss
    # Purple edges: the tensor is detached so no gradient flows along them.
    disc_model -> seg_gen_loss.adv_loss: {
      style: {
        stroke: purple
      }
    }
    seg_model -> disc_model: {
      style: {
        stroke: purple
      }
    }
    (seg_gen_loss -> seg_model).style.stroke: red

    # training the disc_model
    # The real mask is also scored by disc_model; this adds a second
    # disc_model -> disc_loss arrow for the real-mask path.
    seg_train_dl.mask -> disc_model -> disc_loss
    # Every un-indexed `(a -> b)` declaration draws a new arrow in D2, so each
    # backprop edge is declared exactly once.
    (disc_loss -> disc_model).style.stroke: red
    # Green edges: losses recorded for plotting.
    (disc_loss -> raw_disc_loss).style.stroke: green
    (disc_loss -> seg_disc_loss).style.stroke: green
    (seg_gen_loss -> sup_seg_loss).style.stroke: green
  }

  # Epoch over unlabelled images (no mask): seg_model is updated through
  # unseg_gen_loss, which is fed by seg_model's prediction and by
  # disc_model's (detached) output.
  unsegmented_training_epoch(): {
    # training the seg_model
    unseg_train_dl: {
      raw image: {
        text: |md n input channels |
      }
    }
    seg_model: {
      text: |md n -> m |
    }
    disc_model: {
      text: |md m -> 1 |
    }
    unseg_train_dl.raw image -> seg_model
    seg_model -> disc_model
    disc_model -> unseg_gen_loss: {
      style: {
        stroke: purple
      }
    }
    seg_model -> unseg_gen_loss
    (unseg_gen_loss -> seg_model).style.stroke: red
    (unseg_gen_loss -> unsup_seg_loss).style.stroke: green
  }

  # Validation on labelled data: forward pass only, so no red (backprop) edges.
  validation_epoch(): {
    seg_val_dl: {
      raw image: {
        text: |md
          n input channels
        |
      }
      mask: {
        text: |md
          m output classes
        |
      }
    }
    seg_model: {
      text: |md n -> m |
    }
    # Declared explicitly so it carries the same label as in the training panels.
    disc_model: {
      text: |md m -> 1 |
    }
    seg_gen_loss: {
      BCELoss
      adv_loss
    }
    seg_model -> disc_model -> seg_gen_loss.adv_loss
    seg_val_dl.raw image -> seg_model
    seg_model -> seg_gen_loss.BCELoss
    seg_val_dl.mask -> seg_gen_loss.BCELoss
    (seg_gen_loss -> val_seg_loss).style.stroke: green
  }
}
# Key for the edge colors used in train_GAN(). "blue" is D2's default
# connection color, i.e. edges declared without an explicit stroke.
Legend: {
  grid-rows: 2
  text1: |md
    blue: forward pass
  |
  text2: |md
    red: backpropagation
  |
  text3: |md
    purple: detach to avoid backpropagation
  |
  text4: |md
    green: save for plotting
  |
}