diff --git a/querent/config/core/gpt_llm_config.py b/querent/config/core/gpt_llm_config.py deleted file mode 100644 index 0cea459e..00000000 --- a/querent/config/core/gpt_llm_config.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Optional -from querent.config.core.llm_config import LLM_Config -import os - -class GPTConfig(LLM_Config): - id: str = "" - name: str = "OPENAIEngine" - description: str = "An engine for NER using BERT and knowledge graph operations using OPENAI" - version: str = "0.0.1" - logger: str = "OPENAI.engine_config" - ner_model_name: str = "dbmdz/bert-large-cased-finetuned-conll03-english" - rel_model_name: str = "gpt-3.5-turbo" - requests_per_minute: int = 3 - openai_api_key: str = "" - user_context: str = None - huggingface_token: Optional[str] = None - - def __init__(self, config_source=None, **kwargs): - config_data = {} - config_data.update(kwargs) - if config_source: - config_data = self.load_config(config_source) - if "config" in config_data: - config_data.update(config_data["config"]) - super().__init__(**config_data) - - - @classmethod - def load_config(cls, config_source) -> dict: - if isinstance(config_source, dict): - # If config source is a dictionary, return a dictionary - cls.config_data = config_source - else: - raise ValueError("Invalid config. Must be a valid dictionary") - - env_vars = dict(os.environ) - cls.config_data.update(env_vars) - return cls.config_data \ No newline at end of file diff --git a/querent/config/core/llm_config.py b/querent/config/core/llm_config.py index 744475e6..c1d888fa 100644 --- a/querent/config/core/llm_config.py +++ b/querent/config/core/llm_config.py @@ -2,7 +2,6 @@ from pydantic import BaseModel, Field from typing import List, Dict, Any, Optional import os - from querent.config.engine.engine_config import EngineConfig class LLM_Config(EngineConfig): @@ -11,11 +10,11 @@ class LLM_Config(EngineConfig): description: str = "An engine for NER and knowledge graph operations." version: str = "0.0.1" logger: str = "LLM.engine_config" - ner_model_name: str = "dbmdz/bert-large-cased-finetuned-conll03-english" + ner_model_name: str = "English" spacy_model_path: str = 'en_core_web_lg' nltk_path: str = '/model/nltk_data' - rel_model_type: str = 'llama' - rel_model_path: str = './tests/llama-2-7b-chat.Q5_K_M.gguf' + rel_model_type: str = 'bert' + rel_model_path: str = 'bert-base-uncased' grammar_file_path: str = './querent/kg/rel_helperfunctions/json.gbnf' emb_model_name: str = 'sentence-transformers/all-MiniLM-L6-v2' user_context: str = Field(default="In a semantic triple (Subject, Predicate & Object) framework, determine which of the above entity is the subject and which is the object based on the context along with the predicate between these entities. 
Please also identify the subject type, object type & predicate type.") diff --git a/querent/config/core/opensource_llm_config.py b/querent/config/core/opensource_llm_config.py index ff5f33e3..13e9b124 100644 --- a/querent/config/core/opensource_llm_config.py +++ b/querent/config/core/opensource_llm_config.py @@ -8,7 +8,7 @@ class Opensource_LLM_Config(BaseModel): version: str = "0.0.1" logger: str = "RelationshipExtractor.engine_config" model_type: str = 'llama' - model_path: str = './tests/llama-2-7b-chat.Q5_K_M.gguf' + model_path: str = '' grammar_file_path: str = './querent/kg/rel_helperfunctions/json.gbnf' qa_template: str = Field(default=None) emb_model_name: str = 'sentence-transformers/all-MiniLM-L6-v2' diff --git a/querent/core/transformers/bert_ner_opensourcellm.py b/querent/core/transformers/bert_ner_opensourcellm.py index 0e5bca00..510c84eb 100644 --- a/querent/core/transformers/bert_ner_opensourcellm.py +++ b/querent/core/transformers/bert_ner_opensourcellm.py @@ -1,9 +1,7 @@ import json -import re -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer +import transformers import time - -import unidecode from querent.common.types.ingested_table import IngestedTables from querent.kg.ner_helperfunctions.fixed_predicate import FixedPredicateExtractor from querent.common.types.ingested_images import IngestedImages @@ -28,76 +26,155 @@ from querent.config.core.llm_config import LLM_Config from querent.kg.rel_helperfunctions.triple_to_json import TripleToJsonConverter from querent.kg.rel_helperfunctions.embedding_store import EmbeddingStore +from querent.models.model_manager import ModelManager +from querent.models.gguf_metadata_extractor import GGUFMetadataExtractor +from querent.kg.rel_helperfunctions.attn_based_relationship_model_getter import get_model +from querent.kg.rel_helperfunctions.attn_based_relationship_filter import process_tokens, trim_triples class BERTLLM(BaseEngine): def __init__( self, - input_queue:QuerentQueue, + input_queue: QuerentQueue, config: LLM_Config, Embedding=None - ): + ): self.logger = setup_logger(__name__, "BERTLLM") super().__init__(input_queue) - self.skip_inferences=config.skip_inferences + self.skip_inferences = config.skip_inferences + self.enable_filtering = config.enable_filtering + self.filter_params = config.filter_params or {} + self.sample_entities = config.sample_entities + self.fixed_entities = config.fixed_entities + self.fixed_relationships = config.fixed_relationships + self.sample_relationships = config.sample_relationships + self.user_context = config.user_context + self.isConfinedSearch = config.is_confined_search + self.attn_based_rel_extraction = True + self.create_emb = EmbeddingStore() if not Embedding else Embedding + try: - self.graph_config = GraphConfig(identifier=config.name) - self.contextual_graph = QuerentKG(self.graph_config) - self.semantic_graph = QuerentKG(self.graph_config) - self.file_buffer = FileBuffer() - self.ner_tokenizer = AutoTokenizer.from_pretrained(config.ner_model_name) - self.ner_model = NER_LLM.load_model(config.ner_model_name, "NER") - self.ner_llm_instance = NER_LLM(provided_tokenizer=self.ner_tokenizer, provided_model=self.ner_model) - self.nlp_model = NER_LLM.set_nlp_model(config.spacy_model_path) - self.nlp_model = NER_LLM.get_class_variable() - if not Embedding: - self.create_emb = EmbeddingStore() - else: - self.create_emb = Embedding - if not self.skip_inferences: - mock_config = Opensource_LLM_Config(qa_template=config.user_context, - model_type = 
config.rel_model_type, - model_path = config.rel_model_path, - grammar_file_path = config.grammar_file_path, - emb_model_name = config.emb_model_name, - spacy_model_path = config.spacy_model_path, - nltk_path = config.nltk_path - ) - self.semantic_extractor = RelationExtractor(mock_config,self.create_emb) - self.attn_scores_instance = EntityAttentionExtractor(model=self.ner_model, tokenizer=self.ner_tokenizer) - self.enable_filtering = config.enable_filtering - self.filter_params = config.filter_params or {} - self.triple_filter = None + self._initialize_components(config) + self._initialize_models(config) + self._initialize_extractors(config) + self._initialize_entity_context_extractor() + self._initialize_predicate_context_extractor(config) + if self.enable_filtering: self.triple_filter = TripleFilter(**self.filter_params) - self.sample_entities = config.sample_entities - self.fixed_entities = config.fixed_entities - if self.fixed_entities and not self.sample_entities: - raise ValueError("If specific entities are provided, their types should also be provided.") - if self.fixed_entities and self.sample_entities: - self.entity_context_extractor = FixedEntityExtractor(fixed_entities=self.fixed_entities, entity_types=self.sample_entities,model = self.nlp_model) - elif self.sample_entities: - self.entity_context_extractor = FixedEntityExtractor(entity_types=self.sample_entities, model = self.nlp_model) - else: - self.entity_context_extractor = None - self.fixed_relationships = config.fixed_relationships - self.sample_relationships = config.sample_relationships - if self.fixed_relationships and not self.sample_relationships: - raise ValueError("If specific predicates are provided, their types should also be provided.") - if self.fixed_relationships and self.sample_relationships: - self.predicate_context_extractor = FixedPredicateExtractor(fixed_predicates=self.fixed_relationships, predicate_types=self.sample_relationships,model = self.nlp_model) - self.predicate_json = self.predicate_context_extractor.construct_predicate_json(self.fixed_relationships, self.sample_relationships) - self.predicate_json_emb = self.create_emb.generate_relationship_embeddings(self.predicate_json) - elif self.sample_relationships: - self.predicate_context_extractor = FixedPredicateExtractor(predicate_types=self.sample_relationships,model = self.nlp_model) - self.predicate_json = self.predicate_context_extractor.construct_predicate_json(relationship_types=self.sample_relationships) - self.predicate_json_emb = self.create_emb.generate_relationship_embeddings(self.predicate_json) else: - self.predicate_context_extractor = None - self.user_context = config.user_context - self.isConfinedSearch = config.is_confined_search + self.triple_filter = None + except Exception as e: - self.logger.error("Error initializing BERT LLM Class", e) + self.logger.error("Error initializing BERT LLM Class") raise e + + def _initialize_components(self, config): + self.graph_config = GraphConfig(identifier=config.name) + self.contextual_graph = QuerentKG(self.graph_config) + self.semantic_graph = QuerentKG(self.graph_config) + self.file_buffer = FileBuffer() + self.model_manager = ModelManager() + + def _initialize_models(self, config): + self.ner_model_initialized = self.model_manager.get_model(config.ner_model_name) + if not self.skip_inferences and self.attn_based_rel_extraction == False: + extractor = GGUFMetadataExtractor(config.rel_model_path) + model_metadata = extractor.dump_metadata() + rel_model_name = 
extractor.extract_general_name(model_metadata) + self.rel_model_initialized = self.model_manager.get_model(rel_model_name, model_path=config.rel_model_path) + self.ner_llm_instance = NER_LLM(ner_model_name=self.ner_model_initialized) + self.ner_tokenizer = self.ner_llm_instance.ner_tokenizer + self.ner_model = self.ner_llm_instance.ner_model + self.nlp_model = NER_LLM.set_nlp_model(config.spacy_model_path) + self.nlp_model = NER_LLM.get_class_variable() + + def _initialize_extractors(self, config): + if not self.skip_inferences and self.attn_based_rel_extraction == False: + mock_config = Opensource_LLM_Config( + qa_template=config.user_context, + model_type=config.rel_model_type, + model_path=self.rel_model_initialized, + grammar_file_path=config.grammar_file_path, + emb_model_name=config.emb_model_name, + spacy_model_path=config.spacy_model_path, + nltk_path=config.nltk_path + ) + self.semantic_extractor = RelationExtractor(mock_config, self.create_emb) + + elif not self.skip_inferences and self.attn_based_rel_extraction == True: + # config.rel_model_path = 'bert-base-uncased' + config.rel_model_path = self.ner_model_initialized + model_config = AutoConfig.from_pretrained(config.rel_model_path) + if 'bert' in model_config.model_type.lower(): + self.ner_helper_instance = NER_LLM(ner_model_name=config.rel_model_path) + self.ner_helper_tokenizer = self.ner_helper_instance.ner_tokenizer + self.ner_helper_model = self.ner_helper_instance.ner_model + self.extractor = get_model("bert",model_tokenizer= self.ner_helper_tokenizer,model=self.ner_helper_model) + elif 'llama' in model_config.model_type.lower() or 'mpt' in model_config.model_type.lower(): + # model_id = "TheBloke/Llama-2-7B-GGUF" + # filename = "llama-2-7b.Q5_K_M.gguf" + # self.ner_tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename) + # self.model = transformers.AutoModelForCausalLM.from_pretrained(model_id, gguf_file=filename) + # self.ner_helper_instance = NER_LLM(provided_tokenizer =self.ner_tokenizer, provided_model=self.model) + self.model = transformers.AutoModelForCausalLM.from_pretrained(config.rel_model_path,trust_remote_code=True) + # self.ner_helper_instance = NER_LLM(ner_model_name= config.rel_model_path, provided_model=self.model) + self.ner_helper_instance = self.ner_llm_instance + self.ner_helper_tokenizer = self.ner_helper_instance.ner_tokenizer + self.ner_helper_model = self.ner_helper_instance.ner_model + self.extractor = get_model("llama",model_tokenizer= self.ner_helper_tokenizer,model=self.ner_helper_model) + else: + raise ValueError("Selected Model not supported for Attnetion Based Graph Extraction") + self.attn_scores_instance = EntityAttentionExtractor(model=self.ner_model, tokenizer=self.ner_tokenizer) + + def _initialize_entity_context_extractor(self): + if self.fixed_entities and not self.sample_entities: + raise ValueError("If specific entities are provided, their types should also be provided.") + + if self.fixed_entities and self.sample_entities: + self.entity_context_extractor = FixedEntityExtractor( + fixed_entities=self.fixed_entities, + entity_types=self.sample_entities, + model=self.nlp_model + ) + elif self.sample_entities: + self.entity_context_extractor = FixedEntityExtractor( + entity_types=self.sample_entities, + model=self.nlp_model + ) + else: + self.entity_context_extractor = None + + def _initialize_predicate_context_extractor(self, config): + if self.fixed_relationships and not self.sample_relationships: + raise ValueError("If specific predicates are provided, their types 
should also be provided.") + + self.predicate_json = None + if self.skip_inferences: + self.predicate_context_extractor = None + elif self.fixed_relationships and self.sample_relationships: + self.predicate_context_extractor = FixedPredicateExtractor( + fixed_predicates=self.fixed_relationships, + predicate_types=self.sample_relationships, + model=self.nlp_model + ) + self.predicate_json = self.predicate_context_extractor.construct_predicate_json( + self.fixed_relationships, + self.sample_relationships + ) + elif self.sample_relationships: + self.predicate_context_extractor = FixedPredicateExtractor( + predicate_types=self.sample_relationships, + model=self.nlp_model + ) + self.predicate_json = self.predicate_context_extractor.construct_predicate_json( + relationship_types=self.sample_relationships + ) + else: + self.predicate_context_extractor = None + + if self.predicate_json: + self.predicate_json_emb = self.create_emb.generate_relationship_embeddings(self.predicate_json) + @@ -168,7 +245,18 @@ async def process_images(self, data: IngestedImages): if graph_json: current_state = EventState(event_type=EventType.Graph, timestamp=time.time(), payload=graph_json, file=file, doc_source=doc_source, image_id=unique_id) await self.set_state(new_state=current_state) - vector_json = json.dumps(TripleToJsonConverter.convert_vectorjson(updated_tuple, blob)) + subject, json_str, object_ = updated_tuple + context = json.loads(json_str) + sen_emb = self.create_emb.get_embeddings([context['context']])[0] + sub_emb = self.create_emb.get_embeddings(subject)[0] + obj_emb = self.create_emb.get_embeddings(object_)[0] + predicate_score=1 + final_emb = TripleToJsonConverter.dynamic_weighted_average_embeddings( + [sub_emb, obj_emb, sen_emb], + base_weights=[predicate_score, predicate_score, 3], + normalize_weights=True # Normalize weights to ensure they sum to 1 + ) + vector_json = json.dumps(TripleToJsonConverter.convert_vectorjson(updated_tuple, blob, final_emb)) if vector_json: current_state = EventState(event_type=EventType.Vector, timestamp=time.time(), payload=vector_json, file=file, doc_source=doc_source, image_id=unique_id) await self.set_state(new_state=current_state) @@ -207,86 +295,154 @@ def set_filter_params(self, **kwargs): async def process_tokens(self, data: IngestedTokens): - doc_entity_pairs = [] try: + doc_entity_pairs = [] doc_source = data.doc_source + if not BERTLLM.validate_ingested_tokens(data): - self.set_termination_event() - return - if data.data: - clean_text = ' '.join(data.data) - else: - clean_text = data.data - if not data.is_token_stream : - file, content = self.file_buffer.add_chunk( - data.get_file_path(), clean_text) - else: - content = clean_text - file = data.get_file_path() - if content: - if self.fixed_entities: - content = self.entity_context_extractor.find_entity_sentences(content) - if self.fixed_relationships: - content = self.predicate_context_extractor.find_predicate_sentences(content) - (_, doc_entity_pairs) = self.ner_llm_instance.get_entity_pairs(isConfinedSearch= self.isConfinedSearch, - content=content, - fixed_entities=self.fixed_entities, - sample_entities=self.sample_entities) - else: + self.set_termination_event() return - if self.sample_entities: - doc_entity_pairs = self.entity_context_extractor.process_entity_types(doc_entities=doc_entity_pairs) - if any(doc_entity_pairs): - doc_entity_pairs = self.ner_llm_instance.remove_duplicates(doc_entity_pairs) + + content, file = self._prepare_content(data) + if not content: + return + if self.fixed_entities: + 
content = self.entity_context_extractor.find_entity_sentences(content) + doc_entity_pairs = self._get_entity_pairs(content) + if not doc_entity_pairs: + return + + doc_entity_pairs = self._process_entity_types(doc_entity_pairs) + if not self.entity_context_extractor and not self.predicate_context_extractor: pairs_withattn = self.attn_scores_instance.extract_and_append_attention_weights(doc_entity_pairs) - if self.enable_filtering == True and not self.entity_context_extractor and self.count_entity_pairs(pairs_withattn)>1 and not self.predicate_context_extractor: - self.entity_embedding_extractor = EntityEmbeddingExtractor(self.ner_model, self.ner_tokenizer) - pairs_withemb = self.entity_embedding_extractor.extract_and_append_entity_embeddings(pairs_withattn) - else: - pairs_withemb = pairs_withattn - pairs_with_predicates = process_data(pairs_withemb, file) - if self.enable_filtering == True and not self.entity_context_extractor and self.count_entity_pairs(pairs_withattn)>1 and not self.predicate_context_extractor: - cluster_output = self.triple_filter.cluster_triples(pairs_with_predicates) - clustered_triples = cluster_output['filtered_triples'] - cluster_labels = cluster_output['cluster_labels'] - cluster_persistence = cluster_output['cluster_persistence'] - if clustered_triples: - filtered_triples, reduction_count = self.triple_filter.filter_triples(clustered_triples) - else: - self.logger.debug(f"Filtering in {self.__class__.__name__} producing 0 entity pairs. Filtering Disabled. ") - filtered_triples = pairs_with_predicates - else: - filtered_triples = pairs_with_predicates - if not filtered_triples: - return - elif not self.skip_inferences: - relationships = self.semantic_extractor.process_tokens(filtered_triples, fixed_entities=(len(self.sample_entities) >= 1)) - if len(relationships) > 0: - if self.fixed_relationships and self.sample_relationships: - embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True, generate_embeddings_with_fixed_relationship = True) - elif self.sample_relationships: - embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True) - else: - embedding_triples = self.create_emb.generate_embeddings(relationships) - if self.sample_relationships: - embedding_triples = self.predicate_context_extractor.update_embedding_triples_with_similarity(self.predicate_json_emb, embedding_triples) - for triple in embedding_triples: - if not self.termination_event.is_set(): - graph_json = json.dumps(TripleToJsonConverter.convert_graphjson(triple)) - if graph_json: - current_state = EventState(event_type=EventType.Graph, timestamp=time.time(), payload=graph_json, file=file, doc_source=doc_source) - await self.set_state(new_state=current_state) - vector_json = json.dumps(TripleToJsonConverter.convert_vectorjson(triple)) - if vector_json: - current_state = EventState(event_type=EventType.Vector, timestamp=time.time(), payload=vector_json, file=file, doc_source=doc_source) - await self.set_state(new_state=current_state) - else: - return - else: - return - else: - return filtered_triples, file else: + pairs_withattn = doc_entity_pairs + pairs_with_predicates = self._process_pairs_with_embeddings(pairs_withattn, file) + filtered_triples = self._filter_triples(pairs_with_predicates, pairs_withattn) + if not filtered_triples: return + + if not self.skip_inferences: + await self._process_relationships(filtered_triples, file, doc_source) + else: + return filtered_triples, file + except Exception as e: 
self.logger.debug(f"Invalid {self.__class__.__name__} configuration. Unable to process tokens. {e}") + + def _prepare_content(self, data): + if data.data: + clean_text = ' '.join(data.data) + else: + clean_text = data.data + + if not data.is_token_stream: + file, content = self.file_buffer.add_chunk(data.get_file_path(), clean_text) + else: + content = clean_text + file = data.get_file_path() + return content, file + + def _get_entity_pairs(self, content): + return self.ner_llm_instance.get_entity_pairs( + isConfinedSearch=self.isConfinedSearch, + content=content, + fixed_entities=self.fixed_entities, + sample_entities=self.sample_entities + )[1] + + def _process_entity_types(self, doc_entity_pairs): + if self.sample_entities: + doc_entity_pairs = self.entity_context_extractor.process_entity_types(doc_entities=doc_entity_pairs) + if any(doc_entity_pairs): + doc_entity_pairs = self.ner_llm_instance.remove_duplicates(doc_entity_pairs) + return doc_entity_pairs + + def _process_pairs_with_embeddings(self, pairs_withattn, file): + if self.enable_filtering and not self.entity_context_extractor and self.count_entity_pairs(pairs_withattn) > 1 and not self.predicate_context_extractor: + self.entity_embedding_extractor = EntityEmbeddingExtractor(self.ner_model, self.ner_tokenizer) + pairs_withemb = self.entity_embedding_extractor.extract_and_append_entity_embeddings(pairs_withattn) + else: + pairs_withemb = pairs_withattn + return process_data(pairs_withemb, file) + + def _filter_triples(self, pairs_with_predicates, pairs_withattn): + if self.enable_filtering and not self.entity_context_extractor and self.count_entity_pairs(pairs_withattn) > 1 and not self.predicate_context_extractor: + cluster_output = self.triple_filter.cluster_triples(pairs_with_predicates) + clustered_triples = cluster_output['filtered_triples'] + if clustered_triples: + filtered_triples, _ = self.triple_filter.filter_triples(clustered_triples) + else: + self.logger.debug(f"Filtering in {self.__class__.__name__} producing 0 entity pairs. 
Filtering Disabled.") + filtered_triples = pairs_with_predicates + else: + filtered_triples = pairs_with_predicates + return filtered_triples + + async def _process_relationships(self, filtered_triples, file, doc_source): + if self.attn_based_rel_extraction == False: + relationships = self.semantic_extractor.process_tokens( + filtered_triples, + fixed_entities=(len(self.sample_entities) >= 1) + ) + else: + filtered_triples = trim_triples(filtered_triples) + relationships = process_tokens(filtered_triples=filtered_triples, ner_instance=self.ner_helper_instance, extractor=self.extractor, nlp_model=self.nlp_model) + if not relationships: + return + + embedding_triples = self._generate_embeddings(relationships) + await self._process_embedding_triples(embedding_triples, file, doc_source) + + def _generate_embeddings(self, relationships): + if self.fixed_relationships and self.sample_relationships: + return self.create_emb.generate_embeddings( + relationships, + relationship_finder=True, + generate_embeddings_with_fixed_relationship=True + ) + elif self.sample_relationships: + return self.create_emb.generate_embeddings(relationships, relationship_finder=True) + else: + return self.create_emb.generate_embeddings(relationships) + + async def _process_embedding_triples(self, embedding_triples, file, doc_source): + if self.sample_relationships: + embedding_triples = self.predicate_context_extractor.update_embedding_triples_with_similarity( + self.predicate_json_emb, embedding_triples) + + for triple in embedding_triples: + if self.termination_event.is_set(): + return + + graph_json = json.dumps(TripleToJsonConverter.convert_graphjson(triple)) + if graph_json: + current_state = EventState( + event_type=EventType.Graph, + timestamp=time.time(), + payload=graph_json, + file=file, + doc_source=doc_source + ) + await self.set_state(new_state=current_state) + subject, json_str, object_ = triple + context = json.loads(json_str) + sen_emb = self.create_emb.get_embeddings([context['context']])[0] + sub_emb = self.create_emb.get_embeddings(subject)[0] + obj_emb = self.create_emb.get_embeddings(object_)[0] + predicate_score=context['score'] + final_emb = TripleToJsonConverter.dynamic_weighted_average_embeddings( + [sub_emb, obj_emb, sen_emb], + base_weights=[predicate_score, predicate_score, 3], + normalize_weights=True # Normalize weights to ensure they sum to 1 + ) + vector_json = json.dumps(TripleToJsonConverter.convert_vectorjson(triple=triple, embeddings=final_emb)) + if vector_json: + current_state = EventState( + event_type=EventType.Vector, + timestamp=time.time(), + payload=vector_json, + file=file, + doc_source=doc_source + ) + await self.set_state(new_state=current_state) diff --git a/querent/core/transformers/fixed_entities_set_opensourcellm.py b/querent/core/transformers/fixed_entities_set_opensourcellm.py deleted file mode 100644 index 7cac1f27..00000000 --- a/querent/core/transformers/fixed_entities_set_opensourcellm.py +++ /dev/null @@ -1,195 +0,0 @@ -import json -from transformers import AutoTokenizer -import time -from querent.common.types.ingested_table import IngestedTables -from querent.kg.ner_helperfunctions.fixed_predicate import FixedPredicateExtractor -from querent.common.types.ingested_images import IngestedImages -from querent.config.core.opensource_llm_config import Opensource_LLM_Config -from querent.core.transformers.relationship_extraction_llm import RelationExtractor -from querent.kg.rel_helperfunctions.contextual_predicate import process_data -from 
querent.kg.ner_helperfunctions.fixed_entities import FixedEntityExtractor -from querent.kg.ner_helperfunctions.ner_llm_transformer import NER_LLM -from querent.common.types.querent_event import EventState, EventType -from querent.core.base_engine import BaseEngine -from querent.common.types.ingested_tokens import IngestedTokens -from querent.common.types.ingested_messages import IngestedMessages -from querent.common.types.ingested_code import IngestedCode -from querent.common.types.querent_queue import QuerentQueue -from querent.common.types.file_buffer import FileBuffer -from querent.logging.logger import setup_logger -from querent.kg.querent_kg import QuerentKG -from querent.config.graph_config import GraphConfig -from querent.kg.ner_helperfunctions.filter_triples import TripleFilter -from querent.config.core.llm_config import LLM_Config -from querent.kg.rel_helperfunctions.triple_to_json import TripleToJsonConverter -from querent.kg.rel_helperfunctions.embedding_store import EmbeddingStore - - -class Fixed_Entities_LLM(BaseEngine): - def __init__( - self, - input_queue:QuerentQueue, - config: LLM_Config, - Embedding=None - ): - self.logger = setup_logger(__name__, "Fixed_Entities_LLM") - super().__init__(input_queue) - self.skip_inferences=config.skip_inferences - if not self.skip_inferences: - mock_config = Opensource_LLM_Config(qa_template=config.user_context, - model_type = config.rel_model_type, - model_path = config.rel_model_path, - grammar_file_path = config.grammar_file_path, - emb_model_name = config.emb_model_name, - is_confined_search = config.is_confined_search, - spacy_model_path = config.spacy_model_path, - nltk_path = config.nltk_path, - ) - self.semantic_extractor = RelationExtractor(mock_config) - self.graph_config = GraphConfig(identifier=config.name) - self.contextual_graph = QuerentKG(self.graph_config) - self.semantic_graph = QuerentKG(self.graph_config) - self.file_buffer = FileBuffer() - self.ner_tokenizer = AutoTokenizer.from_pretrained(config.ner_model_name) - self.ner_llm_instance = NER_LLM(provided_tokenizer=self.ner_tokenizer, provided_model= "dummy") - self.nlp_model = NER_LLM.set_nlp_model(config.spacy_model_path) - self.nlp_model = NER_LLM.get_class_variable() - huggingface_token = config.huggingface_token - if Embedding is None: - self.create_emb = EmbeddingStore() - else: - self.create_emb = Embedding - self.enable_filtering = config.enable_filtering - self.filter_params = config.filter_params or {} - self.triple_filter = None - if self.enable_filtering: - self.triple_filter = TripleFilter(**self.filter_params) - self.sample_entities = config.sample_entities - self.fixed_entities = config.fixed_entities - if self.fixed_entities and not self.sample_entities: - raise ValueError("If specific entities are provided, their types should also be provided.") - if self.fixed_entities and self.sample_entities: - self.entity_context_extractor = FixedEntityExtractor(fixed_entities=self.fixed_entities, entity_types=self.sample_entities,model = self.nlp_model) - elif self.sample_entities: - self.entity_context_extractor = FixedEntityExtractor(entity_types=self.sample_entities, model = self.nlp_model) - else: - self.entity_context_extractor = None - self.fixed_relationships = config.fixed_relationships - self.sample_relationships = config.sample_relationships - if self.fixed_relationships and not self.sample_relationships: - raise ValueError("If specific predicates are provided, their types should also be provided.") - if self.fixed_relationships and 
self.sample_relationships: - self.predicate_context_extractor = FixedPredicateExtractor(fixed_predicates=self.fixed_relationships, predicate_types=self.sample_relationships,model = self.nlp_model) - self.predicate_json = self.predicate_context_extractor.construct_predicate_json(self.fixed_relationships, self.sample_relationships) - self.predicate_json_emb = self.create_emb.generate_relationship_embeddings(self.predicate_json) - elif self.sample_relationships: - self.predicate_context_extractor = FixedPredicateExtractor(predicate_types=self.sample_relationships,model = self.nlp_model) - self.predicate_json = self.predicate_context_extractor.construct_predicate_json(relationship_types=self.sample_relationships) - self.predicate_json_emb = self.create_emb.generate_relationship_embeddings(self.predicate_json) - else: - self.predicate_context_extractor = None - self.user_context = config.user_context - self.isConfinedSearch = config.is_confined_search - - - - def validate(self) -> bool: - return self.ner_tokenizer is not None - - def process_messages(self, data: IngestedMessages): - return super().process_messages(data) - - def process_tables(self, data: IngestedTables): - pass - - async def process_images(self, data: IngestedImages): - pass - - async def process_code(self, data: IngestedCode): - return super().process_code(data) - - @staticmethod - def validate_ingested_tokens(data: IngestedTokens) -> bool: - if data.is_error(): - - return False - - return True - - @staticmethod - def validate_ingested_images(data: IngestedImages) -> bool: - if data.is_error(): - - return False - - return True - - - async def process_tokens(self, data: IngestedTokens): - doc_entity_pairs = [] - try: - doc_source = data.doc_source - if not Fixed_Entities_LLM.validate_ingested_tokens(data): - self.set_termination_event() - return - if data.data: - clean_text = ' '.join(data.data) - else: - clean_text = data.data - if not data.is_token_stream : - file, content = self.file_buffer.add_chunk( - data.get_file_path(), clean_text) - else: - content = clean_text - file = data.get_file_path() - if content: - if self.fixed_entities: - content = self.entity_context_extractor.find_entity_sentences(content) - if self.fixed_relationships: - content = self.predicate_context_extractor.find_predicate_sentences(content) - (_, doc_entity_pairs) = self.ner_llm_instance.get_entity_pairs(isConfinedSearch= self.isConfinedSearch, - content=content, - fixed_entities=self.fixed_entities, - sample_entities=self.sample_entities) - else: - return - if self.sample_entities: - doc_entity_pairs = self.entity_context_extractor.process_entity_types(doc_entities=doc_entity_pairs) - if doc_entity_pairs and any(doc_entity_pairs): - doc_entity_pairs = self.ner_llm_instance.remove_duplicates(doc_entity_pairs) - filtered_triples = process_data(doc_entity_pairs, file) - if not filtered_triples: - self.logger.debug("No entity pairs") - return - elif not self.skip_inferences: - relationships = self.semantic_extractor.process_tokens(filtered_triples) - self.logger.debug(f"length of relationships {len(relationships)}") - if len(relationships) > 0: - if self.fixed_relationships and self.sample_relationships: - embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True, generate_embeddings_with_fixed_relationship = True) - elif self.sample_relationships: - embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True) - else: - embedding_triples = 
self.create_emb.generate_embeddings(relationships) - if self.sample_relationships: - embedding_triples = self.predicate_context_extractor.update_embedding_triples_with_similarity(self.predicate_json_emb, embedding_triples) - for triple in embedding_triples: - if not self.termination_event.is_set(): - graph_json = json.dumps(TripleToJsonConverter.convert_graphjson(triple)) - if graph_json: - current_state = EventState(event_type=EventType.Graph,timestamp=time.time(), payload=graph_json, file=file, doc_source=doc_source) - await self.set_state(new_state=current_state) - vector_json = json.dumps(TripleToJsonConverter.convert_vectorjson(triple)) - if vector_json: - current_state = EventState(event_type=EventType.Vector, timestamp=time.time(), payload=vector_json, file=file, doc_source=doc_source) - await self.set_state(new_state=current_state) - else: - return - else: - return - else: - return filtered_triples, file - else: - return - except Exception as e: - self.logger.debug(f"Invalid {self.__class__.__name__} configuration. Unable to process tokens. {e}") diff --git a/querent/core/transformers/gpt_llm_bert_ner_or_fixed_entities_set_ner.py b/querent/core/transformers/gpt_llm_bert_ner_or_fixed_entities_set_ner.py deleted file mode 100644 index 8be671cb..00000000 --- a/querent/core/transformers/gpt_llm_bert_ner_or_fixed_entities_set_ner.py +++ /dev/null @@ -1,379 +0,0 @@ -import json -import re -import time -from querent.common.types.ingested_table import IngestedTables -from querent.core.transformers.fixed_entities_set_opensourcellm import Fixed_Entities_LLM -from querent.kg.ner_helperfunctions.fixed_predicate import FixedPredicateExtractor -from querent.config.core.gpt_llm_config import GPTConfig -from querent.core.transformers.bert_ner_opensourcellm import BERTLLM -from querent.common.types.ingested_images import IngestedImages -from querent.kg.ner_helperfunctions.ner_llm_transformer import NER_LLM -from querent.kg.rel_helperfunctions.openai_functions import FunctionRegistry -from querent.common.types.querent_event import EventState, EventType -from querent.core.base_engine import BaseEngine -from querent.common.types.ingested_tokens import IngestedTokens -from querent.common.types.ingested_messages import IngestedMessages -from querent.common.types.ingested_code import IngestedCode -from querent.common.types.querent_queue import QuerentQueue -from querent.kg.rel_helperfunctions.embedding_store import EmbeddingStore -from typing import Any, List, Tuple -from querent.kg.rel_helperfunctions.triple_to_json import TripleToJsonConverter -from querent.logging.logger import setup_logger -from querent.config.core.llm_config import LLM_Config -from openai import OpenAI -from tenacity import ( - retry, - stop_after_attempt, - wait_random_exponential, - wait_fixed -) -from dotenv import load_dotenv, find_dotenv -import json - -_ = load_dotenv(find_dotenv()) - - -class GPTLLM(BaseEngine): - def __init__( - self, - input_queue: QuerentQueue, - config: GPTConfig - ): - self.logger = setup_logger(__name__, "OPENAILLM") - try: - super().__init__(input_queue) - llm_config = LLM_Config( - ner_model_name=config.ner_model_name, - enable_filtering=config.enable_filtering, - filter_params={ - 'score_threshold': config.filter_params['score_threshold'], - 'attention_score_threshold': config.filter_params['attention_score_threshold'], - 'similarity_threshold': config.filter_params['similarity_threshold'], - 'min_cluster_size': config.filter_params['min_cluster_size'], - 'min_samples': 
config.filter_params['min_samples'], - 'cluster_persistence_threshold':config.filter_params['cluster_persistence_threshold'] - }, - sample_entities = config.sample_entities, - fixed_entities = config.fixed_entities, - skip_inferences= True, - is_confined_search = config.is_confined_search, - huggingface_token = config.huggingface_token, - spacy_model_path = config.spacy_model_path, - nltk_path = config.nltk_path, - fixed_relationships = config.fixed_relationships, - sample_relationships = config.sample_relationships) - self.fixed_entities = config.fixed_entities - self.is_confined_search = config.is_confined_search - self.fixed_relationships = config.fixed_relationships - self.sample_relationships = config.sample_relationships - self.user_context = config.user_context - self.nlp_model = NER_LLM.set_nlp_model(config.spacy_model_path) - self.nlp_model = NER_LLM.get_class_variable() - self.create_emb = EmbeddingStore() - if self.fixed_relationships and not self.sample_relationships: - raise ValueError("If specific predicates are provided, their types should also be provided.") - if self.fixed_relationships and self.sample_relationships: - self.predicate_context_extractor = FixedPredicateExtractor(fixed_predicates=self.fixed_relationships, predicate_types=self.sample_relationships,model = self.nlp_model) - self.predicate_json = self.predicate_context_extractor.construct_predicate_json(self.fixed_relationships, self.sample_relationships) - self.predicate_json_emb = self.create_emb.generate_relationship_embeddings(self.predicate_json) - elif self.sample_relationships: - self.predicate_context_extractor = FixedPredicateExtractor(predicate_types=self.sample_relationships,model = self.nlp_model) - self.predicate_json = self.predicate_context_extractor.construct_predicate_json(relationship_types=self.sample_relationships) - self.predicate_json_emb = self.create_emb.generate_relationship_embeddings(self.predicate_json) - else: - self.predicate_context_extractor = None - if config.is_confined_search: - self.llm_instance = Fixed_Entities_LLM(input_queue, llm_config, self.create_emb) - else : - self.llm_instance = BERTLLM(input_queue, llm_config, self.create_emb) - if not isinstance (self.llm_instance, BERTLLM): - self.process_image_instance = BERTLLM(input_queue, llm_config, self.create_emb) - else: - self.process_image_instance = self.llm_instance - self.rel_model_name = config.rel_model_name - if config.openai_api_key: - self.gpt_llm = OpenAI(api_key=config.openai_api_key) - else: - self.gpt_llm = OpenAI() - self.function_registry = FunctionRegistry() - - except Exception as e: - self.logger.error(f"Invalid {self.__class__.__name__} configuration. Unable to Initialize. {e}") - raise Exception(f"Invalid {self.__class__.__name__} configuration. Unable to Initialize. 
{e}") - - def validate(self) -> bool: - return isinstance(self.llm_instance, BERTLLM) or isinstance(self.llm_instance, Fixed_Entities_LLM) - - def process_messages(self, data: IngestedMessages): - return super().process_messages(data) - - async def process_images(self, data: IngestedImages): - try: - if not GPTLLM.validate_ingested_images(data): - self.set_termination_event() - return - blob = data.image - unique_id = str(hash(data.image)) - doc_source = data.doc_source - result = await self.process_image_instance.process_images(data) - if not result: return - else: - filtered_triples, file, ner_instance = result - for triple in filtered_triples: - updated_tuple = ner_instance.final_ingested_images_tuples(triple, create_embeddings=self.create_emb) - graph_json = json.dumps(TripleToJsonConverter.convert_graphjson(updated_tuple)) - if graph_json: - current_state = EventState(event_type=EventType.Graph, timestamp=time.time(), payload=graph_json, file=file, doc_source=doc_source, image_id=unique_id) - await self.set_state(new_state=current_state) - vector_json = json.dumps(TripleToJsonConverter.convert_vectorjson(updated_tuple, blob)) - if vector_json: - current_state = EventState(event_type=EventType.Vector, timestamp=time.time(), payload=vector_json, file=file, doc_source=doc_source, image_id=unique_id) - await self.set_state(new_state=current_state) - - except Exception as e: - self.logger.debug(f"Invalid {self.__class__.__name__} configuration. Unable to process tokens. {e}") - - async def process_code(self, data: IngestedCode): - return super().process_messages(data) - - async def process_tables(self, data: IngestedTables): - return super().process_tables(data) - - @staticmethod - def validate_ingested_tokens(data: IngestedTokens) -> bool: - if data.is_error(): - - return False - - return True - - @staticmethod - def validate_ingested_images(data: IngestedImages) -> bool: - if data.is_error(): - - return False - - return True - def extract_semantic_triples(self, chat_completion): - # Extract the message content from the ChatCompletion - message_content = chat_completion.choices[0].message.content.replace('\n', '') - - # Parse the JSON content into a Python list - try: - triples_list = eval(message_content) - if not isinstance(triples_list, list): - raise ValueError("Content is not a list") - except Exception as e: - return [] - - return triples_list - - @staticmethod - def remove_items_from_tuples(data: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]: - result = [] - keys_to_remove = ['entity1_embedding', 'entity2_embedding', 'entity1_attnscore', 'entity2_attnscore', 'pair_attnscore'] - - for tup in data: - json_data = json.loads(tup[1]) - for key in keys_to_remove: - json_data.pop(key, None) - modified_json_str = json.dumps(json_data, ensure_ascii=False) - modified_tuple = (tup[0], modified_json_str, tup[2]) - result.append(modified_tuple) - - return result - - async def process_triples(self, context, entity1, entity2, entity1_label, entity2_label): - try: - if not self.user_context and not self.fixed_entities: - identify_entity_message = f"""Please analyze the provided context below. Once you have understood the context, answer the user query using the specified output format. 
- - Context: {context} - Entity 1: {entity1} and Entity 2: {entity2} - Output Format: - [ - {{ - 'subject': 'Identified as the main entity in the context, typically the initiator or primary focus of the action or topic being discussed.', - 'predicate': 'The relationship (predicate) between the subject and the object.', - 'object': 'This parameter represents the entity in the context directly impacted by or involved in the action, typically the recipient or target of the main verb's action.', - 'subject_type': 'The category of the subject entity e.g. location, person, event, material, process etc.', - 'object_type': 'The category of the object entity e.g. location, person, event, material, process etc.', - 'predicate_type': 'The category of the predicate e.g. causative, action, ownership, occurance etc.' - }}, - ] - """ - messages_classify_entity = [ - {"role": "user", "content": identify_entity_message}, - {"role": "user", "content": "Query : In the context of a semantic triple framework, first identify which entity is subject and which is the object along with their respective types. Also determine the predicate and predicate type."}, - ] - elif not self.user_context and self.fixed_entities : - identify_entity_message = f"""Please analyze the provided context below. Once you have understood the context, answer the user query using the specified output format. - - Context: {context} - Entity 1: {entity1} and Entity 2: {entity2} - Entity 1_Type: {entity1_label} and Entity 2_Type: {entity2_label} - Output Format: - [ - {{ - 'subject': 'Identified as the main entity in the context, typically the initiator or primary focus of the action or topic being discussed.', - 'predicate': 'The relationship (predicate) between the subject and the object.', - 'object': 'This parameter represents the entity in the context directly impacted by or involved in the action, typically the recipient or target of the main verb's action.', - 'subject_type': 'The category of the subject entity e.g. location, person, event, material, process etc.', - 'object_type': 'The category of the object entity e.g. location, person, event, material, process etc.', - 'predicate_type': 'The category of the predicate e.g. causative, action, ownership, occurance etc.' - }}, - ] - """ - messages_classify_entity = [ - {"role": "user", "content": identify_entity_message}, - {"role": "user", "content": "Query : In the context of a semantic triple framework, first identify which entity is subject and which is the object and also validate and output their their respective types. Also determine the predicate and predicate type."}, - ] - elif self.user_context and self.fixed_entities : - identify_entity_message = f"""Please analyze the provided context below. Once you have understood the context, answer the user query using the specified output format. - - Context: {context} - Entity 1: {entity1} and Entity 2: {entity2} - Entity 1_Type: {entity1_label} and Entity 2_Type: {entity2_label} - Output Format: - [ - {{ - 'subject': 'Identified as the main entity in the context, typically the initiator or primary focus of the action or topic being discussed.', - 'predicate': 'The relationship (predicate) between the subject and the object.', - 'object': 'This parameter represents the entity in the context directly impacted by or involved in the action, typically the recipient or target of the main verb's action.', - 'subject_type': 'The category of the subject entity e.g. 
location, person, event, material, process etc.', - 'object_type': 'The category of the object entity e.g. location, person, event, material, process etc.', - 'predicate_type': 'The category of the predicate e.g. causative, action, ownership, occurance etc.' - }}, - ] - """ - messages_classify_entity = [ - {"role": "user", "content": identify_entity_message}, - {"role": "user", "content": self.user_context}, - ] - elif self.user_context and not self.fixed_entities : - identify_entity_message = f"""Please analyze the provided context below. Once you have understood the context, answer the user query using the specified output format. - - Context: {context} - Entity 1: {entity1} and Entity 2: {entity2} - Output Format: - [ - {{ - 'subject': 'Identified as the main entity in the context, typically the initiator or primary focus of the action or topic being discussed.', - 'predicate': 'The relationship (predicate) between the subject and the object.', - 'object': 'This parameter represents the entity in the context directly impacted by or involved in the action, typically the recipient or target of the main verb's action.', - 'subject_type': 'The category of the subject entity e.g. location, person, event, material, process etc.', - 'object_type': 'The category of the object entity e.g. location, person, event, material, process etc.', - 'predicate_type': 'The category of the predicate e.g. causative, action, ownership, occurance etc.' - }}, - ] - """ - messages_classify_entity = [ - {"role": "user", "content": identify_entity_message}, - {"role": "user", "content": self.user_context}, - ] - identify_predicate_response = self.generate_response( - messages_classify_entity, - "predicate_info" - ) - semantic_triples = self.extract_semantic_triples(identify_predicate_response) - if len(semantic_triples)>0: - return { - 'subject_type': semantic_triples[0]['subject_type'].lower().replace(" ", "_"), - 'subject': semantic_triples[0]['subject'].lower(), - 'object_type': semantic_triples[0]['object_type'].lower().replace(" ", "_"), - 'object': semantic_triples[0]['object'].lower(), - 'predicate': semantic_triples[0]['predicate'].lower(), - 'predicate_type': semantic_triples[0]['predicate_type'].lower().replace(" ", "_") - } - except Exception as e: - self.logger.error(f"Invalid {self.__class__.__name__} configuration. Unable to process triples using GPT. 
{e}") - - # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) - def completion_with_backoff(self, **kwargs): - return self.gpt_llm.chat.completions.create(**kwargs) - - def generate_response(self, messages, name): - response = self.completion_with_backoff( - model=self.rel_model_name, - messages=messages, - temperature=0 - ) - return response - - def generate_output_tuple(self,result, context_json): - context_data = json.loads(context_json) - context = context_data.get("context", "") - subject_type = result.get("subject_type", "Unlabeled") - subject = result.get("subject", "") - object_type = result.get("object_type", "Unlabeled") - object = result.get("object", "") - predicate = result.get("predicate", "") - predicate_type = result.get("predicate_type", "Unlabled") - output_tuple = ( - subject, - f'{{"predicate": "{predicate}", "predicate_type": "{predicate_type}", "context": "{context}", "file": "{context_data.get("file_path", "")}", "subject_type": "{subject_type}", "object_type": "{object_type}"}}', - object - ) - - return output_tuple - - def extract_key(tup): - subject, json_string, obj = tup - data = json.loads(json_string.replace("\n", "")) - return (subject, data.get('predicate'), obj) - - async def process_tokens(self, data: IngestedTokens): - try: - if not GPTLLM.validate_ingested_tokens(data): - self.set_termination_event() - return - doc_source = data.doc_source - relationships = [] - unique_keys = set() - result = await self.llm_instance.process_tokens(data) - if not result: return - else: - filtered_triples, file = result - modified_data = GPTLLM.remove_items_from_tuples(filtered_triples) - for entity1, context_json, entity2 in modified_data: - context_data = json.loads(context_json) - context = context_data.get("context", "") - entity1_label = context_data.get("entity1_label", "") - entity2_label = context_data.get("entity2_label", "") - entity1_nn_chunk = context_data.get("entity1_nn_chunk","") - entity2_nn_chunk = context_data.get("entity2_nn_chunk","") - result = await self.process_triples(context, entity1_nn_chunk, entity2_nn_chunk, entity1_label, entity2_label) - if result: - output_tuple = self.generate_output_tuple(result, context_json) - key = GPTLLM.extract_key(output_tuple) - if key not in unique_keys: - unique_keys.add(key) - relationships.append(output_tuple) - if len(relationships) > 0: - if self.fixed_relationships and self.sample_relationships: - embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True, generate_embeddings_with_fixed_relationship = True) - elif self.sample_relationships: - embedding_triples = self.create_emb.generate_embeddings(relationships, relationship_finder=True) - else: - embedding_triples = self.create_emb.generate_embeddings(relationships) - if self.sample_relationships: - embedding_triples = self.predicate_context_extractor.update_embedding_triples_with_similarity(self.predicate_json_emb, embedding_triples) - for triple in embedding_triples: - if not self.termination_event.is_set(): - graph_json = json.dumps(TripleToJsonConverter.convert_graphjson(triple)) - if graph_json: - current_state = EventState(event_type=EventType.Graph,timestamp = time.time(), payload= graph_json, file=file, doc_source=doc_source) - await self.set_state(new_state=current_state) - vector_json = json.dumps(TripleToJsonConverter.convert_vectorjson(triple)) - if vector_json: - current_state = EventState(event_type=EventType.Vector,timestamp=time.time(), payload = vector_json, file=file, 
doc_source=doc_source) - await self.set_state(new_state=current_state) - else: - return - else: - return - except Exception as e: - self.logger.error(f"Invalid {self.__class__.__name__} configuration. Unable to extract predicates using GPT. {e}") - - async def process_messages(self, data: IngestedMessages): - raise NotImplementedError \ No newline at end of file diff --git a/querent/core/transformers/gpt_llm_gpt_ner.py b/querent/core/transformers/gpt_llm_gpt_ner.py deleted file mode 100644 index a1ff4c2a..00000000 --- a/querent/core/transformers/gpt_llm_gpt_ner.py +++ /dev/null @@ -1,266 +0,0 @@ -import json -import re -from unidecode import unidecode -import time - - -import spacy -from querent.config.core.gpt_llm_config import GPTConfig -from querent.common.types.ingested_images import IngestedImages -from querent.kg.rel_helperfunctions.openai_functions import FunctionRegistry -from querent.common.types.querent_event import EventState, EventType -from querent.core.base_engine import BaseEngine -from querent.common.types.ingested_tokens import IngestedTokens -from querent.common.types.ingested_messages import IngestedMessages -from querent.common.types.ingested_code import IngestedCode -from querent.common.types.querent_queue import QuerentQueue -from querent.kg.rel_helperfunctions.embedding_store import EmbeddingStore -from querent.kg.rel_helperfunctions.triple_to_json import TripleToJsonConverter -from querent.logging.logger import setup_logger -from openai import OpenAI -from querent.common.types.file_buffer import FileBuffer -from tenacity import ( - retry, - stop_after_attempt, - wait_random_exponential, -) -from dotenv import load_dotenv, find_dotenv -import json - -_ = load_dotenv(find_dotenv()) - - -class GPTNERLLM(BaseEngine): - def __init__( - self, - input_queue: QuerentQueue, - config: GPTConfig - ): - self.logger = setup_logger(__name__, "OPENAINERLLM") - try: - self.nlp = spacy.load(config.spacy_model_path) - super().__init__(input_queue) - self.file_buffer = FileBuffer() - self.rel_model_name = config.rel_model_name - if config.openai_api_key: - self.gpt_llm = OpenAI(api_key=config.openai_api_key) - else: - self.gpt_llm = OpenAI() - self.function_registry = FunctionRegistry() - self.create_emb = EmbeddingStore() - self.user_context = config.user_context - - except Exception as e: - self.logger.error(f"Invalid {self.__class__.__name__} configuration. Unable to Initialize. {e}") - raise Exception(f"Invalid {self.__class__.__name__} configuration. Unable to Initialize. 
{e}") - - def validate(self) -> bool: - return True - - def process_messages(self, data: IngestedMessages): - return super().process_messages(data) - - def process_images(self, data: IngestedImages): - return super().process_messages(data) - - async def process_code(self, data: IngestedCode): - return super().process_messages(data) - - @staticmethod - def validate_ingested_tokens(data: IngestedTokens) -> bool: - if data.is_error(): - - return False - - return True - - def get_context(self, sentences): - contexts = [] - for i, sentence in enumerate(sentences): - # Previous sentence or empty string if current sentence is the first - prev_sentence = sentences[i-1] if i > 0 else '' - # Next sentence or empty string if current sentence is the last - next_sentence = sentences[i+1] if i < len(sentences)-1 else '' - # Current context - context = f"{prev_sentence} {sentence} {next_sentence}".strip() - contexts.append(context) - return contexts - - @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) - def completion_with_backoff(self, **kwargs): - return self.gpt_llm.chat.completions.create(**kwargs) - - def generate_response(self, messages): - response = self.completion_with_backoff( - model=self.rel_model_name, - messages=messages, - temperature=0, - ) - return response - - def extract_semantic_triples(self, chat_completion): - # Extract the message content from the ChatCompletion - message_content = chat_completion.choices[0].message.content.replace('\n', '') - - # Parse the JSON content into a Python list - try: - triples_list = eval(message_content) - if not isinstance(triples_list, list): - raise ValueError("Content is not a list") - except Exception as e: - return [] - - return triples_list - - def filter_relevant_triples(self, triples, context, max_distance): - def tokenize_text(text): - # Tokenize the text while ignoring common filler words - filler_words = ["the", "and", "in", "of", "a", "an", "to", "with"] - tokens = re.findall(r'\b\w+\b', text) - tokens = [token.lower() for token in tokens if token.lower() not in filler_words] - return tokens - - def calculate_distance(tokenized_text, subject, object): - # Find the start index of subject and object in tokenized_text - subject_tokens = tokenize_text(subject) - object_tokens = tokenize_text(object) - - subject_start_index = None - object_start_index = None - - for i in range(len(tokenized_text)): - if tokenized_text[i:i+len(subject_tokens)] == subject_tokens: - subject_start_index = i - break - - for i in range(len(tokenized_text)): - if tokenized_text[i:i+len(object_tokens)] == object_tokens: - object_start_index = i - break - - if subject_start_index is not None and object_start_index is not None: - return abs(subject_start_index - object_start_index) - else: - return float('inf') # Return a large distance if either subject or object is not found - - tokenized_context = tokenize_text(context.lower()) # Convert context to lowercase - relevant_triples = [] - - for triple in triples: - subject = triple['subject'].lower() # Convert subject to lowercase - object = triple['object'].lower() # Convert object to lowercase - - distance = calculate_distance(tokenized_context, subject, object) - - if distance <= max_distance: - if (triple.get("subject") is None or triple.get("subect") == "") or (triple.get("object") is None or triple.get("object") == "") or (triple.get("subject_type") is None or triple.get("subject_type") == "") or (triple.get("object_type") is None or triple.get("object_type") == "") or 
(triple.get("predicate") is None or triple.get("predicate") == "") or (triple.get("predicate_type") is None or triple.get("predicate_type") == ""): - self.logger.debug(f"Received none while creating semantic triples") - continue - triple['subject'] = subject - triple['object'] = object - triple['subject_type'] = triple['subject_type'].lower().replace(" ", "_") - triple['object_type'] = triple['object_type'].lower().replace(" ", "_") - triple['predicate'] = triple['predicate'].lower() - triple['predicate_type'] = triple['predicate_type'].lower().replace(" ", "_") - triple['sentence'] = context # Add the context as a key-value pair - relevant_triples.append(triple) - - return relevant_triples - - def remove_duplicate_triplets(self,input_list): - seen = set() - result = [] - - for item in input_list: - triplet = (item['subject'], item['predicate'], item['object']) - if triplet not in seen: - result.append(item) - seen.add(triplet) - - return result - - async def process_tokens(self, data: IngestedTokens): - try: - if not GPTNERLLM.validate_ingested_tokens(data): - self.set_termination_event() - return - - doc_source = data.doc_source - if data.data: - clean_text = ' '.join(data.data) - else: - clean_text = data.data - if not data.is_token_stream : - file, content = self.file_buffer.add_chunk( - data.get_file_path(), clean_text) - else: - content = clean_text - file = data.get_file_path() - if content: - doc = self.nlp(content) - sentences = [sent.text for sent in doc.sents] - contexts = self.get_context(sentences) - final_triples = [] - for context in contexts: - identify_entity_message = f"""Please analyze the provided context below. Once you have understood the context, answer the user query using the specified output format. - - Context: {contexts[0]} - - Output Format: - [ - {{ - 'subject': 'Identified as the main entity in the context, typically the initiator or primary focus of the action or topic being discussed.', - 'predicate': 'The relationship (predicate) between the subject and the object.', - 'object': 'This parameter represents the entity in the context directly impacted by or involved in the action, typically the recipient or target of the main verb's action.', - 'subject_type': 'The category of the subject entity e.g. location, person, event, material, process etc.', - 'object_type': 'The category of the object entity e.g. location, person, event, material, process etc.', - 'predicate_type': 'The category of the predicate e.g. causative, action, ownership, occurance etc.' - }}, - {{}}, # Additional triples go here - {{}}, # Additional triples go here - # ... # Additional triples go here - ] - """ - if not self.user_context: - messages_classify_entity = [ - {"role": "user", "content": identify_entity_message}, - {"role": "user", "content": "Query: First, identify all geological entities in the provided context. Then, create relevant semantic triples (Subject, Predicate, Object) and also categorize the respective the Subject, Object types (e.g. location, person, event, material, process etc.) and Predicate type. 
Use the above output format to provide all the relevant semantic triples."}, - ] - else : - messages_classify_entity = [ - {"role": "user", "content": identify_entity_message}, - {"role": "user", "content": self.user_context}, - ] - identify_entity_response = self.generate_response(messages_classify_entity) - try: - semantic_triples = self.extract_semantic_triples(identify_entity_response) - relevant_triples = self.filter_relevant_triples(semantic_triples, context, 10) - except Exception as e: - self.logger.debug(f"Error extracting semantic triples in GPT NER & LLM Class: {e}") - continue - if len(relevant_triples)>0: - final_triples.extend(relevant_triples) - final_triples = self.remove_duplicate_triplets(final_triples) - if len(final_triples) > 0: - for triple in final_triples: - if not self.termination_event.is_set(): - graph_json = json.dumps(triple) - if graph_json: - current_state = EventState(event_type=EventType.Graph,timestamp=time.time(), payload=graph_json, file=file, doc_source=doc_source) - await self.set_state(new_state=current_state) - context_embeddings = self.create_emb.get_embeddings([triple['sentence']])[0] - triple['context_embeddings'] = context_embeddings - triple['context'] = triple['sentence'] - vector_json = json.dumps(TripleToJsonConverter.convert_vectorjson((triple['subject'],json.dumps(triple), triple['object']))) - if vector_json: - current_state = EventState(event_type=EventType.Vector,timestamp=time.time(), payload=vector_json, file=file, doc_source=doc_source) - await self.set_state(new_state=current_state) - else: - return - else: - return - except Exception as e: - self.logger.debug(f"Invalid {self.__class__.__name__} configuration. Unable to extract predicates using GPT NER LLM class. {e}") - - async def process_messages(self, data: IngestedMessages): - raise NotImplementedError diff --git a/querent/core/transformers/relationship_extraction_llm.py b/querent/core/transformers/relationship_extraction_llm.py index 6bb61da0..a51d89e2 100644 --- a/querent/core/transformers/relationship_extraction_llm.py +++ b/querent/core/transformers/relationship_extraction_llm.py @@ -138,7 +138,8 @@ def create_semantic_triple(self, input1, input2): "context": input2_data.get("context", ""), "file_path": input2_data.get("file_path", ""), "subject_type": input1.get("subject_type","Unlabeled"), - "object_type": input1.get("object_type","Unlabeled") + "object_type": input1.get("object_type","Unlabeled"), + "score":1 }), input1.get("object","") ) diff --git a/querent/kg/ner_helperfunctions/dependency_parsing.py b/querent/kg/ner_helperfunctions/dependency_parsing.py index baeb4e14..00249b99 100644 --- a/querent/kg/ner_helperfunctions/dependency_parsing.py +++ b/querent/kg/ner_helperfunctions/dependency_parsing.py @@ -18,17 +18,13 @@ List of noun chunks identified in the sentence. filtered_chunks : list Filtered noun chunks based on certain criteria. - merged_entities : list + noun_chunks : list Entities merged based on overlapping criteria. Methods: -------- load_spacy_model(): Loads the specified SpaCy model. - filter_chunks(): - Filters the noun chunks based on length, stop words, and POS tagging. - merge_overlapping_entities(): - Merges entities that overlap with each other. compare_entities_with_chunks(): Compares the entities with the noun chunks and updates the entity details. 
process_entities(): @@ -43,42 +39,12 @@ def __init__(self, entities=None, sentence=None, model=None): self.nlp = model self.doc = self.nlp(self.sentence) self.noun_chunks = list(self.doc.noun_chunks) - self.filtered_chunks = self.filter_chunks() - self.merged_entities = self.merge_overlapping_entities() + self.noun_chunks = list(self.doc.noun_chunks) self.compare_entities_with_chunks() self.entities = self.process_entities() except Exception as e: raise Exception(f"Error Initializing Dependency Parsing Class: {e}") - def filter_chunks(self): - try: - filtered_chunks = [] - relevant_pos_tags = {"NOUN", "PROPN", "ADJ"} - for chunk in self.noun_chunks: - # Filtering logic - if len(chunk) > 1 and not chunk.root.is_stop and chunk.root.pos_ in relevant_pos_tags: - filtered_chunks.append(chunk) - return filtered_chunks - - except Exception as e: - raise Exception(f"Error filtering chunks: {e}") - - - def merge_overlapping_entities(self): - try: - merged_entities = [] - i = 0 - while i < len(self.filtered_chunks): - entity = self.filtered_chunks[i] - while i + 1 < len(self.filtered_chunks) and entity.end >= self.filtered_chunks[i + 1].start: - entity = self.doc[entity.start:self.filtered_chunks[i + 1].end] - i += 1 - merged_entities.append(entity) - i += 1 - return merged_entities - except Exception as e: - raise Exception(f"Error merging overlapping entities: {e}") - def compare_entities_with_chunks(self): try: for entity in self.entities: diff --git a/querent/kg/ner_helperfunctions/ner_llm_transformer.py b/querent/kg/ner_helperfunctions/ner_llm_transformer.py index 46808d82..010a1c22 100644 --- a/querent/kg/ner_helperfunctions/ner_llm_transformer.py +++ b/querent/kg/ner_helperfunctions/ner_llm_transformer.py @@ -54,12 +54,13 @@ def set_nlp_model(cls, model_path): cls.nlp = spacy.load(model_path) def __init__( - self, ner_model_name="dbmdz/bert-large-cased-finetuned-conll03-english", + self, ner_model_name="", filler_tokens=None, provided_tokenizer=None, provided_model=None, ): self.logger = setup_logger(__name__, "NER_LLM") + self.device = "cpu" if provided_tokenizer: self.ner_tokenizer = provided_tokenizer else: @@ -68,6 +69,7 @@ def __init__( self.ner_model = provided_model else: self.ner_model = NER_LLM.load_model(ner_model_name, "NER") + self.ner_model.eval() self.filler_tokens = filler_tokens or ["of", "a", "the", "in", "on", "at", "and", "or", "with","(",")","-"] @@ -124,9 +126,9 @@ def _tokenize_and_chunk(self, data: str) -> List[Tuple[List[str], str, int]]: raise Exception(f"An error occurred while tokenizing: {e}") return tokenized_sentences - def _token_distance(self, tokens, start_idx1, start_idx2, noun_chunk1, noun_chunk2): + def _token_distance(self, tokens, start_idx1, nn_chunk_length_idx1, start_idx2, noun_chunk1, noun_chunk2): distance = 0 - for idx in range(start_idx1 + 1, start_idx2): + for idx in range(start_idx1 + nn_chunk_length_idx1, start_idx2): token = tokens[idx] if (token not in self.filler_tokens and token not in noun_chunk1 and @@ -137,32 +139,38 @@ def _token_distance(self, tokens, start_idx1, start_idx2, noun_chunk1, noun_chun def transform_entity_pairs(self, entity_pairs): - transformed_pairs = [] - sentence_group = {} - for pair, metadata in entity_pairs: - combined_sentence = ' '.join(filter(None, [ - metadata['previous_sentence'], - metadata['current_sentence'], - metadata['next_sentence'] - ])) - if combined_sentence not in sentence_group: - sentence_group[combined_sentence] = [] - sentence_group[combined_sentence].append(pair) - - for combined_sentence, pairs 
in sentence_group.items(): - for entity1, entity2 in pairs: - meta_dict = { - "entity1_score": entity1['score'], - "entity2_score": entity2['score'], - "entity1_label": entity1['label'], - "entity2_label": entity2['label'], - "entity1_nn_chunk":entity1['noun_chunk'], - "entity2_nn_chunk":entity2['noun_chunk'], - } - new_pair = (entity1['entity'], combined_sentence, entity2['entity'], meta_dict) - transformed_pairs.append(new_pair) - - return transformed_pairs + try: + transformed_pairs = [] + sentence_group = {} + for pair, metadata in entity_pairs: + combined_sentence = ' '.join(filter(None, [ + metadata['previous_sentence'], + metadata['current_sentence'], + metadata['next_sentence'] + ])) + current_sentence = metadata['current_sentence'] + if combined_sentence not in sentence_group: + sentence_group[combined_sentence] = [] + sentence_group[combined_sentence].append(pair + (current_sentence,)) + + for combined_sentence, pairs in sentence_group.items(): + for entity1, entity2, current_sentence in pairs: + meta_dict = { + "entity1_score": entity1['score'], + "entity2_score": entity2['score'], + "entity1_label": entity1['label'], + "entity2_label": entity2['label'], + "entity1_nn_chunk":entity1['noun_chunk'], + "entity2_nn_chunk":entity2['noun_chunk'], + "current_sentence":current_sentence + } + new_pair = (entity1['entity'], combined_sentence, entity2['entity'], meta_dict) + transformed_pairs.append(new_pair) + + return transformed_pairs + except Exception as e: + self.logger.error(f"Error transforming entity pairs: {e}") + raise Exception(f"Error transforming entity pairs: {e}") def get_chunks(self, tokens: List[str], max_chunk_size=510): chunks = [] @@ -179,9 +187,10 @@ def extract_entities_from_chunk(self, chunk: List[str]): results = [] try: input_ids = self.ner_tokenizer.convert_tokens_to_ids(chunk) - input_tensor = torch.tensor([input_ids]) + input_tensor = torch.tensor([input_ids], device=self.device) + attention_mask = torch.ones(input_tensor.shape, device=self.device) with torch.no_grad(): - outputs = self.ner_model(input_tensor) + outputs = self.ner_model(input_tensor, attention_mask=attention_mask) predictions = torch.argmax(outputs[0], dim=2) scores = torch.nn.functional.softmax(outputs[0], dim=2) label_ids = predictions[0].tolist() @@ -206,7 +215,7 @@ def combine_entities_wordpiece(self, entities: List[dict], tokens: List[str]): i = 0 while i < len(entities): entity = entities[i] - while i + 1 < len(entities) and entities[i + 1]["entity"].startswith("##"): + while i + 1 < len(entities) and entities[i + 1]["entity"].startswith("##") and entities[i + 1]["start_idx"] - entities[i]["start_idx"] ==1: entity["entity"] += entities[i + 1]["entity"][2:] entity["score"] = (entity["score"] + entities[i + 1]["score"]) / 2 i += 1 @@ -256,8 +265,8 @@ def extract_binary_pairs(self, entities: List[dict], tokens: List[str], all_sent for j in range(i + 1, len(entities)): if entities[i]["start_idx"] + 1 == entities[j]["start_idx"]: continue - distance = self._token_distance(tokens, entities[i]["start_idx"], entities[j]["start_idx"],entities[i]["noun_chunk"], entities[j]["noun_chunk"]) - if distance <= 30: + distance = self._token_distance(tokens, entities[i]["start_idx"], entities[i]["noun_chunk_length"],entities[j]["start_idx"],entities[i]["noun_chunk"], entities[j]["noun_chunk"]) + if distance <= 10: pair = (entities[i], entities[j]) if pair not in binary_pairs: metadata = { @@ -332,7 +341,21 @@ def filter_matching_entities(self, tuples_nested_list, entities_nested_list): return matched_tuples 
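To make the stricter pairing rule above concrete, here is a minimal standalone sketch of the updated token-distance logic: counting starts only after the full first noun chunk (start_idx1 + nn_chunk_length_idx1), filler tokens and tokens belonging to either noun chunk are ignored, and a pair is kept only when the resulting distance is at most 10. The token list and indices below are illustrative placeholders, not data from the repository.

FILLER_TOKENS = {"of", "a", "the", "in", "on", "at", "and", "or", "with", "(", ")", "-"}

def token_distance(tokens, start_idx1, nn_chunk_length_idx1, start_idx2, chunk1, chunk2):
    # Count meaningful tokens strictly between the end of the first noun chunk
    # and the start of the second entity mention.
    distance = 0
    for idx in range(start_idx1 + nn_chunk_length_idx1, start_idx2):
        token = tokens[idx]
        if token not in FILLER_TOKENS and token not in chunk1 and token not in chunk2:
            distance += 1
    return distance

tokens = ["the", "gangotri", "glacier", "flows", "towards", "the", "northwest"]
# "flows" and "towards" are counted, the filler "the" is skipped -> distance 2, so the pair survives the <= 10 cut
print(token_distance(tokens, 1, 2, 6, "gangotri glacier", "northwest"))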
- + def find_subword_indices(self, text, entity): + subwords = self.ner_tokenizer.tokenize(entity) + subword_ids = self.ner_tokenizer.convert_tokens_to_ids(subwords) + token_ids = self.ner_tokenizer.convert_tokens_to_ids(self.ner_tokenizer.tokenize(text)) + subword_positions = [] + for i in range(len(token_ids) - len(subword_ids) + 1): + if token_ids[i:i + len(subword_ids)] == subword_ids: + subword_positions.append((i+1, i + len(subword_ids))) + return subword_positions + + def tokenize_sentence_with_positions(self, sentence: str): + tokens = self.ner_tokenizer.tokenize(sentence) + token_positions = [(token, idx +1 ) for idx, token in enumerate(tokens)] + + return token_positions def extract_entities_from_sentence(self, sentence: str, sentence_idx: int, all_sentences: List[str], fixed_entities_flag: bool, fixed_entities: List[str],entity_types: List[str]): @@ -356,6 +379,7 @@ def extract_entities_from_sentence(self, sentence: str, sentence_idx: int, all_s entity['noun_chunk_length'] = len(entity['noun_chunk'].split()) entities_withnnchunk = final_entities binary_pairs = self.extract_binary_pairs(entities_withnnchunk, tokens, all_sentences, sentence_idx) + return entities_withnnchunk, binary_pairs except Exception as e: self.logger.error(f"Error extracting entities from sentence: {e}") @@ -404,6 +428,7 @@ def remove_duplicates(self, data): if cleaned_sublist: new_data.append(cleaned_sublist) + return new_data diff --git a/querent/kg/rel_helperfunctions/attn_based_relationship_filter.py b/querent/kg/rel_helperfunctions/attn_based_relationship_filter.py new file mode 100644 index 00000000..f9a2b877 --- /dev/null +++ b/querent/kg/rel_helperfunctions/attn_based_relationship_filter.py @@ -0,0 +1,260 @@ +import ast +import json +import torch +from querent.kg.rel_helperfunctions.attn_based_relationship_seach_scope import SearchContextualRelationship as sc +from querent.kg.rel_helperfunctions.attn_based_relationship_seach_scope import EntityPair as ep +from querent.kg.rel_helperfunctions.attn_based_relationship_seach_scope import perform_search +from dataclasses import dataclass +from querent.logging.logger import setup_logger +from typing import Optional +from collections import defaultdict +import numpy +from querent.kg.ner_helperfunctions.ner_llm_transformer import NER_LLM + +@dataclass +class Entity: + text: str + + +@dataclass +class SemanticPairs: + head: Entity + tail: Entity + relations: list[str] + scores: list[float] + + +class IndividualFilter: + def __init__(self, forward_relations: bool, threshold: float, token_idx_with_word: list, spacy_doc): + self.forward_relations = forward_relations + self.threshold = threshold + self.token_idx_with_word = token_idx_with_word + self.doc = spacy_doc + + def filter(self, candidates: list[sc], e_pair: ep) -> SemanticPairs: + response = SemanticPairs( + head=Entity( + text=e_pair.head_entity['noun_chunk'].lower() + ), + tail=Entity( + text=e_pair.tail_entity['noun_chunk'].lower() + ), + relations=[], + scores = [0] + ) + counter = 0 + for candidate in candidates: + if candidate.mean_score() < self.threshold: + continue + rel_txt = '' + rel_idx = [] + last_index = None + valid = True + + for token_id in candidate.relation_tokens: + word, word_id = self.token_idx_with_word[token_id -1] + if self.forward_relations and last_index is not None and word_id - last_index != 1: + valid = False + break + last_index = word_id + + if len(rel_txt) > 0: + rel_txt += ' ' + lowered_word = word.lower() + if lowered_word not in e_pair.head_entity['noun_chunk'] and 
lowered_word not in e_pair.tail_entity['noun_chunk']: + rel_txt += word.lower() + rel_idx.append(word_id) + + if valid: + rel_txt = self.lemmatize(rel_txt, rel_idx) + if len(rel_txt) == 0: + continue + + response.relations.append(rel_txt) + response.scores.append(candidate.mean_score()) + counter = counter +1 + del response.scores[0] + return response + + def combine_entities(self, entity_list): + # This list will store the final entities after combining + combined_entities = [] + # Temporary storage for current entity being processed + current_entity = None + + for entity, index in entity_list: + if entity.startswith('##'): + # If the entity starts with ##, concatenate it with the last part of current_entity + if current_entity: + current_entity = (current_entity[0] + entity[2:], current_entity[1]) + else: + # If the current_entity is not None, it means we have completed processing an entity + if current_entity: + combined_entities.append(current_entity) + # Start a new entity + current_entity = (entity, len(combined_entities) + 1) + + # Append the last processed entity if any + if current_entity: + combined_entities.append(current_entity) + + return combined_entities + + def lemmatize(self, relation: str, indexes: list[int]) -> str: + if relation.isnumeric(): + return '' + + new_text = '' + # Another option would be including 'AUX' + remove_morpho = {'SYM', 'OTHER', 'PUNCT', 'NUM', 'INTJ', 'DET', 'ADP', 'PART'} + last_word = ' ' + words = [] + for idx in indexes: + words.append(self.token_idx_with_word[idx -1]) + words = self.combine_entities(words) + for word, word_id in words: + token = next((token for token in self.doc if word in token.text.lower()), None) + if token and token.pos_ not in remove_morpho: + new_word = token.lemma_.lower() + if last_word != new_word: + new_text += new_word + new_text += ' ' + last_word = new_word + + new_text = new_text.strip() + return new_text + +def get_best_relation(semantic_pair): + scores = [score.item() if isinstance(score, torch.Tensor) else score for score in semantic_pair.scores] + max_index = scores.index(max(scores)) + best_relation = semantic_pair.relations[max_index] + best_score = scores[max_index] + + return best_relation, best_score + + +def frequency_cutoff(ht_relations: list[SemanticPairs], frequency: int): + if frequency == 1: + return + counter: dict[str, int] = {} + for ht_item in ht_relations: + for relation in ht_item.relations: + if relation in counter: + counter[relation] += 1 + else: + counter[relation] = 1 + + for ht_item in ht_relations: + ht_item.relations = [rel for rel in ht_item.relations if counter[rel] >= frequency] + +def trim_triples(data): + try: + trimmed_data = [] + for entity1, predicate, entity2 in data: + predicate_dict = json.loads(predicate) + trimmed_predicate = { + 'context': predicate_dict.get('context', ''), + 'entity1_nn_chunk': predicate_dict.get('entity1_nn_chunk', ''), + 'entity2_nn_chunk': predicate_dict.get('entity2_nn_chunk', ''), + 'entity1_label': predicate_dict.get('entity1_label', ''), + 'entity2_label': predicate_dict.get('entity2_label', ''), + 'file_path': predicate_dict.get('file_path', ''), + 'current_sentence': predicate_dict.get('current_sentence', '') + } + trimmed_data.append((entity1, trimmed_predicate, entity2)) + + return trimmed_data + except Exception as e: + raise Exception(f'Error in trimming triples: {e}') + +def process_tokens(ner_instance : NER_LLM, extractor, filtered_triples, nlp_model): + try: + updated_triples = [] + for subject, predicate_metadata, object in 
filtered_triples: + try: + context = predicate_metadata['current_sentence'].replace("\n"," ") + head_positions = ner_instance.find_subword_indices(context, predicate_metadata['entity1_nn_chunk']) + tail_positions = ner_instance.find_subword_indices(context, predicate_metadata['entity2_nn_chunk']) + if head_positions[0][0] > tail_positions[0][0]: + head_entity = {'entity': object, 'noun_chunk':predicate_metadata['entity2_nn_chunk'], 'entity_label':predicate_metadata['entity2_label'] } + tail_entity = {'entity': subject, 'noun_chunk':predicate_metadata['entity1_nn_chunk'], 'entity_label':predicate_metadata['entity1_label']} + entity_pair = ep(head_entity, tail_entity, context, tail_positions, head_positions) + else: + head_entity = {'entity': subject, 'noun_chunk':predicate_metadata['entity1_nn_chunk'], 'entity_label':predicate_metadata['entity1_label']} + tail_entity = {'entity': object, 'noun_chunk':predicate_metadata['entity2_nn_chunk'], 'entity_label':predicate_metadata['entity2_label']} + entity_pair = ep(head_entity, tail_entity, context, head_positions, tail_positions) + tokenized_sentence = extractor.tokenize_sentence(context) + model_input = extractor.model_input(tokenized_sentence) + attention_matrix = extractor.inference_attention(model_input) + token_idx_with_word = ner_instance.tokenize_sentence_with_positions(context) + spacy_doc = nlp_model(context) + filter = IndividualFilter(True, 0.02, token_idx_with_word, spacy_doc) + + ## HEAD Entity Based Attention Search + candidate_paths = perform_search(entity_pair.head_entity['start_idx'], attention_matrix, entity_pair, search_candidates=5, require_contiguous=True, max_relation_length=8, num_initial_tokens=extractor.num_start_tokens()) + candidate_paths = remove_duplicates(candidate_paths) + filtered_results = filter.filter(candidates=candidate_paths,e_pair=entity_pair) + predicate_he, score_he = get_best_relation(filtered_results) + + ##TAIL ENTITY Based Attention Search + candidate_paths = perform_search(entity_pair.tail_entity['start_idx'], attention_matrix, entity_pair, search_candidates=5, require_contiguous=True, max_relation_length=8, num_initial_tokens=extractor.num_start_tokens()) + candidate_paths = remove_duplicates(candidate_paths) + filtered_results = filter.filter(candidates=candidate_paths,e_pair=entity_pair) + predicate_te, score_te = get_best_relation(filtered_results) + + if score_he > score_te and (score_he >= 0.1 or score_te >= 0.1): + triple = create_semantic_triple(head_entity=head_entity['noun_chunk'], + tail_entity=tail_entity['noun_chunk'], + predicate=predicate_he, + score=score_he, + predicate_metadata=predicate_metadata, + subject_type=head_entity['entity_label'], + object_type=tail_entity['entity_label']) + updated_triples.append(triple) + elif score_he < score_te and (score_he >= 0.1 or score_te >= 0.1): + triple = create_semantic_triple(head_entity=tail_entity['noun_chunk'], + tail_entity=head_entity['noun_chunk'], + predicate=predicate_te, + score=score_te, + predicate_metadata=predicate_metadata, + subject_type=tail_entity['entity_label'], + object_type=head_entity['entity_label']) + updated_triples.append(triple) + except Exception as e: + continue + return updated_triples + except Exception as e: + raise Exception(f'Error in extracting Attention Based Relationships: {e}') + + +def remove_duplicates(candidate_paths): + seen_relations = set() + unique_paths = [] + + for path in candidate_paths: + # Convert the relation_tokens to a tuple to make it hashable + relation_tokens_tuple = 
tuple(path.relation_tokens) + if relation_tokens_tuple not in seen_relations: + seen_relations.add(relation_tokens_tuple) + unique_paths.append(path) + + return unique_paths + +def create_semantic_triple(head_entity, tail_entity, predicate, score, predicate_metadata, subject_type, object_type): + try: + triple = ( + head_entity, + json.dumps({ + "predicate": predicate, + "predicate_type": "", + "context": predicate_metadata["context"].replace('\n',' '), + "file_path": predicate_metadata["file_path"], + "subject_type": subject_type, + "object_type": object_type, + "score":score, + }), + tail_entity + ) + return triple + except Exception as e: + raise Exception(f"Error in creating semantic triple: {e}") diff --git a/querent/kg/rel_helperfunctions/attn_based_relationship_model_getter.py b/querent/kg/rel_helperfunctions/attn_based_relationship_model_getter.py new file mode 100644 index 00000000..75afcc03 --- /dev/null +++ b/querent/kg/rel_helperfunctions/attn_based_relationship_model_getter.py @@ -0,0 +1,150 @@ +import torch +import transformers +from transformers import AutoTokenizer + + + +class AttnRelationshipExtractor: + def __init__(self, model_tokenizer, model): + self.tokenizer = model_tokenizer + self.model = model + + def init_token_idx_2_word_doc_idx(self) -> list[tuple[str, int]]: + """ + This function initializes a dictionary of token index to spacy doc index. It should contain only the + first token in a tokenized sentence, alongside its corresponding doc index in the sentence. + For BERT, this is ('CLS', -1). We use -1 because the CLS does not correspond to a word in the sentence. + :return: A list with the first tokenized item and its doc index. + """ + pass + + def num_start_tokens(self) -> int: + """ + This function returns the number of start tokens in a tokenized sentence. + :return: Integer, representing the number of start tokens. + """ + pass + + def append_last_token(self, listing: list[tuple[str, int]]): + """ + Appends the last token of a tokenized sentence. In the case of BERT, this is only + ('SEP', len(linsting)), as 'SEP' indicates the end of the sentence. + :param listing: List of tokenized words and their corresponding Spacy doc index. + """ + pass + + def model_input(self, tokenized_sequence: list[int]) -> dict[str, torch.Tensor]: + """ + This function prepares the model input. It should correspond to the exact dictionary the model expects. + :param tokenized_sequence: The sentence Havina has tokenized. + :return: The dictionary the language model expects as inputs. + """ + pass + + def tokenize(self, word: str): + """ + Tokenize a word using the model's specific tokenizer. + :param word: A word to tokenize. + """ + pass + + def inference_attention(self, model_input: dict[str, torch.Tensor]): + """ + Perform the inference and return the average of all the attention matrices in the model's last layer. + :param model_input: The language model's input. + :return: + """ + pass + + def maximum_tokens(self) -> int: + """ + Returns the maximum sequence length the language model can consume. + :return: An integer, representing the maximum number of tokens the language model can handle. 
+ """ + pass + + def tokenize_sentence(self, sentence: str): + pass + +class BertBasedModel(AttnRelationshipExtractor): + def __init__(self, model_tokenizer, model): + super().__init__(model_tokenizer, model) + + def init_token_idx_2_word_doc_idx(self) -> list[tuple[str, int]]: + return [('CLS', -1)] + + def num_start_tokens(self) -> int: + return 1 + + def append_last_token(self, listing: list[tuple[str, int]]): + listing.append(('SEP', len(listing))) + + def model_input(self, tokenized_sentence: list[int]) -> dict[str, torch.Tensor]: + tokenized_sentence = [self.tokenizer.cls_token_id] + tokenized_sentence + [self.tokenizer.sep_token_id] + input_dict = { + 'input_ids': torch.tensor(tokenized_sentence, device='cpu').long().unsqueeze(0), + 'token_type_ids': torch.zeros(len(tokenized_sentence), device='cpu').long().unsqueeze(0), + 'attention_mask': torch.ones(len(tokenized_sentence), device='cpu').long().unsqueeze(0), + } + return input_dict + + def tokenize(self, word): + return self.tokenizer(str(word), add_special_tokens=False)['input_ids'] + + def inference_attention(self, model_input: dict[str, torch.Tensor]) -> torch.Tensor: + output = self.model(**model_input, output_attentions=True) + last_att_layer = output.attentions[-1] + mean = torch.mean(last_att_layer, dim=1) + return mean[0] + + def maximum_tokens(self) -> int: + return 512 + + def tokenize_sentence(self, sentence: str): + return self.tokenizer.encode(sentence, add_special_tokens=False) + + +class LlamaBasedModel(AttnRelationshipExtractor): + def __init__(self, model_tokenizer, model): + super().__init__(model_tokenizer, model) + + def init_token_idx_2_word_doc_idx(self) -> list[tuple[str, int]]: + return [] + + def num_start_tokens(self) -> int: + return 0 + + def append_last_token(self, listing: list[tuple[str, int]]): + pass + + def model_input(self, tokenized_sequence: list[int]) -> dict[str, torch.Tensor]: + input_dict = { + 'input_ids': torch.tensor(tokenized_sequence, device=self.device).long().unsqueeze(0), + 'attention_mask': torch.ones(len(tokenized_sequence), device=self.device).long().unsqueeze(0) + } + return input_dict + + def tokenize(self, word: str): + return self.tokenizer(str(word), add_special_tokens=False)['input_ids'] + + def inference_attention(self, model_input: dict[str, torch.Tensor]): + output = self.model(**model_input, output_attentions=True) + last_att_layer = output.attentions[-1] + mean = torch.mean(last_att_layer, dim=1) + return mean[0] + + def maximum_tokens(self) -> int: + return 2048 + + def tokenize_sentence(self, sentence: str): + return self.tokenizer.encode(sentence, add_special_tokens=False) + + +def get_model(model_name:str, model_tokenizer: str, model: str) -> AttnRelationshipExtractor: + if model_name == 'bert': + return BertBasedModel(model_tokenizer, model) + elif model_name == 'llama': + return LlamaBasedModel(model_tokenizer, model) + + raise Exception("Model not found") + diff --git a/querent/kg/rel_helperfunctions/attn_based_relationship_seach_scope.py b/querent/kg/rel_helperfunctions/attn_based_relationship_seach_scope.py new file mode 100644 index 00000000..68d6975a --- /dev/null +++ b/querent/kg/rel_helperfunctions/attn_based_relationship_seach_scope.py @@ -0,0 +1,103 @@ +import torch +import copy +from typing import List, Tuple, Dict +import numpy + +class EntityPair: + def __init__(self, head_entity: Dict, tail_entity: Dict, context: str, head_positions, tail_positions): + self.head_entity = head_entity + self.tail_entity = tail_entity + self.context = context + 
self.head_entity['start_idx'], self.head_entity['end_idx'] = head_positions[0] + self.tail_entity['start_idx'], self.tail_entity['end_idx'] = tail_positions[0] + +class SearchContextualRelationship: + + def __init__(self, initial_token_id): + self.current_token = initial_token_id + self.total_score = 0 + self.visited_tokens = [initial_token_id] + self.relation_tokens = [] + + def add_token(self, token_id, score): + self.current_token = token_id + self.visited_tokens.append(token_id) + self.total_score += score + self.relation_tokens.append(token_id) + + def has_relation(self) -> bool: + return len(self.relation_tokens) > 0 + + def finalize_path(self, score): + self.total_score += score + + def mean_score(self) -> float: + if len(self.relation_tokens) == 0: + return 0 + return self.total_score / len(self.relation_tokens) + + +def sort_by_mean_score(path: SearchContextualRelationship) -> float: + return path.mean_score() + + +def is_valid_token(token_id, pair: EntityPair, candidate_paths: List[SearchContextualRelationship], current_path: SearchContextualRelationship, score: float) -> bool: + if pair.tail_entity['start_idx'] <= token_id <= pair.tail_entity['end_idx']: + if current_path.has_relation(): + current_path.finalize_path(score) + candidate_paths.append(current_path) + return False + + return not (pair.head_entity['start_idx'] <= token_id <= pair.head_entity['end_idx'] or + pair.tail_entity['start_idx'] <= token_id <= pair.tail_entity['end_idx']) + + + +def perform_search(entity_start_index, attention_matrix: torch.Tensor, entity_pair: EntityPair, search_candidates: int, require_contiguous: bool, max_relation_length: int, num_initial_tokens: int) -> List[SearchContextualRelationship]: + """ + Initialize the perform search function with the following parameters: + :param attention_matrix :Mean attention score, average attention each token pays to every other token showing which tokens are most related to each other in the context of the given sentence(s). + :param search_candidates: Number of candidates to select for the next iteration of the search + :param contiguous_token: When generating relations, consider only those with contiguous tokens + :param max_relation_length: Maximum quantity of tokens allowed in a relation. + :patam num_initial_tokens: Different for different models. E.g. 'Bert' adds a '[CLS]' to the start of a sequence, so it is 1. + + """ + try: + queue = [ + SearchContextualRelationship(entity_start_index) + ] + candidate_paths = [] + visited_paths = set() + while len(queue) > 0: + current_path = queue.pop(0) + + if len(current_path.relation_tokens) > max_relation_length: + continue + + if require_contiguous and len(current_path.relation_tokens) > 1 and abs(current_path.relation_tokens[-2] - current_path.relation_tokens[-1]) != 1: + continue + + new_paths = [] + + # How all other tokens attend to an entity e.g. "Emily Stanton" + # These scores indicate how much importance the model places on each token when considering "Emily Stanton." + # The tokens which consider entity "Emily Stanton" important, highlight entity's relationships and relevance within the sentence. 
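# attention_matrix is the head-averaged attention of the model's last layer (see inference_attention),
# so attention_matrix[i, j] is how strongly token i attends to token j; slicing the column for
# current_path.current_token therefore gathers the attention every token pays to the current token.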
+ + attention_scores = attention_matrix[:, current_path.current_token] + for i in range(num_initial_tokens, len(attention_scores) - 1): + next_path = tuple(current_path.visited_tokens + [i]) + if is_valid_token(i, entity_pair, candidate_paths, current_path, attention_scores[i].detach()) and next_path not in visited_paths and current_path.current_token != i: + new_paths.append( + copy.deepcopy(current_path) + ) + new_paths[-1].add_token(i, attention_scores[i].detach()) + visited_paths.add(next_path) + new_paths.sort(key=sort_by_mean_score, reverse=True) + queue += new_paths[:search_candidates] + + return candidate_paths + except Exception as e: + raise e + + diff --git a/querent/kg/rel_helperfunctions/contextual_predicate.py b/querent/kg/rel_helperfunctions/contextual_predicate.py index 908196d4..018e5146 100644 --- a/querent/kg/rel_helperfunctions/contextual_predicate.py +++ b/querent/kg/rel_helperfunctions/contextual_predicate.py @@ -42,6 +42,7 @@ class ContextualPredicate(BaseModel): pair_attnscore: float entity1_embedding: List[float] entity2_embedding: List[float] + current_sentence: str @classmethod @@ -63,7 +64,8 @@ def from_tuple(cls, data: Tuple[str, str, str, Dict[str, str], str]) -> 'Context pair_attnscore=data[3].get('pair_attnscore',1), entity1_embedding=entity1_embedding, entity2_embedding=entity2_embedding, - file_path=data[4] + file_path=data[4], + current_sentence = data[3].get('current_sentence'), ) except Exception as e: raise ValueError(f"Error creating ContextualPredicate from tuple: {e}") diff --git a/querent/kg/rel_helperfunctions/embedding_store.py b/querent/kg/rel_helperfunctions/embedding_store.py index 1faa4ef5..0c0dcb88 100644 --- a/querent/kg/rel_helperfunctions/embedding_store.py +++ b/querent/kg/rel_helperfunctions/embedding_store.py @@ -40,6 +40,7 @@ def generate_embeddings(self, payload, relationship_finder=False, generate_embed predicate_type = data.get("predicate_type","Unlabeled").replace('"', '\\"') subject_type = data.get("subject_type","Unlabeled").replace('"', '\\"') object_type = data.get("object_type","Unlabeled").replace('"', '\\"') + score = data.get("score") context_embeddings = None predicate_embedding = None context_embeddings = self.get_embeddings([context])[0] @@ -54,7 +55,8 @@ def generate_embeddings(self, payload, relationship_finder=False, generate_embed "predicate": predicate, "subject_type": subject_type, "object_type": object_type, - "predicate_emb": predicate_embedding if predicate_embedding is not None else "Not Implemented" + "predicate_emb": predicate_embedding if predicate_embedding is not None else "Not Implemented", + "score":score } updated_json_string = json.dumps(essential_data) processed_pairs.append( diff --git a/querent/kg/rel_helperfunctions/opeai_ratelimiter.py b/querent/kg/rel_helperfunctions/opeai_ratelimiter.py deleted file mode 100644 index 981e2b03..00000000 --- a/querent/kg/rel_helperfunctions/opeai_ratelimiter.py +++ /dev/null @@ -1,14 +0,0 @@ -import time - -class RateLimiter: - def __init__(self, requests_per_minute): - self.requests_per_minute = requests_per_minute - self.timestamps = [] - - def wait_for_request_slot(self): - while len(self.timestamps) >= self.requests_per_minute: - if time.time() - self.timestamps[0] > 60: - self.timestamps.pop(0) - else: - time.sleep(1) - self.timestamps.append(time.time()) diff --git a/querent/kg/rel_helperfunctions/triple_to_json.py b/querent/kg/rel_helperfunctions/triple_to_json.py index 1e999f4a..950d3b4a 100644 --- a/querent/kg/rel_helperfunctions/triple_to_json.py +++ 
b/querent/kg/rel_helperfunctions/triple_to_json.py @@ -1,5 +1,6 @@ import json import re +import numpy as np """ A class to convert triples into different JSON formats. @@ -44,15 +45,29 @@ def convert_graphjson(triple): "object_type": TripleToJsonConverter._normalize_text(predicate_info.get("object_type", "Unlabeled"), replace_space=True), "predicate": TripleToJsonConverter._normalize_text(predicate_info.get("predicate", ""), replace_space=True), "predicate_type": TripleToJsonConverter._normalize_text(predicate_info.get("predicate_type", "Unlabeled"), replace_space=True), - "sentence": predicate_info.get("context", "").lower() + "sentence": predicate_info.get("context", "").lower(), + "score": predicate_info.get("score", 1) } return json_object except Exception as e: raise Exception(f"Error in convert_graphjson: {e}") + + def dynamic_weighted_average_embeddings(embeddings, base_weights, normalize_weights=True): + embeddings = [np.array(emb) for emb in embeddings] + weights = np.array(base_weights, dtype=float) + + if normalize_weights: + weights /= np.sum(weights) # Normalize weights to sum to 1 + + weighted_sum = np.zeros_like(embeddings[0]) + for emb, weight in zip(embeddings, weights): + weighted_sum += emb * weight + + return weighted_sum @staticmethod - def convert_vectorjson(triple, blob = None): + def convert_vectorjson(triple, blob = None, embeddings=None): try: subject, json_str, object_ = triple data = TripleToJsonConverter._parse_json_str(json_str) @@ -62,8 +77,8 @@ def convert_vectorjson(triple, blob = None): id_format = f"{TripleToJsonConverter._normalize_text(subject,replace_space=True)}-{TripleToJsonConverter._normalize_text(data.get('predicate', ''),replace_space=True)}-{TripleToJsonConverter._normalize_text(object_,replace_space=True)}" json_object = { "id": TripleToJsonConverter._normalize_text(id_format), - "embeddings": data.get("context_embeddings", []), - "size": len(data.get("context_embeddings", [])), + "embeddings": embeddings.tolist(), + "size": len(embeddings.tolist()), "namespace": TripleToJsonConverter._normalize_text(data.get("predicate", ""),replace_space=True), "sentence": data.get("context", "").lower(), "blob": blob, diff --git a/querent/kg/rel_helperfunctions/openllm.py b/querent/models/__init__.py similarity index 100% rename from querent/kg/rel_helperfunctions/openllm.py rename to querent/models/__init__.py diff --git a/querent/models/gguf_metadata_extractor.py b/querent/models/gguf_metadata_extractor.py new file mode 100644 index 00000000..d35c6fe2 --- /dev/null +++ b/querent/models/gguf_metadata_extractor.py @@ -0,0 +1,87 @@ +from __future__ import annotations +import sys +from pathlib import Path +import logging +import numpy as np +from gguf import GGUFReader, GGUFValueType +import json + +class GGUFMetadataExtractor: + def __init__(self, model_path: str): + self.model_path = model_path + self.reader = GGUFReader(self.model_path, 'r') + self.model_arch_value = None + + def get_file_host_endian(self) -> tuple[str, str]: + host_endian = 'LITTLE' if np.uint32(1) == np.uint32(1).newbyteorder("<") else 'BIG' + file_endian = 'BIG' if self.reader.byte_order == 'S' else host_endian + return (host_endian, file_endian) + + def dump_metadata(self) -> None: + params = [] + for n, field in enumerate(self.reader.fields.values(), 1): + pretty_type = self.format_field_type(field) + log_message = f' {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}' + if len(field.types) == 1: + curr_type = field.types[0] + if curr_type == GGUFValueType.STRING: + 
log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60])) + elif curr_type in self.reader.gguf_scalar_to_np: + log_message += ' = {0}'.format(field.parts[-1][0]) + params.append(log_message) + return params + + def format_field_type(self, field) -> str: + if not field.types: + return 'N/A' + if field.types[0] == GGUFValueType.ARRAY: + nest_count = len(field.types) - 1 + return '[' * nest_count + str(field.types[-1].name) + ']' * nest_count + return str(field.types[-1].name) + + def dump_metadata_json(self) -> None: + host_endian, file_endian = self.get_file_host_endian() + metadata, tensors = {}, {} + result = { + "filename": self.model_path, + "endian": file_endian, + "metadata": metadata, + "tensors": tensors, + } + self.fill_metadata_json(metadata) + json.dump(result, sys.stdout) + + def fill_metadata_json(self, metadata): + for idx, field in enumerate(self.reader.fields.values()): + curr = { + "index": idx, + "type": field.types[0].name if field.types else 'UNKNOWN', + "offset": field.offset, + "value": self.extract_field_value(field) + } + metadata[field.name] = curr + + def extract_field_value(self, field): + if field.types[:1] == [GGUFValueType.ARRAY]: + if field.types[-1] == GGUFValueType.STRING: + return [str(bytes(part, encoding="utf-8")) for part in field.parts] + return [pv.tolist() for part in field.parts for pv in part] + if field.types[0] == GGUFValueType.STRING: + return str(bytes(field.parts[-1], encoding="utf-8")) + return field.parts[-1].tolist() + + def extract_general_name(self, lines): + for line in lines: + if "general.architecture" in line: + parts = line.split('=') + if len(parts[1].strip().strip("'")) > 1: + return parts[1].strip().strip("'") + return "Name not found" + + +# Usage example +# if __name__ == '__main__': +# extractor = GGUFMetadataExtractor("/home/nishantg/querent-main/querent/tests/llama-2-7b-chat.Q5_K_M.gguf") +# model_metadata = extractor.dump_metadata() +# model_name = extractor.extract_general_name(model_metadata) +# print(model_name) \ No newline at end of file diff --git a/querent/models/model.py b/querent/models/model.py new file mode 100644 index 00000000..7a5140da --- /dev/null +++ b/querent/models/model.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod + +class Model(ABC): + """ Abstract base class for all models. """ + def __init__(self, model_name): + self.model_name = model_name + + @abstractmethod + def return_model_name(self): + """ Return the model name, to be implemented by all subclasses. """ + pass diff --git a/querent/models/model_factory.py b/querent/models/model_factory.py new file mode 100644 index 00000000..21b8edb7 --- /dev/null +++ b/querent/models/model_factory.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod +from querent.models.model import Model + + +class ModelFactory(ABC): + """ Abstract factory for creating models. """ + @abstractmethod + def create(self, model_name) -> Model: + """ Method to create model instances. 
""" + pass + + diff --git a/querent/models/model_manager.py b/querent/models/model_manager.py new file mode 100644 index 00000000..c0626697 --- /dev/null +++ b/querent/models/model_manager.py @@ -0,0 +1,26 @@ +from querent.models.ner_models.english.english import EnglishFactory +from querent.models.ner_models.geobert.geobert import GeoBERTFactory +from querent.models.rel_models.llama.llama import LLAMAFactory + +class ModelManager: + def __init__(self): + # Maps model identifiers to their corresponding factory classes + self.factory_map = { + "English": EnglishFactory, + "GeoBERT": GeoBERTFactory, + "llama" : LLAMAFactory, + } + + def get_model(self, model_identifier, model_path = None): + factory_class = self.factory_map.get(model_identifier) + if not factory_class: + raise Exception(f"No factory available for the model identifier: {model_identifier}") + factory = factory_class() + model = factory.create(model_identifier) + if not model_path: + return model.return_model_name() + else: + if model.return_model_name(): + return model_path + else: + raise Exception("No factory available for the model identifier: {model_identifier} and model path : {model_path}") diff --git a/querent/models/ner_models/__init__.py b/querent/models/ner_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/querent/models/ner_models/english/__init__.py b/querent/models/ner_models/english/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/querent/models/ner_models/english/english.py b/querent/models/ner_models/english/english.py new file mode 100644 index 00000000..e65f75be --- /dev/null +++ b/querent/models/ner_models/english/english.py @@ -0,0 +1,21 @@ +from querent.models.model import Model +from querent.models.model_factory import ModelFactory + + +class English(Model): + """ A specific implementation of a Model for English language processing. """ + def __init__(self, model_name): + super().__init__(model_name) + self.model_instance = None + + def return_model_name(self): + """ Returns the specific model name for the English model. """ + self.model_instance = "dbmdz/bert-large-cased-finetuned-conll03-english" + return self.model_instance + +class EnglishFactory(ModelFactory): + """ Factory for creating English model instances. """ + def create(self, model_name: str) -> Model: + return English(model_name) + + diff --git a/querent/models/ner_models/geobert/__init__.py b/querent/models/ner_models/geobert/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/querent/models/ner_models/geobert/geobert.py b/querent/models/ner_models/geobert/geobert.py new file mode 100644 index 00000000..56b7bee0 --- /dev/null +++ b/querent/models/ner_models/geobert/geobert.py @@ -0,0 +1,19 @@ +from querent.models.model import Model +from querent.models.model_factory import ModelFactory + +class GeoBERT(Model): + """ A specific implementation of a Model for GeoBERT NER. """ + def __init__(self, model_name): + super().__init__(model_name) + self.model_instance = None + + def return_model_name(self): + """ Returns the specific model name GeoBERT model. """ + self.model_instance = "botryan96/GeoBERT" + return self.model_instance + + +class GeoBERTFactory(ModelFactory): + """ Factory for creating English model instances. 
""" + def create(self, model_name: str) -> Model: + return GeoBERT(model_name) \ No newline at end of file diff --git a/querent/models/rel_models/__init__.py b/querent/models/rel_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/querent/models/rel_models/llama/__init__.py b/querent/models/rel_models/llama/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/querent/models/rel_models/llama/llama.py b/querent/models/rel_models/llama/llama.py new file mode 100644 index 00000000..310699d6 --- /dev/null +++ b/querent/models/rel_models/llama/llama.py @@ -0,0 +1,20 @@ +from querent.models.model import Model +from querent.models.model_factory import ModelFactory + +class LLAMA(Model): + """ A specific implementation of a Model for LLAMA2 language processing. """ + def __init__(self, model_name): + super().__init__(model_name) + self.model_instance = None + + def return_model_name(self): + """ Returns the specific model name for the LLAMA2 model. """ + self.model_instance = self.model_name + return self.model_instance + +class LLAMAFactory(ModelFactory): + """ Factory for creating LLAMA v2 model instances. """ + def create(self, model_name: str) -> Model: + return LLAMA(model_name) + + \ No newline at end of file diff --git a/querent/workflow/_helpers.py b/querent/workflow/_helpers.py index 77f75ad2..42202c64 100644 --- a/querent/workflow/_helpers.py +++ b/querent/workflow/_helpers.py @@ -8,8 +8,6 @@ from querent.common.uri import Uri from querent.ingestors.ingestor_manager import IngestorFactoryManager from querent.core.transformers.bert_ner_opensourcellm import BERTLLM -from querent.core.transformers.gpt_llm_gpt_ner import GPTNERLLM -from querent.core.transformers.gpt_llm_bert_ner_or_fixed_entities_set_ner import GPTLLM from querent.common.types.querent_event import EventType from querent.querent.querent import Querent from querent.querent.resource_manager import ResourceManager @@ -110,34 +108,6 @@ def find_first_file(directory, extension): await asyncio.gather(querent_task, token_feeder, check_message_states_task) -async def start_gpt_workflow( - resource_manager: ResourceManager, config: Config, result_queue: QuerentQueue -): - search_directory = os.getenv('MODEL_PATH', '/model/') - setup_nltk_and_spacy_paths(config, search_directory) - # llm_instance = GPTNERLLM(result_queue, config.engines[0]) - llm_instance = GPTLLM(result_queue, config.engines[0]) - - llm_instance.subscribe(EventType.Graph, config.workflow.event_handler) - llm_instance.subscribe(EventType.Vector, config.workflow.event_handler) - querent = Querent( - [llm_instance], - resource_manager=resource_manager, - ) - querent_task = asyncio.create_task(querent.start()) - token_feeder = asyncio.create_task( - receive_token_feeder( - resource_manager=resource_manager, - config=config, - result_queue=result_queue, - ) - ) - check_message_states_task = asyncio.create_task( - check_message_states(config, resource_manager, [querent_task, token_feeder]) - ) - await asyncio.gather(querent_task, token_feeder, check_message_states_task) - - # Config workflow channel for setting termination event async def receive_token_feeder( resource_manager: ResourceManager, config: Config, result_queue: QuerentQueue diff --git a/querent/workflow/workflow.py b/querent/workflow/workflow.py index f4cfbe4a..33a72433 100644 --- a/querent/workflow/workflow.py +++ b/querent/workflow/workflow.py @@ -7,10 +7,6 @@ from querent.config.workflow.workflow_config import WorkflowConfig from 
querent.config.collector.collector_config import CollectorConfig from querent.config.core.llm_config import LLM_Config -from querent.config.core.gpt_llm_config import GPTConfig -from querent.collectors.collector_resolver import CollectorResolver -from querent.common.uri import Uri -from querent.ingestors.ingestor_manager import IngestorFactoryManager from querent.querent.resource_manager import ResourceManager from querent.workflow._helpers import * @@ -73,10 +69,7 @@ async def start_workflow(config_dict: dict): if is_engine_params: engine_config.update(engine_params_json) engine_config_source = engine_config.get("config", {}) - if engine_config["name"] == "knowledge_graph_using_openai": - engine_config.update({"openai_api_key": engine_config["config"]["openai_api_key"]}) - engines.append(GPTConfig(config_source=engine_config)) - elif engine_config["name"] == "knowledge_graph_using_llama2_v1": + if engine_config["name"] == "knowledge_graph_using_llama2_v1": engines.append(LLM_Config(config_source=engine_config)) config_dict["engines"] = engines config_dict["collectors"] = collectors @@ -84,7 +77,6 @@ async def start_workflow(config_dict: dict): config = Config(config_source=config_dict) workflows = { - "knowledge_graph_using_openai": start_gpt_workflow, "knowledge_graph_using_llama2_v1": start_llama_workflow, } diff --git a/requirements.txt b/requirements.txt index c99163ec..694c0c50 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ dropbox==11.36.2 fastembed==0.2.6 ffmpeg-python==0.2.0 gensim==4.3.2 +gguf==0.6.0 google-api-python-client==2.105.0 google-cloud-storage==2.14.0 hdbscan==0.8.33 diff --git a/setup.py b/setup.py index ff3995c9..725c62fe 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ "fastembed==0.2.6", "ffmpeg-python==0.2.0", "gensim==4.3.2", + "gguf==0.6.0", "google-api-python-client==2.105.0", "google-cloud-storage==2.14.0", "hdbscan==0.8.33", @@ -79,7 +80,7 @@ setup( name="querent", - version="3.1.1", + version="3.1.2", author="Querent AI", description="The Asynchronous Data Dynamo and Graph Neural Network Catalyst", long_description=long_description, diff --git a/tests/collectors/test_email_collector.py b/tests/collectors/test_email_collector.py index 3b6eef74..7276719d 100644 --- a/tests/collectors/test_email_collector.py +++ b/tests/collectors/test_email_collector.py @@ -49,7 +49,7 @@ async def poll_and_print(): if chunk is not None: counter += 1 - assert counter == 2 + assert counter == 3 await poll_and_print() diff --git a/tests/data/llm/case_study_files/english.pdf b/tests/data/llm/case_study_files/english.pdf deleted file mode 100644 index 7888997e..00000000 Binary files a/tests/data/llm/case_study_files/english.pdf and /dev/null differ diff --git a/tests/ingestors/test_email_ingestor.py b/tests/ingestors/test_email_ingestor.py index 63f5919f..ac8246ae 100644 --- a/tests/ingestors/test_email_ingestor.py +++ b/tests/ingestors/test_email_ingestor.py @@ -56,7 +56,7 @@ async def poll_and_print(): assert ingested is not None if ingested != "" or ingested is not None: counter += 1 - assert counter == 4 + assert counter == 6 await poll_and_print() # Notice the use of await here diff --git a/tests/kg_tests/fixed_relationship_test.py b/tests/kg_tests/fixed_relationship_test.py index c93e651c..58b6972e 100644 --- a/tests/kg_tests/fixed_relationship_test.py +++ b/tests/kg_tests/fixed_relationship_test.py @@ -48,20 +48,7 @@ # images have been made by Kääb (2005). 
However, this approach is sometimes limited by # weather, clouds, and shadows in areas of high relief. # In a novel approach, the current study presents one of the most comprehensive assess- -# ments of the Gangotri glacier in recent times (2004–2011). The methodology entails the -# utilization of interferometric SAR (InSAR) coherence and sub-pixel offset tracking. While -# complementing most previous studies, the result presented here establishes the effective- -# ness of the techniques implemented to produce robust estimates of areal changes and glacier -# surface velocity in near real time. But, perhaps most importantly, this is one of the few stud- -# ies which has shown the melting trend of Gangotri glacier over a considerably continuous -# period during recent times (2004–2011). -# 2. Study area -# The Gangotri glacier is a valley-type glacier and one of the largest Himalayan glaciers -# located in Uttarkashi district of Uttarakhand, India (Figure 1). Extending between the lat- -# itudes 30◦ 43 22 N–30◦ 55 49 N and longitudes 79◦ 4 41 E–79◦ 16 34 E, Gangotri -# is the only major Himalayan glacier that flows towards the northwest. It spans a length -# of 30.2 km, its width varies between 0.20 and 2.35 km, and it thereby covers an area of -# about 86.32 km2 . While the average thickness of the Gangotri glacier is ∼200 m, its sur- +# ments of the Gangotri glacier in recent times (20fixed_relationshipsthe Gangotri glacier is ∼200 m, its sur- # face elevation varies from 4000 to 7000 m above mean sea level (Jain 2008). Gangotri has # three main tributaries, namely the Raktvarna, the Chaturangi, and the Kirti, with lengthsInternational Journal of Remote Sensing # 70° 0′ 00″ E 75° 0′ 00″ E 80° 0′ 00″ E 85° 0′ 00″ E 90° 0′ 00″ E @@ -69,23 +56,7 @@ # 440 km # 79º 10' E # Raktvarna -# 30° 55′ N -# 220 -# 30° 55′ N -# 35° 0′ 00″ N -# 0 -# 8655 -# Chaturangi -# 30° 0′ 00″ N -# 30° 50′ N -# tri -# go -# an -# G -# 25° 0′ 00″ N -# 30° 50′ N -# Kirti -# Shivling +# 30° 55′ Nfixed_relationships # Hills # Bhagirathi # Hills diff --git a/querent/kg/rel_helperfunctions/openai_functions.py b/tests/reduntant code/openai_functions.py similarity index 100% rename from querent/kg/rel_helperfunctions/openai_functions.py rename to tests/reduntant code/openai_functions.py diff --git a/tests/workflows/Postgres_new_algo.py b/tests/workflows/Postgres_new_algo.py new file mode 100644 index 00000000..7fd80928 --- /dev/null +++ b/tests/workflows/Postgres_new_algo.py @@ -0,0 +1,295 @@ +# import psycopg2 +# from psycopg2 import sql +# from psycopg2.extras import Json + +# from querent.kg.rel_helperfunctions.embedding_store import EmbeddingStore +# import numpy as np + +# class DatabaseManager: +# def __init__(self, dbname, user, password, host, port): +# self.dbname = dbname +# self.user = user +# self.password = password +# self.host = host +# self.port = port +# self.connection = None + +# def connect_db(self): +# try: +# self.connection = psycopg2.connect( +# dbname=self.dbname, +# user=self.user, +# password=self.password, +# host=self.host, +# port=self.port +# ) +# print("Database connection established") +# except Exception as e: +# print(f"Error connecting to database: {e}") + +# def create_tables(self): +# create_metadata_table_query = """ +# CREATE TABLE IF NOT EXISTS metadata ( +# id SERIAL PRIMARY KEY, +# subject VARCHAR(255), +# subject_type VARCHAR(255), +# predicate VARCHAR(255), +# object VARCHAR(255), +# object_type VARCHAR(255), +# sentence TEXT, +# file VARCHAR(255), +# doc_source VARCHAR(255) +# ); 
+# """ + +# create_embedding_table_query = """ +# CREATE TABLE IF NOT EXISTS embedding ( +# id SERIAL PRIMARY KEY, +# document_source VARCHAR, +# file VARCHAR, +# knowledge TEXT, +# sentence TEXT, +# predicate TEXT, +# embeddings vector(384) +# ); +# """ + +# try: +# with self.connection.cursor() as cursor: +# cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;") # Enable pgvector extension +# cursor.execute(create_metadata_table_query) +# cursor.execute(create_embedding_table_query) +# self.connection.commit() +# print("Tables created successfully") +# except Exception as e: +# print(f"Error creating tables: {e}") +# self.connection.rollback() + +# def insert_metadata(self, subject, subject_type, predicate, object, object_type, sentence, file, doc_source): +# insert_query = """ +# INSERT INTO metadata (subject, subject_type, predicate, object, object_type, sentence, file, doc_source) +# VALUES (%s, %s, %s, %s, %s, %s, %s, %s) +# RETURNING id; +# """ +# try: +# with self.connection.cursor() as cursor: +# cursor.execute(insert_query, (subject, subject_type, predicate, object, object_type, sentence, file, doc_source)) +# metadata_id = cursor.fetchone()[0] +# self.connection.commit() +# return metadata_id +# except Exception as e: +# print(f"Error inserting metadata: {e}") +# self.connection.rollback() + +# def insert_embedding(self,document_source, knowledge, sentence, predicate, embeddings, file): +# insert_query = """ +# INSERT INTO embedding (document_source, file, knowledge, sentence, predicate, embeddings) +# VALUES (%s, %s, %s, %s, %s, %s); +# """ +# try: +# with self.connection.cursor() as cursor: +# cursor.execute(insert_query, (document_source, file, knowledge, sentence, predicate, embeddings)) +# self.connection.commit() +# except Exception as e: +# print(f"Error inserting embedding: {e}") +# self.connection.rollback() + +# def close_connection(self): +# if self.connection: +# self.connection.close() +# print("Database connection closed") + +# def find_similar_embeddings(self, sentence_embedding, top_k=3, similarity_threshold=0.9): +# # print("Senetence embeddi ---", sentence_embedding) +# emb = sentence_embedding +# query = f""" +# SELECT id, 1 - (embeddings <=> '{emb}') AS cosine_similarity +# FROM public.embedding +# ORDER BY cosine_similarity DESC +# LIMIT {top_k}; +# """ +# try: +# with self.connection.cursor() as cursor: +# cursor.execute(query, (sentence_embedding, top_k)) +# results = cursor.fetchall() +# for result in results: +# print("Result -----------", result) +# filtered_results = [result for result in results if result[1] >= similarity_threshold] +# return filtered_results +# except Exception as e: +# print(f"Error in finding similar embeddings: {e}") +# return [] + +# def fetch_metadata_by_ids(self, metadata_ids): +# print("metafataaaa ids-----", metadata_ids) + +# query = """ +# SELECT * FROM public.metadata WHERE id IN %s; +# """ +# try: +# with self.connection.cursor() as cursor: +# cursor.execute(query, (tuple(metadata_ids),)) +# results = cursor.fetchall() +# return results +# except Exception as e: +# print(f"Error fetching metadata: {e}") +# return [] + +# def traverser_bfs(self, metadata_ids): +# print("Metadata IDs ---", metadata_ids) +# if not metadata_ids: +# return [] +# fetch_query = """ +# SELECT * FROM public.metadata WHERE id IN %s; +# """ +# incoming_query = """ +# SELECT * FROM public.metadata WHERE object = %s; +# """ +# outgoing_query = """ +# SELECT * FROM public.metadata WHERE subject = %s; +# """ + +# try: +# with self.connection.cursor() 
as cursor: +# cursor.execute(fetch_query, (tuple(metadata_ids),)) +# initial_results = cursor.fetchall() + +# related_results = [] + +# # For each row in the initial results, find incoming and outgoing edges +# for row in initial_results: +# subject = row[1] # 'subject' is the second column +# object = row[4] # 'object' is the fifth column + +# # Find incoming edges for the subject +# cursor.execute(incoming_query, (subject,)) +# incoming_edges = cursor.fetchall() + +# # Find outgoing edges for the object +# cursor.execute(outgoing_query, (object,)) +# outgoing_edges = cursor.fetchall() +# related_results.append({ +# 'metadata_id': row[0], +# 'subject': subject, +# 'object': object, +# 'incoming_edges': incoming_edges, +# 'outgoing_edges': outgoing_edges +# }) + +# return related_results + +# except Exception as e: +# print(f"Error fetching related metadata: {e}") +# return [] + +# def show_detailed_relationship_paths(self, data): +# for entry in data: +# print(f"Base Entry: Subject = {entry['subject']}, Object = {entry['object']}") +# print("Incoming Relationships:") +# if entry['incoming_edges']: +# for edge in entry['incoming_edges']: +# print(f" From {edge[1]} via {edge[3]} (Predicate) to {edge[4]} (Object)") +# print(f" Description: {edge[5]}") +# print(f" Source: {edge[7]}") +# else: +# print(" No incoming relationships found.") + +# print("Outgoing Relationships:") +# if entry['outgoing_edges']: +# for edge in entry['outgoing_edges']: +# print(f" From {edge[1]} via {edge[3]} (Predicate) to {edge[4]} (Object)") +# print(f" Description: {edge[6]}") +# print(f" Source: {edge[7]}") +# else: +# print(" No outgoing relationships found.") +# print("\n -----------------------------------------------------") + +# def suggest_queries_based_on_edges(self, data): +# print("Suggested Queries Based on Relationships:") +# for entry in data: +# subject = entry['subject'] +# object = entry['object'] + +# # Outgoing edges from the subject +# if entry['outgoing_edges']: +# print(f"From '{object}':") +# for edge in entry['outgoing_edges']: +# print(f" Explore '{edge[4]}' related to '{object}' via '{edge[3]}' (outgoing) : id - {edge[0]}") +# else: +# print(f"No outgoing queries suggested for '{object}'.") + +# # Incoming edges to the object +# if entry['incoming_edges']: +# print(f"To '{subject}':") +# for edge in entry['incoming_edges']: +# print(f" Explore '{edge[1]}' affecting '{subject}' via '{edge[3]}' (incoming) : id - {edge[0]}") +# else: +# print(f"No incoming queries suggested for '{subject}'.") + +# print("\n") + + + +# # Usage example +# if __name__ == "__main__": +# db_manager = DatabaseManager( +# dbname="querent_test", +# user="querent", +# password="querent", +# host="localhost", +# port="5432" +# ) + +# db_manager.connect_db() +# db_manager.create_tables() + +# # # Example data insertion +# # metadata_id = db_manager.insert_metadata( +# # subject='the_environmental_sciences_department', +# # subject_type='i_org', +# # predicate='have_be_advocate_clean_energy_use', +# # object='dr__emily_stanton', +# # object_type='i_per', +# # sentence='This is an example sentence.', +# # file='example_file', +# # doc_source='example_source' +# # ) + +# # db_manager.insert_embedding( +# # subject_emb=[0.1, 0.2, 0.3], # Example vectors +# # object_emb=[0.4, 0.5, 0.6], +# # predicate_emb=[0.7, 0.8, 0.9], +# # sentence_emb=[1.0, 1.1, 1.2], +# # metadata_id=metadata_id +# # ) +# # db_manager.update_database_with_averages() +# query_1 = "What is gas injection ?" 
+# # query_1 = "What is eagle ford shale porosity and permiability ?" +# # query_1 = "What is austin chalk formation ?" +# # query_1 = "What type of source rock does austin chalk reservoir have ?" +# # query_1 = "What are some of the important characteristics of Gulf of Mexico basin ?" +# # query_1 = "Which wells are producing oil ?" +# create_emb = EmbeddingStore() +# query_1_emb = create_emb.get_embeddings([query_1])[0] +# # Find similar embeddings in the database +# similar_embeddings = db_manager.find_similar_embeddings(query_1_emb, top_k=10) +# # Extract metadata IDs from the results +# metadata_ids = [result[0] for result in similar_embeddings] + +# # Fetch metadata for these IDs +# # metadata_results = db_manager.fetch_metadata_by_ids(metadata_ids) +# # print(metadata_results) +# # traverser_bfs_results = db_manager.traverser_bfs(metadata_ids=metadata_ids) +# # print(traverser_bfs_results) +# # print(db_manager.show_detailed_relationship_paths(traverser_bfs_results)) +# # print(db_manager.suggest_queries_based_on_edges(traverser_bfs_results)) + + +# ## Second Query +# # print("2nd Query ---------------------------------------------------") +# # user_choice = [27, 29, 171] +# # traverser_bfs_results = db_manager.traverser_bfs(metadata_ids=user_choice) +# # print(traverser_bfs_results) +# # print(db_manager.show_detailed_relationship_paths(traverser_bfs_results)) +# # print(db_manager.suggest_queries_based_on_edges(traverser_bfs_results)) +# db_manager.close_connection() diff --git a/tests/workflows/bert_llm_test_fixed_entities_predicates_workflow.py b/tests/workflows/bert_llm_test_fixed_entities_predicates_workflow.py index 5066ec6a..8ce840ac 100644 --- a/tests/workflows/bert_llm_test_fixed_entities_predicates_workflow.py +++ b/tests/workflows/bert_llm_test_fixed_entities_predicates_workflow.py @@ -17,11 +17,30 @@ # from querent.querent.resource_manager import ResourceManager # from querent.querent.querent import Querent # import time +# from .Postgres_new_algo import DatabaseManager +# from querent.kg.rel_helperfunctions.embedding_store import EmbeddingStore +# import numpy as np # @pytest.mark.asyncio # async def test_ingest_all_async(): +# db_manager = DatabaseManager( +# dbname="querent_test", +# user="querent", +# password="querent", +# host="localhost", +# port="5432" +# ) + +# db_manager.connect_db() +# db_manager.create_tables() + +# create_emb = EmbeddingStore() + # # Set up the collectors -# directories = [ "./tests/data/llm/predicate_checker"] +# # directories = [ "/home/nishantg/querent-main/resp/Data/GOM Basin"] +# # directories = [ "/home/nishantg/querent-main/querent/tests/data/llm/case_study_files"] +# # directories = ["/home/nishantg/querent-main/querent/tests/data/llm/predicate_checker"] +# directories = ["/home/nishantg/querent-main/Demo_june 6/demo_files"] # collectors = [ # FSCollectorFactory().resolve( # Uri("file://" + str(Path(directory).resolve())), @@ -47,57 +66,195 @@ # ingest_task = asyncio.create_task(ingestor_factory_manager.ingest_all_async()) # resource_manager = ResourceManager() # bert_llm_config = LLM_Config( -# # ner_model_name="botryan96/GeoBERT", +# # ner_model_name="English", +# ner_model_name = "GeoBERT", +# rel_model_type = "bert", +# # rel_model_path = "/home/nishantg/Downloads/capybarahermes-2.5-mistral-7b.Q5_K_M.gguf", +# rel_model_path = 'bert-base-uncased', +# # rel_model_path = 'daryl149/llama-2-7b-chat-hf', # enable_filtering=True, # filter_params={ -# 'score_threshold': 0.5, +# 'score_threshold': 0.3, # 'attention_score_threshold': 
0.1, -# 'similarity_threshold': 0.5, +# 'similarity_threshold': 0.2, # 'min_cluster_size': 5, # 'min_samples': 3, -# 'cluster_persistence_threshold':0.2 +# 'cluster_persistence_threshold':0.1 # } -# ,fixed_entities = ["university", "greenwood", "liam zheng", "department", "Metroville", "Emily Stanton", "Coach", "health", "training", "atheletes" ] -# ,sample_entities=["organization", "organization", "person", "department", "city", "person", "person", "method", "method", "person"] -# ,fixed_relationships=[ -# "Increase in research funding leads to environmental science focus", -# "Dr. Emily Stanton's advocacy for cleaner energy", -# "University's commitment to reduce carbon emissions", -# "Dr. Stanton's research influences architectural plans", -# "Collaborative project between sociology and environmental sciences", -# "Student government launches mental health awareness workshops", -# "Enhanced fitness programs improve sports teams' performance", -# "Coach Torres influences student-athletes' holistic health", -# "Partnership expands access to digital resources", -# "Interdisciplinary approach enriches academic experience" -# ] -# , sample_relationships=[ -# "Causal", -# "Contributory", -# "Causal", -# "Influential", -# "Collaborative", -# "Initiative", -# "Beneficial", -# "Influential", -# "Collaborative", -# "Enriching" -# ] -# ,is_confined_search = True +# ,fixed_entities = [ +# "Carbonate", "Clastic", "Porosity", "Permeability", +# "Oil saturation", "Water saturation", "Gas saturation", +# "Depth", "Size", "Temperature", +# "Pressure", "Oil viscosity", "Gas-oil ratio", +# "Water cut", "Recovery factor", "Enhanced recovery technique", +# "Horizontal drilling", "Hydraulic fracturing", "Water injection", "Gas injection", "Steam injection", +# "Seismic activity", "Structural deformation", "Faulting", +# "Cap rock integrity", "Compartmentalization", +# "Connectivity", "Production rate", "Depletion rate", +# "Exploration technique", "Drilling technique", "Completion technique", +# "Environmental impact", "Regulatory compliance", +# "Economic analysis", "Market analysis", "oil well", "gas well", "well", "oil field", "gas field", "eagle ford", "ghawar", "johan sverdrup", "karachaganak","maracaibo" +# ] +# , sample_entities = [ +# "rock_type", "rock_type", "reservoir_property", "reservoir_property", +# "reservoir_property", "reservoir_property", "reservoir_property", +# "reservoir_characteristic", "reservoir_characteristic", "reservoir_characteristic", +# "reservoir_characteristic", "reservoir_property", "reservoir_property", +# "production_metric", "production_metric", "recovery_technique", +# "drilling_technique", "recovery_technique", "recovery_technique", "recovery_technique", "recovery_technique", +# "geological_feature", "geological_feature", "geological_feature", +# "reservoir_feature", "reservoir_feature", +# "reservoir_feature", "production_metric", "production_metric", +# "exploration_method", "drilling_method", "completion_method", +# "environmental_aspect", "regulatory_aspect", +# "economic_aspect", "economic_aspect","hydrocarbon_source","hydrocarbon_source","hydrocarbon_source","hydrocarbon_source","hydrocarbon_source","reservoir","reservoir","reservoir","reservoir","reservoir" +# ] +# # ,fixed_entities = [ +# # "Hadean", "Archean", "Proterozoic", "Phanerozoic", +# # "Paleozoic", "Mesozoic", "Cenozoic", +# # "Cambrian", "Ordovician", "Silurian", "Devonian", "Carboniferous", "Permian", +# # "Triassic", "Jurassic", "Cretaceous", +# # "Paleogene", "Neogene", "Quaternary", +# # 
"Paleocene", "Eocene", "Oligocene", +# # "Miocene", "Pliocene", +# # "Pleistocene", "Holocene", +# # "Anticline", "Syncline", "Fault", "Salt dome", "Horst", "Graben", +# # "Reef", "Shoal", "Deltaic deposits", "Turbidite", "Channel sandstone", +# # "Sandstone", "Limestone", "Dolomite", "Shale", +# # "Source rock", "Cap rock", +# # "Crude oil", "Natural gas", "Coalbed methane", "Tar sands", "Gas hydrates", +# # "Structural trap", "Stratigraphic trap", "Combination trap", "Salt trap", "Unconformity trap", +# # "Hydrocarbon migration", "Hydrocarbon accumulation", +# # "Placer deposits", "Vein deposit", "Porphyry deposit", "Kimberlite pipe", "Laterite deposit", +# # "Volcanic rock", "Basalt", "Geothermal gradient", "Sedimentology", +# # "Paleontology", "Biostratigraphy", "Sequence stratigraphy", "Geophysical survey", +# # "Magnetic anomaly", "Gravitational anomaly", "Petrology", "Geochemistry", "Hydrogeology", "trap" +# # ] + +# # , sample_entities=[ +# # "geological_eon", "geological_eon", "geological_eon", "geological_eon", +# # "geological_era", "geological_era", "geological_era", +# # "geological_period", "geological_period", "geological_period", "geological_period", "geological_period", "geological_period", +# # "geological_period", "geological_period", "geological_period", +# # "geological_period", "geological_period", "geological_period", +# # "geological_epoch", "geological_epoch", "geological_epoch", +# # "geological_epoch", "geological_epoch", +# # "geological_epoch", "geological_epoch", "structural_feature", "structural_feature", "structural_feature", "structural_feature", "structural_feature", "structural_feature", +# # "stratigraphic_feature", "stratigraphic_feature", "stratigraphic_feature", "stratigraphic_feature", "stratigraphic_feature", +# # "rock_type", "rock_type", "rock_type", "rock_type", +# # "rock_type", "rock_type", "hydrocarbon_source", +# # "hydrocarbon", "hydrocarbon", "hydrocarbon", "hydrocarbon", +# # "trap_type", "trap_type", "trap_type", "trap_type", "trap_type", +# # "geological_process", "geological_process", +# # "mineral_deposit", "mineral_deposit", "mineral_deposit", "mineral_deposit", "mineral_deposit", +# # "rock_type", "rock_type", "geological_process", "geological_discipline", +# # "geological_discipline", "geological_method", "geological_method", "geological_method", +# # "geophysical_feature", "geophysical_feature", "geological_discipline", "geological_discipline", "geological_discipline", "trap_type" +# # ] +# # ,fixed_entities = ["university", "greenwood", "liam zheng", "department", "Metroville", "Emily Stanton", "Coach", "health", "training", "atheletes" ] +# # ,sample_entities=["organization", "organization", "person", "department", "city", "person", "person", "method", "method", "person"] +# ,is_confined_search = True +# # fixed_relationships=[ +# # "Increase in research funding leads to environmental science focus", +# # "Dr. Emily Stanton's advocacy for cleaner energy", +# # "University's commitment to reduce carbon emissions", +# # "Dr. 
Stanton's research influences architectural plans", +# # "Collaborative project between sociology and environmental sciences", +# # "Student government launches mental health awareness workshops", +# # "Enhanced fitness programs improve sports teams' performance", +# # "Coach Torres influences student-athletes' holistic health", +# # "Partnership expands access to digital resources", +# # "Interdisciplinary approach enriches academic experience" +# # ] +# # , sample_relationships=[ +# # "Causal", +# # "Contributory", +# # "Causal", +# # "Influential", +# # "Collaborative", +# # "Initiative", +# # "Beneficial", +# # "Influential", +# # "Collaborative", +# # "Enriching" +# # ] , + # ,user_context="Query: Your task is to analyze and interpret the context to construct semantic triples. The above context is from a university document along with the identified entities using NER. Identify which entity is the subject entity and which is the object entity based on the context, and determine the meaningful relationship or predicate linking the subject entity to the object entity. Also identify the respective subject entity type , object entity and predicate type. Answer:" # ) # llm_instance = BERTLLM(result_queue, bert_llm_config) # class StateChangeCallback(EventCallbackInterface): + +# def average_embeddings(embedding1, embedding2, embedding3, embedding4, predicate_score): +# emb1 = np.array(embedding1) +# emb2 = np.array(embedding2) +# emb3 = np.array(embedding3) * predicate_score +# emb4 = np.array(embedding4) + +# # Calculate the average embedding +# average_emb = (emb1 + emb2 + emb3 + emb4) / 4 +# return average_emb + +# import numpy as np + +# def weighted_average_embeddings(embeddings, weights=None, normalize_weights=True): +# embeddings = [np.array(emb) for emb in embeddings] + +# if weights is None: +# weights = np.ones(len(embeddings)) +# else: +# weights = np.array(weights) +# if len(weights) != len(embeddings): +# raise ValueError("The number of weights must match the number of embeddings.") +# if normalize_weights: +# weights = weights / np.sum(weights) # Normalize weights to sum to 1 + +# weighted_sum = np.zeros_like(embeddings[0]) +# for emb, weight in zip(embeddings, weights): +# weighted_sum += emb * weight + +# return weighted_sum + +# def dynamic_weighted_average_embeddings(embeddings, base_weights, normalize_weights=True): +# embeddings = [np.array(emb) for emb in embeddings] +# weights = np.array(base_weights, dtype=float) + +# if normalize_weights: +# weights /= np.sum(weights) # Normalize weights to sum to 1 + +# weighted_sum = np.zeros_like(embeddings[0]) +# for emb, weight in zip(embeddings, weights): +# weighted_sum += emb * weight + +# return weighted_sum + # def handle_event(self, event_type: EventType, event_state: EventState): # if event_state['event_type'] == EventType.Graph: # triple = json.loads(event_state['payload']) # print("triple: {}".format(triple)) # assert isinstance(triple['subject'], str) and triple['subject'] +# db_manager.insert_metadata( +# subject=triple['subject'], +# subject_type=triple['subject_type'], +# predicate=triple['predicate'], +# object=triple['object'], +# object_type=triple['object_type'], +# sentence=triple['sentence'], +# file=event_state['file'], +# doc_source=event_state['doc_source'] +# ) # elif event_state['event_type'] == EventType.Vector: -# triple = json.loads(event_state['payload']) -# print("id: {}".format(triple['id'])) -# print("namespace: {}".format(triple['namespace'])) +# triple_v = json.loads(event_state['payload']) +# # 
print("Vector Json :----: {}".format(triple_v)) +# db_manager.insert_embedding( +# document_source=event_state['doc_source'], +# knowledge=triple_v['id'], +# predicate=triple_v['namespace'], +# sentence=triple_v['sentence'], +# embeddings=triple_v['embeddings'], +# file=event_state['file'] +# ) + # llm_instance.subscribe(EventType.Graph, StateChangeCallback()) # llm_instance.subscribe(EventType.Vector, StateChangeCallback()) @@ -107,8 +264,10 @@ # ) # querent_task = asyncio.create_task(querent.start()) # await asyncio.gather(ingest_task, querent_task) +# db_manager.close_connection() # if __name__ == "__main__": # # Run the async function # asyncio.run(test_ingest_all_async()) + diff --git a/tests/workflows/colab_test_1.py b/tests/workflows/colab_test_1.py index f902f0e0..3b3d8535 100644 --- a/tests/workflows/colab_test_1.py +++ b/tests/workflows/colab_test_1.py @@ -1,113 +1,113 @@ -import asyncio -import json -import pytest -from pathlib import Path -import uuid -import nltk - -from querent.callback.event_callback_interface import EventCallbackInterface -from querent.common.types.querent_queue import QuerentQueue -from querent.common.types.querent_event import EventState, EventType -from querent.common.types.ingested_tokens import IngestedTokens -from querent.common.types.querent_queue import QuerentQueue -from querent.config.core.llm_config import LLM_Config -from querent.core.transformers.bert_ner_opensourcellm import BERTLLM -from querent.querent.resource_manager import ResourceManager -from querent.querent.querent import Querent -from querent.collectors.collector_resolver import CollectorResolver -from querent.config.collector.collector_config import FSCollectorConfig -from querent.common.uri import Uri -from querent.ingestors.ingestor_manager import IngestorFactoryManager - -async def main(): - print("Inside main)") -# Setup directories for data collection and configure collectors. - directories = ["/home/nishantg/querent-main/querent/tests/data/llm/case_study_files"] - collectors = [ - # Resolve and configure each collector based on the provided directory and file system configuration. - CollectorResolver().resolve( - Uri("file://" + str(Path(directory).resolve())), - FSCollectorConfig( - config_source={ - "id": str(uuid.uuid4()), - "root_path": directory, - "name": "Local-config", - "config": {}, - "uri": "file://", - } - ), - ) - for directory in directories - ] - - # Connect each collector asynchronously. - for collector in collectors: - await collector.connect() - - # Setup the result queue for processing results from collectors. - result_queue = asyncio.Queue() - - # Initialize the IngestorFactoryManager with the collectors and result queue. - ingestor_factory_manager = IngestorFactoryManager( - collectors=collectors, result_queue=result_queue - ) - - - # Start the asynchronous ingestion process and store the task. - ingest_task = asyncio.create_task(ingestor_factory_manager.ingest_all_async()) - - # Initialize the ResourceManager. - resource_manager = ResourceManager() - - # Set NLTK data path for natural language processing tasks. - nltk.data.path=["/home/nishantg/querent-main/model/nltk_data"] - - # Configure the BERT language model for named entity recognition (NER) and filtering. 
- bert_llm_config = LLM_Config( - rel_model_path="/home/nishantg/querent-main/model/llama-2-7b-chat.Q5_K_M.gguf", - grammar_file_path="/home/nishantg/querent-main/model/json.gbnf", - spacy_model_path="/home/nishantg/querent-main/model/en_core_web_lg-3.7.1/en_core_web_lg/en_core_web_lg-3.7.1", - ner_model_name="dbmdz/bert-large-cased-finetuned-conll03-english", - nltk_path="/home/nishantg/querent-main/model/nltk_data", - enable_filtering=True, - filter_params={ - 'score_threshold': 0.5, - 'attention_score_threshold': 0.1, - 'similarity_threshold': 0.5, - 'min_cluster_size': 5, - 'min_samples': 3, - 'cluster_persistence_threshold':0.2 - } - ) - - # Initialize the BERTLLM instance with the result queue and configuration. - llm_instance = BERTLLM(result_queue, bert_llm_config) - - # Define a function to automatically terminate the task after 5 minutes - async def terminate_querent(result_queue): - await asyncio.sleep(180) - await result_queue.put(None) - await result_queue.put(None) - - # Define a callback class to handle state changes and print resulting triples. - class StateChangeCallback(EventCallbackInterface): - def handle_event(self, event_type: EventType, event_state: EventState): - assert event_state['event_type'] == EventType.Graph - triple = json.loads(event_state['payload']) - print("triple: {}".format(triple)) - - # Subscribe the BERTLLM instance to graph events using the StateChangeCallback. - llm_instance.subscribe(EventType.Graph, StateChangeCallback()) - - # Initialize Querent with the BERTLLM instance and ResourceManager. - querent = Querent( - [llm_instance], - resource_manager=resource_manager, - ) - - # Start Querent and the ingestion task asynchronously and wait for both to complete. - querent_task = asyncio.create_task(querent.start()) - terminate_task = asyncio.create_task(terminate_querent(result_queue)) - await asyncio.gather(querent_task, ingest_task, terminate_task) -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file +# import asyncio +# import json +# import pytest +# from pathlib import Path +# import uuid +# import nltk + +# from querent.callback.event_callback_interface import EventCallbackInterface +# from querent.common.types.querent_queue import QuerentQueue +# from querent.common.types.querent_event import EventState, EventType +# from querent.common.types.ingested_tokens import IngestedTokens +# from querent.common.types.querent_queue import QuerentQueue +# from querent.config.core.llm_config import LLM_Config +# from querent.core.transformers.bert_ner_opensourcellm import BERTLLM +# from querent.querent.resource_manager import ResourceManager +# from querent.querent.querent import Querent +# from querent.collectors.collector_resolver import CollectorResolver +# from querent.config.collector.collector_config import FSCollectorConfig +# from querent.common.uri import Uri +# from querent.ingestors.ingestor_manager import IngestorFactoryManager + +# async def main(): +# print("Inside main)") +# # Setup directories for data collection and configure collectors. +# directories = ["/home/nishantg/querent-main/querent/tests/data/llm/case_study_files"] +# collectors = [ +# # Resolve and configure each collector based on the provided directory and file system configuration. 
+# CollectorResolver().resolve( +# Uri("file://" + str(Path(directory).resolve())), +# FSCollectorConfig( +# config_source={ +# "id": str(uuid.uuid4()), +# "root_path": directory, +# "name": "Local-config", +# "config": {}, +# "uri": "file://", +# } +# ), +# ) +# for directory in directories +# ] + +# # Connect each collector asynchronously. +# for collector in collectors: +# await collector.connect() + +# # Setup the result queue for processing results from collectors. +# result_queue = asyncio.Queue() + +# # Initialize the IngestorFactoryManager with the collectors and result queue. +# ingestor_factory_manager = IngestorFactoryManager( +# collectors=collectors, result_queue=result_queue +# ) + + +# # Start the asynchronous ingestion process and store the task. +# ingest_task = asyncio.create_task(ingestor_factory_manager.ingest_all_async()) + +# # Initialize the ResourceManager. +# resource_manager = ResourceManager() + +# # Set NLTK data path for natural language processing tasks. +# nltk.data.path=["/home/nishantg/querent-main/model/nltk_data"] + +# # Configure the BERT language model for named entity recognition (NER) and filtering. +# bert_llm_config = LLM_Config( +# rel_model_path="/home/nishantg/querent-main/model/llama-2-7b-chat.Q5_K_M.gguf", +# grammar_file_path="/home/nishantg/querent-main/model/json.gbnf", +# spacy_model_path="/home/nishantg/querent-main/model/en_core_web_lg-3.7.1/en_core_web_lg/en_core_web_lg-3.7.1", +# ner_model_name="dbmdz/bert-large-cased-finetuned-conll03-english", +# nltk_path="/home/nishantg/querent-main/model/nltk_data", +# enable_filtering=True, +# filter_params={ +# 'score_threshold': 0.5, +# 'attention_score_threshold': 0.1, +# 'similarity_threshold': 0.5, +# 'min_cluster_size': 5, +# 'min_samples': 3, +# 'cluster_persistence_threshold':0.2 +# } +# ) + +# # Initialize the BERTLLM instance with the result queue and configuration. +# llm_instance = BERTLLM(result_queue, bert_llm_config) + +# # Define a function to automatically terminate the task after 5 minutes +# async def terminate_querent(result_queue): +# await asyncio.sleep(180) +# await result_queue.put(None) +# await result_queue.put(None) + +# # Define a callback class to handle state changes and print resulting triples. +# class StateChangeCallback(EventCallbackInterface): +# def handle_event(self, event_type: EventType, event_state: EventState): +# assert event_state['event_type'] == EventType.Graph +# triple = json.loads(event_state['payload']) +# print("triple: {}".format(triple)) + +# # Subscribe the BERTLLM instance to graph events using the StateChangeCallback. +# llm_instance.subscribe(EventType.Graph, StateChangeCallback()) + +# # Initialize Querent with the BERTLLM instance and ResourceManager. +# querent = Querent( +# [llm_instance], +# resource_manager=resource_manager, +# ) + +# # Start Querent and the ingestion task asynchronously and wait for both to complete. 
+# querent_task = asyncio.create_task(querent.start()) +# terminate_task = asyncio.create_task(terminate_querent(result_queue)) +# await asyncio.gather(querent_task, ingest_task, terminate_task) +# if __name__ == "__main__": +# asyncio.run(main()) \ No newline at end of file diff --git a/tests/workflows/gpt_llm_case_study.py b/tests/workflows/gpt_llm_case_study.py deleted file mode 100644 index 73190b72..00000000 --- a/tests/workflows/gpt_llm_case_study.py +++ /dev/null @@ -1,147 +0,0 @@ -# import asyncio -# from asyncio import Queue -# import json -# from pathlib import Path -# from querent.callback.event_callback_interface import EventCallbackInterface -# from querent.collectors.fs.fs_collector import FSCollectorFactory -# from querent.common.types.ingested_tokens import IngestedTokens -# from querent.common.types.querent_event import EventState, EventType -# from querent.config.collector.collector_config import FSCollectorConfig -# from querent.common.uri import Uri -# from querent.config.core.gpt_llm_config import GPTConfig -# from querent.core.transformers.gpt_llm_bert_ner_or_fixed_entities_set_ner import GPTLLM -# from querent.ingestors.ingestor_manager import IngestorFactoryManager -# import pytest -# import uuid -# from querent.common.types.file_buffer import FileBuffer -# from querent.core.transformers.bert_ner_opensourcellm import BERTLLM -# from querent.processors.text_cleanup_processor import TextCleanupProcessor -# from querent.querent.resource_manager import ResourceManager -# from querent.querent.querent import Querent -# import time -# from querent.storage.postgres_graphevent_storage import DatabaseConnection - -# @pytest.mark.asyncio -# async def test_ingest_all_async(): -# # Set up the collectors -# # db_conn = DatabaseConnection(dbname="postgres", -# # user="querent", -# # password="querent", -# # host="localhost", -# # port="5432") -# # ml_conn = MilvusDBConnection() -# directories = [ "./tests/data/llm/one_file/"] -# collectors = [ -# FSCollectorFactory().resolve( -# Uri("file://" + str(Path(directory).resolve())), -# FSCollectorConfig(config_source={ -# "id": str(uuid.uuid4()), -# "root_path": directory, -# "name": "Local-config", -# "config": {}, -# "backend": "localfile", -# "uri": "file://", -# }), -# ) -# for directory in directories -# ] - -# # Set up the result queue -# result_queue = asyncio.Queue() -# text_cleanup_processor = TextCleanupProcessor() -# # Create the IngestorFactoryManager -# ingestor_factory_manager = IngestorFactoryManager( -# collectors=collectors, result_queue=result_queue, processors=[text_cleanup_processor] -# ) -# ingest_task = asyncio.create_task(ingestor_factory_manager.ingest_all_async()) -# print("Going to start ingesting now.......") -# resource_manager = ResourceManager() -# bert_llm_config = GPTConfig( -# ner_model_name="botryan96/GeoBERT", -# rel_model_path="/path/to/model.gguf", -# enable_filtering=True, -# filter_params={ -# 'score_threshold': 0.5, -# 'attention_score_threshold': 0.1, -# 'similarity_threshold': 0.5, -# 'min_cluster_size': 5, -# 'min_samples': 3, -# 'cluster_persistence_threshold':0.2 -# } -# ,fixed_entities = [ -# "Carbonate", "Clastic", "Porosity", "Permeability", -# "Oil saturation", "Water saturation", "Gas saturation", -# "Depth", "Size", "Temperature", -# "Pressure", "Oil viscosity", "Gas-oil ratio", -# "Water cut", "Recovery factor", "Enhanced recovery technique", -# "Horizontal drilling", "Hydraulic fracturing", "Water injection", "Gas injection", "Steam injection", -# "Seismic activity", "Structural 
deformation", "Faulting", -# "Cap rock integrity", "Compartmentalization", -# "Connectivity", "Production rate", "Depletion rate", -# "Exploration technique", "Drilling technique", "Completion technique", -# "Environmental impact", "Regulatory compliance", -# "Economic analysis", "Market analysis", "oil well", "gas well", "oil field", "Gas field", "eagle ford", "ghawar", "johan sverdrup", "karachaganak","maracaibo" -# ] -# , sample_entities = [ -# "rock_type", "rock_type", "reservoir_property", "reservoir_property", -# "reservoir_property", "reservoir_property", "reservoir_property", -# "reservoir_characteristic", "reservoir_characteristic", "reservoir_characteristic", -# "reservoir_characteristic", "reservoir_property", "reservoir_property", -# "production_metric", "production_metric", "recovery_technique", -# "drilling_technique", "recovery_technique", "recovery_technique", "recovery_technique", "recovery_technique", -# "geological_feature", "geological_feature", "geological_feature", -# "reservoir_feature", "reservoir_feature", -# "reservoir_feature", "production_metric", "production_metric", -# "exploration_method", "drilling_method", "completion_method", -# "environmental_aspect", "regulatory_aspect", -# "economic_aspect", "economic_aspect","hydrocarbon_source","hydrocarbon_source","hydrocarbon_source","hydrocarbon_source","reservoir","reservoir","reservoir","reservoir","reservoir" -# ] -# , is_confined_search = True, -# openai_api_key = "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", -# # , huggingface_token = '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' -# user_context = """Query: Your task is to analyze and interpret the context to construct semantic triples. The above context is from a geological research study on reservoirs and the above entities and their respective types have already been identified. -# Please Identify the entity which is the subject and the entity which is object based on the context, and determine the meaningful relationship or predicate linking the subject entity to the object entity. Determine whether the entity labels provided match the subject type and object type and correct if needed.Also provide the predicate type. 
-# Answer:""" -# ) -# llm_instance = GPTLLM(result_queue, bert_llm_config) -# class StateChangeCallback(EventCallbackInterface): -# def handle_event(self, event_type: EventType, event_state: EventState): -# if event_state["event_type"] == EventType.Vector : -# triple = json.loads(event_state["payload"]) -# # print("triple: {}".format(triple)) -# vector_triple = json.loads(event_state["payload"]) -# # print("Inside Vector event ---------------------------------", vector_triple) -# # milvus_coll = ml_conn.create_collection(collection_name=vector_triple['namespace'],dim = 384) -# # ml_conn.insert_vector_event(id = vector_triple['id'], embedding= vector_triple['embeddings'], namespace= vector_triple['namespace'], document=event_state["file"], collection= milvus_coll ) -# # assert event_state.event_type == EventType.Graph -# if event_state["event_type"] == EventType.Graph : -# triple = json.loads(event_state["payload"]) -# print("file---------------------",event_state["file"], "----------------", type(event_state["file"])) -# # print("triple: {}".format(triple)) -# graph_event_data = { -# 'subject': triple['subject'], -# 'subject_type': triple['subject_type'], -# 'object': triple['object'], -# 'object_type': triple['object_type'], -# 'predicate': triple['predicate'], -# 'predicate_type': triple['predicate_type'], -# 'sentence': triple['sentence'], -# 'document_id': event_state["file"] -# } -# # db_conn.insert_graph_event(graph_event_data) -# assert isinstance(triple['subject'], str) and triple['subject'] - -# llm_instance.subscribe(EventType.Graph, StateChangeCallback()) -# # llm_instance.subscribe(EventType.Vector, StateChangeCallback()) -# querent = Querent( -# [llm_instance], -# resource_manager=resource_manager, -# ) -# querent_task = asyncio.create_task(querent.start()) -# await asyncio.gather(querent_task, ingest_task) -# # db_conn.close() - -# if __name__ == "__main__": - -# # Run the async function -# asyncio.run(test_ingest_all_async()) diff --git a/tests/workflows/gpt_llm_test_fixed_entities_predicates_workflow.py b/tests/workflows/gpt_llm_test_fixed_entities_predicates_workflow.py deleted file mode 100644 index f744b5e3..00000000 --- a/tests/workflows/gpt_llm_test_fixed_entities_predicates_workflow.py +++ /dev/null @@ -1,116 +0,0 @@ -# import asyncio -# from asyncio import Queue -# import json -# from pathlib import Path -# from querent.callback.event_callback_interface import EventCallbackInterface -# from querent.collectors.fs.fs_collector import FSCollectorFactory -# from querent.common.types.ingested_tokens import IngestedTokens -# from querent.common.types.querent_event import EventState, EventType -# from querent.config.collector.collector_config import FSCollectorConfig -# from querent.common.uri import Uri -# from querent.config.core.llm_config import LLM_Config -# from querent.ingestors.ingestor_manager import IngestorFactoryManager -# import pytest -# import uuid -# from querent.common.types.file_buffer import FileBuffer -# from querent.core.transformers.bert_ner_opensourcellm import BERTLLM -# from querent.querent.resource_manager import ResourceManager -# from querent.querent.querent import Querent -# import time -# from querent.core.transformers.gpt_llm_bert_ner_or_fixed_entities_set_ner import GPTLLM -# from querent.config.core.gpt_llm_config import GPTConfig - -# @pytest.mark.asyncio -# async def test_ingest_all_async(): -# # Set up the collectors -# directories = [ "/home/ansh/pyg-trail/testing-xlsx"] -# collectors = [ -# FSCollectorFactory().resolve( -# Uri("file://" 
+ str(Path(directory).resolve())), -# FSCollectorConfig(config_source={ -# "id": str(uuid.uuid4()), -# "root_path": directory, -# "name": "Local-config", -# "config": {}, -# "backend": "localfile", -# "uri": "file://", -# }), -# ) -# for directory in directories -# ] - -# # Set up the result queue -# result_queue = asyncio.Queue() - -# # Create the IngestorFactoryManager -# ingestor_factory_manager = IngestorFactoryManager( -# collectors=collectors, result_queue=result_queue -# ) -# ingest_task = asyncio.create_task(ingestor_factory_manager.ingest_all_async()) -# resource_manager = ResourceManager() -# bert_llm_config = GPTConfig( -# # ner_model_name="botryan96/GeoBERT", -# enable_filtering=True, -# openai_api_key="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", -# filter_params={ -# 'score_threshold': 0.5, -# 'attention_score_threshold': 0.1, -# 'similarity_threshold': 0.5, -# 'min_cluster_size': 5, -# 'min_samples': 3, -# 'cluster_persistence_threshold':0.2 -# } -# ,fixed_entities = ["university", "greenwood", "liam zheng", "department", "Metroville", "Emily Stanton", "Coach", "health", "training", "atheletes" ] -# ,sample_entities=["organization", "organization", "person", "department", "city", "person", "person", "method", "method", "person"] -# ,fixed_relationships=[ -# "Increase in research funding leads to environmental science focus", -# "Dr. Emily Stanton's advocacy for cleaner energy", -# "University's commitment to reduce carbon emissions", -# "Dr. Stanton's research influences architectural plans", -# "Collaborative project between sociology and environmental sciences", -# "Student government launches mental health awareness workshops", -# "Enhanced fitness programs improve sports teams' performance", -# "Coach Torres influences student-athletes' holistic health", -# "Partnership expands access to digital resources", -# "Interdisciplinary approach enriches academic experience" -# ] -# , sample_relationships=[ -# "Causal", -# "Contributory", -# "Causal", -# "Influential", -# "Collaborative", -# "Initiative", -# "Beneficial", -# "Influential", -# "Collaborative", -# "Enriching" -# ] -# ,is_confined_search = True - -# ,user_context="Your task is to analyze and interpret the context to construct semantic triples. The above context is from a university document along with the identified entities using NER. Identify which entity is the subject entity and which is the object entity based on the context, and determine the meaningful relationship or predicate linking the subject entity to the object entity. Also identify the respective subject entity type , object entity and predicate type. 
Answer:" -# ) -# llm_instance = GPTLLM(result_queue, bert_llm_config) -# class StateChangeCallback(EventCallbackInterface): -# def handle_event(self, event_type: EventType, event_state: EventState): -# if event_state['event_type'] == EventType.Graph: -# triple = json.loads(event_state['payload']) -# print("triple: {}".format(triple)) -# assert isinstance(triple['subject'], str) and triple['subject'] -# elif event_state['event_type'] == EventType.Vector: -# triple = json.loads(event_state['payload']) -# print("id: {}".format(triple['id'])) -# print("namespace: {}".format(triple['namespace'])) -# llm_instance.subscribe(EventType.Graph, StateChangeCallback()) -# llm_instance.subscribe(EventType.Vector, StateChangeCallback()) -# querent = Querent( -# [llm_instance], -# resource_manager=resource_manager, -# ) -# querent_task = asyncio.create_task(querent.start()) -# await asyncio.gather(ingest_task, querent_task) - -# if __name__ == "__main__": - -# # Run the async function -# asyncio.run(test_ingest_all_async()) diff --git a/tests/workflows/openai_case_study_workflow.py b/tests/workflows/openai_case_study_workflow.py deleted file mode 100644 index 869cb87c..00000000 --- a/tests/workflows/openai_case_study_workflow.py +++ /dev/null @@ -1,158 +0,0 @@ -# import asyncio -# from asyncio import Queue -# import json -# from pathlib import Path -# from querent.callback.event_callback_interface import EventCallbackInterface -# from querent.collectors.fs.fs_collector import FSCollectorFactory -# from querent.common.types.ingested_tokens import IngestedTokens -# from querent.common.types.querent_event import EventState, EventType -# from querent.config.collector.collector_config import FSCollectorConfig -# from querent.common.uri import Uri -# from querent.config.core.llm_config import LLM_Config -# from querent.core.transformers.fixed_entities_set_opensourcellm import Fixed_Entities_LLM -# from querent.ingestors.ingestor_manager import IngestorFactoryManager -# import pytest -# import uuid -# from querent.common.types.file_buffer import FileBuffer -# from querent.core.transformers.bert_ner_opensourcellm import BERTLLM -# from querent.processors.text_cleanup_processor import TextCleanupProcessor -# from querent.querent.resource_manager import ResourceManager -# from querent.querent.querent import Querent -# import time -# from querent.storage.postgres_graphevent_storage import DatabaseConnection -# # from querent.storage.milvus_vectorevent_storage import MilvusDBConnection -# from querent.config.core.gpt_llm_config import GPTConfig -# from querent.core.transformers.gpt_llm_bert_ner_or_fixed_entities_set_ner import GPTLLM - -# @pytest.mark.asyncio -# async def test_ingest_all_async(): -# # Set up the collectors -# # db_conn = DatabaseConnection(dbname="postgres", -# # user="postgres", -# # password="querent", -# # host="localhost", -# # port="5432") -# # # ml_conn = MilvusDBConnection() -# directories = [ "./tests/data/llm/one_file/"] -# collectors = [ -# FSCollectorFactory().resolve( -# Uri("file://" + str(Path(directory).resolve())), -# FSCollectorConfig(config_source={ -# "id": str(uuid.uuid4()), -# "root_path": directory, -# "name": "Local-config", -# "config": {}, -# "backend": "localfile", -# "uri": "file://", -# }), -# ) -# for directory in directories -# ] - -# # Set up the result queue -# result_queue = asyncio.Queue() -# text_cleanup_processor = TextCleanupProcessor() -# # Create the IngestorFactoryManager -# ingestor_factory_manager = IngestorFactoryManager( -# collectors=collectors, 
result_queue=result_queue, processors=[text_cleanup_processor] -# ) -# ingest_task = asyncio.create_task(ingestor_factory_manager.ingest_all_async()) -# resource_manager = ResourceManager() -# gpt_llm_config = GPTConfig( -# ner_model_name="dbmdz/bert-large-cased-finetuned-conll03-english", -# # rel_model_path="/home/nishantg/Downloads/openhermes-2.5-mistral-7b.Q5_K_M.gguf", -# enable_filtering=True, -# openai_api_key="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" -# ,filter_params={ -# 'score_threshold': 0.5, -# 'attention_score_threshold': 0.1, -# 'similarity_threshold': 0.5, -# 'min_cluster_size': 5, -# 'min_samples': 3, -# 'cluster_persistence_threshold':0.2 -# } -# ,user_context="Query: Your task is to analyze and interpret the context to construct semantic triples. Please Identify the entity which is the subject and the entity which is object based on the context, and determine the meaningful relationship or predicate linking the subject entity to the object entity. Determine whether the entity labels provided match the subject type and object type and correct if needed. Also provide the predicate type. Answer:" -# # ,fixed_entities = [ -# # "Hadean", "Archean", "Proterozoic", "Phanerozoic", -# # "Paleozoic", "Mesozoic", "Cenozoic", -# # "Cambrian", "Ordovician", "Silurian", "Devonian", "Carboniferous", "Permian", -# # "Triassic", "Jurassic", "Cretaceous", -# # "Paleogene", "Neogene", "Quaternary", -# # "Paleocene", "Eocene", "Oligocene", -# # "Miocene", "Pliocene", -# # "Pleistocene", "Holocene", -# # "Anticline", "Syncline", "Fault", "Salt dome", "Horst", "Graben", -# # "Reef", "Shoal", "Deltaic deposits", "Turbidite", "Channel sandstone", -# # "Sandstone", "Limestone", "Dolomite", "Shale", -# # "Source rock", "Cap rock", "Shale gas", -# # "Crude oil", "Natural gas", "Shale oil", "Coalbed methane", "Tar sands", "Gas hydrates", -# # "Structural trap", "Stratigraphic trap", "Combination trap", "Salt trap", "Unconformity trap", -# # "Hydrocarbon migration", "Hydrocarbon accumulation", -# # "Placer deposits", "Vein deposit", "Porphyry deposit", "Kimberlite pipe", "Laterite deposit", -# # "Volcanic rock", "Basalt", "Geothermal gradient", "Sedimentology", -# # "Paleontology", "Biostratigraphy", "Sequence stratigraphy", "Geophysical survey", -# # "Magnetic anomaly", "Gravitational anomaly", "Petrology", "Geochemistry", "Hydrogeology" -# # ] - -# # , sample_entities=[ -# # "geological_eon", "geological_eon", "geological_eon", "geological_eon", -# # "geological_era", "geological_era", "geological_era", -# # "geological_period", "geological_period", "geological_period", "geological_period", "geological_period", "geological_period", -# # "geological_period", "geological_period", "geological_period", -# # "geological_period", "geological_period", "geological_period", -# # "geological_epoch", "geological_epoch", "geological_epoch", -# # "geological_epoch", "geological_epoch", -# # "geological_epoch", "geological_epoch", "structural_feature", "structural_feature", "structural_feature", "structural_feature", "structural_feature", "structural_feature", -# # "stratigraphic_feature", "stratigraphic_feature", "stratigraphic_feature", "stratigraphic_feature", "stratigraphic_feature", -# # "rock_type", "rock_type", "rock_type", "rock_type", -# # "rock_type", "rock_type", "hydrocarbon_source", -# # "hydrocarbon", "hydrocarbon", "hydrocarbon", "hydrocarbon", "hydrocarbon", "hydrocarbon", -# # "trap_type", "trap_type", "trap_type", "trap_type", "trap_type", -# # "geological_process", 
"geological_process", -# # "mineral_deposit", "mineral_deposit", "mineral_deposit", "mineral_deposit", "mineral_deposit", -# # "rock_type", "rock_type", "geological_process", "geological_discipline", -# # "geological_discipline", "geological_method", "geological_method", "geological_method", -# # "geophysical_feature", "geophysical_feature", "geological_discipline", "geological_discipline", "geological_discipline" -# # ] -# # , is_confined_search = True -# # , huggingface_token = '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' -# ) -# llm_instance = GPTLLM(result_queue, gpt_llm_config) -# class StateChangeCallback(EventCallbackInterface): -# def handle_event(self, event_type: EventType, event_state: EventState): -# # assert event_state.event_type == EventType.Graph -# if event_state['event_type'] == EventType.Graph : -# triple = json.loads(event_state['payload']) -# print("file---------------------",event_state['file'], "----------------", type(event_state['file'])) -# print("triple: {}".format(triple)) -# graph_event_data = { -# 'subject': triple['subject'], -# 'subject_type': triple['subject_type'], -# 'object': triple['object'], -# 'object_type': triple['object_type'], -# 'predicate': triple['predicate'], -# 'predicate_type': triple['predicate_type'], -# 'sentence': triple['sentence'], -# 'document_id': event_state['file'] -# } -# # db_conn.insert_graph_event(graph_event_data) -# assert isinstance(triple['subject'], str) and triple['subject'] -# # else : -# # vector_triple = json.loads(event_state.payload) -# # print("Inside Vector event ---------------------------------", vector_triple) -# # milvus_coll = ml_conn.create_collection(collection_name=vector_triple['namespace'],dim = 384) -# # ml_conn.insert_vector_event(id = vector_triple['id'], embedding= vector_triple['embeddings'], namespace= vector_triple['namespace'], document=event_state.file, collection= milvus_coll ) -# llm_instance.subscribe(EventType.Graph, StateChangeCallback()) -# # llm_instance.subscribe(EventType.Vector, StateChangeCallback()) -# querent = Querent( -# [llm_instance], -# resource_manager=resource_manager, -# ) -# querent_task = asyncio.create_task(querent.start()) -# await asyncio.gather(ingest_task, querent_task) -# # db_conn.close() - -# if __name__ == "__main__": - -# # Run the async function -# asyncio.run(test_ingest_all_async()) diff --git a/tests/workflows/openai_ingested_images_test.py b/tests/workflows/openai_ingested_images_test.py deleted file mode 100644 index c71d3255..00000000 --- a/tests/workflows/openai_ingested_images_test.py +++ /dev/null @@ -1,117 +0,0 @@ -# import asyncio -# from asyncio import Queue -# import json -# from pathlib import Path -# from querent.callback.event_callback_interface import EventCallbackInterface -# from querent.collectors.fs.fs_collector import FSCollectorFactory -# from querent.common.types.ingested_tokens import IngestedTokens -# from querent.common.types.querent_event import EventState, EventType -# from querent.config.collector.collector_config import FSCollectorConfig -# from querent.common.uri import Uri -# from querent.config.core.llm_config import LLM_Config -# from querent.core.transformers.fixed_entities_set_opensourcellm import Fixed_Entities_LLM -# from querent.ingestors.ingestor_manager import IngestorFactoryManager -# import pytest -# import uuid -# from querent.common.types.file_buffer import FileBuffer -# from querent.core.transformers.bert_ner_opensourcellm import BERTLLM -# from querent.processors.text_cleanup_processor import TextCleanupProcessor -# from 
querent.querent.resource_manager import ResourceManager -# from querent.querent.querent import Querent -# import time -# # from querent.storage.milvus_vectorevent_storage import MilvusDBConnection -# from querent.config.core.gpt_llm_config import GPTConfig -# from querent.core.transformers.gpt_llm_bert_ner_or_fixed_entities_set_ner import GPTLLM - -# @pytest.mark.asyncio -# async def test_ingest_all_async(): -# # Set up the collectors -# # db_conn = DatabaseConnection(dbname="postgres", -# # user="postgres", -# # password="querent", -# # host="localhost", -# # port="5432") -# # # ml_conn = MilvusDBConnection() -# directories = [ "/home/ansh/pyg-trail/testing-ocr"] -# collectors = [ -# FSCollectorFactory().resolve( -# Uri("file://" + str(Path(directory).resolve())), -# FSCollectorConfig(config_source={ -# "id": str(uuid.uuid4()), -# "root_path": directory, -# "name": "Local-config", -# "config": {}, -# "backend": "localfile", -# "uri": "file://", -# }), -# ) -# for directory in directories -# ] - -# # Set up the result queue -# result_queue = asyncio.Queue() -# text_cleanup_processor = TextCleanupProcessor() -# # Create the IngestorFactoryManager -# ingestor_factory_manager = IngestorFactoryManager( -# collectors=collectors, result_queue=result_queue, processors=[text_cleanup_processor] -# ) -# ingest_task = asyncio.create_task(ingestor_factory_manager.ingest_all_async()) -# resource_manager = ResourceManager() -# gpt_llm_config = GPTConfig( -# # ner_model_name="botryan96/GeoBERT", -# # rel_model_path="/home/nishantg/Downloads/openhermes-2.5-mistral-7b.Q5_K_M.gguf", -# # enable_filtering=True, -# openai_api_key="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" -# # ,filter_params={ -# # 'score_threshold': 0.5, -# # 'attention_score_threshold': 0.1, -# # 'similarity_threshold': 0.5, -# # 'min_cluster_size': 5, -# # 'min_samples': 3, -# # 'cluster_persistence_threshold':0.2 -# # } -# ,user_context="Query: Your task is to analyze and interpret the context to construct semantic triples. Please Identify the entity which is the subject and the entity which is object based on the context, and determine the meaningful relationship or predicate linking the subject entity to the object entity. Determine whether the entity labels provided match the subject type and object type and correct if needed. Also provide the predicate type. 
Answer:" -# # ,fixed_entities =["modeling", "sonic", "symmetry","isotropy", "Carbonate", "Clastic", "Porosity", "Permeability", "Oil saturation", "Water saturation", "Gas saturation", "Depth", "Size", "Temperature", "Pressure", "Oil viscosity", "Gas-oil ratio", "Water cut", "Recovery factor", "Enhanced recovery technique", "Horizontal drilling", "Hydraulic fracturing", "Water injection", "Gas injection", "Steam injection", "Seismic activity", "Structural deformation", "Faulting", "Cap rock integrity", "Compartmentalization", "Connectivity", "Production rate", "Depletion rate", "Exploration technique", "Drilling technique", "Completion technique", "Environmental impact", "Regulatory compliance", "Economic analysis", "Market analysis", "oil well", "gas well", "oil field", "Gas field", "eagle ford shale", "ghawar", "johan sverdrup", "karachaganak", "maracaibo"], -# # sample_entities = ["method","method","method","method", "rock_type", "rock_type", "reservoir_property", "reservoir_property", "reservoir_property", "reservoir_property", "reservoir_property", "reservoir_characteristic", "reservoir_characteristic", "reservoir_characteristic", "reservoir_characteristic", "reservoir_property", "reservoir_property", "production_metric", "production_metric", "recovery_technique", "drilling_technique", "recovery_technique", "recovery_technique", "recovery_technique", "recovery_technique", "geological_feature", "geological_feature", "geological_feature", "reservoir_feature", "reservoir_feature", "reservoir_feature", "production_metric", "production_metric", "exploration_method", "drilling_method", "completion_method", "environmental_aspect", "regulatory_aspect", "economic_aspect", "economic_aspect", "hydrocarbon_source", "hydrocarbon_source", "hydrocarbon_source", "hydrocarbon_source", "reservoir", "reservoir", "reservoir", "reservoir", "reservoir"] -# # , is_confined_search = True -# # , huggingface_token = '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' -# ) -# llm_instance = GPTLLM(result_queue, gpt_llm_config) -# class StateChangeCallback(EventCallbackInterface): -# def handle_event(self, event_type: EventType, event_state: EventState): -# # assert event_state.event_type == EventType.Graph -# if event_state['event_type'] == EventType.Graph : -# triple = json.loads(event_state['payload']) -# print("file---------------------",event_state['file'], "----------------", type(event_state['file'])) -# print("triple: {}".format(triple)) -# graph_event_data = { -# 'subject': triple['subject'], -# 'subject_type': triple['subject_type'], -# 'object': triple['object'], -# 'object_type': triple['object_type'], -# 'predicate': triple['predicate'], -# 'predicate_type': triple['predicate_type'], -# 'sentence': triple['sentence'], -# 'document_id': event_state['file'] -# } -# # db_conn.insert_graph_event(graph_event_data) -# assert isinstance(triple['subject'], str) and triple['subject'] -# # else : -# # vector_triple = json.loads(event_state.payload) -# # print("Inside Vector event ---------------------------------", vector_triple) -# # milvus_coll = ml_conn.create_collection(collection_name=vector_triple['namespace'],dim = 384) -# # ml_conn.insert_vector_event(id = vector_triple['id'], embedding= vector_triple['embeddings'], namespace= vector_triple['namespace'], document=event_state.file, collection= milvus_coll ) -# llm_instance.subscribe(EventType.Graph, StateChangeCallback()) -# # llm_instance.subscribe(EventType.Vector, StateChangeCallback()) -# querent = Querent( -# [llm_instance], -# 
resource_manager=resource_manager, -# ) -# querent_task = asyncio.create_task(querent.start()) -# await asyncio.gather(ingest_task, querent_task) -# # db_conn.close() - -# if __name__ == "__main__": - -# # Run the async function -# asyncio.run(test_ingest_all_async()) diff --git a/tests/workflows/openai_ner_case_study_workflow.py b/tests/workflows/openai_ner_case_study_workflow.py deleted file mode 100644 index 3e19b23b..00000000 --- a/tests/workflows/openai_ner_case_study_workflow.py +++ /dev/null @@ -1,101 +0,0 @@ -# import asyncio -# from asyncio import Queue -# import json -# from pathlib import Path -# from querent.callback.event_callback_interface import EventCallbackInterface -# from querent.collectors.fs.fs_collector import FSCollectorFactory -# from querent.common.types.ingested_tokens import IngestedTokens -# from querent.common.types.querent_event import EventState, EventType -# from querent.config.collector.collector_config import FSCollectorConfig -# from querent.common.uri import Uri -# from querent.config.core.llm_config import LLM_Config -# from querent.core.transformers.fixed_entities_set_opensourcellm import Fixed_Entities_LLM -# from querent.ingestors.ingestor_manager import IngestorFactoryManager -# import pytest -# import uuid -# from querent.common.types.file_buffer import FileBuffer -# from querent.core.transformers.bert_ner_opensourcellm import BERTLLM -# from querent.processors.text_cleanup_processor import TextCleanupProcessor -# from querent.querent.resource_manager import ResourceManager -# from querent.querent.querent import Querent -# import time -# from querent.storage.postgres_graphevent_storage import DatabaseConnection -# from querent.storage.milvus_vectorevent_storage import MilvusDBConnection -# from querent.config.core.gpt_llm_config import GPTConfig -# from querent.core.transformers.gpt_llm_gpt_ner import GPTNERLLM - -# @pytest.mark.asyncio -# async def test_ingest_all_async(): -# # Set up the collectors -# # db_conn = DatabaseConnection(dbname="postgres", -# # user="postgres", -# # password="querent", -# # host="localhost", -# # port="5432") -# # ml_conn = MilvusDBConnection() -# directories = [ "./tests/data/llm/one_file/"] -# collectors = [ -# FSCollectorFactory().resolve( -# Uri("file://" + str(Path(directory).resolve())), -# FSCollectorConfig(config_source={ -# "id": str(uuid.uuid4()), -# "root_path": directory, -# "name": "Local-config", -# "config": {}, -# "backend": "localfile", -# "uri": "file://", -# }), -# ) -# for directory in directories -# ] - -# # Set up the result queue -# result_queue = asyncio.Queue() - -# text_cleanup_processor = TextCleanupProcessor() -# # Create the IngestorFactoryManager -# ingestor_factory_manager = IngestorFactoryManager( -# collectors=collectors, result_queue=result_queue, processors=[text_cleanup_processor] -# ) -# ingest_task = asyncio.create_task(ingestor_factory_manager.ingest_all_async()) -# resource_manager = ResourceManager() -# gpt_llm_config = GPTConfig() -# llm_instance = GPTNERLLM(result_queue, gpt_llm_config) -# class StateChangeCallback(EventCallbackInterface): -# def handle_event(self, event_type: EventType, event_state: EventState): -# # assert event_state.event_type == EventType.Graph -# if event_state['event_type'] == EventType.Graph : -# triple = json.loads(event_state['payload']) -# print("file---------------------",event_state["file"], "----------------", type(event_state["file"])) -# # print("triple: {}".format(triple)) -# graph_event_data = { -# 'subject': triple['subject'], -# 
'subject_type': triple['subject_type'], -# 'object': triple['object'], -# 'object_type': triple['object_type'], -# 'predicate': triple['predicate'], -# 'predicate_type': triple['predicate_type'], -# 'sentence': triple['sentence'], -# 'document_id': event_state['file'] -# } -# # db_conn.insert_graph_event(graph_event_data) -# assert isinstance(triple['subject'], str) and triple['subject'] -# # else : -# # vector_triple = json.loads(event_state.payload) -# # print("Inside Vector event ---------------------------------", vector_triple) -# # milvus_coll = ml_conn.create_collection(collection_name=vector_triple['namespace'],dim = 384) -# # ml_conn.insert_vector_event(id = vector_triple['id'], embedding= vector_triple['embeddings'], namespace= vector_triple['namespace'], document=event_state.file, collection= milvus_coll ) -# llm_instance.subscribe(EventType.Graph, StateChangeCallback()) -# # llm_instance.subscribe(EventType.Vector, StateChangeCallback()) -# querent = Querent( -# [llm_instance], -# resource_manager=resource_manager, -# ) -# querent_task = asyncio.create_task(querent.start()) -# await asyncio.gather(ingest_task, querent_task) -# # db_conn.close() - -# if __name__ == "__main__": - -# # Run the async function -# asyncio.run(test_ingest_all_async())
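
For reference, the ingest-and-subscribe loop that these deleted GPT-based tests exercised is the same loop the retained open-source path uses. Below is a minimal, untested sketch of that loop rewritten against the retained BERTLLM engine and LLM_Config. The data directory, the default LLM_Config values, and the print-only callback are illustrative assumptions, and every call signature follows the removed test code above rather than a verified current API.

# Hedged sketch: mirrors the workflow of the deleted test files, swapping the
# removed GPTConfig/GPTLLM pair for the retained LLM_Config/BERTLLM pair.
import asyncio
import json
import uuid
from pathlib import Path

from querent.callback.event_callback_interface import EventCallbackInterface
from querent.collectors.fs.fs_collector import FSCollectorFactory
from querent.common.types.querent_event import EventState, EventType
from querent.common.uri import Uri
from querent.config.collector.collector_config import FSCollectorConfig
from querent.config.core.llm_config import LLM_Config
from querent.core.transformers.bert_ner_opensourcellm import BERTLLM
from querent.ingestors.ingestor_manager import IngestorFactoryManager
from querent.processors.text_cleanup_processor import TextCleanupProcessor
from querent.querent.querent import Querent
from querent.querent.resource_manager import ResourceManager


class GraphCallback(EventCallbackInterface):
    # Receives semantic-triple payloads emitted as Graph events, as in the
    # StateChangeCallback of the removed tests; here it only prints them.
    def handle_event(self, event_type: EventType, event_state: EventState):
        if event_state["event_type"] == EventType.Graph:
            triple = json.loads(event_state["payload"])
            print(triple["subject"], triple["predicate"], triple["object"])


async def run_workflow(directory: str = "./tests/data/llm/one_file/"):  # assumed sample path
    # File-system collector over the sample directory.
    collector = FSCollectorFactory().resolve(
        Uri("file://" + str(Path(directory).resolve())),
        FSCollectorConfig(config_source={
            "id": str(uuid.uuid4()),
            "root_path": directory,
            "name": "Local-config",
            "config": {},
            "backend": "localfile",
            "uri": "file://",
        }),
    )

    # Shared queue between ingestion and the engine.
    result_queue = asyncio.Queue()
    ingestor_factory_manager = IngestorFactoryManager(
        collectors=[collector],
        result_queue=result_queue,
        processors=[TextCleanupProcessor()],
    )
    ingest_task = asyncio.create_task(ingestor_factory_manager.ingest_all_async())

    # Default config; adjust fields (fixed_entities, filter_params, etc.) as needed.
    llm_instance = BERTLLM(result_queue, LLM_Config())
    llm_instance.subscribe(EventType.Graph, GraphCallback())

    querent = Querent([llm_instance], resource_manager=ResourceManager())
    await asyncio.gather(ingest_task, asyncio.create_task(querent.start()))


if __name__ == "__main__":
    asyncio.run(run_workflow())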