Initial refactoring for stateless_llama.py
#213
Conversation
def run_vmfb_comparison(args):
Can you move run_vmfb_comparison out of stateless_llama into a completely separate runner that just takes a vmfb and runs it? Also, have a way to disable and enable the torch comparison.
Good idea, on it!
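A minimal sketch of what that standalone runner could look like, assuming the IREE Python runtime bindings (iree.runtime) and a hypothetical --compare_vs_torch flag; the exported module/function names and input shapes below are placeholders, not the final API:

```python
# Hypothetical standalone vmfb runner (sketch only): loads a compiled vmfb,
# runs one exported function, and optionally compares against eager torch.
import argparse

import numpy as np
import iree.runtime as ireert


def run_vmfb(vmfb_path: str, driver: str, inputs):
    config = ireert.Config(driver)
    ctx = ireert.SystemContext(config=config)
    vm_module = ireert.VmModule.mmap(config.vm_instance, vmfb_path)
    ctx.add_vm_module(vm_module)
    # "module" and "run_initialize" stand in for whatever the exported
    # module and entry point are actually named.
    return ctx.modules.module["run_initialize"](*inputs)


def main():
    parser = argparse.ArgumentParser(description="Run a compiled LLM vmfb.")
    parser.add_argument("--vmfb_path", required=True)
    parser.add_argument("--device", default="local-task")
    parser.add_argument(
        "--compare_vs_torch",
        action="store_true",
        help="Also run the eager torch model and diff the outputs.",
    )
    args = parser.parse_args()

    example_input = [np.zeros((1, 1), dtype=np.int64)]  # placeholder token ids
    turbine_output = run_vmfb(args.vmfb_path, args.device, example_input)
    print(turbine_output)

    if args.compare_vs_torch:
        # Torch path elided: load the HF model here and diff against
        # turbine_output only when the flag is set.
        pass


if __name__ == "__main__":
    main()
```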
@@ -2,6 +2,8 @@
import sys
Ultimately I think we just want stateless_llama to only contain llama-specific code like the StateUpdateModule class and, for now, json_schema, and have the rest as separate, more generic LLM exporter code that uses model_builder.py to generate IR/vmfb, plus an LLM vmfb runner. Also, we need to clean out the examples dir that has old pathing, referenced here: #207
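Roughly, the split being described (file names other than stateless_llama.py are illustrative):
- stateless_llama.py: llama-specific pieces only (StateUpdateModule and, for now, json_schema)
- a generic LLM exporter built on model_builder.py that produces IR/vmfb
- a standalone LLM vmfb runner that takes a vmfb and runs it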
Also thinking about including an easy command to be able to chat with LLAMA, since that's what people would try when they approach a new LLM anyway. @IanNod, what do you think of this as a first task after I get the refactoring done?
Reading exporter.py and model_builder.py, it sounds like I need to:
"--iree-input-type=torch", | ||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary", | ||
"--mlir-print-debuginfo", | ||
"--mlir-print-op-on-diagnostic=false", | ||
"--iree-llvmcpu-target-cpu-features=host", | ||
"--iree-llvmcpu-target-triple=x86_64-linux-gnu", | ||
"--iree-llvmcpu-enable-microkernels", | ||
"--iree-llvmcpu-stack-allocation-limit=256000", | ||
"--iree-stream-resource-index-bits=64", | ||
"--iree-vm-target-index-bits=64", | ||
"--iree-vm-bytecode-module-strip-source-map=true", | ||
"--iree-util-zero-fill-elided-attrs", | ||
"--iree-vm-target-truncate-unsupported-floats", | ||
"--iree-codegen-check-ir-before-llvm-conversion=false", | ||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary", | ||
"--iree-opt-const-expr-hoisting=False", |
So these flags are something I copy-pasted offhandedly from a vicuna.py custom compile command in a Google doc somewhere; I don't think they're a canonical, solid list to just make the default for all compiled models in turbine. Some of them are probably fine, but they're worth looking through.
)
with open(path, "wb+") as f:
    f.write(flatbuffer_blob)
print("saved to ", path)
use a logger
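For instance, a sketch using the standard logging module in place of the bare print; the save_vmfb wrapper is illustrative, not the existing function name:

```python
import logging

logger = logging.getLogger(__name__)


def save_vmfb(path, flatbuffer_blob):
    # Same behavior as the diff above, but emits a log record instead of a
    # print so callers can control verbosity via the logging config.
    with open(path, "wb+") as f:
        f.write(flatbuffer_blob)
    logger.info("Saved vmfb to %s", path)
```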
flatbuffer_blob = ireec.compile_str(
    module_str,
    target_backends=["llvm-cpu"],
the target backend should probably be a function argument
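A sketch of that change, keeping ireec.compile_str but threading the backend (and any extra flags) through as parameters; compile_to_vmfb and its defaults are illustrative:

```python
import iree.compiler as ireec


def compile_to_vmfb(module_str, target_backends=("llvm-cpu",), extra_args=()):
    # The caller chooses the backend (e.g. "llvm-cpu", "vulkan-spirv", "cuda")
    # instead of it being hard-coded inside the exporter.
    return ireec.compile_str(
        module_str,
        target_backends=list(target_backends),
        extra_args=list(extra_args),
    )
```

The existing call site would then become flatbuffer_blob = compile_to_vmfb(module_str) for the CPU default.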
python/shark_turbine/aot/exporter.py (Outdated)
    return all_pkv_tensors

def export_llama(
Don't put any llama-specific code here. This should be for general exporting only. If we want to make abstractions for what we did in the stateless_llama model (i.e. exporting pkvs as a global), we can do that, but it shouldn't be llama-specific and probably warrants its own py file somewhere.
# TODO (Dan): replace this with a file once I figure out paths on windows exe
This needs to not exist more than once in the whole codebase, and could come from its own python file
(as opposed to json file, which is harder to bundle in an exe)
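A sketch of the single-source-of-truth version: keep the schema string in one small Python module and import it wherever it is needed. The module name, constant name, and contents below are placeholders, not the real schema:

```python
# llm_schemas.py (hypothetical module): the one place the schema string lives,
# so it is defined exactly once and bundles into an exe more easily than a
# loose .json file would.
import json

DEFAULT_STATE_SCHEMA = json.dumps(
    {
        # Placeholder content; the actual json_schema string from
        # stateless_llama.py would live here verbatim.
        "example_key": "example_value",
    }
)
```

Consumers would then do from llm_schemas import DEFAULT_STATE_SCHEMA instead of re-declaring the string.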
Let's make a Python runner wrapper that can do both compile and inference from one CLI invocation. Bash scripts aren't OS-agnostic.
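A rough sketch of such a wrapper as a single argparse CLI with compile and run subcommands, built on the IREE Python tools so nothing shells out to bash; subcommand and flag names are illustrative:

```python
import argparse

import iree.compiler as ireec


def compile_cmd(args):
    # Compile torch-dialect MLIR straight to a vmfb via the Python API.
    flatbuffer_blob = ireec.compile_file(
        args.mlir_path,
        target_backends=[args.target_backend],
        extra_args=["--iree-input-type=torch"],
    )
    with open(args.vmfb_path, "wb") as f:
        f.write(flatbuffer_blob)


def run_cmd(args):
    # Would delegate to the standalone vmfb runner sketched earlier.
    raise NotImplementedError("hook up the vmfb runner here")


def main():
    parser = argparse.ArgumentParser(description="Compile and/or run an LLM module.")
    sub = parser.add_subparsers(dest="command", required=True)

    compile_p = sub.add_parser("compile")
    compile_p.add_argument("--mlir_path", required=True)
    compile_p.add_argument("--vmfb_path", required=True)
    compile_p.add_argument("--target_backend", default="llvm-cpu")
    compile_p.set_defaults(func=compile_cmd)

    run_p = sub.add_parser("run")
    run_p.add_argument("--vmfb_path", required=True)
    run_p.set_defaults(func=run_cmd)

    args = parser.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()
```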
Can see some of my comments, but broadly, nothing model-specific should go into exporter.py. Chat functionality already exists in SHARK, which imports this; you can add code to do it here too, but it may be unnecessary.
Current plan:
Force-pushed from 50c8032 to 7999ddb
Force-pushed from f478670 to a827607
All of this would be easier to do once we have some end-to-end tests. I've taken notes on Dan & Ian's comments and will open another PR later.
This partially addresses #184 and #185.