@@ -10,9 +10,12 @@ using DataFrames
10
10
# ╔═╡ 525b4126-e6fa-4616-9676-b8b6f31e403f
11
11
using CSV
12
12
13
- # ╔═╡ 9e807a3f-d861-4b98-a119-54b8b69bae27
13
+ # ╔═╡ 93b98458-2288-48e5-acf2-0b849bc4aaba
14
14
using Statistics
15
15
16
+ # ╔═╡ 7d05f65d-ff26-41a6-889e-b03df04dd48b
17
+ using XGBoost
18
+
16
19
# ╔═╡ 884905d7-ef97-4b62-acc9-843ce2e79f79
17
20
md """
18
21
# Julia DataFrames
@@ -77,36 +80,75 @@ iris[1:10, 1:4]
77
80
# ╔═╡ e8cb2729-37ec-4f60-a529-d7b139ac5133
78
81
# Select every column whose name starts with "sepal_". The pattern is anchored
# because a regex `_*` means "zero or more underscores", not a glob wildcard.
iris[!, r"^sepal_"]
79
82
80
- # ╔═╡ 41446fa3-fd89-4ee8-a3bc-045d9003e92c
81
- md """
82
- ## Mutating DataFrames
83
- """
83
+ # ╔═╡ f865cb0d-c4a5-4dcc-89ed-a44d0acf88d2
84
+ iris[! , Not (:class )]
84
85
85
- # ╔═╡ d8f04fc9-8ba6-48d5-8c64-2ef3103e5077
86
+ # ╔═╡ 9db64168-338c-4bf7-b430-98c2efd39db3
86
87
md """
87
- ## Missing values
88
+ ## Classifying with XGBoost
88
89
"""
89
90
90
- # ╔═╡ d5350cc4-f91a-4208-9a97-8440b7e637d9
91
- md """
92
- ## Broadcasting
93
- """
91
+ # ╔═╡ cbef1767-4c50-4f40-aa29-e9bddbaeb2a0
92
+ import Random: shuffle
94
93
95
- # ╔═╡ f7567427-9c14-462c-b507-368e561ed540
96
- md """
97
- ## Selectors and Transformation functions
98
- """
94
+ # ╔═╡ 226ecb13-12d4-407c-b58a-924c421e2423
95
"""
    split(df, at=0.8; rng=nothing)

Shuffle the rows of `df` and partition them into two tables.

The first `round(Int, nrow(df) * at)` shuffled rows are returned as the first
element of the tuple; the remaining rows are returned as the second element.

# Arguments
- `df`: table to partition. The non-mutating `shuffle` is used, so `df` itself is
  not reordered.
- `at`: fraction of the rows, within `[0, 1]`, assigned to the first partition.

# Keywords
- `rng`: optional random number generator forwarded to `shuffle`, which makes the
  split reproducible. When `nothing`, `shuffle` falls back to its default generator.

# Throws
- `ArgumentError` if `at` is outside `[0, 1]` (including `NaN`).

Note: this function shares its name with `Base.split` (string splitting).
"""
function split(df, at=0.8; rng=nothing)
    # Fail early with a clear message instead of a BoundsError from the slicing below.
    0 <= at <= 1 || throw(ArgumentError("`at` must be within [0, 1], got $at"))
    shuffled = isnothing(rng) ? shuffle(df) : shuffle(rng, df)
    endpoint = round(Int, nrow(shuffled) * at)
    return shuffled[1:endpoint, :], shuffled[(endpoint + 1):end, :]
end
100
+
101
+ # ╔═╡ 1815eda4-0d09-4b47-ba41-5c568f6804bc
102
+ train, test = split (iris)
103
+
104
+ # ╔═╡ d684258f-a215-4b1a-a6ab-8d4d114e7b16
105
+ classes = unique (iris. class)
106
+
107
+ # ╔═╡ 42dc519a-506b-4013-bace-5c7f1f1c18f7
108
"""
    encode(x)

Return the position of class label `x` within the notebook-level `classes`
collection, converted to `Float32`.

# Throws
- `ArgumentError` if `x` does not occur in `classes`.
"""
function encode(x)::Float32
    index = findfirst(==(x), classes)
    # Without this guard an unknown label fails later as an opaque conversion
    # error from `nothing` to `Float32`.
    isnothing(index) && throw(ArgumentError("unknown class label: $(repr(x))"))
    return index
end
109
+
110
+ # ╔═╡ b9ee45ad-37a7-4f6f-a044-28dd60e04de6
111
"""
    decode(x)

Map a numeric prediction `x` back to a class label from `classes`.

`x` is rounded to the nearest integer and the result is clamped to the valid index
range of `classes`, so a prediction that slightly under- or overshoots the encoded
label range resolves to the nearest class instead of raising a `BoundsError`.
Non-finite values (`NaN`, `Inf`) still raise from `round(Int, x)`.
"""
function decode(x)
    index = clamp(round(Int, x), firstindex(classes), lastindex(classes))
    return classes[index]
end
112
+
113
+ # ╔═╡ d96c37be-da2f-4c0a-93c4-9ebf7f060603
114
+ train_labels = map (encode, train. class)
115
+
116
+ # ╔═╡ 1a112052-f7e3-4bc8-bec6-70fcaa28eb1f
117
+ model = xgboost ((train[:, Not (:class )], train_labels))
118
+
119
+ # ╔═╡ e36eecca-feb6-4c76-9a14-0ffb675e904d
120
+ predictions = predict (model, test[:, Not (:class )])
121
+
122
+ # ╔═╡ d0000e98-e712-4769-94b5-dc918637d7d5
123
"""
    rmse(xs, ys)

Return the root-mean-square error between `xs` and `ys`: the square root of the
mean of the element-wise squared differences `xs - ys`.
"""
function rmse(xs, ys)
    squared_errors = (xs - ys) .^ 2
    return sqrt(mean(squared_errors))
end
124
+
125
+ # ╔═╡ 7557839c-a580-44c3-a135-b9384cff89c3
126
+ rmse (encode .(test. class), predictions)
127
+
128
+ # ╔═╡ 73413e5f-5c7f-4955-9acb-f5c5c996a903
129
+ count (encode .(test. class) .!= round .(predictions))
130
+
131
+ # ╔═╡ 697135f2-cf29-4769-bb75-0c615d30a2ab
132
# Show the rows of `iris` whose decoded model prediction differs from the
# recorded class. Work happens on a copy so `iris` itself gains no new column.
let
    labelled = copy(iris)
    features = labelled[:, Not(:class)]
    labelled.prediction = decode.(predict(model, features))
    mismatched = labelled.class .!= labelled.prediction
    labelled[mismatched, :]
end
99
138
100
139
# ╔═╡ 00000000-0000-0000-0000-000000000001
101
140
PLUTO_PROJECT_TOML_CONTENTS = """
102
141
[deps]
103
142
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
104
143
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
144
+ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
105
145
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
146
+ XGBoost = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
106
147
107
148
[compat]
108
149
CSV = "~0.10.12"
109
150
DataFrames = "~1.6.1"
151
+ XGBoost = "~2.5.1"
110
152
"""
111
153
112
154
# ╔═╡ 00000000-0000-0000-0000-000000000002
@@ -115,20 +157,46 @@ PLUTO_MANIFEST_TOML_CONTENTS = """
115
157
116
158
julia_version = "1.10.0"
117
159
manifest_format = "2.0"
118
- project_hash = "851423e89ab26b4ffbe11b2caf1f9c12c4416e1f"
160
+ project_hash = "f397abcce5970fb0bcf05ebb743a7ab5e4c03c76"
161
+
162
+ [[deps.AbstractTrees]]
163
+ git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c"
164
+ uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
165
+ version = "0.4.4"
166
+
167
+ [[deps.ArgTools]]
168
+ uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
169
+ version = "1.1.1"
119
170
120
171
[[deps.Artifacts]]
121
172
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
122
173
123
174
[[deps.Base64]]
124
175
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
125
176
177
+ [[deps.CEnum]]
178
+ git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc"
179
+ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
180
+ version = "0.5.0"
181
+
126
182
[[deps.CSV]]
127
183
deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"]
128
184
git-tree-sha1 = "679e69c611fff422038e9e21e270c4197d49d918"
129
185
uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
130
186
version = "0.10.12"
131
187
188
+ [[deps.CUDA_Driver_jll]]
189
+ deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
190
+ git-tree-sha1 = "d01bfc999768f0a31ed36f5d22a76161fc63079c"
191
+ uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc"
192
+ version = "0.7.0+1"
193
+
194
+ [[deps.CUDA_Runtime_jll]]
195
+ deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
196
+ git-tree-sha1 = "8e25c009d2bf16c2c31a70a6e9e8939f7325cc84"
197
+ uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
198
+ version = "0.11.1+0"
199
+
132
200
[[deps.CodecZlib]]
133
201
deps = ["TranscodingStreams", "Zlib_jll"]
134
202
git-tree-sha1 = "cd67fc487743b2f0fd4380d4cbd3a24660d0eec8"
@@ -181,12 +249,20 @@ version = "1.0.0"
181
249
deps = ["Printf"]
182
250
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
183
251
252
+ [[deps.Downloads]]
253
+ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
254
+ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
255
+ version = "1.6.0"
256
+
184
257
[[deps.FilePathsBase]]
185
258
deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"]
186
259
git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa"
187
260
uuid = "48062228-2e41-5def-b9a4-89aafe57970f"
188
261
version = "0.9.21"
189
262
263
+ [[deps.FileWatching]]
264
+ uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
265
+
190
266
[[deps.Future]]
191
267
deps = ["Random"]
192
268
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
@@ -211,11 +287,63 @@ git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
211
287
uuid = "82899510-4779-5014-852e-03e436cf321d"
212
288
version = "1.0.0"
213
289
290
+ [[deps.JLLWrappers]]
291
+ deps = ["Artifacts", "Preferences"]
292
+ git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
293
+ uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
294
+ version = "1.5.0"
295
+
296
+ [[deps.JSON3]]
297
+ deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"]
298
+ git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b"
299
+ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
300
+ version = "1.14.0"
301
+
302
+ [deps.JSON3.extensions]
303
+ JSON3ArrowExt = ["ArrowTypes"]
304
+
305
+ [deps.JSON3.weakdeps]
306
+ ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd"
307
+
308
+ [[deps.LLVMOpenMP_jll]]
309
+ deps = ["Artifacts", "JLLWrappers", "Libdl"]
310
+ git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713"
311
+ uuid = "1d63c593-3942-5779-bab2-d838dc0a180e"
312
+ version = "15.0.7+0"
313
+
214
314
[[deps.LaTeXStrings]]
215
315
git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec"
216
316
uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
217
317
version = "1.3.1"
218
318
319
+ [[deps.LazyArtifacts]]
320
+ deps = ["Artifacts", "Pkg"]
321
+ uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
322
+
323
+ [[deps.LibCURL]]
324
+ deps = ["LibCURL_jll", "MozillaCACerts_jll"]
325
+ uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
326
+ version = "0.6.4"
327
+
328
+ [[deps.LibCURL_jll]]
329
+ deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
330
+ uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
331
+ version = "8.4.0+0"
332
+
333
+ [[deps.LibGit2]]
334
+ deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
335
+ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
336
+
337
+ [[deps.LibGit2_jll]]
338
+ deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"]
339
+ uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
340
+ version = "1.6.4+0"
341
+
342
+ [[deps.LibSSH2_jll]]
343
+ deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
344
+ uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
345
+ version = "1.11.0+1"
346
+
219
347
[[deps.Libdl]]
220
348
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
221
349
@@ -230,6 +358,11 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
230
358
deps = ["Base64"]
231
359
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
232
360
361
+ [[deps.MbedTLS_jll]]
362
+ deps = ["Artifacts", "Libdl"]
363
+ uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
364
+ version = "2.28.2+1"
365
+
233
366
[[deps.Missings]]
234
367
deps = ["DataAPI"]
235
368
git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272"
@@ -239,6 +372,14 @@ version = "1.1.0"
239
372
[[deps.Mmap]]
240
373
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
241
374
375
+ [[deps.MozillaCACerts_jll]]
376
+ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
377
+ version = "2023.1.10"
378
+
379
+ [[deps.NetworkOptions]]
380
+ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
381
+ version = "1.2.0"
382
+
242
383
[[deps.OpenBLAS_jll]]
243
384
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
244
385
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
@@ -255,6 +396,11 @@ git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821"
255
396
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
256
397
version = "2.8.1"
257
398
399
+ [[deps.Pkg]]
400
+ deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
401
+ uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
402
+ version = "1.10.0"
403
+
258
404
[[deps.PooledArrays]]
259
405
deps = ["DataAPI", "Future"]
260
406
git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3"
@@ -323,6 +469,12 @@ deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
323
469
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
324
470
version = "1.10.0"
325
471
472
+ [[deps.SparseMatricesCSR]]
473
+ deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
474
+ git-tree-sha1 = "38677ca58e80b5cad2382e5a1848f93b054ad28d"
475
+ uuid = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1"
476
+ version = "0.6.7"
477
+
326
478
[[deps.Statistics]]
327
479
deps = ["LinearAlgebra", "SparseArrays"]
328
480
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -334,6 +486,16 @@ git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5"
334
486
uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e"
335
487
version = "0.3.4"
336
488
489
+ [[deps.StructTypes]]
490
+ deps = ["Dates", "UUIDs"]
491
+ git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70"
492
+ uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
493
+ version = "1.10.0"
494
+
495
+ [[deps.SuiteSparse]]
496
+ deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
497
+ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
498
+
337
499
[[deps.SuiteSparse_jll]]
338
500
deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
339
501
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
@@ -356,6 +518,11 @@ git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d"
356
518
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
357
519
version = "1.11.1"
358
520
521
+ [[deps.Tar]]
522
+ deps = ["ArgTools", "SHA"]
523
+ uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
524
+ version = "1.10.0"
525
+
359
526
[[deps.Test]]
360
527
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
361
528
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -387,6 +554,26 @@ git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7"
387
554
uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60"
388
555
version = "1.6.1"
389
556
557
+ [[deps.XGBoost]]
558
+ deps = ["AbstractTrees", "CEnum", "JSON3", "LinearAlgebra", "OrderedCollections", "SparseArrays", "SparseMatricesCSR", "Statistics", "Tables", "XGBoost_jll"]
559
+ git-tree-sha1 = "bacb62e07d104630094c8dac2fd070f5d4b9b305"
560
+ uuid = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
561
+ version = "2.5.1"
562
+
563
+ [deps.XGBoost.extensions]
564
+ XGBoostCUDAExt = "CUDA"
565
+ XGBoostTermExt = "Term"
566
+
567
+ [deps.XGBoost.weakdeps]
568
+ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
569
+ Term = "22787eb5-b846-44ae-b979-8e399b8463ab"
570
+
571
+ [[deps.XGBoost_jll]]
572
+ deps = ["Artifacts", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "LazyArtifacts", "Libdl", "TOML"]
573
+ git-tree-sha1 = "1c0aa2390a7ebb28a3d6c214f64e57a24091fbd7"
574
+ uuid = "a5c6f535-4255-5ca2-a466-0e519f119c46"
575
+ version = "2.0.1+0"
576
+
390
577
[[deps.Zlib_jll]]
391
578
deps = ["Libdl"]
392
579
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
@@ -396,6 +583,16 @@ version = "1.2.13+1"
396
583
deps = ["Artifacts", "Libdl"]
397
584
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
398
585
version = "5.8.0+1"
586
+
587
+ [[deps.nghttp2_jll]]
588
+ deps = ["Artifacts", "Libdl"]
589
+ uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
590
+ version = "1.52.0+1"
591
+
592
+ [[deps.p7zip_jll]]
593
+ deps = ["Artifacts", "Libdl"]
594
+ uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
595
+ version = "17.4.0+2"
399
596
"""
400
597
401
598
# ╔═╡ Cell order:
@@ -409,7 +606,7 @@ version = "5.8.0+1"
409
606
# ╠═32944ad3-641d-4da4-80aa-e3fa29ecbbe9
410
607
# ╟─77c66f83-b02c-4ebd-987b-45c3e8928a3b
411
608
# ╠═d06db870-6995-4b32-a838-d0ce284bc6d6
412
- # ╠═9e807a3f-d861-4b98-a119-54b8b69bae27
609
+ # ╠═93b98458-2288-48e5-acf2-0b849bc4aaba
413
610
# ╠═bdba45b6-6e2f-449b-a1a6-c44023b78b81
414
611
# ╠═3ed5a310-7853-437a-b01f-6ab6fa744472
415
612
# ╟─8df9f8aa-3253-4f7e-a801-4cc64a38ed0c
@@ -418,9 +615,21 @@ version = "5.8.0+1"
418
615
# ╠═94075008-cc6f-4bbe-a9bb-46d587349e39
419
616
# ╠═58d46e1d-3b30-46f1-b142-a068c001a765
420
617
# ╠═e8cb2729-37ec-4f60-a529-d7b139ac5133
421
- # ╟─41446fa3-fd89-4ee8-a3bc-045d9003e92c
422
- # ╟─d8f04fc9-8ba6-48d5-8c64-2ef3103e5077
423
- # ╟─d5350cc4-f91a-4208-9a97-8440b7e637d9
424
- # ╟─f7567427-9c14-462c-b507-368e561ed540
618
+ # ╠═f865cb0d-c4a5-4dcc-89ed-a44d0acf88d2
619
+ # ╟─9db64168-338c-4bf7-b430-98c2efd39db3
620
+ # ╠═7d05f65d-ff26-41a6-889e-b03df04dd48b
621
+ # ╠═cbef1767-4c50-4f40-aa29-e9bddbaeb2a0
622
+ # ╠═226ecb13-12d4-407c-b58a-924c421e2423
623
+ # ╠═1815eda4-0d09-4b47-ba41-5c568f6804bc
624
+ # ╠═d684258f-a215-4b1a-a6ab-8d4d114e7b16
625
+ # ╠═42dc519a-506b-4013-bace-5c7f1f1c18f7
626
+ # ╠═b9ee45ad-37a7-4f6f-a044-28dd60e04de6
627
+ # ╠═d96c37be-da2f-4c0a-93c4-9ebf7f060603
628
+ # ╠═1a112052-f7e3-4bc8-bec6-70fcaa28eb1f
629
+ # ╠═e36eecca-feb6-4c76-9a14-0ffb675e904d
630
+ # ╠═d0000e98-e712-4769-94b5-dc918637d7d5
631
+ # ╠═7557839c-a580-44c3-a135-b9384cff89c3
632
+ # ╠═73413e5f-5c7f-4955-9acb-f5c5c996a903
633
+ # ╠═697135f2-cf29-4769-bb75-0c615d30a2ab
425
634
# ╟─00000000-0000-0000-0000-000000000001
426
635
# ╟─00000000-0000-0000-0000-000000000002
0 commit comments