Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vraspar/phi 3 ios update #467

Merged
merged 6 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 115 additions & 13 deletions mobile/examples/phi-3/ios/LocalLLM/LocalLLM/ContentView.swift
Original file line number Diff line number Diff line change
@@ -1,27 +1,124 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
vraspar marked this conversation as resolved.
Show resolved Hide resolved

import SwiftUI


/// One entry in the chat transcript rendered by `ContentView`.
struct Message: Identifiable {
    /// Identity used by `ForEach`; a new UUID is created for every instance.
    let id: UUID = UUID()
    /// Text displayed in the bubble. Mutable so the text can be replaced after creation.
    var text: String
    /// `true` when the message was typed by the user, `false` for the response placeholder.
    let isUser: Bool
}

struct ContentView: View {
// Shared token store; its published `decodedTokens` drive the model response text.
@ObservedObject var tokenUpdater = SharedTokenUpdater.shared
// Text currently typed into the message field.
@State private var userInput: String = ""
@State private var messages: [Message] = [] // Store chat messages locally
@State private var isGenerating: Bool = false // Track token generation state
@State private var stats: String = "" // Token generation stats, shown below the chat when non-empty

vraspar marked this conversation as resolved.
Show resolved Hide resolved
var body: some View {
VStack {
// ChatBubbles
ScrollView {
VStack(alignment: .leading) {
ForEach(tokenUpdater.decodedTokens, id: \.self) { token in
Text(token)
.padding(.horizontal, 5)
VStack(alignment: .leading, spacing: 20) {
ForEach(messages) { message in
ChatBubble(text: message.text, isUser: message.isUser)
.padding(.horizontal, 20)
}
if !stats.isEmpty {
Text(stats)
.font(.footnote)
.foregroundColor(.gray)
.padding(.horizontal, 20)
.padding(.top, 5)
.multilineTextAlignment(.center)
}
}
.padding()
.padding(.top, 20)
}
Button("Generate Tokens") {
DispatchQueue.global(qos: .background).async {
// TODO: add user prompt question UI
GenAIGenerator.generate("Who is the current US president?");


// User input
HStack {
TextField("Type your message...", text: $userInput)
.padding()
.background(Color(.systemGray6))
.cornerRadius(20)
.padding(.horizontal)

Button(action: {
// Check for non-empty input
guard !userInput.trimmingCharacters(in: .whitespaces).isEmpty else { return }

messages.append(Message(text: userInput, isUser: true))
messages.append(Message(text: "", isUser: false)) // Placeholder for AI response


// clear previously generated tokens
SharedTokenUpdater.shared.clearTokens()

let prompt = userInput
userInput = ""
isGenerating = true


DispatchQueue.global(qos: .background).async {
GenAIGenerator.generate(prompt)
}
}) {
Image(systemName: "paperplane.fill")
.foregroundColor(.white)
.padding()
.background(isGenerating ? Color.gray : Color.pastelGreen)
.clipShape(Circle())
.padding(.trailing, 10)
}
.disabled(isGenerating)
}
.padding(.bottom, 20)
}
.background(Color(.systemGroupedBackground))
.edgesIgnoringSafeArea(.bottom)
.onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationCompleted"))) { _ in
isGenerating = false // Re-enable the button when token generation is complete
}
.onReceive(SharedTokenUpdater.shared.$decodedTokens) { tokens in
// update model response
if let lastIndex = messages.lastIndex(where: { !$0.isUser }) {
let combinedText = tokens.joined(separator: "")
messages[lastIndex].text = combinedText
}
}
.onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationStats"))) { notification in
if let userInfo = notification.userInfo,
let totalTime = userInfo["totalTime"] as? Int,
let firstTokenTime = userInfo["firstTokenTime"] as? Int,
let tokenCount = userInfo["tokenCount"] as? Int {
stats = "Generated \(tokenCount) tokens in \(totalTime) ms. First token in \(firstTokenTime) ms."
}
vraspar marked this conversation as resolved.
Show resolved Hide resolved
}
vraspar marked this conversation as resolved.
Show resolved Hide resolved
}
}

struct ChatBubble: View {
var text: String
var isUser: Bool

/// Lays out one chat bubble in a row.
///
/// User messages are pushed to the trailing edge (spacer first) and drawn in
/// pastel green; all other messages sit on the leading edge on a system gray
/// background.
var body: some View {
    HStack {
        if isUser {
            Spacer()
            Text(text)
                .padding()
                .background(Color.pastelGreen)
                .foregroundColor(.white)
                .cornerRadius(25)
                .padding(.horizontal, 10)
        } else {
            Text(text)
                .padding()
                .background(Color(.systemGray5))
                // `systemGray5` is a dynamic color that turns dark in Dark Mode,
                // so the text color must adapt as well; a fixed `.black` would
                // become unreadable there. `.primary` follows the color scheme.
                .foregroundColor(.primary)
                .cornerRadius(25)
                .padding(.horizontal, 20)
            Spacer()
        }
    }
}
Expand All @@ -32,3 +129,8 @@ struct ContentView_Previews: PreviewProvider {
ContentView()
}
}

// Custom colors shared by the chat views.
extension Color {
    /// Pastel green used for the user's chat bubble and the enabled send button.
    static let pastelGreen: Color = Color(red: 0.6, green: 0.9, blue: 0.6)
}
128 changes: 96 additions & 32 deletions mobile/examples/phi-3/ios/LocalLLM/LocalLLM/GenAIGenerator.mm
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,109 @@
#include "LocalLLM-Swift.h"
#include "ort_genai.h"
#include "ort_genai_c.h"

#include <chrono>

@implementation GenAIGenerator

+ (void)generate:(nonnull NSString*)input_user_question {
NSString* llmPath = [[NSBundle mainBundle] resourcePath];
const char* modelPath = llmPath.cString;

auto model = OgaModel::Create(modelPath);
auto tokenizer = OgaTokenizer::Create(*model);

NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question];
const char* prompt = [promptString UTF8String];

auto sequences = OgaSequences::Create();
tokenizer->Encode(prompt, *sequences);
typedef std::chrono::high_resolution_clock Clock;
vraspar marked this conversation as resolved.
Show resolved Hide resolved
typedef std::chrono::time_point<Clock> TimePoint;

auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 200);
params->SetInputSequences(*sequences);

// Streaming Output to generate token by token
auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);

auto generator = OgaGenerator::Create(*model, *params);
+ (void)generate:(nonnull NSString*)input_user_question {
NSLog(@"Starting token generation...");

NSString* llmPath = [[NSBundle mainBundle] resourcePath];
const char* modelPath = llmPath.cString;

// Log model creation
NSLog(@"Creating model ...");
auto model = OgaModel::Create(modelPath);
vraspar marked this conversation as resolved.
Show resolved Hide resolved
if (!model) {
vraspar marked this conversation as resolved.
Show resolved Hide resolved
NSLog(@"Failed to create model.");
return;
}

NSLog(@"Creating tokenizer...");
auto tokenizer = OgaTokenizer::Create(*model);
if (!tokenizer) {
NSLog(@"Failed to create tokenizer.");
return;
}
vraspar marked this conversation as resolved.
Show resolved Hide resolved

auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);

// Construct the prompt
NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question];
const char* prompt = [promptString UTF8String];

NSLog(@"Encoding prompt...");
auto sequences = OgaSequences::Create();
tokenizer->Encode(prompt, *sequences);

// Log parameters
NSLog(@"Setting generator parameters...");
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 200);
params->SetInputSequences(*sequences);

NSLog(@"Creating generator...");
auto generator = OgaGenerator::Create(*model, *params);

bool isFirstToken = true;
TimePoint startTime = Clock::now();
TimePoint firstTokenTime;
int tokenCount = 0;

NSLog(@"Starting token generation loop...");
while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken();

if (isFirstToken) {
NSLog(@"First token generated.");
firstTokenTime = Clock::now();
vraspar marked this conversation as resolved.
Show resolved Hide resolved
isFirstToken = false;
}

// Get the sequence data
const int32_t* seq = generator->GetSequenceData(0);
size_t seq_len = generator->GetSequenceCount(0);

// Decode the new token
const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);

// Check for decoding failure
if (!decode_tokens) {
NSLog(@"Token decoding failed.");
break;
}

NSLog(@"Decoded token: %s", decode_tokens);
vraspar marked this conversation as resolved.
Show resolved Hide resolved
tokenCount++;

// Convert token to NSString and update UI on the main thread
NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
[SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
}

while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken();

const int32_t* seq = generator->GetSequenceData(0);
size_t seq_len = generator->GetSequenceCount(0);
const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);
TimePoint endTime = Clock::now();
auto totalDuration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
auto firstTokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(firstTokenTime - startTime).count();

NSLog(@"Token generation completed. Total time: %lld ms, First token time: %lld ms, Total tokens: %d", totalDuration, firstTokenDuration, tokenCount);

NSLog(@"Decoded tokens: %s", decode_tokens);
NSDictionary *stats = @{
@"totalTime": @(totalDuration),
@"firstTokenTime": @(firstTokenDuration),
@"tokenCount": @(tokenCount)
};

// Add decoded token to SharedTokenUpdater
NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
[SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
}
// notify main thread that token generation is complete
dispatch_async(dispatch_get_main_queue(), ^{
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationCompleted" object:nil];
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationStats" object:nil userInfo:stats];
});
NSLog(@"Token generation completed.");
}

@end
8 changes: 5 additions & 3 deletions mobile/examples/phi-3/ios/LocalLLM/LocalLLM/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ git clone https://github.com/microsoft/onnxruntime-genai

cd onnxruntime-genai

python3 build.py --parallel --build_dir ./build_iphoneos --ios --ios_sysroot iphoneos --ios_arch arm64 --ios_deployment_target 16.6 --cmake_generator Xcode
python3 build.py --parallel --build_dir ./build_iphoneos --ios --apple_sysroot iphoneos --osx_arch arm64 --apple_deploy_target 16.6 --cmake_generator Xcode

```

Expand Down Expand Up @@ -98,12 +98,14 @@ The app uses Objective-C/C++ since using Generative AI with ONNX Runtime C++ API

Download from hf repo: <https://huggingface.co/microsoft/Phi-3-mini-128k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4>

After downloading completes, you need to copy files over to the `Resources` directory in the `Destination` column of `Target-LocalLLM`->`Build Phases`-> `New Copy File Phases` -> `Copy Files`.
After downloading the files, click the `LocalLLM` project in the sidebar and go to `Targets > LocalLLM > Build Phases`. Find the Copy Files section, set the Destination to `Resources`, and add the downloaded files.

When the app is launched, Xcode automatically copies the model files from the `Resources` folder and installs them directly on the iOS device.

### 4. Run the app and check out the streaming output token results

**Note**: The current app is only set up with a simple initial prompt question; you can try your own prompts or refine the UI based on your requirements.

***Notice:*** The current Xcode project runs on iOS 16.6; feel free to adjust the iOS/build settings for the latest iOS versions accordingly.
***Notice:*** The current Xcode project runs on iOS 16.6; feel free to adjust the iOS/build settings for the latest iOS versions accordingly.

![alt text](<Simulator Screenshot - iPhone 16.png>)
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,10 @@ import Foundation
self.decodedTokens.append(token)
}
}

/// Empties `decodedTokens`.
///
/// The removal is dispatched asynchronously to the main queue, so the array
/// has not necessarily been cleared yet when this method returns.
@objc func clearTokens() {
    DispatchQueue.main.async {
        self.decodedTokens.removeAll(keepingCapacity: false)
    }
}
}
vraspar marked this conversation as resolved.
Show resolved Hide resolved
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading