Merge pull request #68 from solaoi/main
add gpt
solaoi authored May 28, 2023
2 parents 198ab0e + e0a352b commit ba69c98
Showing 16 changed files with 401 additions and 63 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -3,6 +3,7 @@
[![license](https://img.shields.io/github/license/solaoi/lycoris)](https://github.com/solaoi/lycoris/blob/main/LICENSE)
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/solaoi/lycoris)](https://github.com/solaoi/lycoris/releases)
[![GitHub Sponsors](https://img.shields.io/github/sponsors/solaoi?color=db61a2)](https://github.com/sponsors/solaoi)
[![PIXIV FANBOX](https://img.shields.io/badge/%E5%AF%84%E4%BB%98-PIXIV%20FANBOX-ff69b4)](https://solaoi.fanbox.cc/)

A voice-note application that transcribes speech with real-time speech recognition, without any external data transmission.

3 changes: 2 additions & 1 deletion src-tauri/migrations/001.sql
@@ -22,8 +22,9 @@ CREATE TABLE settings (
);
INSERT INTO settings(setting_name, setting_status) VALUES("speakerLanguage", NULL);
INSERT INTO settings(setting_name, setting_status) VALUES("transcriptionAccuracy", "off");
INSERT INTO settings(setting_name, setting_status) VALUES("settingKey", "");
INSERT INTO settings(setting_name, setting_status) VALUES("settingKeyOpenai", "");
INSERT INTO settings(setting_name, setting_status) VALUES("settingLanguage", "日本語");
INSERT INTO settings(setting_name, setting_status) VALUES("settingTemplate", "");
CREATE TABLE models (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_name TEXT,
9 changes: 8 additions & 1 deletion src-tauri/src/main.rs
@@ -106,14 +106,21 @@ fn start_trace_command(
*lock = Some(stop_convert_tx);

std::thread::spawn(move || {
if transcription_accuracy.starts_with("online") {
if transcription_accuracy.starts_with("online-transcript") {
let mut transcription_online = module::transcription_online::TranscriptionOnline::new(
window.app_handle(),
transcription_accuracy,
speaker_language,
note_id,
);
transcription_online.start(stop_convert_rx, true);
} else if transcription_accuracy.starts_with("online-chat") {
let mut chat_online = module::chat_online::ChatOnline::new(
window.app_handle(),
speaker_language,
note_id,
);
chat_online.start(stop_convert_rx, true);
} else {
let mut transcription = module::transcription::Transcription::new(
window.app_handle(),
247 changes: 247 additions & 0 deletions src-tauri/src/module/chat_online.rs
@@ -0,0 +1,247 @@
use tokio::{fs::File, io::AsyncReadExt};

use super::sqlite::Sqlite;

use crossbeam_channel::Receiver;

use reqwest::{
header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
multipart, Client,
};
use serde_json::{json, Value};
use tauri::{AppHandle, Manager};

#[derive(Debug, Clone, serde::Serialize)]
pub struct TraceCompletion {}

pub struct ChatOnline {
app_handle: AppHandle,
sqlite: Sqlite,
speaker_language: String,
note_id: u64,
token: String,
}

impl ChatOnline {
pub fn new(app_handle: AppHandle, speaker_language: String, note_id: u64) -> Self {
let sqlite = Sqlite::new();
let token = sqlite.select_whisper_token().unwrap();
Self {
app_handle,
sqlite,
speaker_language,
note_id,
token,
}
}

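// Keep converting queued speech rows until the queue is empty (continuous mode) or a stop signal arrives.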
pub fn start(&mut self, stop_convert_rx: Receiver<()>, is_continuous: bool) {
while Self::convert(self).is_ok() {
if is_continuous {
let vosk_speech = self.sqlite.select_vosk(self.note_id);
if vosk_speech.is_err() {
self.app_handle
.clone()
.emit_all("traceCompletion", TraceCompletion {})
.unwrap();
break;
}
}
if stop_convert_rx.try_recv().is_ok() {
let vosk_speech = self.sqlite.select_vosk(self.note_id);
if vosk_speech.is_err() {
self.app_handle
.clone()
.emit_all("traceCompletion", TraceCompletion {})
.unwrap();
} else {
self.app_handle
.clone()
.emit_all("traceUnCompletion", TraceCompletion {})
.unwrap();
}
break;
}
}
}

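// Upload the recorded WAV file to the OpenAI Whisper transcription endpoint and return the recognized text.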
#[tokio::main]
async fn request_whisper(
speaker_language: String,
file_path: String,
token: String,
) -> Result<String, Box<dyn std::error::Error>> {
let url = "https://api.openai.com/v1/audio/transcriptions";

let model = "whisper-1";

let client = Client::new();

let mut headers = HeaderMap::new();
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {}", token))?,
);
let mut file = File::open(file_path).await?;
let mut buffer = Vec::new();
file.read_to_end(&mut buffer).await?;

let part_file = multipart::Part::bytes(buffer)
.file_name("test.wav")
.mime_str("audio/wav")?;

let part_model = multipart::Part::text(model);
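// Map the selected speaker-language model name to the ISO 639-1 code expected by the Whisper API.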
let language = if speaker_language.starts_with("en-us")
|| speaker_language.starts_with("small-en-us")
{
"en"
} else if speaker_language.starts_with("cn") || speaker_language.starts_with("small-cn") {
"zh"
} else if speaker_language.starts_with("small-ko") {
"ko"
} else if speaker_language.starts_with("fr") || speaker_language.starts_with("small-fr") {
"fr"
} else if speaker_language.starts_with("de") || speaker_language.starts_with("small-de") {
"de"
} else if speaker_language.starts_with("ru") || speaker_language.starts_with("small-ru") {
"ru"
} else if speaker_language.starts_with("es") || speaker_language.starts_with("small-es") {
"es"
} else if speaker_language.starts_with("small-pt") {
"pt"
} else if speaker_language.starts_with("small-tr") {
"tr"
} else if speaker_language.starts_with("vn") || speaker_language.starts_with("small-vn") {
"vi"
} else if speaker_language.starts_with("it") || speaker_language.starts_with("small-it") {
"it"
} else if speaker_language.starts_with("small-nl") {
"nl"
} else if speaker_language.starts_with("small-ca") {
"ca"
} else if speaker_language.starts_with("uk") || speaker_language.starts_with("small-uk") {
"uk"
} else if speaker_language.starts_with("small-sv") {
"sv"
} else if speaker_language.starts_with("hi") || speaker_language.starts_with("small-hi") {
"hi"
} else if speaker_language.starts_with("small-cs") {
"cs"
} else if speaker_language.starts_with("small-pl") {
"pl"
} else {
"ja"
};
let part_language = multipart::Part::text(language);

let form = multipart::Form::new()
.part("file", part_file)
.part("model", part_model)
.part("language", part_language);

let response = client
.post(url)
.headers(headers)
.multipart(form)
.send()
.await?;

println!("Status: {}", response.status());
let json_response: Value = response.json().await?;
println!("Response: {:?}", json_response);
let response_text = json_response["text"]
.as_str()
.unwrap_or("text field not found");

Ok(response_text.to_string())
}

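// Send the transcribed question to the OpenAI chat completions endpoint, prepending the saved template as a system message when one is set.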
#[tokio::main]
async fn request_gpt(
question: &str,
token: String,
template: String,
) -> Result<String, Box<dyn std::error::Error>> {
let url = "https://api.openai.com/v1/chat/completions";

let model = "gpt-3.5-turbo";
let temperature = 0;

let client = Client::new();

let mut headers = HeaderMap::new();
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {}", token))?,
);
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));

let post_body = if template != "" {
json!({
"model": model,
"temperature": temperature,
"messages": [{"role": "system", "content": template},{"role": "user", "content": question}]
})
} else {
json!({
"model": model,
"temperature": temperature,
"messages": [{"role": "user", "content": question}]
})
};

let response = client
.post(url)
.headers(headers)
.json(&post_body)
.send()
.await?;

println!("Status: {}", response.status());
let json_response: Value = response.json().await?;
println!("Response: {:?}", json_response);
let response_text = json_response["choices"][0]["message"]["content"]
.as_str()
.unwrap_or("choices[0].message.content field not found");

Ok(response_text.to_string())
}

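// Process one queued speech row: transcribe it with Whisper, ask GPT for an answer, and store the result as a Q/A pair.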
fn convert(&mut self) -> Result<(), rusqlite::Error> {
let vosk_speech = self.sqlite.select_vosk(self.note_id);
return vosk_speech.and_then(|speech| {
let result = Self::request_whisper(
self.speaker_language.clone(),
speech.wav,
self.token.clone(),
);
if result.is_ok() {
let question = result.unwrap();
let result = self.sqlite.select_ai_template();
let template = if result.is_ok() {
result.unwrap()
} else {
"".to_string()
};
let result = Self::request_gpt(&question, self.token.clone(), template);
if result.is_ok() {
let answer = result.unwrap();
let updated = self.sqlite.update_model_vosk_to_whisper(
speech.id,
format!("Q. {}\nA. {}", question, answer),
);

self.app_handle
.clone()
.emit_all("finalTextConverted", updated.unwrap())
.unwrap();
} else {
println!("gpt api is temporally failed, so skipping...")
}
} else {
println!("whisper api is temporally failed, so skipping...")
}
Ok(())
});
}
}
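
For reference, the snippet below is a minimal, self-contained sketch of the same chat-completions request that request_gpt assembles, handy for checking an API key outside the app. Reading the key from an OPENAI_API_KEY environment variable, the standalone main function, and the hard-coded prompt are assumptions for illustration; they are not part of this commit.

use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use reqwest::Client;
use serde_json::{json, Value};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Assumption: the key comes from an environment variable here; the app reads it
    // from the settings table (settingKeyOpenai) via Sqlite::select_whisper_token.
    let token = std::env::var("OPENAI_API_KEY")?;

    let mut headers = HeaderMap::new();
    headers.insert(
        AUTHORIZATION,
        HeaderValue::from_str(&format!("Bearer {}", token))?,
    );
    headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));

    // Same body shape as request_gpt builds when no template is configured.
    let body = json!({
        "model": "gpt-3.5-turbo",
        "temperature": 0,
        "messages": [{"role": "user", "content": "Hello"}]
    });

    let response = Client::new()
        .post("https://api.openai.com/v1/chat/completions")
        .headers(headers)
        .json(&body)
        .send()
        .await?;

    let json_response: Value = response.json().await?;
    println!(
        "{}",
        json_response["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("choices[0].message.content field not found")
    );
    Ok(())
}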
1 change: 1 addition & 0 deletions src-tauri/src/module/mod.rs
@@ -9,4 +9,5 @@ mod sqlite;
mod transcriber;
pub mod transcription;
pub mod transcription_online;
pub mod chat_online;
mod writer;
13 changes: 10 additions & 3 deletions src-tauri/src/module/record.rs
@@ -23,8 +23,8 @@ use crossbeam_channel::{unbounded, Receiver};
use tauri::{api::path::data_dir, AppHandle, Manager};

use super::{
recognizer::MyRecognizer, sqlite::Sqlite, transcription::Transcription,
transcription_online::TranscriptionOnline, writer::Writer,
chat_online::ChatOnline, recognizer::MyRecognizer, sqlite::Sqlite,
transcription::Transcription, transcription_online::TranscriptionOnline, writer::Writer,
};

pub struct Record {
@@ -187,14 +187,21 @@ impl Record {
let mut lock = is_converting_clone.lock().unwrap();
*lock = true;
drop(lock);
if transcription_accuracy_clone.starts_with("online") {
if transcription_accuracy_clone.starts_with("online-transcript") {
let mut transcription_online = TranscriptionOnline::new(
app_handle_clone,
transcription_accuracy_clone,
speaker_language_clone,
note_id,
);
transcription_online.start(stop_convert_rx_clone, false);
} else if transcription_accuracy_clone.starts_with("online-chat") {
let mut chat_online = ChatOnline::new(
app_handle_clone,
speaker_language_clone,
note_id,
);
chat_online.start(stop_convert_rx_clone, false);
} else {
let mut transcription = Transcription::new(
app_handle_clone,
10 changes: 9 additions & 1 deletion src-tauri/src/module/sqlite.rs
@@ -75,7 +75,15 @@

pub fn select_whisper_token(&self) -> Result<String, rusqlite::Error> {
return self.conn.query_row(
"SELECT setting_status FROM settings WHERE setting_name = \"settingKey\"",
"SELECT setting_status FROM settings WHERE setting_name = \"settingKeyOpenai\"",
params![],
|row| Ok(row.get_unwrap(0)),
);
}

pub fn select_ai_template(&self) -> Result<String, rusqlite::Error> {
return self.conn.query_row(
"SELECT setting_status FROM settings WHERE setting_name = \"settingTemplate\"",
params![],
|row| Ok(row.get_unwrap(0)),
);
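
As a quick way to exercise the renamed settingKeyOpenai lookup and the new select_ai_template query in isolation, a rough rusqlite sketch follows; the in-memory database, the trimmed-down settings schema, and the dummy key value are assumptions for illustration only.

use rusqlite::{params, Connection};

fn main() -> rusqlite::Result<()> {
    // Assumption: a simplified settings table; the real schema lives in migrations/001.sql.
    let conn = Connection::open_in_memory()?;
    conn.execute_batch(
        "CREATE TABLE settings (setting_name TEXT, setting_status TEXT);
         INSERT INTO settings(setting_name, setting_status) VALUES('settingKeyOpenai', 'sk-dummy');
         INSERT INTO settings(setting_name, setting_status) VALUES('settingTemplate', '');",
    )?;

    // Mirrors Sqlite::select_whisper_token after this change.
    let token: String = conn.query_row(
        "SELECT setting_status FROM settings WHERE setting_name = \"settingKeyOpenai\"",
        params![],
        |row| row.get(0),
    )?;

    // Mirrors the new Sqlite::select_ai_template.
    let template: String = conn.query_row(
        "SELECT setting_status FROM settings WHERE setting_name = \"settingTemplate\"",
        params![],
        |row| row.get(0),
    )?;

    println!("token: {}, template: {:?}", token, template);
    Ok(())
}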
11 changes: 4 additions & 7 deletions src-tauri/src/module/transcription_online.rs
@@ -189,13 +189,10 @@ impl TranscriptionOnline {
.sqlite
.update_model_vosk_to_whisper(speech.id, result.unwrap());

let updated = updated.unwrap();
if updated.content != "" {
self.app_handle
.clone()
.emit_all("finalTextConverted", updated)
.unwrap();
}
self.app_handle
.clone()
.emit_all("finalTextConverted", updated.unwrap())
.unwrap();
} else {
println!("whisper api is temporally failed, so skipping...")
}
2 changes: 1 addition & 1 deletion src-tauri/tauri.conf.json
@@ -8,7 +8,7 @@
},
"package": {
"productName": "Lycoris",
"version": "0.8.0"
"version": "0.9.0"
},
"tauri": {
"allowlist": {