-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.jl
265 lines (243 loc) · 6.95 KB
/
utils.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
"""
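Utility functions: index/mask conversion, text and LaTeX table writers, float formatting, LaTeX string helpers, table queries, and objective-value caching.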
"""
using LaTeXStrings, Printf, DelimitedFiles, JLD
using DataStructures
import Printf: Format, format # for string formatting
# Turn the expression handed to the macro into its source text, e.g. `@Name(foo)`
# expands to the string "foo". The argument is stringified, never evaluated.
macro Name(arg)
    return string(arg)
end
# Reverse of `zip`: given a collection of tuples (or structs), return one collection per
# field. The fields are taken from `eltype(a)`, so the element type has to expose them.
function unzip(a)
    fields = fieldnames(eltype(a))
    return map(f -> getfield.(a, f), fields)
end
"""
converts a list of integer indices into a Bool mask of length `len` (true at the listed indices, false elsewhere), mainly for faster slicing
"""
function int_to_bin(xint, len)
    # start from an all-false mask of the requested length ...
    mask = fill(false, len)
    # ... and switch on every position named in `xint`
    mask[xint] .= true
    return mask
end
"""
inverse of `int_to_bin`: returns the indices at which the mask `xbin` equals 1
"""
function bin_to_int(xbin)
    # `==(1)` is the one-argument predicate `el -> el == 1`; findall collects the
    # indices of all entries that satisfy it
    return findall(==(1), xbin)
end
# Write the entries of `strinput` (an iterable of strings) to `filename` as one line:
# every entry is followed by a tab and the line is closed with a newline.
# `mode` is forwarded to `open` ("w" overwrites the file, "a" appends to it).
function writestringline(strinput, filename; mode="w")
    open(filename, mode) do io
        line = join(s * "\t" for s in strinput)
        print(io, line * "\n")
    end
end
# Write a matrix as the body rows of a LaTeX table to `filename` (overwriting it).
# Each cell is converted with `string` and followed by a tab, cells are separated by
# "& ", and every row is closed with a LaTeX row break, plus a horizontal rule when
# `hline` is true.
function writelatextable(table, filename; hline = true)
    open(filename, "w") do io
        for i in axes(table, 1)
            # joining with "& " yields the same text as appending "& " after every
            # cell and trimming the last separator
            cells = join((string(table[i, j]) * "\t" for j in axes(table, 2)), "& ")
            row = cells * raw" \\ "
            if hline
                row *= raw" \hline"
            end
            print(io, row * "\n")
        end
    end
end
"""
reference function to query any type of data, the queries are similar to this
"""
function dataqueryref()
# NOTE(review): none of the arrays used below (best10, centers, kidx, table_null,
# table_exp, table_k, datainfo, didx, atominfo, cidx, table_centers) are defined or
# passed in here, so they are looked up as globals. Presumably this body is a
# collection of copy-paste snippets rather than something to call as-is -- confirm.
# query example:
# for every row of best10, find the first row of centers whose first two columns are
# equal to it and record that row index in kidx
for i ∈ axes(best10, 1)
for j ∈ axes(centers, 1)
if best10[i, 1] == centers[j,1] && best10[i, 2] == centers[j, 2]
push!(kidx, j)
break
end
end
end
# string formatting example:
# overwrite the last two columns (rows 2:end) with their values formatted to 3 decimals
table_null[2:end, end-1:end] = map(el -> @sprintf("%.3f", el), table_null[2:end, end-1:end])
# table header example:
# fill the first row with LaTeX header cells (raw strings keep the backslashes literal)
table_exp[1,:] = [raw"\textbf{MAE}", raw"\textbf{\begin{tabular}[c]{@{}c@{}}Null \\ train \\MAE\end{tabular}}", raw"\textbf{model}", raw"$k$", raw"$f$", raw"$n_{af}$", raw"$n_{mf}$", raw"$t_s$", raw"$t_p$"]
# query from setup_info.txt example:
# column 2 of row i takes column 4 of datainfo at row didx[(i-1)*5 + 1], i.e. the first
# entry of each consecutive group of 5 in didx
for i ∈ axes(table_k, 1)
table_k[i, 2] = datainfo[didx[(i-1)*5 + 1], 4]
end
# taking mean example:
# column 5 of row i is the mean of column 4 of atominfo over rows (i-1)*5+1 .. (i-1)*5+5
# NOTE(review): `mean` does not appear to come from the packages imported at the top of
# this file (it lives in Statistics) -- confirm it is in scope where this is used.
for i ∈ axes(table_k, 1)
table_k[i, 5] = mean(atominfo[(i-1)*5+1:(i-1)*5+5, 4])
end
# filling str with latex interp example:
# L"..." is the LaTeXStrings string macro; each header cell is that string followed by
# the corresponding entry of cidx
for i ∈ eachindex(cidx)
table_centers[1, i] = L"$k=$"*string(cidx[i])
end
end
# Format every element of `data` as a string with 3 digits after the decimal point.
# Elements whose default printout already uses exponent notation (it contains "e")
# keep scientific notation ("%.3e"); all others are written in fixed-point ("%.3f").
function clean_float(data)
    fmt(el) = occursin("e", string(el)) ? @sprintf("%.3e", el) : @sprintf("%.3f", el)
    return map(fmt, data)
end
# Format the number `x` with `n` digits after the decimal point, where `n` is chosen at
# run time. `scientific=true` selects exponent notation ("%.<n>e"), otherwise
# fixed-point notation ("%.<n>f") is used.
function format_string_float(n,x; scientific=false)
    conversion = scientific ? "e" : "f"
    return format(Format("%.$(n)" * conversion), x)
end
"""
converts floats of x integers and y decimals to scientific notation "x.ye-n"
"""
function convert_to_scientific(data; n_decimals = 3)
    # Returns a collection shaped like `data` (via `map`) whose elements are strings in
    # scientific notation with `n_decimals` digits after the decimal point.
    #
    # The format used to be hard-coded to "%.3e", which silently ignored `n_decimals`.
    # The spec is now built from the keyword; the default of 3 reproduces the old output.
    # It is built once, outside the map, because it does not depend on the element.
    spec = Format("%.$(n_decimals)e")
    return map(el -> format(spec, el), data)
end
# Escape underscores for LaTeX: in every string of `data`, replace "_" by "\_".
#
# - `data`: collection of strings (processed with `map`, so the shape is kept).
# - `count`: maximum number of underscores escaped per string, counted from the left.
#   The default of 1 keeps the historical "first underscore only" behaviour; pass
#   `typemax(Int)` to escape all of them.
#
# Strings without an underscore come back unchanged. `replace` is used instead of
# slicing at `index - 1`, because byte arithmetic on a string index throws a
# StringIndexError when the character in front of the underscore is multi-byte
# (e.g. "α_b").
function latex_(data; count::Integer = 1)
    return map(el -> replace(el, "_" => raw"\_"; count = count), data)
end
"""
wrap numbers in latex in chemical format, e.g., H2 -> H₂
ONLY supports single digit for now.........
"""
function latex_chemformat(s)
    # Every digit character is wrapped on its own as an inline-math subscript
    # (raw strings keep the dollar signs literal); all other characters are copied
    # through unchanged. Consecutive digits therefore become consecutive subscripts.
    return join((isdigit(c) ? raw"$_" * c * raw"$" : c) for c in s)
end
"""
wrap string in \\textbf{}
"""
function latex_bold(s)
    # raw strings keep the backslash literal, so the LaTeX command needs no escaping
    opening, closing = raw"\textbf{", raw"}"
    return opening * s * closing
end
"""
=== query functions ===
"""
"""
very specific function, may be changed at will
query the information from a table of the row with the minimum MAE
"""
function query_min_f(table; feature_type = "")
    # The columns are hard-coded: column 2 is compared against `feature_type` and
    # column 7 holds the value that is minimised.
    if isempty(feature_type)
        # no filter requested: minimise over every row
        return argmin(table[:, 7])
    end
    # rows whose column 2 equals the requested feature type
    rows = [i for i in axes(table, 1) if table[i, 2] == feature_type]
    # argmin is relative to the filtered rows, so map it back to a row index of `table`
    return rows[argmin(table[rows, 7])]
end
"""
query the row index of data by column info
params:
- colids = list of column ids
- coldatas = list of data entry corresponding to the colids
"""
function query_indices(tb, colids, coldatas)
    # kept as an untyped `[]` so the returned container type stays the same
    matches = []
    ncond = length(colids)
    for row in axes(tb, 1)
        # number of (column, expected value) pairs this row satisfies
        nhits = count(tb[row, col] == coldatas[k] for (k, col) in enumerate(colids))
        # a row qualifies only when every listed column matches
        if nhits == ncond
            push!(matches, row)
        end
    end
    return matches
end
"""
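more generic query min function: among the rows matched by `query_indices`, returns the row index whose value in column `idselcol` is smallest
params:
- table: matrix containing info
- colids: list of column ids selected for query
- colnames: list of values that the columns in `colids` must equal (one value per column id)
- idselcol: id of the column in which the minimum is searched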
"""
function query_min(table, colids, colnames, idselcol)
    # rows whose `colids` columns equal the corresponding `colnames` entries
    candidates = query_indices(table, colids, colnames)
    # position (within the candidates) of the smallest value in column `idselcol`
    _, best = findmin(table[candidates, idselcol])
    # translate that position back into a row index of `table`
    return candidates[best]
end
# Re-save the "data" entry of each listed JLD file with every element converted to
# Matrix{Float64}. The files under data/ are overwritten in place.
function main_convert_datatype()
    for name in ("ACSF_51", "SOAP", "FCHL19")
        path = "data/"*name*".jld"
        converted = Matrix{Float64}.(load(path, "data"))
        save(path, "data", converted)
    end
end
"""
function to cache/track iterates, matches x then return new_fobj for cache hit
"""
function track_cache(path_tracker, fun_obj, x, f_id, x_ids::Vector{Int};
                     fun_params = (), fun_arg_params = Dict(),
                     fun_x_transform = nothing, fun_x_transform_params = [])
    # Arguments as used below:
    # - path_tracker: delimited text file holding one evaluated point per row
    # - fun_obj: objective, called as fun_obj(x, fun_params...; fun_arg_params...)
    # - f_id: column of the tracker file that stores the objective value
    # - x_ids: first and last column of the tracker file that store x
    # - fun_x_transform: optional map applied to x (with fun_x_transform_params)
    #   before the objective is evaluated; the untransformed x is what gets stored

    # cache lookup: only attempted when the file holds more than one byte
    if filesize(path_tracker) > 1
        tracker = readdlm(path_tracker)
        hit = findfirst(row -> x == tracker[row, x_ids[1]:x_ids[2]], axes(tracker, 1))
        if hit !== nothing
            # cache hit: reuse the stored objective value of the first matching row
            println("x found in tracker!")
            return tracker[hit, f_id]
        end
    end

    # cache miss: evaluate the objective, on the transformed x when a transform is given
    println("x not found, computing f(x)")
    arg = fun_x_transform === nothing ? x : fun_x_transform(x, fun_x_transform_params...)
    new_fobj = fun_obj(arg, fun_params...; fun_arg_params...)
    # append "f, x..." as a new row so that the next call with the same x is a hit
    writestringline(string.(vcat(new_fobj, x)'), path_tracker; mode="a")
    return new_fobj
end