Skip to content

Commit

Permalink
metas出力時に話者情報をマージする (#728)
Browse files Browse the repository at this point in the history
* `metas`出力時に話者情報をマージする

* スタイルもソートする

* `SpeakerMeta::{speaker_order,style_order}`を導入

* テストを追加

* `merge`にdoc

* 話者情報が食い違うやつはwarning止まりにする

* `StyleMeta`自体に`order`を持たせる
  • Loading branch information
qryxip authored Feb 9, 2024
1 parent eafd765 commit 0f2c3b2
Show file tree
Hide file tree
Showing 8 changed files with 296 additions and 37 deletions.
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/__internal/interop.rs
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pub use crate::synthesizer::blocking::PerformInference;
pub use crate::{metas::merge as merge_metas, synthesizer::blocking::PerformInference};
43 changes: 25 additions & 18 deletions crates/voicevox_core/src/infer/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ use std::{
use anyhow::bail;
use educe::Educe;
use enum_map::{Enum as _, EnumMap};
use indexmap::IndexMap;
use itertools::{iproduct, Itertools as _};

use crate::{
error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult},
infer::{InferenceOperation, ParamInfo},
manifest::ModelInnerId,
metas::{SpeakerMeta, StyleId, StyleMeta, VoiceModelMeta},
metas::{self, SpeakerMeta, StyleId, StyleMeta, VoiceModelMeta},
voice_model::{VoiceModelHeader, VoiceModelId},
Result,
};
Expand Down Expand Up @@ -119,7 +120,7 @@ impl<R: InferenceRuntime, D: InferenceDomain> Status<R, D> {
#[derive(Educe)]
#[educe(Default(bound = "R: InferenceRuntime, D: InferenceDomain"))]
struct LoadedModels<R: InferenceRuntime, D: InferenceDomain>(
BTreeMap<VoiceModelId, LoadedModel<R, D>>,
IndexMap<VoiceModelId, LoadedModel<R, D>>,
);

struct LoadedModel<R: InferenceRuntime, D: InferenceDomain> {
Expand All @@ -130,11 +131,7 @@ struct LoadedModel<R: InferenceRuntime, D: InferenceDomain> {

impl<R: InferenceRuntime, D: InferenceDomain> LoadedModels<R, D> {
fn metas(&self) -> VoiceModelMeta {
self.0
.values()
.flat_map(|LoadedModel { metas, .. }| metas)
.cloned()
.collect()
metas::merge(self.0.values().flat_map(|LoadedModel { metas, .. }| metas))
}

fn ids_for(&self, style_id: StyleId) -> Result<(VoiceModelId, ModelInnerId)> {
Expand Down Expand Up @@ -184,20 +181,29 @@ impl<R: InferenceRuntime, D: InferenceDomain> LoadedModels<R, D> {
///
/// # Errors
///
/// 音声モデルIDかスタイルIDが`model_header`と重複するとき、エラーを返す。
/// 次の場合にエラーを返す。
///
/// - 音声モデルIDかスタイルIDが`model_header`と重複するとき
fn ensure_acceptable(&self, model_header: &VoiceModelHeader) -> LoadModelResult<()> {
let loaded = self.styles();
let external = model_header
.metas
.iter()
.flat_map(|speaker| speaker.styles());

let error = |context| LoadModelError {
path: model_header.path.clone(),
context,
source: None,
};

let loaded = self.speakers();
let external = model_header.metas.iter();
for (loaded, external) in iproduct!(loaded, external) {
if loaded.speaker_uuid() == external.speaker_uuid() {
loaded.warn_diff_except_styles(external);
}
}

let loaded = self.styles();
let external = model_header
.metas
.iter()
.flat_map(|speaker| speaker.styles());
if self.0.contains_key(&model_header.id) {
return Err(error(LoadModelErrorKind::ModelAlreadyLoaded {
id: model_header.id.clone(),
Expand Down Expand Up @@ -242,11 +248,12 @@ impl<R: InferenceRuntime, D: InferenceDomain> LoadedModels<R, D> {
Ok(())
}

fn speakers(&self) -> impl Iterator<Item = &SpeakerMeta> + Clone {
self.0.values().flat_map(|LoadedModel { metas, .. }| metas)
}

fn styles(&self) -> impl Iterator<Item = &StyleMeta> {
self.0
.values()
.flat_map(|LoadedModel { metas, .. }| metas)
.flat_map(|speaker| speaker.styles())
self.speakers().flat_map(|speaker| speaker.styles())
}
}

Expand Down
193 changes: 191 additions & 2 deletions crates/voicevox_core/src/metas.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,42 @@
use std::fmt::Display;
use std::fmt::{Debug, Display};

use derive_getters::Getters;
use derive_new::new;
use indexmap::IndexMap;
use itertools::Itertools as _;
use serde::{Deserialize, Serialize};
use tracing::warn;

/// [`speaker_uuid`]をキーとして複数の[`SpeakerMeta`]をマージする。
///
/// マージする際話者は[`SpeakerMeta::order`]、スタイルは[`StyleMeta::order`]をもとに安定ソートされる。
/// `order`が無い話者とスタイルは、そうでないものよりも後ろに置かれる。
///
/// [`speaker_uuid`]: SpeakerMeta::speaker_uuid
pub fn merge<'a>(metas: impl IntoIterator<Item = &'a SpeakerMeta>) -> Vec<SpeakerMeta> {
return metas
.into_iter()
.fold(IndexMap::<_, SpeakerMeta>::new(), |mut acc, speaker| {
acc.entry(&speaker.speaker_uuid)
.and_modify(|acc| acc.styles.extend(speaker.styles.clone()))
.or_insert_with(|| speaker.clone());
acc
})
.into_values()
.update(|speaker| {
speaker
.styles
.sort_by_key(|&StyleMeta { order, .. }| key(order));
})
.sorted_by_key(|&SpeakerMeta { order, .. }| key(order))
.collect();

fn key(order: Option<u32>) -> impl Ord {
order
.map(Into::into)
.unwrap_or_else(|| u64::from(u32::MAX) + 1)
}
}

/// [`StyleId`]の実体。
///
Expand All @@ -15,7 +49,7 @@ pub type RawStyleId = u32;
///
/// [**話者**(_speaker_)]: SpeakerMeta
/// [**スタイル**(_style_)]: StyleMeta
#[derive(PartialEq, Eq, Clone, Copy, Ord, PartialOrd, Deserialize, Serialize, new, Debug)]
#[derive(PartialEq, Eq, Clone, Copy, Ord, Hash, PartialOrd, Deserialize, Serialize, new, Debug)]
pub struct StyleId(RawStyleId);

impl StyleId {
Expand Down Expand Up @@ -65,6 +99,52 @@ pub struct SpeakerMeta {
version: StyleVersion,
/// 話者のUUID。
speaker_uuid: String,
/// 話者の順番。
///
/// `SpeakerMeta`の列は、この値に対して昇順に並んでいるべきである。
order: Option<u32>,
}

impl SpeakerMeta {
/// # Panics
///
/// `speaker_uuid`が異なるときパニックする。
pub(crate) fn warn_diff_except_styles(&self, other: &Self) {
let Self {
name: name1,
styles: _,
version: version1,
speaker_uuid: speaker_uuid1,
order: order1,
} = self;

let Self {
name: name2,
styles: _,
version: version2,
speaker_uuid: speaker_uuid2,
order: order2,
} = other;

if speaker_uuid1 != speaker_uuid2 {
panic!("must be equal: {speaker_uuid1} != {speaker_uuid2:?}");
}

warn_diff(speaker_uuid1, "name", name1, name2);
warn_diff(speaker_uuid1, "version", version1, version2);
warn_diff(speaker_uuid1, "order", order1, order2);

fn warn_diff<T: PartialEq + Debug>(
speaker_uuid: &str,
field_name: &str,
left: &T,
right: &T,
) {
if left != right {
warn!("`{speaker_uuid}`: different `{field_name}` ({left:?} != {right:?})");
}
}
}
}

/// **スタイル**(_style_)のメタ情報。
Expand All @@ -74,4 +154,113 @@ pub struct StyleMeta {
id: StyleId,
/// スタイル名。
name: String,
/// スタイルの順番。
///
/// [`SpeakerMeta::styles`]は、この値に対して昇順に並んでいるべきである。
order: Option<u32>,
}

#[cfg(test)]
mod tests {
use once_cell::sync::Lazy;
use serde_json::json;

#[test]
fn merge_works() -> anyhow::Result<()> {
static INPUT: Lazy<serde_json::Value> = Lazy::new(|| {
json!([
{
"name": "B",
"styles": [
{
"id": 3,
"name": "B_1",
"order": 0
}
],
"version": "0.0.0",
"speaker_uuid": "f34ab151-c0f5-4e0a-9ad2-51ce30dba24d",
"order": 1
},
{
"name": "A",
"styles": [
{
"id": 2,
"name": "A_3",
"order": 2
}
],
"version": "0.0.0",
"speaker_uuid": "d6fd707c-a451-48e9-8f00-fe9ee3bf6264",
"order": 0
},
{
"name": "A",
"styles": [
{
"id": 1,
"name": "A_1",
"order": 0
},
{
"id": 0,
"name": "A_2",
"order": 1
}
],
"version": "0.0.0",
"speaker_uuid": "d6fd707c-a451-48e9-8f00-fe9ee3bf6264",
"order": 0
}
])
});

static EXPECTED: Lazy<serde_json::Value> = Lazy::new(|| {
json!([
{
"name": "A",
"styles": [
{
"id": 1,
"name": "A_1",
"order": 0
},
{
"id": 0,
"name": "A_2",
"order": 1
},
{
"id": 2,
"name": "A_3",
"order": 2
}
],
"version": "0.0.0",
"speaker_uuid": "d6fd707c-a451-48e9-8f00-fe9ee3bf6264",
"order": 0
},
{
"name": "B",
"styles": [
{
"id": 3,
"name": "B_1",
"order": 0
}
],
"version": "0.0.0",
"speaker_uuid": "f34ab151-c0f5-4e0a-9ad2-51ce30dba24d",
"order": 1
}
])
});

let input = &serde_json::from_value::<Vec<_>>(INPUT.clone())?;
let actual = serde_json::to_value(super::merge(input))?;

pretty_assertions::assert_eq!(*EXPECTED, actual);
Ok(())
}
}
12 changes: 11 additions & 1 deletion crates/voicevox_core/src/voice_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@ pub type RawVoiceModelId = String;

/// 音声モデルID。
#[derive(
PartialEq, Eq, Clone, Ord, PartialOrd, Deserialize, new, Getters, derive_more::Display, Debug,
PartialEq,
Eq,
Clone,
Ord,
Hash,
PartialOrd,
Deserialize,
new,
Getters,
derive_more::Display,
Debug,
)]
pub struct VoiceModelId {
raw_voice_model_id: RawVoiceModelId,
Expand Down
4 changes: 3 additions & 1 deletion crates/voicevox_core_c_api/src/compatible_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ static VOICE_MODEL_SET: Lazy<VoiceModelSet> = Lazy::new(|| {
.iter()
.map(|vvm| (vvm.id().clone(), vvm.clone()))
.collect();
let metas: Vec<_> = all_vvms.iter().flat_map(|vvm| vvm.metas()).collect();
let metas = voicevox_core::__internal::interop::merge_metas(
all_vvms.iter().flat_map(|vvm| vvm.metas()),
);
let mut style_model_map = BTreeMap::default();
for vvm in all_vvms.iter() {
for meta in vvm.metas().iter() {
Expand Down
Loading

0 comments on commit 0f2c3b2

Please sign in to comment.