From 9be4286665ab419e54707411cf8bd8cd870923f8 Mon Sep 17 00:00:00 2001 From: Ruwen Hahn Date: Mon, 6 Jan 2025 15:46:33 +0100 Subject: [PATCH 1/9] Fix `whisper_buffer` and `resampled_buffer` data race `media_unpause` was causing `whisper_buffer` to be freed while `vad_based_segmentation`/`hybrid_vad_segmentation` need that buffer to not be modified for the duration of those calls --- src/transcription-filter-callbacks.cpp | 4 +--- src/transcription-filter-data.h | 1 + src/whisper-utils/whisper-processing.cpp | 7 +++++++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index 044a5bd..37fec00 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -462,10 +462,8 @@ void reset_caption_state(transcription_filter_data *gf_) if (gf_->info_buffer.data != nullptr) { circlebuf_free(&gf_->info_buffer); } - if (gf_->whisper_buffer.data != nullptr) { - circlebuf_free(&gf_->whisper_buffer); - } } + gf_->clear_buffers = true; } void media_play_callback(void *data_, calldata_t *cd) diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index 8201c50..4ca5d91 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -47,6 +47,7 @@ struct transcription_filter_data { float *copy_buffers[MAX_PREPROC_CHANNELS]; struct circlebuf info_buffer; struct circlebuf input_buffers[MAX_PREPROC_CHANNELS]; + std::atomic clear_buffers; struct circlebuf whisper_buffer; /* Resampler */ diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp index b53c5d4..55f57a3 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/whisper-utils/whisper-processing.cpp @@ -386,6 +386,13 @@ void whisper_loop(void *data) } } + if (gf->clear_buffers) { + circlebuf_pop_front(&gf->resampled_buffer, nullptr, 0); + circlebuf_pop_front(&gf->whisper_buffer, nullptr, 0); + current_vad_state = 
{false, now_ms(), 0, 0}; + gf->clear_buffers = false; + } + if (gf->vad_mode == VAD_MODE_HYBRID) { current_vad_state = hybrid_vad_segmentation(gf, current_vad_state); } else if (gf->vad_mode == VAD_MODE_ACTIVE) { From 211e27722d97d4077b64b7237de40e757e96d703 Mon Sep 17 00:00:00 2001 From: Ruwen Hahn Date: Thu, 9 Jan 2025 19:05:36 +0100 Subject: [PATCH 2/9] Slightly improve handling for weird subtitle output filenames --- src/transcription-filter-callbacks.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index 37fec00..7825126 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -200,8 +200,10 @@ void send_translated_sentence_to_file(struct transcription_filter_data *gf, // add a postfix to the file name (without extension) with the translation target language std::string translated_file_path = ""; std::string output_file_path = gf->output_file_path; - std::string file_extension = - output_file_path.substr(output_file_path.find_last_of(".") + 1); + auto point_pos = output_file_path.find_last_of("."); + std::string file_extension = point_pos != output_file_path.npos + ? output_file_path.substr(point_pos + 1) + : ""; std::string file_name = output_file_path.substr(0, output_file_path.find_last_of(".")); translated_file_path = file_name + "_" + target_lang + "." 
+ file_extension; From bda52c1bdeb89501a95935e4a6a72a5a09e4b3cb Mon Sep 17 00:00:00 2001 From: Ruwen Hahn Date: Tue, 14 Jan 2025 15:48:01 +0100 Subject: [PATCH 3/9] Squashed 'deps/c-webvtt-in-video-stream/' content from commit 5579ca6 git-subtree-dir: deps/c-webvtt-in-video-stream git-subtree-split: 5579ca6dc9dcf94e3c14631c6c01b2ee4dfcf005 --- .gitignore | 1 + .vscode/settings.json | 7 + Cargo.lock | 581 +++++++++++++++++++++ Cargo.toml | 27 + build.rs | 8 + cbindgen.toml | 9 + src/lib.rs | 220 ++++++++ video-bytestream-tools/Cargo.toml | 10 + video-bytestream-tools/src/h264.rs | 361 +++++++++++++ video-bytestream-tools/src/h264/annex_b.rs | 100 ++++ video-bytestream-tools/src/h264/avcc.rs | 153 ++++++ video-bytestream-tools/src/lib.rs | 2 + video-bytestream-tools/src/webvtt.rs | 155 ++++++ webvtt-in-video-stream/Cargo.toml | 8 + webvtt-in-video-stream/src/lib.rs | 277 ++++++++++ 15 files changed, 1919 insertions(+) create mode 100644 .gitignore create mode 100644 .vscode/settings.json create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 build.rs create mode 100644 cbindgen.toml create mode 100644 src/lib.rs create mode 100644 video-bytestream-tools/Cargo.toml create mode 100644 video-bytestream-tools/src/h264.rs create mode 100644 video-bytestream-tools/src/h264/annex_b.rs create mode 100644 video-bytestream-tools/src/h264/avcc.rs create mode 100644 video-bytestream-tools/src/lib.rs create mode 100644 video-bytestream-tools/src/webvtt.rs create mode 100644 webvtt-in-video-stream/Cargo.toml create mode 100644 webvtt-in-video-stream/src/lib.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/target diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..9338a73 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "editor.formatOnSave": true, + "evenBetterToml.formatter.reorderKeys": true, + 
"evenBetterToml.formatter.reorderArrays": true, + "evenBetterToml.formatter.trailingNewline": true, + "rust-analyzer.check.command": "clippy" +} diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..b16183c --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,581 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +dependencies = [ + "anstyle", + "windows-sys", +] + +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + +[[package]] +name = "bitstream-io" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e445576659fd04a57b44cbd00aa37aaa815ebefa0aa3cb677a6b5e63d883074f" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "c-webvtt-in-video-stream" +version = "0.1.0" +dependencies = [ + "cbindgen", + "h264-reader", + "strum_macros", + "video-bytestream-tools", + "webvtt-in-video-stream", +] + +[[package]] +name = "cbindgen" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fce8dd7fcfcbf3a0a87d8f515194b49d6135acab73e18bd380d1d93bb1a15eb" +dependencies = [ + "clap", + "heck 0.4.1", + "indexmap", + "log", + "proc-macro2", + "quote", + "serde", + "serde_json", + "syn", + "tempfile", + "toml", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "clap" +version = "4.5.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_lex" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" + +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + +[[package]] +name = "equivalent" +version = "1.0.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "errno" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "four-cc" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3958af68a31b1d1384d3f39b6aa33eb14b6009065b5ca305ddd9712a4237124f" + +[[package]] +name = "h264-reader" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd118dcc322cc71cfc33254a19ebece92cfaaf6d4b4793fec3f7f44fbc4150df" +dependencies = [ + "bitstream-io", + "hex-slice", + "log", + "memchr", + "rfc6381-codec", +] + +[[package]] +name = "hashbrown" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hex-slice" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5491a308e0214554f07a81d8944abe45f552871c12e3c3c6e7e5d354039a6c4c" + +[[package]] +name = "indexmap" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + +[[package]] +name = "itoa" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" + +[[package]] +name = "libc" +version = "0.2.169" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" + +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + +[[package]] +name = "log" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "mp4ra-rust" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be9daf03b43bf3842962947c62ba40f411e46a58774c60838038f04a67d17626" +dependencies = [ + "four-cc", +] + +[[package]] +name = "mpeg4-audio-const" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96a1fe2275b68991faded2c80aa4a33dba398b77d276038b8f50701a22e55918" + +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + +[[package]] +name = "proc-macro2" 
+version = "1.0.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "307e3004becf10f5a6e0d59d20f3cd28231b0e0827a96cd3e0ce6d14bc1e4bb3" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rfc6381-codec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4395f46a67f0d57c57f6a5361f3a9a0c0183a19cab3998892ecdc003de6d8037" +dependencies = [ + "four-cc", + "mp4ra-rust", + "mpeg4-audio-const", +] + +[[package]] +name = "rustix" +version = "0.38.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + +[[package]] +name = "rustversion" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "serde" +version = "1.0.216" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.216" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.133" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "serde_spanned" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" +dependencies = [ + "serde", +] + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + +[[package]] +name = "syn" +version = "2.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tempfile" +version = "3.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +dependencies = [ + "cfg-if", + "fastrand", + "once_cell", + "rustix", + "windows-sys", +] + +[[package]] +name = "thiserror" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f49a1853cf82743e3b7950f77e0f4d622ca36cf4317cba00c767838bac8d490" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8381894bb3efe0c4acac3ded651301ceee58a15d47c2e34885ed1908ad667061" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name 
= "toml" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + +[[package]] +name = "unicode-ident" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "uuid" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" + +[[package]] +name = "video-bytestream-tools" +version = "0.1.0" +dependencies = [ + "byteorder", + "h264-reader", + "thiserror", + "uuid", +] + +[[package]] +name = "webvtt-in-video-stream" +version = "0.1.0" +dependencies = [ + "thiserror", + "video-bytestream-tools", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "winnow" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +dependencies = [ + "memchr", +] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..c2e37c1 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,27 @@ +[package] +edition = "2021" +name = "c-webvtt-in-video-stream" +version = "0.1.0" + +[lib] +crate-type = ["staticlib"] + +[profile.release] +debug = 2 +panic = "abort" + +[profile.dev] +debug = 2 +panic = "abort" + +[workspace] +members = ["webvtt-in-video-stream", "video-bytestream-tools"] + +[dependencies] +h264-reader = "0.7.0" +strum_macros = "0.26.3" +video-bytestream-tools = {path = "./video-bytestream-tools"} +webvtt-in-video-stream = {path = "./webvtt-in-video-stream"} + +[build-dependencies] +cbindgen = "0.27.0" diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..930ba57 --- /dev/null +++ b/build.rs @@ -0,0 +1,8 @@ +fn main() { + let crate_dir = std::env::var_os("CARGO_MANIFEST_DIR").unwrap(); + match cbindgen::generate(crate_dir) { + Ok(bindings) => bindings.write_to_file("target/webvtt-in-sei.h"), + Err(cbindgen::Error::ParseSyntaxError { .. 
}) => return, // ignore in favor of cargo's syntax check + Err(err) => panic!("{:?}", err), + }; +} diff --git a/cbindgen.toml b/cbindgen.toml new file mode 100644 index 0000000..379c21b --- /dev/null +++ b/cbindgen.toml @@ -0,0 +1,9 @@ +cpp_compat = true +language = "c" + +[parse] +include = ["webvtt-in-video-stream"] +parse_deps = true + +[export] +include = ["CodecFlavor"] diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..07ad373 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,220 @@ +use std::{ + error::Error, + ffi::{c_char, CStr}, + time::Duration, +}; +use strum_macros::FromRepr; +use video_bytestream_tools::{ + h264::{self, H264ByteStreamWrite, NalHeader, NalUnitWrite, RbspWrite}, + webvtt::WebvttWrite, +}; +use webvtt_in_video_stream::{WebvttMuxer, WebvttMuxerBuilder, WebvttString}; + +#[no_mangle] +pub extern "C" fn webvtt_create_muxer_builder( + latency_to_video_in_msecs: u16, + send_frequency_hz: u8, + video_frame_time_in_nsecs: u64, +) -> Box { + Box::new(WebvttMuxerBuilder::new( + Duration::from_millis(latency_to_video_in_msecs.into()), + send_frequency_hz, + Duration::from_nanos(video_frame_time_in_nsecs), + )) +} + +fn turn_into_webvtt_string(ptr: *const c_char) -> Option { + if ptr.is_null() { + return None; + } + let c_str = unsafe { CStr::from_ptr(ptr) }; + WebvttString::from_string(c_str.to_string_lossy().into_owned()).ok() +} + +#[no_mangle] +pub extern "C" fn webvtt_muxer_builder_add_track( + builder: Option<&mut WebvttMuxerBuilder>, + default: bool, + autoselect: bool, + forced: bool, + name_ptr: *const c_char, + language_ptr: *const c_char, + assoc_language_ptr: *const c_char, + characteristics_ptr: *const c_char, +) -> bool { + let Some(builder) = builder else { return false }; + let Some(name) = turn_into_webvtt_string(name_ptr) else { + return false; + }; + let Some(language) = turn_into_webvtt_string(language_ptr) else { + return false; + }; + let assoc_language = turn_into_webvtt_string(assoc_language_ptr); + let 
characteristics = turn_into_webvtt_string(characteristics_ptr); + builder + .add_track( + default, + autoselect, + forced, + name, + language, + assoc_language, + characteristics, + ) + .is_ok() +} + +#[no_mangle] +pub extern "C" fn webvtt_muxer_builder_create_muxer( + muxer_builder: Option>, +) -> Option> { + muxer_builder.map(|builder| Box::new(builder.create_muxer())) +} + +#[no_mangle] +pub extern "C" fn webvtt_muxer_free(_: Option>) {} + +#[no_mangle] +pub extern "C" fn webvtt_muxer_add_cue( + muxer: Option<&WebvttMuxer>, + track: u8, + start_time_in_msecs: u64, + duration_in_msecs: u64, + text_ptr: *const c_char, +) -> bool { + let Some(muxer) = muxer else { return false }; + let Some(text) = turn_into_webvtt_string(text_ptr) else { + return false; + }; + muxer + .add_cue( + track, + Duration::from_millis(start_time_in_msecs), + Duration::from_millis(duration_in_msecs), + text, + ) + .is_ok() +} + +#[derive(FromRepr, Copy, Clone)] +#[repr(u8)] +enum CodecFlavor { + H264Avcc1, + H264Avcc2, + H264Avcc4, + H264AnnexB, +} + +impl CodecFlavor { + fn into_internal(self) -> CodecFlavorInternal { + match self { + CodecFlavor::H264Avcc1 => CodecFlavorInternal::H264(CodecFlavorH264::Avcc(1)), + CodecFlavor::H264Avcc2 => CodecFlavorInternal::H264(CodecFlavorH264::Avcc(2)), + CodecFlavor::H264Avcc4 => CodecFlavorInternal::H264(CodecFlavorH264::Avcc(4)), + CodecFlavor::H264AnnexB => CodecFlavorInternal::H264(CodecFlavorH264::AnnexB), + } + } +} + +enum CodecFlavorH264 { + Avcc(usize), + AnnexB, +} + +enum CodecFlavorInternal { + H264(CodecFlavorH264), +} + +pub struct WebvttBuffer(Vec); + +#[no_mangle] +pub extern "C" fn webvtt_muxer_try_mux_into_bytestream( + muxer: Option<&WebvttMuxer>, + video_timestamp_in_nsecs: u64, + add_header: bool, + codec_flavor: u8, +) -> Option> { + fn mux_into_bytestream<'a, W: WebvttWrite + 'a>( + muxer: &WebvttMuxer, + video_timestamp: Duration, + add_header: bool, + buffer: &'a mut Vec, + init: impl Fn(&'a mut Vec) -> Result>, + finish: 
impl Fn(W) -> Result<(), Box>, + ) -> Result> { + let mut writer = init(buffer)?; + if !muxer.try_mux_into_bytestream(video_timestamp, add_header, &mut writer)? { + return Ok(false); + } + finish(writer)?; + Ok(true) + } + + fn create_nal_header() -> NalHeader { + NalHeader::from_nal_unit_type_and_nal_ref_idc(h264_reader::nal::UnitType::SEI, 0).unwrap() + } + + fn inner( + muxer: Option<&WebvttMuxer>, + video_timestamp_in_nsecs: u64, + add_header: bool, + codec_flavor: u8, + ) -> Option> { + let muxer = muxer?; + let video_timestamp = Duration::from_nanos(video_timestamp_in_nsecs); + let codec_flavor = CodecFlavor::from_repr(codec_flavor)?; + let mut buffer = vec![]; + let data_written = match codec_flavor.into_internal() { + CodecFlavorInternal::H264(CodecFlavorH264::AnnexB) => mux_into_bytestream( + muxer, + video_timestamp, + add_header, + &mut buffer, + |buffer| { + Ok(h264::annex_b::AnnexBWriter::new(buffer) + .start_write_nal_unit()? + .write_nal_header(create_nal_header())?) + }, + |write| { + write.finish_rbsp()?; + Ok(()) + }, + ) + .ok()?, + CodecFlavorInternal::H264(CodecFlavorH264::Avcc(length_size)) => mux_into_bytestream( + muxer, + video_timestamp, + add_header, + &mut buffer, + |buffer| { + Ok(h264::avcc::AVCCWriter::new(length_size, buffer)? + .start_write_nal_unit()? + .write_nal_header(create_nal_header())?) 
+ }, + |write| { + write.finish_rbsp()?; + Ok(()) + }, + ) + .ok()?, + }; + if !data_written { + return None; + } + Some(Box::new(WebvttBuffer(buffer))) + } + inner(muxer, video_timestamp_in_nsecs, add_header, codec_flavor) +} + +#[no_mangle] +pub extern "C" fn webvtt_buffer_data(buffer: Option<&WebvttBuffer>) -> *const u8 { + buffer.map(|b| b.0.as_ptr()).unwrap_or(std::ptr::null()) +} + +#[no_mangle] +pub extern "C" fn webvtt_buffer_length(buffer: Option<&WebvttBuffer>) -> usize { + buffer.map(|b| b.0.len()).unwrap_or(0) +} + +#[no_mangle] +pub extern "C" fn webvtt_buffer_free(_: Option>) {} diff --git a/video-bytestream-tools/Cargo.toml b/video-bytestream-tools/Cargo.toml new file mode 100644 index 0000000..1ebc56e --- /dev/null +++ b/video-bytestream-tools/Cargo.toml @@ -0,0 +1,10 @@ +[package] +edition = "2021" +name = "video-bytestream-tools" +version = "0.1.0" + +[dependencies] +byteorder = "1.5.0" +h264-reader = "0.7.0" +thiserror = "2.0.4" +uuid = "1.11.0" diff --git a/video-bytestream-tools/src/h264.rs b/video-bytestream-tools/src/h264.rs new file mode 100644 index 0000000..91f33c0 --- /dev/null +++ b/video-bytestream-tools/src/h264.rs @@ -0,0 +1,361 @@ +use crate::webvtt::{write_webvtt_header, write_webvtt_payload, WebvttTrack, WebvttWrite}; +use byteorder::WriteBytesExt; +use h264_reader::nal::UnitType; +use std::{collections::VecDeque, io::Write, time::Duration}; + +type Result = std::result::Result; + +pub mod annex_b; +pub mod avcc; + +pub trait H264ByteStreamWrite { + type Writer: NalUnitWrite; + fn start_write_nal_unit(self) -> Result; +} + +impl H264ByteStreamWrite for W { + type Writer = NalUnitWriter; + + fn start_write_nal_unit(self) -> Result { + Ok(NalUnitWriter::new(self)) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct NalHeader { + nal_unit_type: UnitType, + nal_ref_idc: u8, +} + +#[derive(Debug, Clone, Copy)] +pub enum NalHeaderError { + NalRefIdcOutOfRange(u8), + InvalidNalRefIdcForNalUnitType { + nal_unit_type: UnitType, + 
nal_ref_idc: u8, + }, + NalUnitTypeOutOfRange(UnitType), +} + +impl NalHeader { + pub fn from_nal_unit_type_and_nal_ref_idc( + nal_unit_type: UnitType, + nal_ref_idc: u8, + ) -> Result { + if nal_ref_idc >= 4 { + return Err(NalHeaderError::NalRefIdcOutOfRange(nal_ref_idc)); + } + match nal_unit_type.id() { + 0 => Err(NalHeaderError::NalUnitTypeOutOfRange(nal_unit_type)), + 6 | 9 | 10 | 11 | 12 => { + if nal_ref_idc == 0 { + Ok(NalHeader { + nal_unit_type, + nal_ref_idc, + }) + } else { + Err(NalHeaderError::InvalidNalRefIdcForNalUnitType { + nal_unit_type, + nal_ref_idc, + }) + } + } + 5 => { + if nal_ref_idc != 0 { + Ok(NalHeader { + nal_unit_type, + nal_ref_idc, + }) + } else { + Err(NalHeaderError::InvalidNalRefIdcForNalUnitType { + nal_unit_type, + nal_ref_idc, + }) + } + } + 32.. => Err(NalHeaderError::NalUnitTypeOutOfRange(nal_unit_type)), + _ => Ok(NalHeader { + nal_unit_type, + nal_ref_idc, + }), + } + } + + fn as_header_byte(&self) -> u8 { + self.nal_ref_idc << 5 | self.nal_unit_type.id() + } +} + +pub struct NalUnitWriter { + inner: W, +} + +pub trait NalUnitWrite { + type Writer: RbspWrite; + fn write_nal_header(self, nal_header: NalHeader) -> Result; +} + +impl NalUnitWriter { + fn new(inner: W) -> Self { + Self { inner } + } +} + +impl NalUnitWrite for NalUnitWriter { + type Writer = RbspWriter; + + fn write_nal_header(mut self, nal_header: NalHeader) -> Result> { + self.inner.write_u8(nal_header.as_header_byte())?; + Ok(RbspWriter::new(self.inner)) + } +} + +pub struct RbspWriter { + last_written: VecDeque, + inner: W, +} + +pub trait RbspWrite { + type Writer: H264ByteStreamWrite; + fn finish_rbsp(self) -> Result; +} + +impl RbspWriter { + pub fn new(inner: W) -> Self { + Self { + last_written: VecDeque::with_capacity(3), + inner, + } + } +} + +impl RbspWrite for RbspWriter { + type Writer = W; + fn finish_rbsp(mut self) -> Result { + self.write_u8(0x80)?; + Ok(self.inner) + } +} + +impl Write for RbspWriter { + fn write(&mut self, buf: &[u8]) -> 
Result { + let mut written = 0; + for &byte in buf { + let mut last_written_iter = self.last_written.iter(); + if last_written_iter.next() == Some(&0) + && last_written_iter.next() == Some(&0) + && (byte == 0 || byte == 1 || byte == 2 || byte == 3) + { + self.inner.write_u8(3)?; + self.last_written.clear(); + } + self.inner.write_u8(byte)?; + written += 1; + self.last_written.push_back(byte); + if self.last_written.len() > 2 { + self.last_written.pop_front(); + } + } + Ok(written) + } + + fn flush(&mut self) -> Result<()> { + self.inner.flush() + } +} + +pub(crate) struct CountingSink { + count: usize, +} + +impl CountingSink { + pub fn new() -> Self { + Self { count: 0 } + } + + pub fn count(&self) -> usize { + self.count + } +} + +impl Write for CountingSink { + fn write(&mut self, buf: &[u8]) -> Result { + self.count += buf.len(); + Ok(buf.len()) + } + + fn flush(&mut self) -> Result<()> { + Ok(()) + } +} + +pub(crate) fn write_sei_header( + writer: &mut W, + mut payload_type: usize, + mut payload_size: usize, +) -> std::io::Result<()> { + while payload_type >= 255 { + writer.write_u8(255)?; + payload_type -= 255; + } + writer.write_u8(payload_type.try_into().unwrap())?; + while payload_size >= 255 { + writer.write_u8(255)?; + payload_size -= 255; + } + writer.write_u8(payload_size.try_into().unwrap())?; + Ok(()) +} + +impl WebvttWrite for RbspWriter { + fn write_webvtt_header( + &mut self, + max_latency_to_video: Duration, + send_frequency_hz: u8, + subtitle_tracks: &[WebvttTrack], + ) -> std::io::Result<()> { + write_webvtt_header( + self, + max_latency_to_video, + send_frequency_hz, + subtitle_tracks, + ) + } + + fn write_webvtt_payload( + &mut self, + track_index: u8, + chunk_number: u64, + chunk_version: u8, + video_offset: Duration, + webvtt_payload: &str, // TODO: replace with string type that checks for interior NULs + ) -> std::io::Result<()> { + write_webvtt_payload( + self, + track_index, + chunk_number, + chunk_version, + video_offset, + 
webvtt_payload, + ) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + h264::{NalHeader, NalUnitWrite, NalUnitWriter, RbspWrite}, + webvtt::{WebvttWrite, PAYLOAD_GUID, USER_DATA_UNREGISTERED}, + }; + use byteorder::{BigEndian, ReadBytesExt}; + use h264_reader::nal::{Nal, RefNal, UnitType}; + use std::{io::Read, time::Duration}; + + #[test] + fn check_webvtt_sei() { + let mut writer = vec![]; + + let nalu_writer = NalUnitWriter::new(&mut writer); + let nal_unit_type = h264_reader::nal::UnitType::SEI; + let nal_ref_idc = 0; + let nal_header = + NalHeader::from_nal_unit_type_and_nal_ref_idc(nal_unit_type, nal_ref_idc).unwrap(); + let mut payload_writer = nalu_writer.write_nal_header(nal_header).unwrap(); + let track_index = 0; + let chunk_number = 1; + let chunk_version = 0; + let video_offset = Duration::from_millis(200); + let webvtt_payload = "Some unverified data"; + payload_writer + .write_webvtt_payload( + track_index, + chunk_number, + chunk_version, + video_offset, + webvtt_payload, + ) + .unwrap(); + payload_writer.finish_rbsp().unwrap(); + assert!(&writer[3..19] == PAYLOAD_GUID.as_bytes()); + + let nal = RefNal::new(&writer, &[], true); + assert!(nal.is_complete()); + assert!(nal.header().unwrap().nal_unit_type() == UnitType::SEI); + let mut byte_reader = nal.rbsp_bytes(); + + assert!(usize::from(byte_reader.read_u8().unwrap()) == USER_DATA_UNREGISTERED); + let mut length = 0; + loop { + let byte = byte_reader.read_u8().unwrap(); + length += usize::from(byte); + if byte != 255 { + break; + } + } + assert!(length + 1 == byte_reader.clone().bytes().count()); + byte_reader.read_u128::().unwrap(); + assert!(track_index == byte_reader.read_u8().unwrap()); + assert!(chunk_number == byte_reader.read_u64::().unwrap()); + assert!(chunk_version == byte_reader.read_u8().unwrap()); + assert!( + u16::try_from(video_offset.as_millis()).unwrap() + == byte_reader.read_u16::().unwrap() + ); + println!("{writer:02x?}"); + } + + #[test] + fn check_webvtt_multi_sei() { + let 
mut writer = vec![]; + + let nalu_writer = NalUnitWriter::new(&mut writer); + let nal_unit_type = h264_reader::nal::UnitType::SEI; + let nal_ref_idc = 0; + let nal_header = + NalHeader::from_nal_unit_type_and_nal_ref_idc(nal_unit_type, nal_ref_idc).unwrap(); + let mut payload_writer = nalu_writer.write_nal_header(nal_header).unwrap(); + let track_index = 0; + let chunk_number = 1; + let chunk_version = 0; + let video_offset = Duration::from_millis(200); + let webvtt_payload = "Some unverified data"; + payload_writer + .write_webvtt_payload( + track_index, + chunk_number, + chunk_version, + video_offset, + webvtt_payload, + ) + .unwrap(); + payload_writer + .write_webvtt_payload(1, 1, 0, video_offset, "Something else") + .unwrap(); + payload_writer.finish_rbsp().unwrap(); + assert!(&writer[3..19] == PAYLOAD_GUID.as_bytes()); + + let nal = RefNal::new(&writer, &[], true); + assert!(nal.is_complete()); + assert!(nal.header().unwrap().nal_unit_type() == UnitType::SEI); + let mut byte_reader = nal.rbsp_bytes(); + + assert!(usize::from(byte_reader.read_u8().unwrap()) == USER_DATA_UNREGISTERED); + let mut _length = 0; + loop { + let byte = byte_reader.read_u8().unwrap(); + _length += usize::from(byte); + if byte != 255 { + break; + } + } + byte_reader.read_u128::().unwrap(); + assert!(track_index == byte_reader.read_u8().unwrap()); + assert!(chunk_number == byte_reader.read_u64::().unwrap()); + assert!(chunk_version == byte_reader.read_u8().unwrap()); + assert!( + u16::try_from(video_offset.as_millis()).unwrap() + == byte_reader.read_u16::().unwrap() + ); + println!("{writer:02x?}"); + } +} diff --git a/video-bytestream-tools/src/h264/annex_b.rs b/video-bytestream-tools/src/h264/annex_b.rs new file mode 100644 index 0000000..09b0e66 --- /dev/null +++ b/video-bytestream-tools/src/h264/annex_b.rs @@ -0,0 +1,100 @@ +use super::{ + H264ByteStreamWrite, NalHeader, NalUnitWrite, NalUnitWriter, RbspWrite, RbspWriter, Result, +}; +use crate::webvtt::{WebvttTrack, WebvttWrite}; 
+use byteorder::WriteBytesExt; +use std::{io::Write, time::Duration}; + +pub struct AnnexBWriter { + leading_zero_8bits_written: bool, + inner: W, +} + +impl AnnexBWriter { + pub fn new(inner: W) -> Self { + Self { + leading_zero_8bits_written: false, + inner, + } + } +} + +impl H264ByteStreamWrite for AnnexBWriter { + type Writer = AnnexBNalUnitWriter; + + fn start_write_nal_unit(mut self) -> Result> { + if !self.leading_zero_8bits_written { + self.inner.write_u8(0)?; + self.leading_zero_8bits_written = true; + } + self.inner.write_all(&[0, 0, 1])?; + Ok(AnnexBNalUnitWriter { + inner: NalUnitWriter::new(self.inner), + }) + } +} + +pub struct AnnexBNalUnitWriter { + inner: NalUnitWriter, +} + +impl AnnexBNalUnitWriter { + fn _nal_unit_writer(&mut self) -> &mut NalUnitWriter { + &mut self.inner + } +} + +impl NalUnitWrite for AnnexBNalUnitWriter { + type Writer = AnnexBRbspWriter; + + fn write_nal_header(self, nal_header: NalHeader) -> Result> { + self.inner + .write_nal_header(nal_header) + .map(|inner| AnnexBRbspWriter { inner }) + } +} + +pub struct AnnexBRbspWriter { + inner: RbspWriter, +} + +impl AnnexBRbspWriter {} + +impl RbspWrite for AnnexBRbspWriter { + type Writer = AnnexBWriter; + + fn finish_rbsp(self) -> Result { + self.inner + .finish_rbsp() + .map(|writer| AnnexBWriter::new(writer)) + } +} + +impl WebvttWrite for AnnexBRbspWriter { + fn write_webvtt_header( + &mut self, + max_latency_to_video: Duration, + send_frequency_hz: u8, + subtitle_tracks: &[WebvttTrack], + ) -> std::io::Result<()> { + self.inner + .write_webvtt_header(max_latency_to_video, send_frequency_hz, subtitle_tracks) + } + + fn write_webvtt_payload( + &mut self, + track_index: u8, + chunk_number: u64, + chunk_version: u8, + video_offset: Duration, + webvtt_payload: &str, // TODO: replace with string type that checks for interior NULs + ) -> std::io::Result<()> { + self.inner.write_webvtt_payload( + track_index, + chunk_number, + chunk_version, + video_offset, + webvtt_payload, + ) + 
} +} diff --git a/video-bytestream-tools/src/h264/avcc.rs b/video-bytestream-tools/src/h264/avcc.rs new file mode 100644 index 0000000..672b2e4 --- /dev/null +++ b/video-bytestream-tools/src/h264/avcc.rs @@ -0,0 +1,153 @@ +use super::{ + H264ByteStreamWrite, NalHeader, NalUnitWrite, NalUnitWriter, RbspWrite, RbspWriter, Result, +}; +use crate::webvtt::{WebvttTrack, WebvttWrite}; +use byteorder::{BigEndian, WriteBytesExt}; +use std::{io::Write, time::Duration}; +use thiserror::Error; + +const AVCC_MAX_LENGTH: [usize; 4] = [0xff, 0xff_ff, 0, 0xff_ff_ff_ff]; + +pub struct AVCCWriter { + length_size: usize, + inner: W, +} + +#[derive(Error, Debug)] +#[error("AVCC length of {0} is unsupported")] +pub struct InvalidLengthError(pub usize); + +#[derive(Error, Debug)] +#[error("Tried to write {required} bytes which exceeds the max size of {max}")] +pub struct MaxNalUnitSizeExceededError { + max: usize, + required: usize, +} + +impl AVCCWriter { + pub fn new(length_size: usize, inner: W) -> Result { + match length_size { + 1 | 2 | 4 => Ok(Self { length_size, inner }), + _ => Err(InvalidLengthError(length_size)), + } + } +} + +impl H264ByteStreamWrite for AVCCWriter { + type Writer = AVCCNalUnitWriter>; + + fn start_write_nal_unit(self) -> Result>> { + Ok(AVCCNalUnitWriter { + inner: NalUnitWriter::new(AVCCWriterBuffer::new(self)), + }) + } +} + +pub struct AVCCWriterBuffer { + avcc_buffer: Vec, + avcc_writer: AVCCWriter, +} + +impl AVCCWriterBuffer { + fn new(avcc_writer: AVCCWriter) -> Self { + Self { + avcc_buffer: vec![], + avcc_writer, + } + } + + fn finish(mut self) -> Result> { + match self.avcc_writer.length_size { + 1 => self.write_u8(self.avcc_buffer.len().try_into().unwrap())?, + 2 => self.write_u16::(self.avcc_buffer.len().try_into().unwrap())?, + 4 => self.write_u32::(self.avcc_buffer.len().try_into().unwrap())?, + _ => unreachable!(), + } + self.avcc_writer.inner.write_all(&self.avcc_buffer)?; + Ok(self.avcc_writer) + } +} + +impl Write for AVCCWriterBuffer { + 
fn write(&mut self, buf: &[u8]) -> std::io::Result { + let length = self.avcc_buffer.len(); + let additional_length = buf.len(); + if length + additional_length > AVCC_MAX_LENGTH[self.avcc_writer.length_size] { + Err(std::io::Error::other(MaxNalUnitSizeExceededError { + max: AVCC_MAX_LENGTH[self.avcc_writer.length_size], + required: length + additional_length, + })) + } else { + self.avcc_buffer.write(buf) + } + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} + +pub struct AVCCNalUnitWriter { + inner: NalUnitWriter, +} + +impl AVCCNalUnitWriter { + fn _nal_unit_writer(&mut self) -> &mut NalUnitWriter { + &mut self.inner + } +} + +impl NalUnitWrite for AVCCNalUnitWriter> { + type Writer = AVCCRbspWriter>; + + fn write_nal_header( + self, + nal_header: NalHeader, + ) -> Result>> { + self.inner + .write_nal_header(nal_header) + .map(|inner| AVCCRbspWriter { inner }) + } +} + +pub struct AVCCRbspWriter { + inner: RbspWriter, +} + +impl RbspWrite for AVCCRbspWriter> { + type Writer = AVCCWriter; + + fn finish_rbsp(self) -> Result { + let buffer = self.inner.finish_rbsp()?; + buffer.finish() + } +} + +impl WebvttWrite for AVCCRbspWriter { + fn write_webvtt_header( + &mut self, + max_latency_to_video: Duration, + send_frequency_hz: u8, + subtitle_tracks: &[WebvttTrack], + ) -> std::io::Result<()> { + self.inner + .write_webvtt_header(max_latency_to_video, send_frequency_hz, subtitle_tracks) + } + + fn write_webvtt_payload( + &mut self, + track_index: u8, + chunk_number: u64, + chunk_version: u8, + video_offset: Duration, + webvtt_payload: &str, // TODO: replace with string type that checks for interior NULs + ) -> std::io::Result<()> { + self.inner.write_webvtt_payload( + track_index, + chunk_number, + chunk_version, + video_offset, + webvtt_payload, + ) + } +} diff --git a/video-bytestream-tools/src/lib.rs b/video-bytestream-tools/src/lib.rs new file mode 100644 index 0000000..874d072 --- /dev/null +++ b/video-bytestream-tools/src/lib.rs @@ -0,0 +1,2 
@@ +pub mod h264; +pub mod webvtt; diff --git a/video-bytestream-tools/src/webvtt.rs b/video-bytestream-tools/src/webvtt.rs new file mode 100644 index 0000000..ac692a1 --- /dev/null +++ b/video-bytestream-tools/src/webvtt.rs @@ -0,0 +1,155 @@ +use crate::h264::{write_sei_header, CountingSink}; +use byteorder::{BigEndian, WriteBytesExt}; +use std::{io::Write, time::Duration}; +use uuid::{uuid, Uuid}; + +pub const USER_DATA_UNREGISTERED: usize = 5; +pub const HEADER_GUID: Uuid = uuid!("cc7124bd-5f1c-4592-b27a-e2d9d218ef9e"); +pub const PAYLOAD_GUID: Uuid = uuid!("a0cb4dd1-9db2-4635-a76b-1c9fefd6c37b"); + +trait WriteCStrExt: Write { + fn write_c_str(&mut self, string: &str) -> std::io::Result<()> { + self.write_all(string.as_bytes())?; + self.write_u8(0)?; + Ok(()) + } +} + +impl WriteCStrExt for W {} + +pub struct WebvttTrack<'a> { + pub default: bool, + pub autoselect: bool, + pub forced: bool, + pub name: &'a str, + pub language: &'a str, + pub assoc_language: Option<&'a str>, + pub characteristics: Option<&'a str>, +} + +pub(crate) fn write_webvtt_header( + writer: &mut W, + max_latency_to_video: Duration, + send_frequency_hz: u8, + subtitle_tracks: &[WebvttTrack], +) -> std::io::Result<()> { + fn inner( + writer: &mut W, + max_latency_to_video: Duration, + send_frequency_hz: u8, + subtitle_tracks: &[WebvttTrack], + ) -> std::io::Result<()> { + writer.write_all(HEADER_GUID.as_bytes())?; + writer.write_u16::(max_latency_to_video.as_millis().try_into().unwrap())?; + writer.write_u8(send_frequency_hz)?; + writer.write_u8(subtitle_tracks.len().try_into().unwrap())?; + for track in subtitle_tracks { + let flags = { + let mut flags: u8 = 0; + if track.default { + flags |= 0b1000_0000; + } + if track.autoselect { + flags |= 0b0100_0000; + } + if track.forced { + flags |= 0b0010_0000; + } + if track.assoc_language.is_some() { + flags |= 0b0001_0000; + } + if track.characteristics.is_some() { + flags |= 0b0000_1000; + } + flags + }; + writer.write_u8(flags)?; + 
writer.write_c_str(track.name)?; + writer.write_c_str(track.language)?; + if let Some(assoc_language) = track.assoc_language { + writer.write_c_str(assoc_language)?; + } + if let Some(characteristics) = track.characteristics { + writer.write_c_str(characteristics)?; + } + } + Ok(()) + } + let mut count = CountingSink::new(); + inner( + &mut count, + max_latency_to_video, + send_frequency_hz, + subtitle_tracks, + )?; + write_sei_header(writer, USER_DATA_UNREGISTERED, count.count())?; + inner( + writer, + max_latency_to_video, + send_frequency_hz, + subtitle_tracks, + ) +} + +pub(crate) fn write_webvtt_payload( + writer: &mut W, + track_index: u8, + chunk_number: u64, + chunk_version: u8, + video_offset: Duration, + webvtt_payload: &str, // TODO: replace with string type that checks for interior NULs +) -> std::io::Result<()> { + fn inner( + writer: &mut W, + track_index: u8, + chunk_number: u64, + chunk_version: u8, + video_offset: Duration, + webvtt_payload: &str, + ) -> std::io::Result<()> { + writer.write_all(PAYLOAD_GUID.as_bytes())?; + writer.write_u8(track_index)?; + writer.write_u64::(chunk_number)?; + writer.write_u8(chunk_version)?; + writer.write_u16::(video_offset.as_millis().try_into().unwrap())?; + writer.write_c_str(webvtt_payload)?; + Ok(()) + } + + let mut count = CountingSink::new(); + inner( + &mut count, + track_index, + chunk_number, + chunk_version, + video_offset, + webvtt_payload, + )?; + write_sei_header(writer, USER_DATA_UNREGISTERED, count.count())?; + inner( + writer, + track_index, + chunk_number, + chunk_version, + video_offset, + webvtt_payload, + ) +} + +pub trait WebvttWrite { + fn write_webvtt_header( + &mut self, + max_latency_to_video: Duration, + send_frequency_hz: u8, + subtitle_tracks: &[WebvttTrack], + ) -> std::io::Result<()>; + + fn write_webvtt_payload( + &mut self, + track_index: u8, + chunk_number: u64, + chunk_version: u8, + video_offset: Duration, + webvtt_payload: &str, // TODO: replace with string type that checks for 
interior NULs + ) -> std::io::Result<()>; +} diff --git a/webvtt-in-video-stream/Cargo.toml b/webvtt-in-video-stream/Cargo.toml new file mode 100644 index 0000000..4c42a49 --- /dev/null +++ b/webvtt-in-video-stream/Cargo.toml @@ -0,0 +1,8 @@ +[package] +edition = "2021" +name = "webvtt-in-video-stream" +version = "0.1.0" + +[dependencies] +thiserror = "2.0.4" +video-bytestream-tools = {path = "../video-bytestream-tools"} diff --git a/webvtt-in-video-stream/src/lib.rs b/webvtt-in-video-stream/src/lib.rs new file mode 100644 index 0000000..d21bf18 --- /dev/null +++ b/webvtt-in-video-stream/src/lib.rs @@ -0,0 +1,277 @@ +use std::{collections::VecDeque, sync::Mutex, time::Duration}; +use video_bytestream_tools::webvtt::WebvttWrite; + +pub struct WebvttMuxerBuilder { + latency_to_video: Duration, + send_frequency_hz: u8, + video_frame_time: Duration, + tracks: Vec, +} + +struct WebvttMuxerTrack { + cues: VecDeque, + default: bool, + autoselect: bool, + forced: bool, + name: String, + language: String, + assoc_language: Option, + characteristics: Option, +} + +pub struct WebvttMuxer { + latency_to_video: Duration, + send_frequency_hz: u8, + video_frame_time: Duration, + inner: Mutex, +} + +struct WebvttMuxerInner { + tracks: Vec, + webvtt_buffer: String, + next_chunk_number: u64, + first_video_timestamp: Option, +} + +// TODO: this should probably be moved into video-bytestream-tools instead +pub struct WebvttString(String); + +struct WebvttCue { + start_time: Duration, + duration: Duration, + text: WebvttString, +} + +pub struct NulError { + pub string: String, + pub nul_position: usize, +} + +impl WebvttString { + /// Create a `WebvttString`. + /// This verifies that there are no interior NUL bytes, since + /// the WebVTT-in-SEI wire format uses NUL terminated strings. + /// + /// # Errors + /// + /// This function will return an error if there are any NUL bytes in the string. 
+ pub fn from_string(string: String) -> Result { + if let Some(nul_position) = string.find('\0') { + Err(NulError { + string, + nul_position, + }) + } else { + Ok(WebvttString(string)) + } + } +} + +pub struct TooManySubtitleTracksError { + pub name: WebvttString, + pub language: WebvttString, + pub assoc_language: Option, + pub characteristics: Option, +} + +impl WebvttMuxerBuilder { + pub fn new( + latency_to_video: Duration, + send_frequency_hz: u8, + video_frame_time: Duration, + ) -> Self { + Self { + latency_to_video, + send_frequency_hz, + video_frame_time, + tracks: vec![], + } + } + + // FIXME: split these arguments somehow? + #[allow(clippy::too_many_arguments)] + pub fn add_track( + &mut self, + default: bool, + autoselect: bool, + forced: bool, + name: WebvttString, + language: WebvttString, + assoc_language: Option, + characteristics: Option, + ) -> Result<&mut Self, TooManySubtitleTracksError> { + if self.tracks.len() == 0xff { + return Err(TooManySubtitleTracksError { + name, + language, + assoc_language, + characteristics, + }); + } + self.tracks.push(WebvttMuxerTrack { + cues: VecDeque::new(), + default, + autoselect, + forced, + name: name.0, + language: language.0, + assoc_language: assoc_language.map(|a| a.0), + characteristics: characteristics.map(|c| c.0), + }); + Ok(self) + } + + pub fn create_muxer(self) -> WebvttMuxer { + WebvttMuxer { + latency_to_video: self.latency_to_video, + send_frequency_hz: self.send_frequency_hz, + video_frame_time: self.video_frame_time, + inner: Mutex::new(WebvttMuxerInner { + tracks: self.tracks, + webvtt_buffer: String::new(), + next_chunk_number: 0, + first_video_timestamp: None, + }), + } + } +} + +pub struct InvalidWebvttTrack(pub u8); + +impl WebvttMuxer { + pub fn add_cue( + &self, + track: u8, + start_time: Duration, + duration: Duration, + text: WebvttString, + ) -> Result<(), InvalidWebvttTrack> { + let mut inner = self.inner.lock().unwrap(); + let tracks = &mut inner.tracks; + let track = tracks + 
.get_mut(usize::from(track)) + .ok_or(InvalidWebvttTrack(track))?; + let cues = &mut track.cues; + let index = cues + .iter() + .position(|c| c.start_time > start_time) + .unwrap_or(cues.len()); + cues.insert( + index, + WebvttCue { + start_time, + duration, + text, + }, + ); + Ok(()) + } + + fn consume_cues_into_chunk<'a>( + cues: &mut VecDeque, + timestamp: Duration, + duration: Duration, + buffer: &'a mut String, + ) -> &'a str { + while cues + .front() + .map(|cue| (cue.start_time + cue.duration) < timestamp) + .unwrap_or(false) + { + cues.pop_front(); + } + + buffer.clear(); + + for cue in &*cues { + if cue.start_time > (timestamp + duration) { + break; + } + let cue_start = if cue.start_time > timestamp { + cue.start_time + } else { + timestamp + }; + let cue_end = (cue.start_time + cue.duration).min(timestamp + duration); + buffer.push_str(&format!( + "{:0>2}:{:0>2}:{:0>2}.{:0>3} --> {:0>2}:{:0>2}:{:0>2}.{:0>3}\n{}\n\n", + cue_start.as_secs() / 3600, + cue_start.as_secs() % 3600 / 60, + cue_start.as_secs() % 60, + cue_start.as_millis() % 1000, + cue_end.as_secs() / 3600, + cue_end.as_secs() % 3600 / 60, + cue_end.as_secs() % 60, + cue_end.as_millis() % 1000, + cue.text.0 + )) + } + buffer.as_str() + } + + pub fn try_mux_into_bytestream( + &self, + video_timestamp: Duration, + add_header: bool, + writer: &mut impl WebvttWrite, + ) -> std::io::Result { + let mut inner = self.inner.lock().unwrap(); + let WebvttMuxerInner { + tracks, + webvtt_buffer, + next_chunk_number, + first_video_timestamp, + } = &mut *inner; + + if add_header { + // TODO: cache this? forward iter instead? 
+ let webvtt_tracks = tracks + .iter() + .map(|track| video_bytestream_tools::webvtt::WebvttTrack { + default: track.default, + autoselect: track.autoselect, + forced: track.forced, + language: &track.language, + name: &track.name, + assoc_language: track.assoc_language.as_deref(), + characteristics: track.characteristics.as_deref(), + }) + .collect::>(); + writer.write_webvtt_header( + self.latency_to_video, + self.send_frequency_hz, + &webvtt_tracks, + )?; + } + + let duration_between_sends = + Duration::from_secs_f64(1. / f64::from(self.send_frequency_hz)); + let first_video_timestamp = &*first_video_timestamp.get_or_insert(video_timestamp); + let next_chunk_webvtt_timestamp = + u32::try_from(*next_chunk_number).unwrap() * duration_between_sends; + let next_chunk_video_timestamp = + *first_video_timestamp + self.latency_to_video + next_chunk_webvtt_timestamp; + if next_chunk_video_timestamp > video_timestamp + self.video_frame_time * 2 { + return Ok(add_header); + } + let chunk_number = *next_chunk_number; + // TODO: return an error type that allows skipping chunks if the writer fails? 
+ for (track_index, track) in tracks.iter_mut().enumerate() { + let webvtt_payload = Self::consume_cues_into_chunk( + &mut track.cues, + next_chunk_webvtt_timestamp, + duration_between_sends, + webvtt_buffer, + ); + writer.write_webvtt_payload( + u8::try_from(track_index).unwrap(), + chunk_number, + 0, + video_timestamp - (*first_video_timestamp + next_chunk_webvtt_timestamp), + webvtt_payload, + )?; + } + *next_chunk_number += 1; + Ok(true) + } +} From 3814549b203d21f58cb79d2082a81ac95554bd72 Mon Sep 17 00:00:00 2001 From: Ruwen Hahn Date: Fri, 20 Dec 2024 21:24:24 +0100 Subject: [PATCH 4/9] Add WIP webvtt sei functionality --- CMakeLists.txt | 8 + cmake/BuildWebVTT.cmake | 15 ++ src/plugin-main.c | 2 + src/transcription-filter-callbacks.cpp | 178 ++++++++++++++++++++++- src/transcription-filter-data.h | 84 ++++++++++- src/transcription-filter.cpp | 37 +++++ src/whisper-utils/whisper-processing.cpp | 4 +- 7 files changed, 325 insertions(+), 3 deletions(-) create mode 100644 cmake/BuildWebVTT.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index 4ccb91d..ce3ef46 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,6 +106,14 @@ if(DEFINED ENV{LOCALVOCAL_EXTRA_VERBOSE}) target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE LOCALVOCAL_EXTRA_VERBOSE) endif() +option(ENABLE_WEBVTT "Enable WebVTT embedding" ON) + +if(ENABLE_WEBVTT) + include(cmake/BuildWebVTT.cmake) + target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE c_webvtt_in_video_stream) + target_compile_definitions(c_webvtt_in_video_stream INTERFACE ENABLE_WEBVTT) +endif() + target_sources( ${CMAKE_PROJECT_NAME} PRIVATE src/plugin-main.c diff --git a/cmake/BuildWebVTT.cmake b/cmake/BuildWebVTT.cmake new file mode 100644 index 0000000..fbae50d --- /dev/null +++ b/cmake/BuildWebVTT.cmake @@ -0,0 +1,15 @@ +include(FetchContent) + +FetchContent_Declare( + Corrosion + GIT_REPOSITORY https://github.com/corrosion-rs/corrosion.git + GIT_TAG v0.5 # Optionally specify a commit hash, version tag or branch here +) 
+FetchContent_MakeAvailable(Corrosion) + +# Import targets defined in a package or workspace manifest `Cargo.toml` file +corrosion_import_crate(MANIFEST_PATH "${CMAKE_SOURCE_DIR}/deps/c-webvtt-in-video-stream/Cargo.toml" CRATE_TYPES + "staticlib" PROFILE release) + +set_target_properties(c_webvtt_in_video_stream PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + "${CMAKE_SOURCE_DIR}/deps/c-webvtt-in-video-stream/target/") diff --git a/src/plugin-main.c b/src/plugin-main.c index 49cdca9..9ea8fbb 100644 --- a/src/plugin-main.c +++ b/src/plugin-main.c @@ -28,10 +28,12 @@ MODULE_EXPORT const char *obs_module_description(void) } extern struct obs_source_info transcription_filter_info; +extern void load_packet_callback_functions(); bool obs_module_load(void) { obs_register_source(&transcription_filter_info); + load_packet_callback_functions(); obs_log(LOG_INFO, "plugin loaded successfully (version %s)", PLUGIN_VERSION); return true; } diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index 7825126..d81fef3 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -3,6 +3,7 @@ #endif #include +#include #include #include @@ -231,7 +232,32 @@ void send_caption_to_stream(DetectionResultWithText result, const std::string &s } } -void set_text_callback(struct transcription_filter_data *gf, +#ifdef ENABLE_WEBVTT +void send_caption_to_webvtt(uint64_t possible_end_ts_ms, DetectionResultWithText result, + const std::string &str_copy, transcription_filter_data &gf) +{ + auto lock = std::unique_lock(gf.active_outputs_mutex); + for (auto &output : gf.active_outputs) { + for (size_t i = 0; i < MAX_OUTPUT_VIDEO_ENCODERS; i++) { + auto &muxer = output.webvtt_muxer[i]; + if (!muxer) + continue; + + auto duration = result.end_timestamp_ms - result.start_timestamp_ms; + auto segment_start_ts = possible_end_ts_ms - duration; + if (segment_start_ts < output.start_timestamp_ms) { + duration -= 
output.start_timestamp_ms - segment_start_ts; + segment_start_ts = output.start_timestamp_ms; + } + webvtt_muxer_add_cue(muxer.get(), 0, + segment_start_ts - output.start_timestamp_ms, duration, + str_copy.c_str()); + } + } +} +#endif + +void set_text_callback(uint64_t possible_end_ts, struct transcription_filter_data *gf, const DetectionResultWithText &resultIn) { DetectionResultWithText result = resultIn; @@ -342,6 +368,11 @@ void set_text_callback(struct transcription_filter_data *gf, send_caption_to_stream(result, str_copy, gf); } +#ifdef ENABLE_WEBVTT + if (result.result == DETECTION_RESULT_SPEECH) + send_caption_to_webvtt(possible_end_ts, result, str_copy, *gf); +#endif + if (gf->save_to_file && gf->output_file_path != "" && result.result == DETECTION_RESULT_SPEECH) { send_sentence_to_file(gf, result, str_copy, gf->output_file_path, true); @@ -363,6 +394,134 @@ void set_text_callback(struct transcription_filter_data *gf, } }; +#ifdef ENABLE_WEBVTT +void output_packet_added_callback(obs_output_t *output, struct encoder_packet *pkt, + struct encoder_packet_time *pkt_time, void *param) +{ + if (!pkt || !pkt_time) + return; + if (pkt->type != OBS_ENCODER_VIDEO) + return; + if (pkt->track_idx >= MAX_OUTPUT_VIDEO_ENCODERS) + return; + + auto &gf = *static_cast(param); + auto lock = std::unique_lock(gf.active_outputs_mutex); + auto it = std::find_if(gf.active_outputs.begin(), gf.active_outputs.end(), [&](auto &val) { + return obs_weak_output_references_output(val.output, output); + }); + if (it == gf.active_outputs.end()) + return; + + if (!it->initialized) { + it->initialized = true; + for (size_t i = 0; i < MAX_OUTPUT_VIDEO_ENCODERS; i++) { + auto encoder = obs_output_get_video_encoder2(output, i); + if (!encoder) + continue; + + auto &codec_flavor = it->codec_flavor[i]; + if (strcmp(obs_encoder_get_codec(encoder), "h264") == 0) { + codec_flavor = H264AnnexB; + } else if (strcmp(obs_encoder_get_codec(encoder), "av1") == 0) { + continue; + } else if 
(strcmp(obs_encoder_get_codec(encoder), "hevc") == 0) { + continue; + } else { + continue; + } + + auto video = obs_encoder_video(encoder); + auto voi = video_output_get_info(video); + + auto muxer_builder = webvtt_create_muxer_builder( + 10'000, 2, + util_mul_div64(1000000000ULL, voi->fps_den, voi->fps_num)); + // TODO: change name/language? + webvtt_muxer_builder_add_track(muxer_builder, false, false, false, + "Subtitles", "en", nullptr, nullptr); + webvtt_muxer_builder_add_track(muxer_builder, false, false, false, "Empty", + "en", nullptr, nullptr); + it->webvtt_muxer[i].reset(webvtt_muxer_builder_create_muxer(muxer_builder)); + } + } + + auto &muxer = it->webvtt_muxer[pkt->track_idx]; + if (!muxer) + return; + + std::unique_ptr buffer{ + webvtt_muxer_try_mux_into_bytestream(muxer.get(), pkt_time->cts, pkt->keyframe, + it->codec_flavor[pkt->track_idx])}; + + if (!buffer) + return; + + long ref = 1; + + DARRAY(uint8_t) out_data; + da_init(out_data); + da_reserve(out_data, sizeof(ref) + pkt->size + webvtt_buffer_length(buffer.get())); + + // Copy the original packet + da_push_back_array(out_data, (uint8_t *)&ref, sizeof(ref)); + da_push_back_array(out_data, pkt->data, pkt->size); + da_push_back_array(out_data, webvtt_buffer_data(buffer.get()), + webvtt_buffer_length(buffer.get())); + + auto old_pkt = *pkt; + obs_encoder_packet_release(pkt); + *pkt = old_pkt; + + pkt->data = (uint8_t *)out_data.array + sizeof(ref); + pkt->size = out_data.num - sizeof(ref); +} + +void add_webvtt_output(transcription_filter_data &gf, obs_output_t *output) +{ + if (!obs_output_add_packet_callback_) + return; + + auto start_ms = now_ms(); + + auto lock = std::unique_lock(gf.active_outputs_mutex); + gf.active_outputs.push_back({}); + auto &entry = gf.active_outputs.back(); + entry.output = obs_output_get_weak_output(output); + entry.start_timestamp_ms = start_ms; + obs_output_add_packet_callback_(output, output_packet_added_callback, &gf); +} + +void 
remove_webvtt_output(transcription_filter_data &gf, obs_output_t *output) +{ + if (!obs_output_remove_packet_callback_) + return; + + auto lock = std::unique_lock(gf.active_outputs_mutex); + for (auto iter = gf.active_outputs.begin(); iter != gf.active_outputs.end(); iter++) { + auto &webvtt_output = *iter; + if (!obs_weak_output_references_output(webvtt_output.output, output)) + continue; + + obs_output_remove_packet_callback_(output, output_packet_added_callback, &gf); + gf.active_outputs.erase(iter); + return; + } +} + +void remove_all_webvtt_outputs(std::unique_lock & /*active_outputs_lock*/, + transcription_filter_data &gf) +{ + for (auto &output : gf.active_outputs) { + auto obs_output = OBSOutputAutoRelease{obs_weak_output_get_output(output.output)}; + if (!obs_output) + continue; + + obs_output_remove_packet_callback_(obs_output, output_packet_added_callback, &gf); + } +} +#endif + /** * @brief Callback function to handle recording state changes in OBS. * @@ -385,6 +544,9 @@ void recording_state_callback(enum obs_frontend_event event, void *data) struct transcription_filter_data *gf_ = static_cast(data); if (event == OBS_FRONTEND_EVENT_RECORDING_STARTING) { +#ifdef ENABLE_WEBVTT + add_webvtt_output(*gf_, OBSOutputAutoRelease{obs_frontend_get_recording_output()}); +#endif if (gf_->save_srt && gf_->save_only_while_recording && gf_->output_file_path != "") { obs_log(gf_->log_level, "Recording started. 
Resetting srt file."); @@ -397,6 +559,11 @@ void recording_state_callback(enum obs_frontend_event event, void *data) gf_->sentence_number = 1; gf_->start_timestamp_ms = now_ms(); } + } else if (event == OBS_FRONTEND_EVENT_RECORDING_STOPPING) { +#ifdef ENABLE_WEBVTT + remove_webvtt_output(*gf_, + OBSOutputAutoRelease{obs_frontend_get_recording_output()}); +#endif } else if (event == OBS_FRONTEND_EVENT_RECORDING_STOPPED) { if (!gf_->save_only_while_recording || !gf_->rename_file_to_match_recording) { return; @@ -430,6 +597,15 @@ void recording_state_callback(enum obs_frontend_event event, void *data) newPath = recordingPath.parent_path() / newPath.filename(); fs::rename(outputPath, newPath); + } else if (event == OBS_FRONTEND_EVENT_STREAMING_STARTING) { +#ifdef ENABLE_WEBVTT + add_webvtt_output(*gf_, OBSOutputAutoRelease{obs_frontend_get_streaming_output()}); +#endif + } else if (event == OBS_FRONTEND_EVENT_STREAMING_STOPPING) { +#ifdef ENABLE_WEBVTT + remove_webvtt_output(*gf_, + OBSOutputAutoRelease{obs_frontend_get_streaming_output()}); +#endif } } diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index 4ca5d91..8518163 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -1,6 +1,11 @@ #ifndef TRANSCRIPTION_FILTER_DATA_H #define TRANSCRIPTION_FILTER_DATA_H +#ifdef ENABLE_WEBVTT +#include +#include +#endif + #include #include #include @@ -23,6 +28,67 @@ #define MAX_PREPROC_CHANNELS 10 +#if !defined(LIBOBS_MAJOR_VERSION) || LIBOBS_MAJOR_VERSION < 31 +struct encoder_packet_time { + /* PTS used to associate uncompressed frames with encoded packets. */ + int64_t pts; + + /* Composition timestamp is when the frame was rendered, + * captured via os_gettime_ns(). + */ + uint64_t cts; + + /* FER (Frame Encode Request) is when the frame was + * submitted to the encoder for encoding via the encode + * callback (e.g. encode_texture2()), captured via os_gettime_ns(). 
+ */ + uint64_t fer; + + /* FERC (Frame Encode Request Complete) is when + * the associated FER event completed. If the encode + * is synchronous with the call, this means FERC - FER + * measures the actual encode time, otherwise if the + * encode is asynchronous, it measures the pipeline + * delay between encode request and encode complete. + * FERC is also captured via os_gettime_ns(). + */ + uint64_t ferc; + + /* PIR (Packet Interleave Request) is when the encoded packet + * is interleaved with the stream. PIR is captured via + * os_gettime_ns(). The difference between PIR and CTS gives + * the total latency between frame rendering + * and packet interleaving. + */ + uint64_t pir; +}; +#endif + +using obs_output_add_packet_callback_t = + void(obs_output_t *output, + void (*packet_cb)(obs_output_t *output, struct encoder_packet *pkt, + struct encoder_packet_time *pkt_time, void *param), + void *param); +using obs_output_remove_packet_callback_t = + void(obs_output_t *output, + void (*packet_cb)(obs_output_t *output, struct encoder_packet *pkt, + struct encoder_packet_time *pkt_time, void *param), + void *param); + +extern obs_output_add_packet_callback_t *obs_output_add_packet_callback_; +extern obs_output_remove_packet_callback_t *obs_output_remove_packet_callback_; +extern "C" void load_packet_callback_functions(); + +#ifdef ENABLE_WEBVTT +struct webvtt_muxer_deleter { + void operator()(WebvttMuxer *m) { webvtt_muxer_free(m); } +}; + +struct webvtt_buffer_deleter { + void operator()(WebvttBuffer *b) { webvtt_buffer_free(b); } +}; +#endif + struct transcription_filter_data { obs_source_t *context; // obs filter source (this filter) size_t channels; // number of channels @@ -139,6 +205,21 @@ struct transcription_filter_data { TokenBufferSegmentation buffered_output_output_type = TokenBufferSegmentation::SEGMENTATION_TOKEN; +#ifdef ENABLE_WEBVTT + struct webvtt_output { + OBSWeakOutputAutoRelease output; + uint64_t start_timestamp_ms; + + bool initialized = false; 
+ std::unique_ptr + webvtt_muxer[MAX_OUTPUT_VIDEO_ENCODERS]; + CodecFlavor codec_flavor[MAX_OUTPUT_VIDEO_ENCODERS] = {}; + }; + + std::mutex active_outputs_mutex; + std::vector active_outputs; +#endif + // ctor transcription_filter_data() : whisper_buf_mutex(), whisper_ctx_mutex(), wshiper_thread_cv() { @@ -162,7 +243,8 @@ struct transcription_filter_audio_info { }; // Callback sent when the transcription has a new result -void set_text_callback(struct transcription_filter_data *gf, const DetectionResultWithText &str); +void set_text_callback(uint64_t possible_end_ts, struct transcription_filter_data *gf, + const DetectionResultWithText &str); void clear_current_caption(transcription_filter_data *gf_); // Callback sent when the VAD finds an audio chunk. Sample rate = WHISPER_SAMPLE_RATE, channels = 1 diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index 0f802d4..6d327e4 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -132,6 +133,12 @@ void transcription_filter_remove(void *data, obs_source_t *source) disconnect_source_signals(gf, source); } +#ifdef ENABLE_WEBVTT + +void remove_all_webvtt_outputs(std::unique_lock &active_outputs_lock, + transcription_filter_data &gf); +#endif + void transcription_filter_destroy(void *data) { struct transcription_filter_data *gf = @@ -159,6 +166,14 @@ void transcription_filter_destroy(void *data) circlebuf_free(&gf->resampled_buffer); +#ifdef ENABLE_WEBVTT + { + auto lock = std::unique_lock(gf->active_outputs_mutex); + remove_all_webvtt_outputs(lock, *gf); + gf->active_outputs.clear(); + } +#endif + if (gf->captions_monitor.isEnabled()) { gf->captions_monitor.stopThread(); } @@ -557,3 +572,25 @@ void transcription_filter_hide(void *data) static_cast(data); obs_log(gf->log_level, "filter hide"); } + +obs_output_add_packet_callback_t *obs_output_add_packet_callback_ = nullptr; +obs_output_remove_packet_callback_t 
*obs_output_remove_packet_callback_ = nullptr; + +void load_packet_callback_functions() +{ + auto libobs = os_dlopen("obs"); + if (!libobs) + return; + + auto add_callback = os_dlsym(libobs, "obs_output_add_packet_callback"); + auto remove_callback = os_dlsym(libobs, "obs_output_remove_packet_callback"); + if (!add_callback || !remove_callback) + return; + + obs_output_add_packet_callback_ = + reinterpret_cast(add_callback); + obs_output_remove_packet_callback_ = + reinterpret_cast(remove_callback); + + obs_log(LOG_INFO, "loaded callbacks"); +} diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp index 55f57a3..fb2fc94 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/whisper-utils/whisper-processing.cpp @@ -341,11 +341,13 @@ void run_inference_and_callbacks(transcription_filter_data *gf, uint64_t start_o pcm32f_size * sizeof(float)); } + auto inference_start_ts = now_ms(); + struct DetectionResultWithText inference_result = run_whisper_inference(gf, pcm32f_data, pcm32f_size_with_silence, start_offset_ms, end_offset_ms, vad_state); // output inference result to a text source - set_text_callback(gf, inference_result); + set_text_callback(inference_start_ts, gf, inference_result); if (gf->enable_audio_chunks_callback && vad_state != VAD_STATE_PARTIAL) { audio_chunk_callback(gf, pcm32f_data, pcm32f_size_with_silence, vad_state, From bc761de21cd147a758b4894eee161aaa885e8d74 Mon Sep 17 00:00:00 2001 From: Ruwen Hahn Date: Thu, 9 Jan 2025 16:25:51 +0100 Subject: [PATCH 5/9] Add webvtt recording/streaming settings --- data/locale/en-US.ini | 3 +++ src/transcription-filter-callbacks.cpp | 23 ++++++++++++++++++++--- src/transcription-filter-data.h | 9 +++++++++ src/transcription-filter-properties.cpp | 17 +++++++++++++++++ src/transcription-filter.cpp | 4 ++++ 5 files changed, 53 insertions(+), 3 deletions(-) diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini index a2827fd..7bc434f 100644 --- 
a/data/locale/en-US.ini +++ b/data/locale/en-US.ini @@ -4,6 +4,9 @@ vad_threshold="VAD Threshold" log_level="Internal Log Level" log_words="Log Output to Console" caption_to_stream="Stream Captions" +webvtt_group="WebVTT" +webvtt_caption_to_stream="Add WebVTT captions to stream" +webvtt_caption_to_recording="Add WebVTT captions to recording" subtitle_sources="Output Destination" none_no_output="None / No output" file_output_enable="Save to File" diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index d81fef3..ebba9e7 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -238,6 +238,12 @@ void send_caption_to_webvtt(uint64_t possible_end_ts_ms, DetectionResultWithText { auto lock = std::unique_lock(gf.active_outputs_mutex); for (auto &output : gf.active_outputs) { + if (!gf.webvtt_caption_to_recording && + output.output_type == transcription_filter_data::webvtt_output_type::Recording) + continue; + if (!gf.webvtt_caption_to_stream && + output.output_type == transcription_filter_data::webvtt_output_type::Streaming) + continue; for (size_t i = 0; i < MAX_OUTPUT_VIDEO_ENCODERS; i++) { auto &muxer = output.webvtt_muxer[i]; if (!muxer) @@ -477,17 +483,26 @@ void output_packet_added_callback(obs_output_t *output, struct encoder_packet *p pkt->size = out_data.num - sizeof(ref); } -void add_webvtt_output(transcription_filter_data &gf, obs_output_t *output) +void add_webvtt_output(transcription_filter_data &gf, obs_output_t *output, + transcription_filter_data::webvtt_output_type output_type) { if (!obs_output_add_packet_callback_) return; + if (!gf.webvtt_caption_to_recording && + output_type == transcription_filter_data::webvtt_output_type::Recording) + return; + if (!gf.webvtt_caption_to_stream && + output_type == transcription_filter_data::webvtt_output_type::Streaming) + return; + auto start_ms = now_ms(); auto lock = std::unique_lock(gf.active_outputs_mutex); 
gf.active_outputs.push_back({}); auto &entry = gf.active_outputs.back(); entry.output = obs_output_get_weak_output(output); + entry.output_type = output_type; entry.start_timestamp_ms = start_ms; obs_output_add_packet_callback_(output, output_packet_added_callback, &gf); } @@ -545,7 +560,8 @@ void recording_state_callback(enum obs_frontend_event event, void *data) static_cast(data); if (event == OBS_FRONTEND_EVENT_RECORDING_STARTING) { #ifdef ENABLE_WEBVTT - add_webvtt_output(*gf_, OBSOutputAutoRelease{obs_frontend_get_recording_output()}); + add_webvtt_output(*gf_, OBSOutputAutoRelease{obs_frontend_get_recording_output()}, + transcription_filter_data::webvtt_output_type::Recording); #endif if (gf_->save_srt && gf_->save_only_while_recording && gf_->output_file_path != "") { @@ -599,7 +615,8 @@ void recording_state_callback(enum obs_frontend_event event, void *data) fs::rename(outputPath, newPath); } else if (event == OBS_FRONTEND_EVENT_STREAMING_STARTING) { #ifdef ENABLE_WEBVTT - add_webvtt_output(*gf_, OBSOutputAutoRelease{obs_frontend_get_streaming_output()}); + add_webvtt_output(*gf_, OBSOutputAutoRelease{obs_frontend_get_streaming_output()}, + transcription_filter_data::webvtt_output_type::Streaming); #endif } else if (event == OBS_FRONTEND_EVENT_STREAMING_STOPPING) { #ifdef ENABLE_WEBVTT diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index 8518163..20ceec6 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -206,8 +206,14 @@ struct transcription_filter_data { TokenBufferSegmentation::SEGMENTATION_TOKEN; #ifdef ENABLE_WEBVTT + enum struct webvtt_output_type { + Streaming, + Recording, + }; + struct webvtt_output { OBSWeakOutputAutoRelease output; + webvtt_output_type output_type; uint64_t start_timestamp_ms; bool initialized = false; @@ -218,6 +224,9 @@ struct transcription_filter_data { std::mutex active_outputs_mutex; std::vector active_outputs; + + std::atomic webvtt_caption_to_stream; + 
std::atomic webvtt_caption_to_recording; #endif // ctor diff --git a/src/transcription-filter-properties.cpp b/src/transcription-filter-properties.cpp index 70c7f1e..2e9be99 100644 --- a/src/transcription-filter-properties.cpp +++ b/src/transcription-filter-properties.cpp @@ -411,6 +411,20 @@ void add_translation_group_properties(obs_properties_t *ppts) MT_("translation_no_repeat_ngram_size"), 1, 10, 1); } +#ifdef ENABLE_WEBVTT +void add_webvtt_group_properties(obs_properties_t *ppts) +{ + auto webvtt_group = obs_properties_create(); + obs_properties_add_group(ppts, "webvtt_enable", MT_("webvtt_group"), OBS_GROUP_CHECKABLE, + webvtt_group); + + obs_properties_add_bool(webvtt_group, "webvtt_caption_to_stream", + MT_("webvtt_caption_to_stream")); + obs_properties_add_bool(webvtt_group, "webvtt_caption_to_recording", + MT_("webvtt_caption_to_recording")); +} +#endif + void add_file_output_group_properties(obs_properties_t *ppts) { // create a file output group @@ -617,6 +631,9 @@ obs_properties_t *transcription_filter_properties(void *data) add_transcription_group_properties(ppts, gf); add_translation_group_properties(ppts); add_translation_cloud_group_properties(ppts); +#ifdef ENABLE_WEBVTT + add_webvtt_group_properties(ppts); +#endif add_file_output_group_properties(ppts); add_buffered_output_group_properties(ppts); add_advanced_group_properties(ppts, gf); diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index 6d327e4..77f25d6 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -194,6 +194,10 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->vad_mode = (int)obs_data_get_int(s, "vad_mode"); gf->log_words = obs_data_get_bool(s, "log_words"); gf->caption_to_stream = obs_data_get_bool(s, "caption_to_stream"); +#ifdef ENABLE_WEBVTT + gf->webvtt_caption_to_stream = obs_data_get_bool(s, "webvtt_caption_to_stream"); + gf->webvtt_caption_to_recording = obs_data_get_bool(s, "webvtt_caption_to_recording"); 
+#endif gf->save_to_file = obs_data_get_bool(s, "file_output_enable"); gf->save_srt = obs_data_get_bool(s, "subtitle_save_srt"); gf->truncate_output_file = obs_data_get_bool(s, "truncate_output_file"); From 5d29d02213d5f8760deae771e5e99232d6fe659e Mon Sep 17 00:00:00 2001 From: Ruwen Hahn Date: Thu, 9 Jan 2025 18:22:32 +0100 Subject: [PATCH 6/9] Make latency_to_video_in_msecs and send_frequency_hz configurable --- data/locale/en-US.ini | 2 ++ src/transcription-filter-callbacks.cpp | 3 ++- src/transcription-filter-data.h | 3 +++ src/transcription-filter-properties.cpp | 11 +++++++++++ src/transcription-filter.cpp | 10 ++++++++++ 5 files changed, 28 insertions(+), 1 deletion(-) diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini index 7bc434f..fcb4a7e 100644 --- a/data/locale/en-US.ini +++ b/data/locale/en-US.ini @@ -7,6 +7,8 @@ caption_to_stream="Stream Captions" webvtt_group="WebVTT" webvtt_caption_to_stream="Add WebVTT captions to stream" webvtt_caption_to_recording="Add WebVTT captions to recording" +webvtt_latency_to_video_in_msecs="Latency to video (milliseconds)" +webvtt_send_frequency_hz="Send frequency (Hz)" subtitle_sources="Output Destination" none_no_output="None / No output" file_output_enable="Save to File" diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index ebba9e7..a6f0705 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -421,6 +421,7 @@ void output_packet_added_callback(obs_output_t *output, struct encoder_packet *p if (!it->initialized) { it->initialized = true; + auto settings_lock = std::unique_lock(gf.webvtt_settings_mutex); for (size_t i = 0; i < MAX_OUTPUT_VIDEO_ENCODERS; i++) { auto encoder = obs_output_get_video_encoder2(output, i); if (!encoder) @@ -441,7 +442,7 @@ void output_packet_added_callback(obs_output_t *output, struct encoder_packet *p auto voi = video_output_get_info(video); auto muxer_builder = webvtt_create_muxer_builder( - 
10'000, 2, + gf.latency_to_video_in_msecs, gf.send_frequency_hz, util_mul_div64(1000000000ULL, voi->fps_den, voi->fps_num)); // TODO: change name/language? webvtt_muxer_builder_add_track(muxer_builder, false, false, false, diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index 20ceec6..0647ee2 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -225,6 +225,9 @@ struct transcription_filter_data { std::mutex active_outputs_mutex; std::vector active_outputs; + std::mutex webvtt_settings_mutex; + uint16_t latency_to_video_in_msecs; + uint8_t send_frequency_hz; std::atomic webvtt_caption_to_stream; std::atomic webvtt_caption_to_recording; #endif diff --git a/src/transcription-filter-properties.cpp b/src/transcription-filter-properties.cpp index 2e9be99..6e6e751 100644 --- a/src/transcription-filter-properties.cpp +++ b/src/transcription-filter-properties.cpp @@ -422,6 +422,13 @@ void add_webvtt_group_properties(obs_properties_t *ppts) MT_("webvtt_caption_to_stream")); obs_properties_add_bool(webvtt_group, "webvtt_caption_to_recording", MT_("webvtt_caption_to_recording")); + + obs_properties_add_int_slider(webvtt_group, "webvtt_latency_to_video_in_msecs", + MT_("webvtt_latency_to_video_in_msecs"), 0, + std::numeric_limits::max(), 1); + obs_properties_add_int_slider(webvtt_group, "webvtt_send_frequency_hz", + MT_("webvtt_send_frequency_hz"), 1, + std::numeric_limits::max(), 1); } #endif @@ -715,6 +722,10 @@ void transcription_filter_defaults(obs_data_t *s) "{\n\t\"text\":\"{{sentence}}\",\n\t\"target\":\"{{target_language}}\"\n}"); obs_data_set_default_string(s, "translate_cloud_response_json_path", "translations.0.text"); + // webvtt options + obs_data_set_default_int(s, "webvtt_latency_to_video_in_msecs", 10'000); + obs_data_set_default_int(s, "webvtt_send_frequency_hz", 2); + // Whisper parameters apply_whisper_params_defaults_on_settings(s); } diff --git a/src/transcription-filter.cpp 
b/src/transcription-filter.cpp index 77f25d6..93775a1 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -197,6 +197,16 @@ void transcription_filter_update(void *data, obs_data_t *s) #ifdef ENABLE_WEBVTT gf->webvtt_caption_to_stream = obs_data_get_bool(s, "webvtt_caption_to_stream"); gf->webvtt_caption_to_recording = obs_data_get_bool(s, "webvtt_caption_to_recording"); + + { + auto lock = std::unique_lock(gf->webvtt_settings_mutex); + gf->latency_to_video_in_msecs = static_cast(std::max( + 0ll, std::min(static_cast(std::numeric_limits::max()), + obs_data_get_int(s, "webvtt_latency_to_video_in_msecs")))); + gf->send_frequency_hz = static_cast(std::max( + 1ll, std::min(static_cast(std::numeric_limits::max()), + obs_data_get_int(s, "webvtt_send_frequency_hz")))); + } #endif gf->save_to_file = obs_data_get_bool(s, "file_output_enable"); gf->save_srt = obs_data_get_bool(s, "subtitle_save_srt"); From aa81b6602963c1c6e111cb27252893297e4d8d77 Mon Sep 17 00:00:00 2001 From: Ruwen Hahn Date: Thu, 9 Jan 2025 18:28:11 +0100 Subject: [PATCH 7/9] Make webvtt languages configurable --- data/locale/en-US.ini | 1 + src/transcription-filter-callbacks.cpp | 29 ++++++++++++++++++++----- src/transcription-filter-data.h | 4 ++++ src/transcription-filter-properties.cpp | 19 ++++++++++++++++ src/transcription-filter.cpp | 22 +++++++++++++++++++ 5 files changed, 69 insertions(+), 6 deletions(-) diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini index fcb4a7e..e5a75dc 100644 --- a/data/locale/en-US.ini +++ b/data/locale/en-US.ini @@ -9,6 +9,7 @@ webvtt_caption_to_stream="Add WebVTT captions to stream" webvtt_caption_to_recording="Add WebVTT captions to recording" webvtt_latency_to_video_in_msecs="Latency to video (milliseconds)" webvtt_send_frequency_hz="Send frequency (Hz)" +webvtt_language_description="Language $1" subtitle_sources="Output Destination" none_no_output="None / No output" file_output_enable="Save to File" diff --git 
a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index a6f0705..405a866 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -19,6 +19,7 @@ #include "transcription-utils.h" #include "translation/translation.h" #include "translation/translation-includes.h" +#include "whisper-utils/whisper-language.h" #include "whisper-utils/whisper-utils.h" #include "whisper-utils/whisper-model-utils.h" #include "translation/language_codes.h" @@ -244,6 +245,11 @@ void send_caption_to_webvtt(uint64_t possible_end_ts_ms, DetectionResultWithText if (!gf.webvtt_caption_to_stream && output.output_type == transcription_filter_data::webvtt_output_type::Streaming) continue; + + auto lang_to_track = output.language_to_track.find(result.language); + if (lang_to_track == output.language_to_track.end()) + continue; + for (size_t i = 0; i < MAX_OUTPUT_VIDEO_ENCODERS; i++) { auto &muxer = output.webvtt_muxer[i]; if (!muxer) @@ -255,7 +261,7 @@ void send_caption_to_webvtt(uint64_t possible_end_ts_ms, DetectionResultWithText duration -= output.start_timestamp_ms - segment_start_ts; segment_start_ts = output.start_timestamp_ms; } - webvtt_muxer_add_cue(muxer.get(), 0, + webvtt_muxer_add_cue(muxer.get(), lang_to_track->second, segment_start_ts - output.start_timestamp_ms, duration, str_copy.c_str()); } @@ -444,11 +450,22 @@ void output_packet_added_callback(obs_output_t *output, struct encoder_packet *p auto muxer_builder = webvtt_create_muxer_builder( gf.latency_to_video_in_msecs, gf.send_frequency_hz, util_mul_div64(1000000000ULL, voi->fps_den, voi->fps_num)); - // TODO: change name/language? - webvtt_muxer_builder_add_track(muxer_builder, false, false, false, - "Subtitles", "en", nullptr, nullptr); - webvtt_muxer_builder_add_track(muxer_builder, false, false, false, "Empty", - "en", nullptr, nullptr); + uint8_t track_index = 0; + // FIXME: this may be too lazy, i.e. 
languages should probably be locked in the signal handler instead + for (auto &lang : gf.active_languages) { + auto lang_it = whisper_available_lang_reverse.find(lang); + if (lang_it == whisper_available_lang.end()) { + obs_log(LOG_WARNING, + "requested language '%s' unknown, track not added", + lang.c_str()); + continue; + } + + webvtt_muxer_builder_add_track(muxer_builder, false, false, false, + lang_it->second.c_str(), + lang.c_str(), nullptr, nullptr); + it->language_to_track[lang] = track_index++; + } it->webvtt_muxer[i].reset(webvtt_muxer_builder_create_muxer(muxer_builder)); } } diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index 0647ee2..0fd1b76 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -27,6 +27,7 @@ #include "translation/cloud-translation/translation-cloud.h" #define MAX_PREPROC_CHANNELS 10 +#define MAX_WEBVTT_TRACKS 5 #if !defined(LIBOBS_MAJOR_VERSION) || LIBOBS_MAJOR_VERSION < 31 struct encoder_packet_time { @@ -217,6 +218,7 @@ struct transcription_filter_data { uint64_t start_timestamp_ms; bool initialized = false; + std::map language_to_track; std::unique_ptr webvtt_muxer[MAX_OUTPUT_VIDEO_ENCODERS]; CodecFlavor codec_flavor[MAX_OUTPUT_VIDEO_ENCODERS] = {}; @@ -228,6 +230,8 @@ struct transcription_filter_data { std::mutex webvtt_settings_mutex; uint16_t latency_to_video_in_msecs; uint8_t send_frequency_hz; + std::vector active_languages; + std::atomic webvtt_caption_to_stream; std::atomic webvtt_caption_to_recording; #endif diff --git a/src/transcription-filter-properties.cpp b/src/transcription-filter-properties.cpp index 6e6e751..e26c50d 100644 --- a/src/transcription-filter-properties.cpp +++ b/src/transcription-filter-properties.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include "transcription-filter-data.h" #include "transcription-filter.h" @@ -429,6 +430,24 @@ void add_webvtt_group_properties(obs_properties_t *ppts) 
obs_properties_add_int_slider(webvtt_group, "webvtt_send_frequency_hz", MT_("webvtt_send_frequency_hz"), 1, std::numeric_limits::max(), 1); + + DStr num_buffer, name_buffer, description_buffer; + for (size_t i = 0; i < MAX_WEBVTT_TRACKS; i++) { + dstr_printf(num_buffer, "%zu", i + 1); + dstr_printf(name_buffer, "webvtt_language_%zu", i); + dstr_copy(description_buffer, MT_("webvtt_language_description")); + dstr_replace(description_buffer, "$1", num_buffer->array); + obs_property_t *language_select = obs_properties_add_list( + webvtt_group, name_buffer->array, description_buffer->array, + OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING); + obs_property_list_add_string(language_select, "None", ""); + for (auto const &pair : whisper_available_lang_reverse) { + if (pair.second == "auto") + continue; + obs_property_list_add_string(language_select, pair.first.c_str(), + pair.second.c_str()); + } + } } #endif diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index 93775a1..e1d4d64 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -206,6 +207,27 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->send_frequency_hz = static_cast(std::max( 1ll, std::min(static_cast(std::numeric_limits::max()), obs_data_get_int(s, "webvtt_send_frequency_hz")))); + + gf->active_languages.clear(); + DStr name_buffer; + for (size_t i = 0; i < MAX_WEBVTT_TRACKS; i++) { + dstr_printf(name_buffer, "webvtt_language_%zu", i); + if (!obs_data_has_user_value(s, name_buffer->array)) + continue; + + std::string lang = obs_data_get_string(s, name_buffer->array); + if (lang.empty()) + continue; + + if (std::find(gf->active_languages.begin(), gf->active_languages.end(), + lang) != gf->active_languages.end()) { + obs_log(LOG_WARNING, "Not adding duplicate language '%s'", + lang.c_str()); + continue; + } + + gf->active_languages.push_back(lang); + } } #endif gf->save_to_file = 
obs_data_get_bool(s, "file_output_enable"); From 96dc52644fd66131e24094b8ad7d6a3cec6b72bf Mon Sep 17 00:00:00 2001 From: Ruwen Hahn Date: Thu, 9 Jan 2025 19:05:04 +0100 Subject: [PATCH 8/9] Add translation and main language separately --- src/transcription-filter-callbacks.cpp | 38 ++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index 405a866..9c569f6 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -303,6 +303,11 @@ void set_text_callback(uint64_t possible_end_ts, struct transcription_filter_dat } } +#ifdef ENABLE_WEBVTT + if (result.result == DETECTION_RESULT_SPEECH) + send_caption_to_webvtt(possible_end_ts, result, str_copy, *gf); +#endif + bool should_translate_local = gf->translate_only_full_sentences ? result.result == DETECTION_RESULT_SPEECH : true; @@ -343,7 +348,21 @@ void set_text_callback(uint64_t possible_end_ts, struct transcription_filter_dat if (should_translate_cloud) { send_sentence_to_cloud_translation_async( str_copy, gf, result.language, - [gf, result](const std::string &translated_sentence_cloud) { + [gf, result, + possible_end_ts](const std::string &translated_sentence_cloud) { +#ifdef ENABLE_WEBVTT + if (result.result == DETECTION_RESULT_SPEECH) { + auto target_lang = language_codes_to_whisper.find( + gf->translate_cloud_target_language); + if (target_lang != language_codes_to_whisper.end()) { + auto res_copy = result; + res_copy.language = target_lang->second; + send_caption_to_webvtt(possible_end_ts, res_copy, + translated_sentence_cloud, + *gf); + } + } +#endif if (gf->translate_cloud_output != "none") { send_caption_to_source(gf->translate_cloud_output, translated_sentence_cloud, gf); @@ -375,16 +394,23 @@ void set_text_callback(uint64_t possible_end_ts, struct transcription_filter_dat } } +#ifdef ENABLE_WEBVTT + if (should_translate_local && result.result == 
DETECTION_RESULT_SPEECH) { + auto target_lang = language_codes_to_whisper.find(gf->target_lang); + if (target_lang != language_codes_to_whisper.end()) { + auto res_copy = result; + res_copy.language = target_lang->second; + send_caption_to_webvtt(possible_end_ts, res_copy, translated_sentence_local, + *gf); + } + } +#endif + if (gf->caption_to_stream && result.result == DETECTION_RESULT_SPEECH) { // TODO: add support for partial transcriptions send_caption_to_stream(result, str_copy, gf); } -#ifdef ENABLE_WEBVTT - if (result.result == DETECTION_RESULT_SPEECH) - send_caption_to_webvtt(possible_end_ts, result, str_copy, *gf); -#endif - if (gf->save_to_file && gf->output_file_path != "" && result.result == DETECTION_RESULT_SPEECH) { send_sentence_to_file(gf, result, str_copy, gf->output_file_path, true); From 66ded7a53239eafef81d2fccf9c15d12f778432d Mon Sep 17 00:00:00 2001 From: Ruwen Hahn Date: Mon, 13 Jan 2025 17:08:37 +0100 Subject: [PATCH 9/9] Add rust CI integration --- .github/workflows/build-project.yaml | 14 ++++++++++++++ .github/workflows/check-format.yaml | 14 ++++++++++++++ cmake/BuildWebVTT.cmake | 8 ++++++++ 3 files changed, 36 insertions(+) diff --git a/.github/workflows/build-project.yaml b/.github/workflows/build-project.yaml index 91cabe1..d9398ff 100644 --- a/.github/workflows/build-project.yaml +++ b/.github/workflows/build-project.yaml @@ -119,6 +119,16 @@ jobs: restore-keys: | ${{ runner.os }}-ccache-${{ matrix.architecture }}- + - uses: actions-rust-lang/setup-rust-toolchain@v1 + if: matrix.architecture == 'arm64' + with: + target: aarch64-apple-darwin + + - uses: actions-rust-lang/setup-rust-toolchain@v1 + if: matrix.architecture == 'x86_64' + with: + target: x86_64-apple-darwin + - name: Set Up Codesigning 🔑 uses: ./.github/actions/setup-macos-codesigning if: fromJSON(needs.check-event.outputs.codesign) @@ -197,6 +207,8 @@ jobs: echo "pluginName=${product_name}" >> $GITHUB_OUTPUT echo "pluginVersion=${product_version}" >> $GITHUB_OUTPUT + - 
uses: actions-rust-lang/setup-rust-toolchain@v1 + - uses: actions/cache@v4 id: ccache-cache with: @@ -271,6 +283,8 @@ jobs: "pluginName=${ProductName}" >> $env:GITHUB_OUTPUT "pluginVersion=${ProductVersion}" >> $env:GITHUB_OUTPUT + - uses: actions-rust-lang/setup-rust-toolchain@v1 + - name: Build Plugin 🧱 uses: ./.github/actions/build-plugin with: diff --git a/.github/workflows/check-format.yaml b/.github/workflows/check-format.yaml index e30b916..f2f51e4 100644 --- a/.github/workflows/check-format.yaml +++ b/.github/workflows/check-format.yaml @@ -25,3 +25,17 @@ jobs: uses: ./.github/actions/run-cmake-format with: failCondition: error + + cargo-fmt: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + components: rustfmt + - name: rustfmt + uses: actions-rust-lang/rustfmt@v1 + with: + manifest-path: deps/c-webvtt-in-video-stream/Cargo.toml diff --git a/cmake/BuildWebVTT.cmake b/cmake/BuildWebVTT.cmake index fbae50d..5b2f4ae 100644 --- a/cmake/BuildWebVTT.cmake +++ b/cmake/BuildWebVTT.cmake @@ -1,5 +1,13 @@ include(FetchContent) +set(Rust_RUSTUP_INSTALL_MISSING_TARGET true) + +if(OS_MACOS) + if("$ENV{MACOS_ARCH}" STREQUAL "x86_64") + set(Rust_CARGO_TARGET "x86_64-apple-darwin") + endif() +endif() + FetchContent_Declare( Corrosion GIT_REPOSITORY https://github.com/corrosion-rs/corrosion.git