Skip to content

Commit 390e791

Browse files
committed
Fixed #36
1 parent ccb4f08 commit 390e791

File tree

2 files changed

+21
-18
lines changed

2 files changed

+21
-18
lines changed

bbconf/bbconf.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from textwrap import indent
55

66
import yacman
7-
import pipestat
8-
from pipestat.exceptions import RecordNotFoundError
7+
from pipestat import PipestatManager
8+
from pipestat.exceptions import RecordNotFoundError, SchemaError
99

1010
from sqlmodel import SQLModel, Field, select
1111
import qdrant_client
@@ -17,6 +17,7 @@
1717
CFG_PATH_KEY,
1818
CFG_PATH_PIPELINE_OUTPUT_KEY,
1919
CFG_PATH_BEDSTAT_DIR_KEY,
20+
CFG_PATH_SENTENCE2VEC_KEY,
2021
DEFAULT_SECTION_VALUES,
2122
CFG_PATH_BEDBUNCHER_DIR_KEY,
2223
BED_TABLE,
@@ -34,7 +35,7 @@
3435
CFG_QDRANT_API_KEY,
3536
CFG_QDRANT_HOST_KEY,
3637
CFG_QDRANT_COLLECTION_NAME_KEY,
37-
DEFAULT_HF_MODEL,
38+
DEFAULT_SENTENCE2VEC_MODEL,
3839
DEFAULT_VEC2VEC_MODEL,
3940
DEFAULT_REGION2_VEC_MODEL,
4041
CFG_ACCESS_METHOD_KEY,
@@ -49,7 +50,6 @@
4950
from bbconf.helpers import raise_missing_key, get_bedbase_cfg
5051
from bbconf.models import DRSModel, AccessMethod, AccessURL
5152

52-
# os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # to suppress verbose warnings tensorflow
5353
from geniml.text2bednn import text2bednn
5454
from geniml.search import QdrantBackend
5555
from fastembed.embedding import FlagEmbedding
@@ -85,12 +85,12 @@ def __init__(self, config_path: str = None, database_only: bool = False):
8585
# Create Pipestat objects and tables if they do not exist
8686
_LOGGER.debug("Creating pipestat objects...")
8787
self.__pipestats = {
88-
BED_TABLE: pipestat.PipestatManager(
88+
BED_TABLE: PipestatManager(
8989
config_file=cfg_path,
9090
schema_path=BED_TABLE_SCHEMA,
9191
database_only=database_only,
9292
),
93-
BEDSET_TABLE: pipestat.PipestatManager(
93+
BEDSET_TABLE: PipestatManager(
9494
config_file=cfg_path,
9595
schema_path=BEDSET_TABLE_SCHEMA,
9696
database_only=database_only,
@@ -102,6 +102,9 @@ def __init__(self, config_path: str = None, database_only: bool = False):
102102
# setup t2bsi object
103103
self._t2bsi = None
104104
try:
105+
self._senta2vec_hg_model_name = self.config[CFG_PATH_KEY].get(
106+
CFG_PATH_SENTENCE2VEC_KEY, DEFAULT_SENTENCE2VEC_MODEL
107+
)
105108
_LOGGER.debug("Setting up qdrant database connection...")
106109
if self.config[CFG_QDRANT_KEY].get(CFG_QDRANT_API_KEY, None):
107110
os.environ["QDRANT_API_KEY"] = self.config[CFG_QDRANT_KEY].get(
@@ -125,7 +128,7 @@ def __init__(self, config_path: str = None, database_only: bool = False):
125128
except qdrant_client.http.exceptions.ResponseHandlingException as err:
126129
_LOGGER.error(f"error in Connection to qdrant! skipping... Error: {err}")
127130

128-
def _read_config_file(self, config_path: str) -> yacman.YAMLConfigManager:
131+
def _read_config_file(self, config_path: str) -> dict:
129132
"""
130133
Read configuration file and insert default values if not set
131134
@@ -218,7 +221,7 @@ def config(self) -> yacman.YAMLConfigManager:
218221
return self._config
219222

220223
@property
221-
def bed(self) -> pipestat.PipestatManager:
224+
def bed(self) -> PipestatManager:
222225
"""
223226
PipestatManager of the bedfiles table
224227
@@ -227,7 +230,7 @@ def bed(self) -> pipestat.PipestatManager:
227230
return self.__pipestats[BED_TABLE]
228231

229232
@property
230-
def bedset(self) -> pipestat.PipestatManager:
233+
def bedset(self) -> PipestatManager:
231234
"""
232235
PipestatManager of the bedsets table
233236
@@ -460,9 +463,7 @@ def select_unique(self, table_name: str, column: str = None) -> List[dict]:
460463
with self.bedset.backend.session:
461464
values = self.bedset.backend.select_records(columns=column)["records"]
462465
else:
463-
raise pipestat.exceptions.SchemaError(
464-
f"Incorrect table name provided {table_name}"
465-
)
466+
raise SchemaError(f"Incorrect table name provided {table_name}")
466467

467468
return [i for n, i in enumerate(values) if i not in values[n + 1 :]]
468469

@@ -510,9 +511,7 @@ def _create_t2bsi_object(self) -> Union[text2bednn.Text2BEDSearchInterface, None
510511

511512
try:
512513
return text2bednn.Text2BEDSearchInterface(
513-
nl2vec_model=FlagEmbedding(
514-
model_name=os.getenv("HF_MODEL", DEFAULT_HF_MODEL)
515-
),
514+
nl2vec_model=FlagEmbedding(model_name=self._senta2vec_hg_model_name),
516515
vec2vec_model=self._config[CFG_PATH_KEY][CFG_PATH_VEC2VEC_KEY],
517516
search_backend=self.qdrant_backend,
518517
)
@@ -654,7 +653,11 @@ def get_result(
654653
return result
655654

656655
def get_drs_metadata(
657-
self, record_type: str, record_id: str, result_id: str, base_uri: str
656+
self,
657+
record_type: Literal["bed", "bedset"],
658+
record_id: str,
659+
result_id: str,
660+
base_uri: str,
658661
) -> DRSModel:
659662
"""
660663
Get DRS metadata for a bed- or bedset-associated file
@@ -700,4 +703,3 @@ def get_drs_metadata(
700703
)
701704

702705
return drs_dict
703-

bbconf/const.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
CFG_PATH_PIPELINE_OUTPUT_KEY = "pipeline_output_path"
3636
CFG_PATH_REGION2VEC_KEY = "region2vec"
3737
CFG_PATH_VEC2VEC_KEY = "vec2vec"
38+
CFG_PATH_SENTENCE2VEC_KEY = "sentence2vec"
3839

3940

4041
CFG_DATABASE_KEY = "database"
@@ -94,7 +95,7 @@
9495
},
9596
}
9697

97-
DEFAULT_HF_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
98+
DEFAULT_SENTENCE2VEC_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
9899
DEFAULT_VEC2VEC_MODEL = "databio/v2v-MiniLM-v2-ATAC-hg38"
99100
DEFAULT_REGION2_VEC_MODEL = "databio/r2v-ChIP-atlas-hg38"
100101

0 commit comments

Comments
 (0)