Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for converting VectorFst to ConstFst to ffi/python lib #268

Merged
merged 4 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions rustfst-ffi/src/fst/const_fst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,21 @@ pub unsafe extern "C" fn const_fst_display(
Ok(())
})
}

/// # Safety
///
/// The pointers should be valid.
#[no_mangle]
pub unsafe extern "C" fn const_fst_from_vec_fst(
vec_fst_prt: *const CFst,
const_fst_ptr: *mut *const CFst,
) -> RUSTFST_FFI_RESULT {
wrap(|| {
let fst = get!(CFst, vec_fst_prt);
let vec_fst = as_fst!(VectorFst<TropicalWeight>, fst);
let const_fst = ConstFst::from(vec_fst.clone());
let raw_pointer = CFst(Box::new(const_fst)).into_raw_pointer();
unsafe { *const_fst_ptr = raw_pointer };
Ok(())
})
}
6 changes: 3 additions & 3 deletions rustfst-python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[build-system]
requires = [
"setuptools>=62.1<63",
"setuptools_rust>=1.3<1.4",
"wheel>=0.34<0.35",
"setuptools>=62.1,<63",
"setuptools_rust>=1.3,<1.4",
"wheel>=0.34,<0.35",
]
19 changes: 19 additions & 0 deletions rustfst-python/rustfst/fst/const_fst.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
)

from rustfst.fst import Fst
from rustfst.fst.vector_fst import VectorFst
from rustfst.symbol_table import SymbolTable
from rustfst.drawing_config import DrawingConfig
from typing import Optional, Union
Expand Down Expand Up @@ -105,6 +106,24 @@ def read(cls, filename: Union[str, Path]) -> ConstFst:

return cls(ptr=fst)

@classmethod
def from_vector_fst(cls, fst: VectorFst) -> ConstFst:
"""
Converts a given `VectorFst` to `ConstFst`
Args:
fst: The `VectorFst` that should be converted
Returns:
A `ConstFst`
Raises:
ValueError: Conversion failed
"""
const_fst = ctypes.pointer(ctypes.c_void_p())
ret_code = lib.const_fst_from_vec_fst(fst.ptr, ctypes.byref(const_fst))
err_msg = "Failed to convert VectorFST to ConstFST"
check_ffi_error(ret_code, err_msg)

return cls(ptr=const_fst)

def write(self, filename: Union[str, Path]):
"""
Serializes FST to a file.
Expand Down
15 changes: 14 additions & 1 deletion rustfst-python/tests/test_fst.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path

from rustfst import VectorFst, Tr, SymbolTable
from rustfst import VectorFst, Tr, SymbolTable, ConstFst
import pytest
from tempfile import NamedTemporaryFile

Expand Down Expand Up @@ -353,3 +353,16 @@ def test_fst_relabel_tables():
assert fst_3 == fst_ref
assert fst_3.input_symbols() == new_isymt
assert fst_3.output_symbols() == new_osymt


def test_const_fst_from_vector_fst():
fst = VectorFst()
s1 = fst.add_state()
s2 = fst.add_state()
fst.add_tr(s1, Tr(1, 2, weight_one(), s2))
fst.set_start(s1)
fst.set_final(s2)

const_fst = ConstFst.from_vector_fst(fst)

assert const_fst.num_trs(const_fst.start()) == 1