diff --git a/Cargo.lock b/Cargo.lock index 48e037a..f595bfa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,27 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "ahash" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8fd72866655d1904d6b0997d0b07ba561047d070fbe29de039031c641b61217" + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "autocfg" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" + [[package]] name = "bitflags" version = "1.2.1" @@ -14,6 +35,18 @@ version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +[[package]] +name = "bytemuck" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" + +[[package]] +name = "cc" +version = "1.0.90" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5" + [[package]] name = "cfg-if" version = "1.0.0" @@ -31,6 +64,74 @@ dependencies = [ "winapi", ] +[[package]] +name = "colored" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbf2150cce219b664a8a70df7a1f933836724b503f8a413af9365b4dcc4d90b8" +dependencies = [ + "lazy_static", + "windows-sys", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + +[[package]] +name = "either" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" + [[package]] name = "endian-type" version = "0.1.2" @@ -70,15 +171,54 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.1.16" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", "libc", "wasi", ] +[[package]] +name = "ggml" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67c7ab58e14ee56b10c892506eaa1e7d6524b95924b85c53d805830207dde1d0" +dependencies = [ + "ggml-sys", + "thiserror", +] + +[[package]] +name = "ggml-sys" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78f1b6a53a659486cf4938f31993295be0189da3acd5b35ea2f41d91957e5d83" +dependencies = [ + "cc", +] + +[[package]] +name = "half" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5eceaaeec696539ddaf7b333340f1af35a5aa87ae3e4f3ead0532f72affab2e" +dependencies = [ + "cfg-if", + "crunchy", +] + +[[package]] +name = "hashbrown" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96282e96bfcd3da0d3aa9938bedf1e50df3269b6db08b4876d2da0bb1a0841cf" +dependencies = [ + "ahash", + "autocfg", +] + [[package]] name = "home" version = "0.5.5" @@ -88,6 +228,27 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "itertools" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f56a2d0bc861f9165be4eb3442afd3c236d8a98afd426f65d92324ae1091a484" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.149" @@ -100,6 +261,93 @@ version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f" +[[package]] +name = "llm" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ef1f3d703cd164bc7efd0b9e9b2268fb6e8a7c080cac719dd953e5e6e0ba42f" +dependencies = [ + "llm-base", + "llm-bloom", + "llm-gpt2", + "llm-gptj", + "llm-llama", + "llm-neox", +] + +[[package]] +name = "llm-base" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79756dcf95aec564245d422a4e7e6552fe53432bf3e42a377a4e26450b6e8da9" +dependencies = [ + "bytemuck", + "ggml", + "half", + "memmap2", + "partial_sort", + "rand", + "serde", + "serde_bytes", + "thiserror", +] + +[[package]] +name = "llm-bloom" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fd53ec4a68b487f07ab79f2ab965ce914b6acb4b2f2eebc80e6c9c66d4265ab" +dependencies = [ + "bytemuck", + "llm-base", +] + +[[package]] +name = "llm-gpt2" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25a5cb7a87fd0305c0f4e9aaf286b35a49254434755c3a13e44f95880c0522af" +dependencies = [ + "bytemuck", + "llm-base", +] + +[[package]] +name = "llm-gptj" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6eb2d600244e1d7280bd4e623530ff6a4096985689f1ce1c43dde9aa6ed1152" +dependencies = [ + "bytemuck", + "llm-base", +] + +[[package]] +name = "llm-llama" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0f0fc6123ad40a602584f9989608f47fe74fa7e546712997b30e0b149c3c34d" +dependencies = [ + "bytemuck", + "llm-base", + "protobuf", + "rand", + "rust_tokenizers", + "serde", + "serde_json", + "thiserror", +] + +[[package]] +name = "llm-neox" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff102bdd82b67e8682026cd1d921d087659f2c13d00084322b387e9e5f6b3bab" +dependencies = [ + "bytemuck", + "llm-base", +] + [[package]] name = "log" version = "0.4.20" @@ -112,6 +360,15 @@ version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" +[[package]] +name = "memmap2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83faa42c0a078c393f6b29d5db232d8be22776a891f8f56e5284faee4a20b327" +dependencies = [ + "libc", +] + [[package]] name = "nibble_vec" version = "0.1.0" @@ -132,6 +389,24 @@ dependencies = [ "libc", ] +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "partial_sort" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7924d1d0ad836f665c9065e26d016c673ece3993f30d340068b16f282afc1156" + +[[package]] +name = "paste" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -140,18 +415,24 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.69" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" dependencies = [ "unicode-ident", ] +[[package]] +name = "protobuf" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e86d370532557ae7573551a1ec8235a0f8d6cb276c7c9e6aa490b511c447485" + [[package]] name = "quote" -version = "1.0.33" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -168,22 +449,20 @@ dependencies = [ [[package]] name = "rand" -version = "0.7.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ - "getrandom", "libc", "rand_chacha", "rand_core", - "rand_hc", ] [[package]] name = "rand_chacha" -version = "0.2.2" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", "rand_core", @@ -191,20 +470,79 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.5.1" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ "getrandom", ] [[package]] -name = "rand_hc" -version = "0.2.0" +name = "rayon" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" dependencies = [ - "rand_core", + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "regex" +version = "1.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" + +[[package]] +name = "rust_tokenizers" +version = "3.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77c4313059ea8764ff2743ffaaa42fba0e4d5f8ff12febe4f3c74d598f629f62" +dependencies = [ + "csv", + "hashbrown", + "itertools", + "lazy_static", + "protobuf", + "rayon", + "regex", + "serde", + "serde_json", + "unicode-normalization", + "unicode-normalization-alignments", ] [[package]] @@ -250,21 +588,78 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db9dfbf470021de34cfaf6983067f460ea19164934a7c2d4b92eec0968eb95f1" dependencies = [ "quote", - "syn", + "syn 1.0.109", ] +[[package]] +name = "ryu" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "serde" +version = "1.0.197" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_bytes" +version = "0.11.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b8497c313fd43ab992087548117643f6fcd935cbf36f176ffda0aacf9591734" +dependencies = [ + "serde", +] + +[[package]] +name = "serde_derive" +version = "1.0.197" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.58", +] + +[[package]] +name = "serde_json" +version = "1.0.115" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd" +dependencies = [ + "itoa", + "ryu", + "serde", +] + [[package]] name = "smallvec" version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" +[[package]] +name = "spinoff" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fee259f96b31e7a18657d11741fe30d63f98e07de70e7a19d2b705ab9b331cdc" +dependencies = [ + "colored", + "once_cell", + "paste", +] + [[package]] name = "str-buf" version = "1.0.6" @@ -282,11 +677,45 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn" +version = "2.0.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "1.0.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.58", +] + [[package]] name = "three_body_interpreter" version = "0.4.5" dependencies = [ + "llm", + "llm-base", "rand", + "spinoff", "unicode-xid", ] @@ -299,12 +728,45 @@ dependencies = [ "three_body_interpreter", ] +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-normalization" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec", +] + [[package]] name = "unicode-segmentation" version = "1.10.1" @@ -331,9 +793,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "wasi" -version = "0.9.0+wasi-snapshot-preview1" +version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "winapi" diff --git a/interpreter/Cargo.toml b/interpreter/Cargo.toml index c3b3630..4cde771 100644 --- a/interpreter/Cargo.toml +++ b/interpreter/Cargo.toml @@ -12,4 +12,8 @@ license = "MIT" unicode-xid = { version = "0.2.1" } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] -rand = { version = "0.7.3" } +rand = { version = "0.8.5" } + +llm = { version = "0.1.1" } +llm-base = { version = "0.1.1" } +spinoff = { version = "0.7.0", default-features = false, features = ["dots", "arc", "line"] } diff --git a/interpreter/src/evaluator/builtins.rs b/interpreter/src/evaluator/builtins.rs index 10c3287..c0f2889 100644 --- a/interpreter/src/evaluator/builtins.rs +++ b/interpreter/src/evaluator/builtins.rs @@ -7,6 +7,11 @@ use crate::evaluator::object::Object; use rand::distributions::Uniform; use rand::{thread_rng, Rng}; +use llm::{load_progress_callback_stdout as load_callback, InferenceParameters, Model}; +use llm_base::InferenceRequest; +use std::{convert::Infallible, io::Write, path::Path}; +use spinoff; + pub fn new_builtins() -> HashMap { let mut builtins = HashMap::new(); builtins.insert(String::from("len"), Object::Builtin(1, monkey_len)); @@ -29,6 +34,10 @@ pub fn new_builtins() -> HashMap { String::from("没关系的都一样"), Object::Builtin(2, three_body_deep_equal), ); + builtins.insert( + String::from("智子工程"), + Object::Builtin(1, three_body_sophon_engineering), + ); builtins } @@ -137,6 +146,176 @@ fn three_body_deep_equal(args: Vec) -> Object { } } +fn three_body_sophon_infer(args: Vec) -> Object { + match &args[0] { + Object::Hash(hash) => { + let model_ptr = match hash.get(&Object::String("model".to_owned())).unwrap() { + Object::NativeObject(model_ptr) => { + model_ptr.clone() + }, + _ => panic!() + }; + let character = hash.get(&Object::String("character".to_owned())).unwrap(); + let model = unsafe { & *model_ptr }; + + let mut session = model.start_session(Default::default()); + let meessage = format!("{}", &args[1]); + let prompt = &format!(" +下面是描述一项任务的说明。需要适当地完成请求的响应。 + +### 角色设定: + +{} + +### 提问: + +{} + +### 回答: + +", character, meessage); + + let sp = spinoff::Spinner::new(spinoff::spinners::Arc, "".to_string(), None); + + if let Err(llm::InferenceError::ContextFull) = session.feed_prompt::( + model, + &InferenceParameters { + ..Default::default() + }, + prompt, + &mut Default::default(), + |t| { + Ok(()) + }, + ) { + println!("Prompt exceeds context window length.") + }; + sp.clear(); + + let res = session.infer::( + model, + &mut thread_rng(), + &InferenceRequest { + prompt: "", + ..Default::default() + }, + // OutputRequest + &mut Default::default(), + |t| { + print!("{t}"); + std::io::stdout().flush().unwrap(); + + Ok(()) + }, + ); + + match res { + Err(err) => println!("\n{err}"), + _ => () + } + Object::Null + }, + _ => panic!() + } +} + + + +fn three_body_sophon_close(args: Vec) -> Object { + match &args[0] { + Object::Hash(hash) => { + let model_ptr = match hash.get(&Object::String("model".to_owned())).unwrap() { + Object::NativeObject(model_ptr) => { + model_ptr.clone() + }, + _ => panic!() + }; + // let model = unsafe { & *model_ptr }; + unsafe { Box::from_raw(model_ptr) }; + // std::mem::drop(model); + Object::Null + }, + _ => panic!() + } +} + + +fn three_body_sophon_engineering(args: Vec) -> Object { + match &args[0] { + Object::Hash(o) => { + let model_type = o[&Object::String("type".to_owned())].clone(); + let model_path = o[&Object::String("path".to_owned())].clone(); + let prompt = o[&Object::String("prompt".to_owned())].clone(); + + let now = std::time::Instant::now(); + + let model_type = { + match model_type { + Object::String(model_type) => { + model_type + }, + _ => { + panic!() + } + } + }; + + let model_type = model_type.as_str(); + + + let model_path = { + match model_path { + Object::String(path) => { + path + }, + _ => { + panic!() + } + } + }; + + let model_path = Path::new(model_path.as_str()); + + let prompt = { + match prompt { + Object::String(prompt) => { + prompt + }, + _ => { + panic!() + } + } + }; + + let character = prompt; + + let architecture = model_type.parse().unwrap_or_else(|e| panic!("{e}")); + + let model = llm::load_dynamic(architecture, model_path, Default::default(), load_callback) + .unwrap_or_else(|err| { + panic!("Failed to load {model_type} model from {model_path:?}: {err}") + }); + + let model = Box::leak(model); + + println!( + "智子工程初始化成功: 耗时 {} ms", + now.elapsed().as_millis() + ); + + let model_ptr = &mut *model as *mut dyn Model; + + let mut session_hash = HashMap::new(); + session_hash.insert(Object::String("model".to_owned()), Object::NativeObject(model_ptr)); + session_hash.insert(Object::String("character".to_owned()), Object::String(character.to_string())); + session_hash.insert(Object::String("infer".to_owned()), Object::Builtin(2, three_body_sophon_infer)); + session_hash.insert(Object::String("close".to_owned()), Object::Builtin(1, three_body_sophon_close)); + Object::Hash(session_hash) + } + _ => Object::Null, + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/interpreter/src/evaluator/object.rs b/interpreter/src/evaluator/object.rs index 7dab4fa..87db104 100644 --- a/interpreter/src/evaluator/object.rs +++ b/interpreter/src/evaluator/object.rs @@ -3,6 +3,8 @@ use std::rc::Rc; use std::cell::RefCell; use std::collections::HashMap; use std::hash::{Hash, Hasher}; +use llm; + use crate::evaluator::env; use crate::ast; use crate::lexer::unescape::escape_str; @@ -23,6 +25,7 @@ pub enum Object { ContinueStatement, Error(String), Null, + NativeObject(*mut dyn llm::Model), } /// This is actually repr @@ -71,6 +74,7 @@ impl fmt::Display for Object { Object::ContinueStatement => write!(f, "ContinueStatement"), Object::ReturnValue(ref value) => write!(f, "ReturnValue({})", value), Object::Error(ref value) => write!(f, "Error({})", value), + Object::NativeObject(ref model) => write!(f, "NativeObject({:?})", (model)), } } }