Skip to content

Commit a8658da

Browse files
Add iris classification using xgboost
1 parent c335a76 commit a8658da

File tree

2 files changed

+235
-22
lines changed

2 files changed

+235
-22
lines changed

examples/dataframes.jl

Lines changed: 230 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@ using DataFrames
1010
# ╔═╡ 525b4126-e6fa-4616-9676-b8b6f31e403f
1111
using CSV
1212

13-
# ╔═╡ 9e807a3f-d861-4b98-a119-54b8b69bae27
13+
# ╔═╡ 93b98458-2288-48e5-acf2-0b849bc4aaba
1414
using Statistics
1515

16+
# ╔═╡ 7d05f65d-ff26-41a6-889e-b03df04dd48b
17+
using XGBoost
18+
1619
# ╔═╡ 884905d7-ef97-4b62-acc9-843ce2e79f79
1720
md"""
1821
# Julia DataFrames
@@ -77,36 +80,75 @@ iris[1:10, 1:4]
7780
# ╔═╡ e8cb2729-37ec-4f60-a529-d7b139ac5133
7881
iris[!, r"sepal_*"]
7982

80-
# ╔═╡ 41446fa3-fd89-4ee8-a3bc-045d9003e92c
81-
md"""
82-
## Mutating DataFrames
83-
"""
83+
# ╔═╡ f865cb0d-c4a5-4dcc-89ed-a44d0acf88d2
84+
iris[!, Not(:class)]
8485

85-
# ╔═╡ d8f04fc9-8ba6-48d5-8c64-2ef3103e5077
86+
# ╔═╡ 9db64168-338c-4bf7-b430-98c2efd39db3
8687
md"""
87-
## Missing values
88+
## Classifying with XGBoost
8889
"""
8990

90-
# ╔═╡ d5350cc4-f91a-4208-9a97-8440b7e637d9
91-
md"""
92-
## Broadcasting
93-
"""
91+
# ╔═╡ cbef1767-4c50-4f40-aa29-e9bddbaeb2a0
92+
import Random: shuffle
9493

95-
# ╔═╡ f7567427-9c14-462c-b507-368e561ed540
96-
md"""
97-
## Selectors and Transformation functions
98-
"""
94+
# ╔═╡ 226ecb13-12d4-407c-b58a-924c421e2423
95+
function split(df, at=0.8)
96+
df = shuffle(df)
97+
endpoint = round(Int, nrow(df) * at)
98+
df[1:endpoint, :], df[endpoint + 1:end, :]
99+
end
100+
101+
# ╔═╡ 1815eda4-0d09-4b47-ba41-5c568f6804bc
102+
train, test = split(iris)
103+
104+
# ╔═╡ d684258f-a215-4b1a-a6ab-8d4d114e7b16
105+
classes = unique(iris.class)
106+
107+
# ╔═╡ 42dc519a-506b-4013-bace-5c7f1f1c18f7
108+
encode(x) :: Float32 = findfirst(==(x), classes)
109+
110+
# ╔═╡ b9ee45ad-37a7-4f6f-a044-28dd60e04de6
111+
decode(x) = classes[round(Int, x)]
112+
113+
# ╔═╡ d96c37be-da2f-4c0a-93c4-9ebf7f060603
114+
train_labels = map(encode, train.class)
115+
116+
# ╔═╡ 1a112052-f7e3-4bc8-bec6-70fcaa28eb1f
117+
model = xgboost((train[:, Not(:class)], train_labels))
118+
119+
# ╔═╡ e36eecca-feb6-4c76-9a14-0ffb675e904d
120+
predictions = predict(model, test[:, Not(:class)])
121+
122+
# ╔═╡ d0000e98-e712-4769-94b5-dc918637d7d5
123+
rmse(xs, ys) = sqrt(mean((xs - ys) .^ 2))
124+
125+
# ╔═╡ 7557839c-a580-44c3-a135-b9384cff89c3
126+
rmse(encode.(test.class), predictions)
127+
128+
# ╔═╡ 73413e5f-5c7f-4955-9acb-f5c5c996a903
129+
count(encode.(test.class) .!= round.(predictions))
130+
131+
# ╔═╡ 697135f2-cf29-4769-bb75-0c615d30a2ab
132+
let
133+
df = copy(iris)
134+
predictions = decode.(predict(model, df[:, Not(:class)]))
135+
df.prediction = predictions
136+
df[df.class .!= df.prediction, :]
137+
end
99138

100139
# ╔═╡ 00000000-0000-0000-0000-000000000001
101140
PLUTO_PROJECT_TOML_CONTENTS = """
102141
[deps]
103142
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
104143
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
144+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
105145
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
146+
XGBoost = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
106147
107148
[compat]
108149
CSV = "~0.10.12"
109150
DataFrames = "~1.6.1"
151+
XGBoost = "~2.5.1"
110152
"""
111153

112154
# ╔═╡ 00000000-0000-0000-0000-000000000002
@@ -115,20 +157,46 @@ PLUTO_MANIFEST_TOML_CONTENTS = """
115157
116158
julia_version = "1.10.0"
117159
manifest_format = "2.0"
118-
project_hash = "851423e89ab26b4ffbe11b2caf1f9c12c4416e1f"
160+
project_hash = "f397abcce5970fb0bcf05ebb743a7ab5e4c03c76"
161+
162+
[[deps.AbstractTrees]]
163+
git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c"
164+
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
165+
version = "0.4.4"
166+
167+
[[deps.ArgTools]]
168+
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
169+
version = "1.1.1"
119170
120171
[[deps.Artifacts]]
121172
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
122173
123174
[[deps.Base64]]
124175
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
125176
177+
[[deps.CEnum]]
178+
git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc"
179+
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
180+
version = "0.5.0"
181+
126182
[[deps.CSV]]
127183
deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"]
128184
git-tree-sha1 = "679e69c611fff422038e9e21e270c4197d49d918"
129185
uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
130186
version = "0.10.12"
131187
188+
[[deps.CUDA_Driver_jll]]
189+
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
190+
git-tree-sha1 = "d01bfc999768f0a31ed36f5d22a76161fc63079c"
191+
uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc"
192+
version = "0.7.0+1"
193+
194+
[[deps.CUDA_Runtime_jll]]
195+
deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
196+
git-tree-sha1 = "8e25c009d2bf16c2c31a70a6e9e8939f7325cc84"
197+
uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
198+
version = "0.11.1+0"
199+
132200
[[deps.CodecZlib]]
133201
deps = ["TranscodingStreams", "Zlib_jll"]
134202
git-tree-sha1 = "cd67fc487743b2f0fd4380d4cbd3a24660d0eec8"
@@ -181,12 +249,20 @@ version = "1.0.0"
181249
deps = ["Printf"]
182250
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
183251
252+
[[deps.Downloads]]
253+
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
254+
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
255+
version = "1.6.0"
256+
184257
[[deps.FilePathsBase]]
185258
deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"]
186259
git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa"
187260
uuid = "48062228-2e41-5def-b9a4-89aafe57970f"
188261
version = "0.9.21"
189262
263+
[[deps.FileWatching]]
264+
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
265+
190266
[[deps.Future]]
191267
deps = ["Random"]
192268
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
@@ -211,11 +287,63 @@ git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
211287
uuid = "82899510-4779-5014-852e-03e436cf321d"
212288
version = "1.0.0"
213289
290+
[[deps.JLLWrappers]]
291+
deps = ["Artifacts", "Preferences"]
292+
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
293+
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
294+
version = "1.5.0"
295+
296+
[[deps.JSON3]]
297+
deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"]
298+
git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b"
299+
uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
300+
version = "1.14.0"
301+
302+
[deps.JSON3.extensions]
303+
JSON3ArrowExt = ["ArrowTypes"]
304+
305+
[deps.JSON3.weakdeps]
306+
ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd"
307+
308+
[[deps.LLVMOpenMP_jll]]
309+
deps = ["Artifacts", "JLLWrappers", "Libdl"]
310+
git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713"
311+
uuid = "1d63c593-3942-5779-bab2-d838dc0a180e"
312+
version = "15.0.7+0"
313+
214314
[[deps.LaTeXStrings]]
215315
git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec"
216316
uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
217317
version = "1.3.1"
218318
319+
[[deps.LazyArtifacts]]
320+
deps = ["Artifacts", "Pkg"]
321+
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
322+
323+
[[deps.LibCURL]]
324+
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
325+
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
326+
version = "0.6.4"
327+
328+
[[deps.LibCURL_jll]]
329+
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
330+
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
331+
version = "8.4.0+0"
332+
333+
[[deps.LibGit2]]
334+
deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
335+
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
336+
337+
[[deps.LibGit2_jll]]
338+
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"]
339+
uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
340+
version = "1.6.4+0"
341+
342+
[[deps.LibSSH2_jll]]
343+
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
344+
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
345+
version = "1.11.0+1"
346+
219347
[[deps.Libdl]]
220348
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
221349
@@ -230,6 +358,11 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
230358
deps = ["Base64"]
231359
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
232360
361+
[[deps.MbedTLS_jll]]
362+
deps = ["Artifacts", "Libdl"]
363+
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
364+
version = "2.28.2+1"
365+
233366
[[deps.Missings]]
234367
deps = ["DataAPI"]
235368
git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272"
@@ -239,6 +372,14 @@ version = "1.1.0"
239372
[[deps.Mmap]]
240373
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
241374
375+
[[deps.MozillaCACerts_jll]]
376+
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
377+
version = "2023.1.10"
378+
379+
[[deps.NetworkOptions]]
380+
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
381+
version = "1.2.0"
382+
242383
[[deps.OpenBLAS_jll]]
243384
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
244385
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
@@ -255,6 +396,11 @@ git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821"
255396
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
256397
version = "2.8.1"
257398
399+
[[deps.Pkg]]
400+
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
401+
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
402+
version = "1.10.0"
403+
258404
[[deps.PooledArrays]]
259405
deps = ["DataAPI", "Future"]
260406
git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3"
@@ -323,6 +469,12 @@ deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
323469
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
324470
version = "1.10.0"
325471
472+
[[deps.SparseMatricesCSR]]
473+
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
474+
git-tree-sha1 = "38677ca58e80b5cad2382e5a1848f93b054ad28d"
475+
uuid = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1"
476+
version = "0.6.7"
477+
326478
[[deps.Statistics]]
327479
deps = ["LinearAlgebra", "SparseArrays"]
328480
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -334,6 +486,16 @@ git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5"
334486
uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e"
335487
version = "0.3.4"
336488
489+
[[deps.StructTypes]]
490+
deps = ["Dates", "UUIDs"]
491+
git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70"
492+
uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
493+
version = "1.10.0"
494+
495+
[[deps.SuiteSparse]]
496+
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
497+
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
498+
337499
[[deps.SuiteSparse_jll]]
338500
deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
339501
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
@@ -356,6 +518,11 @@ git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d"
356518
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
357519
version = "1.11.1"
358520
521+
[[deps.Tar]]
522+
deps = ["ArgTools", "SHA"]
523+
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
524+
version = "1.10.0"
525+
359526
[[deps.Test]]
360527
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
361528
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -387,6 +554,26 @@ git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7"
387554
uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60"
388555
version = "1.6.1"
389556
557+
[[deps.XGBoost]]
558+
deps = ["AbstractTrees", "CEnum", "JSON3", "LinearAlgebra", "OrderedCollections", "SparseArrays", "SparseMatricesCSR", "Statistics", "Tables", "XGBoost_jll"]
559+
git-tree-sha1 = "bacb62e07d104630094c8dac2fd070f5d4b9b305"
560+
uuid = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
561+
version = "2.5.1"
562+
563+
[deps.XGBoost.extensions]
564+
XGBoostCUDAExt = "CUDA"
565+
XGBoostTermExt = "Term"
566+
567+
[deps.XGBoost.weakdeps]
568+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
569+
Term = "22787eb5-b846-44ae-b979-8e399b8463ab"
570+
571+
[[deps.XGBoost_jll]]
572+
deps = ["Artifacts", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "LazyArtifacts", "Libdl", "TOML"]
573+
git-tree-sha1 = "1c0aa2390a7ebb28a3d6c214f64e57a24091fbd7"
574+
uuid = "a5c6f535-4255-5ca2-a466-0e519f119c46"
575+
version = "2.0.1+0"
576+
390577
[[deps.Zlib_jll]]
391578
deps = ["Libdl"]
392579
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
@@ -396,6 +583,16 @@ version = "1.2.13+1"
396583
deps = ["Artifacts", "Libdl"]
397584
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
398585
version = "5.8.0+1"
586+
587+
[[deps.nghttp2_jll]]
588+
deps = ["Artifacts", "Libdl"]
589+
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
590+
version = "1.52.0+1"
591+
592+
[[deps.p7zip_jll]]
593+
deps = ["Artifacts", "Libdl"]
594+
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
595+
version = "17.4.0+2"
399596
"""
400597

401598
# ╔═╡ Cell order:
@@ -409,7 +606,7 @@ version = "5.8.0+1"
409606
# ╠═32944ad3-641d-4da4-80aa-e3fa29ecbbe9
410607
# ╟─77c66f83-b02c-4ebd-987b-45c3e8928a3b
411608
# ╠═d06db870-6995-4b32-a838-d0ce284bc6d6
412-
# ╠═9e807a3f-d861-4b98-a119-54b8b69bae27
609+
# ╠═93b98458-2288-48e5-acf2-0b849bc4aaba
413610
# ╠═bdba45b6-6e2f-449b-a1a6-c44023b78b81
414611
# ╠═3ed5a310-7853-437a-b01f-6ab6fa744472
415612
# ╟─8df9f8aa-3253-4f7e-a801-4cc64a38ed0c
@@ -418,9 +615,21 @@ version = "5.8.0+1"
418615
# ╠═94075008-cc6f-4bbe-a9bb-46d587349e39
419616
# ╠═58d46e1d-3b30-46f1-b142-a068c001a765
420617
# ╠═e8cb2729-37ec-4f60-a529-d7b139ac5133
421-
# ╟─41446fa3-fd89-4ee8-a3bc-045d9003e92c
422-
# ╟─d8f04fc9-8ba6-48d5-8c64-2ef3103e5077
423-
# ╟─d5350cc4-f91a-4208-9a97-8440b7e637d9
424-
# ╟─f7567427-9c14-462c-b507-368e561ed540
618+
# ╠═f865cb0d-c4a5-4dcc-89ed-a44d0acf88d2
619+
# ╟─9db64168-338c-4bf7-b430-98c2efd39db3
620+
# ╠═7d05f65d-ff26-41a6-889e-b03df04dd48b
621+
# ╠═cbef1767-4c50-4f40-aa29-e9bddbaeb2a0
622+
# ╠═226ecb13-12d4-407c-b58a-924c421e2423
623+
# ╠═1815eda4-0d09-4b47-ba41-5c568f6804bc
624+
# ╠═d684258f-a215-4b1a-a6ab-8d4d114e7b16
625+
# ╠═42dc519a-506b-4013-bace-5c7f1f1c18f7
626+
# ╠═b9ee45ad-37a7-4f6f-a044-28dd60e04de6
627+
# ╠═d96c37be-da2f-4c0a-93c4-9ebf7f060603
628+
# ╠═1a112052-f7e3-4bc8-bec6-70fcaa28eb1f
629+
# ╠═e36eecca-feb6-4c76-9a14-0ffb675e904d
630+
# ╠═d0000e98-e712-4769-94b5-dc918637d7d5
631+
# ╠═7557839c-a580-44c3-a135-b9384cff89c3
632+
# ╠═73413e5f-5c7f-4955-9acb-f5c5c996a903
633+
# ╠═697135f2-cf29-4769-bb75-0c615d30a2ab
425634
# ╟─00000000-0000-0000-0000-000000000001
426635
# ╟─00000000-0000-0000-0000-000000000002

0 commit comments

Comments
 (0)