diff --git a/bbconf/bbconf.py b/bbconf/bbconf.py index f467d1d..fd36cd1 100644 --- a/bbconf/bbconf.py +++ b/bbconf/bbconf.py @@ -114,12 +114,21 @@ def __init__(self, config_path: str = None, database_only: bool = False): if self.config[CFG_PATH_KEY].get(CFG_PATH_REGION2VEC_KEY) and self.config[ CFG_PATH_KEY ].get(CFG_PATH_VEC2VEC_KEY): + self.region2vec_model = self.config[CFG_PATH_KEY].get( + CFG_PATH_REGION2VEC_KEY + ) self._t2bsi = self._create_t2bsi_object() else: if not self.config[CFG_PATH_KEY].get(CFG_PATH_REGION2VEC_KEY): _LOGGER.debug( f"{CFG_PATH_REGION2VEC_KEY} was not provided in config file! Using default.." ) + self.region2vec_model = DEFAULT_REGION2_VEC_MODEL + else: + self.region2vec_model = self.config[CFG_PATH_KEY].get( + CFG_PATH_REGION2VEC_KEY + ) + if not self.config[CFG_PATH_KEY].get(CFG_PATH_VEC2VEC_KEY): self.config[CFG_PATH_KEY][ CFG_PATH_VEC2VEC_KEY @@ -555,9 +564,8 @@ def add_bed_to_qdrant( raise BedBaseConfError( "Could not add add region to qdrant. Invalid type, or path. " ) - if not region_to_vec: - reg_2_vec_obj = Region2VecExModel(DEFAULT_REGION2_VEC_MODEL) - else: + if not region_to_vec or isinstance(self.region2vec_model, str): + reg_2_vec_obj = Region2VecExModel(self.region2vec_model) reg_2_vec_obj = region_to_vec bed_embedding = reg_2_vec_obj.encode( bed_region_set,