diff --git a/README.md b/README.md index be17ef34..382aa32e 100644 --- a/README.md +++ b/README.md @@ -5,11 +5,14 @@ # STORM: Synthesis of Topic Outlines through Retrieval and Multi-perspective Question Asking

-| Research preview | Paper | Website | +| Research preview | STORM Paper | Co-STORM Paper | Website |

- **Latest News** 🔥 +- [2024/09] Co-STORM codebase is now released and integrated into `knowledge-storm` python package v1.0.0. Run `pip install knowledge-storm --upgrade` to check it out. + +- [2024/09] We introduce collaborative STORM (Co-STORM) to support human-AI collaborative knowledge curation! [Co-STORM Paper](https://www.arxiv.org/abs/2408.15232) has been accepted to EMNLP 2024 main conference. + - [2024/07] You can now install our package with `pip install knowledge-storm`! - [2024/07] We add `VectorRM` to support grounding on user-provided documents, complementing existing support of search engines (`YouRM`, `BingSearch`). (check out [#58](https://github.com/stanford-oval/storm/pull/58)) - [2024/07] We release demo light for developers a minimal user interface built with streamlit framework in Python, handy for local development and demo hosting (checkout [#54](https://github.com/stanford-oval/storm/pull/54)) @@ -24,17 +27,20 @@

-STORM is a LLM system that writes Wikipedia-like articles from scratch based on Internet search. +STORM is an LLM system that writes Wikipedia-like articles from scratch based on Internet search. Co-STORM extends this capability by letting humans collaborate with the LLM system, supporting more aligned and preferred information seeking and knowledge curation. While the system cannot produce publication-ready articles that often require a significant number of edits, experienced Wikipedia editors have found it helpful in their pre-writing stage. -**Try out our [live research preview](https://storm.genie.stanford.edu/) to see how STORM can help your knowledge exploration journey and please provide feedback to help us improve the system 🙏!** +**More than 70,000 people have tried our [live research preview](https://storm.genie.stanford.edu/). Try it out to see how STORM can help your knowledge exploration journey, and please provide feedback to help us improve the system 🙏!** + +## How STORM & Co-STORM work -## How STORM works +### STORM STORM breaks down generating long articles with citations into two steps: + 1. **Pre-writing stage**: The system conducts Internet-based research to collect references and generates an outline. 2. **Writing stage**: The system uses the outline and references to generate the full-length article with citations.

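Since each step can reuse previously saved results (see the `do_*` flag descriptions in the API section below), the two stages can also be invoked separately. The snippet below is a minimal sketch under that assumption: `runner` stands for a `STORMWikiRunner` that has already been configured as shown later in this README, and the topic string is only an example.

```python
# Sketch only: `runner` is assumed to be a fully configured STORMWikiRunner
# (see the API section below for how to set up the LM and retrieval modules).
topic = "History of the Internet"  # example topic

# Pre-writing stage: Internet-based research and outline generation only.
runner.run(topic=topic, do_research=True, do_generate_outline=True,
           do_generate_article=False, do_polish_article=False)

# Writing stage: load the saved outline/references, then write and polish the article.
runner.run(topic=topic, do_research=False, do_generate_outline=False,
           do_generate_article=True, do_polish_article=True)

runner.post_run()
runner.summary()
```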
@@ -45,9 +51,21 @@ STORM identifies the core of automating the research process as automatically co 1. **Perspective-Guided Question Asking**: Given the input topic, STORM discovers different perspectives by surveying existing articles from similar topics and uses them to control the question-asking process. 2. **Simulated Conversation**: STORM simulates a conversation between a Wikipedia writer and a topic expert grounded in Internet sources to enable the language model to update its understanding of the topic and ask follow-up questions. -Based on the separation of the two stages, STORM is implemented in a highly modular way using [dspy](https://github.com/stanfordnlp/dspy). +### Co-STORM + +Co-STORM proposes **a collaborative discourse protocol** that implements a turn management policy to support smooth collaboration among: +- **Co-STORM LLM experts**: This type of agent generates answers grounded in external knowledge sources and/or raises follow-up questions based on the discourse history. +- **Moderator**: This agent generates thought-provoking questions inspired by information discovered by the retriever but not directly used in previous turns. Question generation is also grounded in retrieved information. +- **Human user**: The human user can either (1) observe the discourse to gain a deeper understanding of the topic, or (2) actively engage in the conversation by injecting utterances to steer the discussion focus. +

+![Co-STORM workflow](assets/co-storm-workflow.jpg) +

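To make the turn management policy above concrete, here is a small, purely illustrative sketch of how a turn policy might decide who speaks next. This is not the actual `DiscourseManager` implementation shipped in `knowledge_storm/collaborative_storm/engine.py`; the `Turn` class and `pick_next_speaker` function are hypothetical stand-ins, and the threshold simply mirrors the real runner's `moderator_override_N_consecutive_answering_turn` argument.

```python
# Illustrative sketch only -- not the library's actual DiscourseManager.
from dataclasses import dataclass
from typing import List


@dataclass
class Turn:
    role: str           # e.g., "Guest", "Moderator", or an expert name
    is_question: bool   # True for questions / information requests


def pick_next_speaker(history: List[Turn], experts: List[str],
                      moderator_after_n_answers: int = 3) -> str:
    """Toy policy: hand the floor to the moderator after N consecutive
    answering turns; otherwise rotate through the expert agents."""
    consecutive_answers = 0
    for turn in reversed(history):
        if turn.is_question:
            break
        consecutive_answers += 1
    if consecutive_answers >= moderator_after_n_answers:
        return "Moderator"
    return experts[len(history) % len(experts)]


# After three grounded answers in a row, the moderator injects a fresh question.
history = [Turn("Guest", True), Turn("Expert A", False),
           Turn("Expert B", False), Turn("Expert A", False)]
print(pick_next_speaker(history, ["Expert A", "Expert B"]))  # -> Moderator
```

In the real system, the selected agent then generates a grounded utterance, the new turn is inserted into the mind map, and the knowledge base may be reorganized before the next turn.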
+ +Co-STORM also maintains a dynamically updated **mind map**, which organizes collected information into a hierarchical concept structure, aiming to **build a shared conceptual space between the human user and the system**. The mind map has been shown to reduce the user's mental load when the discourse grows long and in-depth. + +Both STORM and Co-STORM are implemented in a highly modular way using [dspy](https://github.com/stanfordnlp/dspy). ## Installation @@ -70,9 +88,20 @@ You could also install the source code which allows you to modify the behavior o ## API -The STORM knowledge curation engine is defined as a simple Python `STORMWikiRunner` class. -As STORM is working in the information curation layer, you need to set up the information retrieval module and language model module to create a `STORMWikiRunner` instance. Here is an example of using You.com search engine and OpenAI models. +Currently, our package supports: + +- `OpenAIModel`, `AzureOpenAIModel`, `ClaudeModel`, `VLLMClient`, `TGIClient`, `TogetherClient`, `OllamaClient`, `GoogleModel`, `DeepSeekModel`, `GroqModel` as language model components +- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch` as retrieval module components + +:star2: **PRs for integrating more language models into [knowledge_storm/lm.py](knowledge_storm/lm.py) and search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!** + +Both STORM and Co-STORM work in the information curation layer; you need to set up an information retrieval module and a language model module to create their respective `Runner` classes. + +### STORM + +The STORM knowledge curation engine is defined as a simple Python `STORMWikiRunner` class. Here is an example of using the You.com search engine and OpenAI models. + ```python import os from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs @@ -101,12 +130,6 @@ rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) runner = STORMWikiRunner(engine_args, lm_configs, rm) ``` -Currently, our package support: -- `OpenAIModel`, `AzureOpenAIModel`, `ClaudeModel`, `VLLMClient`, `TGIClient`, `TogetherClient`, `OllamaClient`, `GoogleModel`, `DeepSeekModel`, `GroqModel` as language model components -- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch` as retrieval module components - -:star2: **PRs for integrating more language models into [knowledge_storm/lm.py](knowledge_storm/lm.py) and search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!** - The `STORMWikiRunner` instance can be evoked with the simple `run` method: ```python topic = input('Topic: ') @@ -125,41 +148,126 @@ runner.summary() - `do_generate_article`: if True, generate an article for the topic based on the outline and the collected information; otherwise, load the results. - `do_polish_article`: if True, polish the article by adding a summarization section and (optionally) removing duplicate content; otherwise, load the results. +### Co-STORM + +The Co-STORM knowledge curation engine is defined as a simple Python `CoStormRunner` class. Here is an example of using the Bing search engine and OpenAI models. 
+ +```python +import os + +from knowledge_storm.collaborative_storm.engine import CollaborativeStormLMConfigs, RunnerArgument, CoStormRunner +from knowledge_storm.lm import OpenAIModel +from knowledge_storm.logging_wrapper import LoggingWrapper +from knowledge_storm.rm import BingSearch + +# Co-STORM adopts the same multi-LM system paradigm as STORM +lm_config: CollaborativeStormLMConfigs = CollaborativeStormLMConfigs() +openai_kwargs = { + "api_key": os.getenv("OPENAI_API_KEY"), + "api_provider": "openai", + "temperature": 1.0, + "top_p": 0.9, + "api_base": None, +} +gpt_4o_model_name = "gpt-4o" +question_answering_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs) +discourse_manage_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs) +utterance_polishing_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs) +warmstart_outline_gen_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs) +question_asking_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=300, **openai_kwargs) +knowledge_base_lm = OpenAIModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs) + +lm_config.set_question_answering_lm(question_answering_lm) +lm_config.set_discourse_manage_lm(discourse_manage_lm) +lm_config.set_utterance_polishing_lm(utterance_polishing_lm) +lm_config.set_warmstart_outline_gen_lm(warmstart_outline_gen_lm) +lm_config.set_question_asking_lm(question_asking_lm) +lm_config.set_knowledge_base_lm(knowledge_base_lm) + +# Check out Co-STORM's RunnerArgument class for more configurations. +topic = input('Topic: ') +runner_argument = RunnerArgument(topic=topic, ...) +logging_wrapper = LoggingWrapper(lm_config) +bing_rm = BingSearch(bing_search_api_key=os.environ.get("BING_SEARCH_API_KEY"), + k=runner_argument.retrieve_top_k) +costorm_runner = CoStormRunner(lm_config=lm_config, + runner_argument=runner_argument, + logging_wrapper=logging_wrapper, + rm=bing_rm) +``` + +The `CoStormRunner` instance can be invoked with its `warm_start()` and `step(...)` methods. + +```python +# Warm start the system to build a shared conceptual space between Co-STORM and the user +costorm_runner.warm_start() + +# Step through the collaborative discourse +# Run either of the code snippets below in any order, as many times as you'd like +# To observe the conversation: +conv_turn = costorm_runner.step() +# To inject your utterance to actively steer the conversation: +costorm_runner.step(user_utterance="YOUR UTTERANCE HERE") + +# Generate report based on the collaborative discourse +costorm_runner.knowledge_base.reogranize() +article = costorm_runner.generate_report() +print(article) +``` + + ## Quick Start with Example Scripts -We provide scripts in our [examples folder](examples) as a quick start to run STORM with different configurations. +We provide scripts in our [examples folder](examples) as a quick start to run STORM and Co-STORM with different configurations. + +We suggest using `secrets.toml` to set up the API keys. Create a file `secrets.toml` under the root directory and add the following content: + +```shell +# Set up OpenAI API key. +OPENAI_API_KEY="your_openai_api_key" +# If you are using the API service provided by OpenAI, include the following line: +OPENAI_API_TYPE="openai" +# If you are using the API service provided by Microsoft Azure, include the following lines: +OPENAI_API_TYPE="azure" +AZURE_API_BASE="your_azure_api_base_url" +AZURE_API_VERSION="your_azure_api_version" +# Set up You.com search API key. 
+YDC_API_KEY="your_youcom_api_key" +``` + +### STORM examples **To run STORM with `gpt` family models with default configurations:** -1. We suggest using `secrets.toml` to set up the API keys. Create a file `secrets.toml` under the root directory and add the following content: - ```shell - # Set up OpenAI API key. - OPENAI_API_KEY="your_openai_api_key" - # If you are using the API service provided by OpenAI, include the following line: - OPENAI_API_TYPE="openai" - # If you are using the API service provided by Microsoft Azure, include the following lines: - OPENAI_API_TYPE="azure" - AZURE_API_BASE="your_azure_api_base_url" - AZURE_API_VERSION="your_azure_api_version" - # Set up You.com search API key. - YDC_API_KEY="your_youcom_api_key" - ``` -2. Run the following command. - ``` - python examples/run_storm_wiki_gpt.py \ - --output-dir $OUTPUT_DIR \ - --retriever you \ - --do-research \ - --do-generate-outline \ - --do-generate-article \ - --do-polish-article - ``` + +Run the following command: +```bash +python examples/storm_examples/run_storm_wiki_gpt.py \ + --output-dir $OUTPUT_DIR \ + --retriever you \ + --do-research \ + --do-generate-outline \ + --do-generate-article \ + --do-polish-article +``` **To run STORM using your favorite language models or grounding on your own corpus:** Check out [examples/README.md](examples/README.md). +### Co-STORM examples + +To run Co-STORM with `gpt` family models with default configurations: + +1. Add `BING_SEARCH_API_KEY="xxx"` to `secrets.toml`. +2. Run the following command: + +```bash +python examples/costorm_examples/run_costorm_gpt.py \ + --output-dir $OUTPUT_DIR \ + --retriever bing +``` + ## Customization of the Pipeline +### STORM + If you have installed the source code, you can customize STORM based on your own use case. STORM engine consists of 4 modules: 1. Knowledge Curation Module: Collects a broad coverage of information about the given topic. @@ -169,12 +277,19 @@ If you have installed the source code, you can customize STORM based on your own The interface for each module is defined in `knowledge_storm/interface.py`, while their implementations are instantiated in `knowledge_storm/storm_wiki/modules/*`. These modules can be customized according to your specific requirements (e.g., generating sections in bullet point format instead of full paragraphs). +### Co-STORM + +If you have installed the source code, you can customize Co-STORM based on your own use case: + +1. Co-STORM introduces multiple LLM agent types (i.e., Co-STORM experts and the Moderator). The LLM agent interface is defined in `knowledge_storm/interface.py`, while its implementations are instantiated in `knowledge_storm/collaborative_storm/modules/co_storm_agents.py`. Different LLM agent policies can be customized. +2. Co-STORM introduces a collaborative discourse protocol, with its core function centered on turn policy management. We provide an example implementation of turn policy management through `DiscourseManager` in `knowledge_storm/collaborative_storm/engine.py`. It can be customized and further improved. -## Replicate NAACL2024 result -Please switch to the branch `NAACL-2024-code-backup` [here](https://github.com/stanford-oval/storm/tree/NAACL-2024-code-backup). +## Replicate STORM & Co-STORM paper results +For STORM paper experiments, please switch to the branch `NAACL-2024-code-backup` [here](https://github.com/stanford-oval/storm/tree/NAACL-2024-code-backup). 
+For Co-STORM paper experiments, please switch to the branch `EMNLP-2024-code-backup` (placeholder for now, will be updated soon). ## Roadmap & Contributions Our team is actively working on: @@ -186,13 +301,23 @@ If you have any questions or suggestions, please feel free to open an issue or p Contact person: [Yijia Shao](mailto:shaoyj@stanford.edu) and [Yucheng Jiang](mailto:yuchengj@stanford.edu) ## Acknowledgement -We would like to thank Wikipedia for their excellent open-source content. The FreshWiki dataset is sourced from Wikipedia, licensed under the Creative Commons Attribution-ShareAlike (CC BY-SA) license. +We would like to thank Wikipedia for its excellent open-source content. The FreshWiki dataset is sourced from Wikipedia, licensed under the Creative Commons Attribution-ShareAlike (CC BY-SA) license. We are very grateful to [Michelle Lam](https://michelle123lam.github.io/) for designing the logo for this project and [Dekun Ma](https://dekun.me) for leading the UI development. ## Citation Please cite our paper if you use this code or part of it in your work: ```bibtex +@misc{jiang2024unknownunknowns, + title={Into the Unknown Unknowns: Engaged Human Learning through Participation in Language Model Agent Conversations}, + author={Yucheng Jiang and Yijia Shao and Dekun Ma and Sina J. Semnani and Monica S. Lam}, + year={2024}, + eprint={2408.15232}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2408.15232}, +} + @inproceedings{shao2024assisting, title={{Assisting in Writing Wikipedia-like Articles From Scratch with Large Language Models}}, author={Yijia Shao and Yucheng Jiang and Theodore A. Kanell and Peter Xu and Omar Khattab and Monica S. Lam}, diff --git a/assets/co-storm-workflow.jpg b/assets/co-storm-workflow.jpg new file mode 100644 index 00000000..23b80926 Binary files /dev/null and b/assets/co-storm-workflow.jpg differ diff --git a/examples/costorm_examples/run_costorm_gpt.py b/examples/costorm_examples/run_costorm_gpt.py new file mode 100644 index 00000000..66d6bd3c --- /dev/null +++ b/examples/costorm_examples/run_costorm_gpt.py @@ -0,0 +1,241 @@ +""" +Co-STORM pipeline powered by GPT-4o/4o-mini and Bing search engine. 
+You need to set up the following environment variables to run this script: + - OPENAI_API_KEY: OpenAI API key + - OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure') + - AZURE_API_BASE: Azure API base URL if using Azure API + - AZURE_API_VERSION: Azure API version if using Azure API + - BING_SEARCH_API_KEY: Bing Search API key (or, depending on the chosen retriever, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key) + +Output will be structured as below: +args.output_dir/ + log.json # Log of information-seeking conversation + report.md # Final article generated +""" + +import os +import json +from argparse import ArgumentParser +from knowledge_storm.collaborative_storm.engine import CollaborativeStormLMConfigs, RunnerArgument, CoStormRunner +from knowledge_storm.collaborative_storm.modules.callback import LocalConsolePrintCallBackHandler +from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel +from knowledge_storm.logging_wrapper import LoggingWrapper +from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.utils import load_api_key + + +def main(args): + load_api_key(toml_file_path='secrets.toml') + lm_config: CollaborativeStormLMConfigs = CollaborativeStormLMConfigs() + openai_kwargs = { + "api_key": os.getenv("OPENAI_API_KEY"), + "api_provider": "openai", + "temperature": 1.0, + "top_p": 0.9, + "api_base": None, + } if os.getenv('OPENAI_API_TYPE') == 'openai' else { + "api_key": os.getenv("AZURE_API_KEY"), + "temperature": 1.0, + "top_p": 0.9, + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + } + + ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel + # If you are using the Azure service, make sure the model name matches your own deployed model name. + # The default name here is only used for demonstration and may not match your case. + gpt_4o_mini_model_name = 'gpt-4o-mini' + gpt_4o_model_name = 'gpt-4o' + if os.getenv('OPENAI_API_TYPE') == 'azure': + openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE') + openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION') + + # Co-STORM is a multi-LM system, so different components can be powered by different models. + # The example below uses gpt-4o for every component; to balance cost and quality, you can assign + # cheaper/faster models to lighter components (e.g., discourse management or question asking) and + # keep stronger models for grounded question answering and knowledge base organization. 
+ question_answering_lm = ModelClass(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs) + discourse_manage_lm = ModelClass(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs) + utterance_polishing_lm = ModelClass(model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs) + warmstart_outline_gen_lm = ModelClass(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs) + question_asking_lm = ModelClass(model=gpt_4o_model_name, max_tokens=300, **openai_kwargs) + knowledge_base_lm = ModelClass(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs) + + lm_config.set_question_answering_lm(question_answering_lm) + lm_config.set_discourse_manage_lm(discourse_manage_lm) + lm_config.set_utterance_polishing_lm(utterance_polishing_lm) + lm_config.set_warmstart_outline_gen_lm(warmstart_outline_gen_lm) + lm_config.set_question_asking_lm(question_asking_lm) + lm_config.set_knowledge_base_lm(knowledge_base_lm) + + topic = input('Topic: ') + runner_argument = RunnerArgument( + topic=topic, + retrieve_top_k=args.retrieve_top_k, + max_search_queries=args.max_search_queries, + total_conv_turn=args.total_conv_turn, + max_search_thread=args.max_search_thread, + max_search_queries_per_turn=args.max_search_queries_per_turn, + warmstart_max_num_experts=args.warmstart_max_num_experts, + warmstart_max_turn_per_experts=args.warmstart_max_turn_per_experts, + warmstart_max_thread=args.warmstart_max_thread, + max_thread_num=args.max_thread_num, + max_num_round_table_experts=args.max_num_round_table_experts, + moderator_override_N_consecutive_answering_turn=args.moderator_override_N_consecutive_answering_turn, + node_expansion_trigger_count=args.node_expansion_trigger_count) + logging_wrapper = LoggingWrapper(lm_config) + callback_handler = LocalConsolePrintCallBackHandler() if args.enable_log_print else None + + # Co-STORM is a knowledge curation system which consumes information from the retrieval module. + # Currently, the information source is the Internet and we use search engine API as the retrieval module. + match args.retriever: + case 'bing': + rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=runner_argument.retrieve_top_k) + case 'you': + rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=runner_argument.retrieve_top_k) + case 'brave': + rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=runner_argument.retrieve_top_k) + case 'duckduckgo': + rm = DuckDuckGoSearchRM(k=runner_argument.retrieve_top_k, safe_search='On', region='us-en') + case 'serper': + rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1}) + case 'tavily': + rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=runner_argument.retrieve_top_k, include_raw_content=True) + case 'searxng': + rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=runner_argument.retrieve_top_k) + case _: + raise ValueError(f'Invalid retriever: {args.retriever}. 
Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + + costorm_runner = CoStormRunner(lm_config=lm_config, + runner_argument=runner_argument, + logging_wrapper=logging_wrapper, + rm=rm, + callback_handler=callback_handler) + + # warm start the system + costorm_runner.warm_start() + + # Below is an example of how users may interact with Co-STORM to seek information together + # In actual deployment, we suggest allowing the user to decide whether to observe the agent utterance or inject a turn + + # observing Co-STORM LLM agent utterance for 5 turns + for _ in range(1): + conv_turn = costorm_runner.step() + print(f"**{conv_turn.role}**: {conv_turn.utterance}\n") + + # active engaging by injecting your utterance + your_utterance = input('Your utterance: ') + costorm_runner.step(user_utterance=your_utterance) + + # continue observing + conv_turn = costorm_runner.step() + print(f"**{conv_turn.role}**: {conv_turn.utterance}\n") + + # generate report + costorm_runner.knowledge_base.reogranize() + article = costorm_runner.generate_report() + + # save results + os.makedirs(args.output_dir, exist_ok=True) + + # Save article + with open(os.path.join(args.output_dir, "report.md"), "w") as f: + f.write(article) + + # Save logging + log_dump = costorm_runner.dump_logging_and_reset() + with open(os.path.join(args.output_dir, "log.json"), "w") as f: + json.dump(log_dump, f, indent=2) + + +if __name__ == '__main__': + parser = ArgumentParser() + # global arguments + parser.add_argument('--output-dir', type=str, default='./results/co-storm', + help='Directory to store the outputs.') + parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], + help='The search engine API to use for retrieving information.') + # hyperparameters for co-storm + parser.add_argument( + '--retrieve_top_k', + type=int, + default=10, + help='Retrieve top k results for each query in retriever.' + ) + parser.add_argument( + '--max_search_queries', + type=int, + default=2, + help='Maximum number of search queries to consider for each question.' + ) + parser.add_argument( + '--total_conv_turn', + type=int, + default=20, + help='Maximum number of turns in conversation.' + ) + parser.add_argument( + '--max_search_thread', + type=int, + default=5, + help='Maximum number of parallel threads for retriever.' + ) + parser.add_argument( + '--max_search_queries_per_turn', + type=int, + default=3, + help='Maximum number of search queries to consider in each turn.' + ) + parser.add_argument( + '--warmstart_max_num_experts', + type=int, + default=3, + help='Max number of experts in perspective-guided QA during warm start.' + ) + parser.add_argument( + '--warmstart_max_turn_per_experts', + type=int, + default=2, + help='Max number of turns per perspective during warm start.' + ) + parser.add_argument( + '--warmstart_max_thread', + type=int, + default=3, + help='Max number of threads for parallel perspective-guided QA during warm start.' + ) + parser.add_argument( + '--max_thread_num', + type=int, + default=10, + help=("Maximum number of threads to use. " + "Consider reducing it if you keep getting 'Exceed rate limit' errors when calling the LM API.") + ) + parser.add_argument( + '--max_num_round_table_experts', + type=int, + default=2, + help='Max number of active experts in round table discussion.' 
+ ) + parser.add_argument( + '--moderator_override_N_consecutive_answering_turn', + type=int, + default=3, + help=('Number of consecutive expert answering turns before the moderator overrides the conversation.') + ) + parser.add_argument( + '--node_expansion_trigger_count', + type=int, + default=10, + help='Trigger node expansion for nodes that contain more than N snippets.' + ) + + # Boolean flags + parser.add_argument( + '--enable_log_print', + action='store_true', + help='If set, enable console log print.' + ) + + main(parser.parse_args()) diff --git a/examples/README.md b/examples/storm_examples/README.md similarity index 91% rename from examples/README.md rename to examples/storm_examples/README.md index 6aa04a38..083bd4f6 100644 --- a/examples/README.md +++ b/examples/storm_examples/README.md @@ -11,7 +11,7 @@ We host a number of example scripts for various customization of STORM (e.g., us 2. Run the following command under the root directory of the repository: ``` - python examples/run_storm_wiki_mistral.py \ + python examples/storm_examples/run_storm_wiki_mistral.py \ --url $URL \ --port $PORT \ --output-dir $OUTPUT_DIR \ @@ -50,7 +50,7 @@ By default, STORM is grounded on the Internet using the search engine, but it ca To create the vector store offline, run ``` - python examples/run_storm_wiki_gpt_with_VectorRM.py \ + python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \ --output-dir $OUTPUT_DIR \ --vector-db-mode offline \ --offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \ @@ -65,7 +65,7 @@ By default, STORM is grounded on the Internet using the search engine, but it ca To create the vector store online on a Qdrant server, run ``` - python examples/run_storm_wiki_gpt_with_VectorRM.py \ + python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \ --output-dir $OUTPUT_DIR \ --vector-db-mode online \ --online-vector-db-url $ONLINE_VECTOR_DB_URL \ @@ -83,12 +83,12 @@ By default, STORM is grounded on the Internet using the search engine, but it ca - Run the following command under the root directory to downsample the dataset by filtering papers with terms `[cs.CV]` and get a csv file that match the format mentioned above. ``` - python examples/helper/process_kaggle_arxiv_abstract_dataset.py --input-path $PATH_TO_THE_DOWNLOADED_FILE --output-path $PATH_TO_THE_PROCESSED_CSV + python examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py --input-path $PATH_TO_THE_DOWNLOADED_FILE --output-path $PATH_TO_THE_PROCESSED_CSV ``` - Run the following command to run STORM grounding on the processed dataset. You can input a topic related to computer vision (e.g., "The progress of multimodal models in computer vision") to see the generated article. (Note that the generated article may not include enough details since the quick test only use the abstracts of arxiv papers.) ``` - python examples/run_storm_wiki_gpt_with_VectorRM.py \ + python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \ --output-dir $OUTPUT_DIR \ --vector-db-mode offline \ --offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \ @@ -102,7 +102,7 @@ By default, STORM is grounded on the Internet using the search engine, but it ca - For a quicker run, you can also download the pre-embedded vector store directly from [here](https://drive.google.com/file/d/1bijFkw5BKU7bqcmXMhO-5hg2fdKAL9bf/view?usp=share_link). 
``` - python examples/run_storm_wiki_gpt_with_VectorRM.py \ + python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \ --output-dir $OUTPUT_DIR \ --vector-db-mode offline \ --offline-vector-db-dir $DOWNLOADED_VECTOR_DB_DR \ diff --git a/examples/helper/process_kaggle_arxiv_abstract_dataset.py b/examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py similarity index 100% rename from examples/helper/process_kaggle_arxiv_abstract_dataset.py rename to examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py diff --git a/examples/run_storm_wiki_claude.py b/examples/storm_examples/run_storm_wiki_claude.py similarity index 100% rename from examples/run_storm_wiki_claude.py rename to examples/storm_examples/run_storm_wiki_claude.py diff --git a/examples/run_storm_wiki_deepseek.py b/examples/storm_examples/run_storm_wiki_deepseek.py similarity index 100% rename from examples/run_storm_wiki_deepseek.py rename to examples/storm_examples/run_storm_wiki_deepseek.py diff --git a/examples/run_storm_wiki_gemini.py b/examples/storm_examples/run_storm_wiki_gemini.py similarity index 100% rename from examples/run_storm_wiki_gemini.py rename to examples/storm_examples/run_storm_wiki_gemini.py diff --git a/examples/run_storm_wiki_gpt.py b/examples/storm_examples/run_storm_wiki_gpt.py similarity index 100% rename from examples/run_storm_wiki_gpt.py rename to examples/storm_examples/run_storm_wiki_gpt.py diff --git a/examples/run_storm_wiki_gpt_with_VectorRM.py b/examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py similarity index 99% rename from examples/run_storm_wiki_gpt_with_VectorRM.py rename to examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py index 12a46e4d..8fa4d0d6 100644 --- a/examples/run_storm_wiki_gpt_with_VectorRM.py +++ b/examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py @@ -27,10 +27,8 @@ """ import os -import sys from argparse import ArgumentParser -sys.path.append('./') from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs from knowledge_storm.rm import VectorRM from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel diff --git a/examples/run_storm_wiki_groq.py b/examples/storm_examples/run_storm_wiki_groq.py similarity index 97% rename from examples/run_storm_wiki_groq.py rename to examples/storm_examples/run_storm_wiki_groq.py index 55419c27..0dcaadbb 100644 --- a/examples/run_storm_wiki_groq.py +++ b/examples/storm_examples/run_storm_wiki_groq.py @@ -18,17 +18,10 @@ """ import os -import sys import re -import logging from argparse import ArgumentParser from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs -# Get the absolute path to the directory containing lm.py -lm_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'knowledge_storm')) - -# Add this path to sys.path -sys.path.insert(0, lm_path) # Now import lm directly import lm diff --git a/examples/run_storm_wiki_mistral.py b/examples/storm_examples/run_storm_wiki_mistral.py similarity index 100% rename from examples/run_storm_wiki_mistral.py rename to examples/storm_examples/run_storm_wiki_mistral.py diff --git a/examples/run_storm_wiki_ollama.py b/examples/storm_examples/run_storm_wiki_ollama.py similarity index 99% rename from examples/run_storm_wiki_ollama.py rename to examples/storm_examples/run_storm_wiki_ollama.py index 519467b3..465e065a 100644 --- a/examples/run_storm_wiki_ollama.py +++ b/examples/storm_examples/run_storm_wiki_ollama.py @@ -21,7 +21,6 @@ from dspy 
import Example -sys.path.append('./src') from knowledge_storm.lm import OllamaClient from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs diff --git a/examples/run_storm_wiki_ollama_with_searxng.py b/examples/storm_examples/run_storm_wiki_ollama_with_searxng.py similarity index 100% rename from examples/run_storm_wiki_ollama_with_searxng.py rename to examples/storm_examples/run_storm_wiki_ollama_with_searxng.py diff --git a/examples/run_storm_wiki_serper.py b/examples/storm_examples/run_storm_wiki_serper.py similarity index 100% rename from examples/run_storm_wiki_serper.py rename to examples/storm_examples/run_storm_wiki_serper.py diff --git a/knowledge_storm/__init__.py b/knowledge_storm/__init__.py index 5ce0f2c5..f93158a7 100644 --- a/knowledge_storm/__init__.py +++ b/knowledge_storm/__init__.py @@ -1,7 +1,10 @@ -from .storm_wiki.engine import ( - STORMWikiLMConfigs, - STORMWikiRunnerArguments, - STORMWikiRunner, -) +from .storm_wiki import * +from .collaborative_storm import * +from .encoder import * +from .interface import * +from .lm import * +from .rm import * +from .utils import * +from .dataclass import * -__version__ = "0.2.8" +__version__ = "1.0.0" diff --git a/knowledge_storm/collaborative_storm/__init__.py b/knowledge_storm/collaborative_storm/__init__.py new file mode 100644 index 00000000..b85715b2 --- /dev/null +++ b/knowledge_storm/collaborative_storm/__init__.py @@ -0,0 +1,2 @@ +from .modules import * +from .engine import * diff --git a/knowledge_storm/collaborative_storm/engine.py b/knowledge_storm/collaborative_storm/engine.py new file mode 100644 index 00000000..01684b9b --- /dev/null +++ b/knowledge_storm/collaborative_storm/engine.py @@ -0,0 +1,745 @@ +import dspy +import os +from dataclasses import dataclass, field, asdict +from typing import List, Union, Literal, Optional, Dict + +from .modules import collaborative_storm_utils as collaborative_storm_utils +from .modules.callback import BaseCallbackHandler +from .modules.co_storm_agents import ( + SimulatedUser, + PureRAGAgent, + Moderator, + CoStormExpert, +) +from .modules.expert_generation import GenerateExpertModule +from .modules.warmstart_hierarchical_chat import WarmStartModule +from ..dataclass import ConversationTurn, KnowledgeBase +from ..interface import LMConfigs, Agent +from ..logging_wrapper import LoggingWrapper +from ..lm import OpenAIModel, AzureOpenAIModel, TogetherClient +from ..rm import BingSearch + + +class CollaborativeStormLMConfigs(LMConfigs): + """Configurations for LLM used in different parts of Co-STORM. + + Given that different parts in Co-STORM framework have different complexity, we use different LLM configurations + to achieve a balance between quality and efficiency. If no specific configuration is provided, we use the default + setup in the paper. 
+ """ + + def __init__(self): + self.question_answering_lm = None + self.discourse_manage_lm = None + self.utterance_polishing_lm = None + self.warmstart_outline_gen_lm = None + self.question_asking_lm = None + self.knowledge_base_lm = None + + def init( + self, + lm_type: Literal["openai", "azure", "together"], + temperature: Optional[float] = 1.0, + top_p: Optional[float] = 0.9, + ): + if lm_type and lm_type == "openai": + openai_kwargs = { + "api_key": os.getenv("OPENAI_API_KEY"), + "api_provider": "openai", + "temperature": temperature, + "top_p": top_p, + "api_base": None, + } + self.question_answering_lm = OpenAIModel( + model="gpt-4o-2024-05-13", max_tokens=1000, **openai_kwargs + ) + self.discourse_manage_lm = OpenAIModel( + model="gpt-4o-2024-05-13", max_tokens=500, **openai_kwargs + ) + self.utterance_polishing_lm = OpenAIModel( + model="gpt-4o-2024-05-13", max_tokens=2000, **openai_kwargs + ) + self.warmstart_outline_gen_lm = OpenAIModel( + model="gpt-4-1106-preview", max_tokens=500, **openai_kwargs + ) + self.question_asking_lm = OpenAIModel( + model="gpt-4o-2024-05-13", max_tokens=300, **openai_kwargs + ) + self.knowledge_base_lm = OpenAIModel( + model="gpt-4o-2024-05-13", max_tokens=1000, **openai_kwargs + ) + elif lm_type and lm_type == "azure": + azure_kwargs = { + "api_key": os.getenv("AZURE_API_KEY"), + "temperature": temperature, + "top_p": top_p, + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + } + self.question_answering_lm = AzureOpenAIModel( + model="gpt-4o", max_tokens=1000, **azure_kwargs, model_type="chat" + ) + self.discourse_manage_lm = AzureOpenAIModel( + model="gpt-4o", max_tokens=500, **azure_kwargs, model_type="chat" + ) + self.utterance_polishing_lm = AzureOpenAIModel( + model="gpt-4o", max_tokens=2000, **azure_kwargs, model_type="chat" + ) + self.warmstart_outline_gen_lm = AzureOpenAIModel( + model="gpt-4o", max_tokens=300, **azure_kwargs, model_type="chat" + ) + self.question_asking_lm = AzureOpenAIModel( + model="gpt-4o", max_tokens=300, **azure_kwargs, model_type="chat" + ) + self.knowledge_base_lm = AzureOpenAIModel( + model="gpt-4o", max_tokens=1000, **azure_kwargs, model_type="chat" + ) + elif lm_type and lm_type == "together": + together_kwargs = { + "api_key": os.getenv("TOGETHER_API_KEY"), + "temperature": temperature, + "top_p": top_p, + } + self.question_answering_lm = TogetherClient( + model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + max_tokens=1000, + model_type="chat", + **together_kwargs, + ) + self.discourse_manage_lm = TogetherClient( + model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + max_tokens=500, + model_type="chat", + **together_kwargs, + ) + self.utterance_polishing_lm = TogetherClient( + model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + max_tokens=2000, + model_type="chat", + **together_kwargs, + ) + self.warmstart_outline_gen_lm = TogetherClient( + model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + max_tokens=500, + model_type="chat", + **together_kwargs, + ) + self.question_asking_lm = TogetherClient( + model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + max_tokens=300, + model_type="chat", + **together_kwargs, + ) + self.knowledge_base_lm = TogetherClient( + model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + max_tokens=1000, + model_type="chat", + **together_kwargs, + ) + else: + raise Exception( + "No valid OpenAI API provider is provided. Cannot use default LLM configurations." 
+ ) + + def set_question_answering_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + self.question_answering_lm = model + + def set_discourse_manage_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + self.discourse_manage_lm = model + + def set_utterance_polishing_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + self.utterance_polishing_lm = model + + def set_warmstart_outline_gen_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + self.warmstart_outline_gen_lm = model + + def set_question_asking_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + self.question_asking_lm = model + + def set_knowledge_base_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + self.knowledge_base_lm = model + + def collect_and_reset_lm_usage(self): + lm_usage = {} + for attr_name in self.__dict__: + if "_lm" in attr_name and hasattr( + getattr(self, attr_name), "get_usage_and_reset" + ): + usage = getattr(self, attr_name).get_usage_and_reset() + if any( + value["prompt_tokens"] != 0 or value["completion_tokens"] != 0 + for value in usage.values() + ): + lm_usage[attr_name] = usage + return lm_usage + + def to_dict(self): + """ + Converts the CollaborativeStormLMConfigs instance to a dictionary representation. + + Returns: + dict: The dictionary representation of the CollaborativeStormLMConfigs. + """ + config_dict = {} + for attr_name in self.__dict__: + config_dict[attr_name] = getattr(self, attr_name).kwargs + return config_dict + + +@dataclass +class RunnerArgument: + """Arguments for controlling the STORM Wiki pipeline.""" + + topic: str = field( + metadata={"help": "Topic of discourse"}, + ) + retrieve_top_k: int = field( + default=10, + metadata={"help": "retrieve top k results for each query in retriever"}, + ) + max_search_queries: int = field( + default=2, + metadata={ + "help": "Maximum number of search queries to consider for each question." + }, + ) + total_conv_turn: int = field( + default=20, + metadata={"help": "Maximum number turn in conversation."}, + ) + max_search_thread: int = field( + default=5, + metadata={"help": "Maximum number of parallel thread for retriever"}, + ) + max_search_queries_per_turn: int = field( + default=3, + metadata={"help": "Maximum number of search queries to consider in each turn."}, + ) + warmstart_max_num_experts: int = field( + default=3, + metadata={ + "help": "Max number of experts in perspective guided QA in warm start process" + }, + ) + warmstart_max_turn_per_experts: int = field( + default=2, + metadata={"help": "Max number of turns per perspective in warm start process"}, + ) + warmstart_max_thread: int = field( + default=3, + metadata={ + "help": "Max number thread for parallel perspective guided QA in warm start process" + }, + ) + max_thread_num: int = field( + default=10, + metadata={ + "help": "Maximum number of threads to use. " + "Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API." 
+ }, + ) + max_num_round_table_experts: int = field( + default=2, + metadata={"help": "Max number of active experts in round table discussion."}, + ) + moderator_override_N_consecutive_answering_turn: int = field( + default=3, + metadata={ + "help": "Number of consecutive experts answering turn before moderator override the conversation" + }, + ) + node_expansion_trigger_count: int = field( + default=10, + metadata={ + "help": "Trigger node expansion for node that contain more than N snippets" + }, + ) + disable_moderator: bool = field( + default=False, + metadata={"help": "If True, disable moderator."}, + ) + disable_multi_experts: bool = field( + default=False, + metadata={"help": "If True, disable moderator."}, + ) + rag_only_baseline_mode: bool = field( + default=False, + metadata={"help": "If True, switch to rag online baseline mode"}, + ) + + def to_dict(self): + """ + Converts the RunnerArgument instance to a dictionary representation. + + Returns: + dict: The dictionary representation of the RunnerArgument. + """ + return asdict(self) + + @classmethod + def from_dict(cls, data): + """ + Constructs a RunnerArgument instance from a dictionary representation. + + Args: + data (dict): The dictionary representation of the RunnerArgument. + + Returns: + RunnerArgument: The constructed RunnerArgument instance. + """ + return cls(**data) + + +@dataclass +class TurnPolicySpec: + """ + Represents the policy specifications for determining the behavior of a conversation turn. + + Attributes: + should_reorganize_knowledge_base (bool): + A flag that indicates whether the knowledge base should be reorganized after the current turn. + + should_update_experts_list (bool): + A flag that indicates whether the list of experts should be updated based on the conversation context. + + should_polish_utterance (bool): + A flag that indicates whether the generated utterance should be polished (e.g., refined or rephrased) before it is used in the conversation. + + agent (Agent): + The `Agent` responsible for generating utterances or responses during the conversation turn. + This agent interacts with the knowledge base and the conversation history to produce responses. 
+ """ + + should_reorganize_knowledge_base: bool = False + should_update_experts_list: bool = False + should_polish_utterance: bool = False + agent: Agent = None + + +class DiscourseManager: + def __init__( + self, + logging_wrapper: LoggingWrapper, + lm_config: CollaborativeStormLMConfigs, + runner_argument: RunnerArgument, + rm: dspy.Retrieve, + callback_handler: BaseCallbackHandler, + ): + # parameter management + self.lm_config = lm_config + self.runner_argument = runner_argument + self.logging_wrapper = logging_wrapper + self.callback_handler = callback_handler + self.rm = rm + # role management + self.experts: List[CoStormExpert] = [] + self.simulated_user: SimulatedUser = SimulatedUser( + topic=self.runner_argument.topic, + role_name="Guest", + role_description="", + intent=None, + lm_config=self.lm_config, + runner_argument=self.runner_argument, + logging_wrapper=self.logging_wrapper, + callback_handler=self.callback_handler, + ) + self.pure_rag_agent: PureRAGAgent = PureRAGAgent( + topic=self.runner_argument.topic, + role_name="PureRAG", + role_description="", + lm_config=self.lm_config, + runner_argument=self.runner_argument, + logging_wrapper=self.logging_wrapper, + rm=self.rm, + callback_handler=self.callback_handler, + ) + self.moderator: Moderator = Moderator( + topic=self.runner_argument.topic, + role_name="Moderator", + role_description="", + lm_config=self.lm_config, + runner_argument=self.runner_argument, + logging_wrapper=self.logging_wrapper, + callback_handler=self.callback_handler, + ) + self.general_knowledge_provider = CoStormExpert( + topic=self.runner_argument.topic, + role_name="General Knowledge Provider", + role_description="Focus on broadly covering the basic facts about the question.", + lm_config=self.lm_config, + runner_argument=self.runner_argument, + logging_wrapper=self.logging_wrapper, + rm=self.rm, + callback_handler=self.callback_handler, + ) + self.generate_expert_module = GenerateExpertModule( + engine=self.lm_config.discourse_manage_lm + ) + self.next_turn_moderator_override = False + + def serialize_experts(self) -> List[Dict]: + return [ + { + "topic": expert.topic, + "role_name": expert.role_name, + "role_description": expert.role_description, + } + for expert in self.experts + ] + + def deserialize_experts(self, data: List[Dict]): + for expert_data in data: + self.experts.append( + CoStormExpert( + topic=expert_data["topic"], + role_name=expert_data["role_name"], + role_description=expert_data["role_description"], + lm_config=self.lm_config, + runner_argument=self.runner_argument, + logging_wrapper=self.logging_wrapper, + rm=self.rm, + callback_handler=self.callback_handler, + ) + ) + + def _should_generate_question( + self, conversation_history: List[ConversationTurn] + ) -> bool: + consecutive_non_questioning_turn = 0 + for conv_turn in reversed(conversation_history): + if conv_turn.utterance_type not in [ + "Original Question", + "Information Request", + ]: + consecutive_non_questioning_turn += 1 + else: + break + return ( + consecutive_non_questioning_turn + >= self.runner_argument.moderator_override_N_consecutive_answering_turn + ) + + def _parse_expert_names_to_agent(self, expert_descriptions: Union[str, List[str]]): + if type(expert_descriptions) == str: + expert_descriptions = [expert_descriptions] + agents: CoStormExpert = [] + for expert_name in expert_descriptions: + role_name, role_description = expert_name.split(":") + role_name = role_name.strip() + role_description = role_description.strip() + new_costorm_expert = CoStormExpert( 
+ topic=self.runner_argument.topic, + role_name=role_name, + role_description=role_description, + lm_config=self.lm_config, + runner_argument=self.runner_argument, + logging_wrapper=self.logging_wrapper, + rm=self.rm, + callback_handler=self.callback_handler, + ) + agents.append(new_costorm_expert) + return agents + + def _update_expert_list_from_utterance(self, focus: str, background_info: str): + expert_names = self.generate_expert_module( + topic=self.runner_argument.topic, + background_info=background_info, + focus=focus, + num_experts=self.runner_argument.max_num_round_table_experts, + ).experts + self.experts = self._parse_expert_names_to_agent(expert_names) + + def _is_last_turn_questioning(self, conversation_history: List[ConversationTurn]): + return conversation_history and conversation_history[-1].utterance_type in [ + "Original Question", + "Information Request", + ] + + def get_next_turn_policy( + self, + conversation_history: List[ConversationTurn], + dry_run=False, + simulate_user=False, + simulate_user_intent: str = None, + ) -> TurnPolicySpec: + next_turn_policy = TurnPolicySpec() + if simulate_user: + self.simulated_user.intent = simulate_user_intent + next_turn_policy.agent = self.simulated_user + elif self.runner_argument.rag_only_baseline_mode: + assert self.conversation_history[-1].role == "Guest" + next_turn_policy.agent = self.pure_rag_agent + elif ( + not self.runner_argument.disable_moderator + and self._should_generate_question(conversation_history) + ): + next_turn_policy.agent = self.moderator + next_turn_policy.should_reorganize_knowledge_base = True + elif self.next_turn_moderator_override: + next_turn_policy.agent = self.moderator + if not dry_run: + self.next_turn_moderator_override = False + # experts RAG gen + else: + next_turn_policy.agent = self.general_knowledge_provider + if ( + not self._is_last_turn_questioning(conversation_history) + and not self.runner_argument.disable_multi_experts + ): + if dry_run: + next_turn_policy.agent = self.experts[0] + else: + next_turn_policy.agent = self.experts.pop(0) + self.experts.append(next_turn_policy.agent) + next_turn_policy.should_update_experts_list = ( + self._is_last_turn_questioning(conversation_history) + and not self.runner_argument.disable_multi_experts + ) + next_turn_policy.should_polish_utterance = True + return next_turn_policy + + +class CoStormRunner: + def __init__( + self, + lm_config: CollaborativeStormLMConfigs, + runner_argument: RunnerArgument, + logging_wrapper: LoggingWrapper, + rm: Optional[dspy.Retrieve] = None, + callback_handler: BaseCallbackHandler = None, + ): + self.runner_argument = runner_argument + self.lm_config = lm_config + self.logging_wrapper = logging_wrapper + self.callback_handler = callback_handler + if rm is None: + self.rm = BingSearch(k=runner_argument.retrieve_top_k) + else: + self.rm = rm + self.conversation_history = [] + self.warmstart_conv_archive = [] + self.knowledge_base = KnowledgeBase( + topic=self.runner_argument.topic, + knowledge_base_lm=self.lm_config.knowledge_base_lm, + node_expansion_trigger_count=self.runner_argument.node_expansion_trigger_count, + ) + self.discourse_manager = DiscourseManager( + lm_config=self.lm_config, + runner_argument=self.runner_argument, + logging_wrapper=self.logging_wrapper, + rm=self.rm, + callback_handler=callback_handler, + ) + + def to_dict(self): + return { + "runner_argument": self.runner_argument.to_dict(), + "lm_config": self.lm_config.to_dict(), + "conversation_history": [ + turn.to_dict() for turn in 
self.conversation_history + ], + "warmstart_conv_archive": [ + turn.to_dict() for turn in self.warmstart_conv_archive + ], + "experts": self.discourse_manager.serialize_experts(), + "knowledge_base": self.knowledge_base.to_dict(), + } + + @classmethod + def from_dict(cls, data): + # FIXME: does not use the lm_config data but naively use default setting + lm_config = CollaborativeStormLMConfigs() + lm_config.init(lm_type=os.getenv("OPENAI_API_TYPE")) + costorm_runner = cls( + lm_config=lm_config, + runner_argument=RunnerArgument.from_dict(data["runner_argument"]), + logging_wrapper=LoggingWrapper(lm_config), + ) + costorm_runner.conversation_history = [ + ConversationTurn.from_dict(turn) for turn in data["conversation_history"] + ] + costorm_runner.warmstart_conv_archive = [ + ConversationTurn.from_dict(turn) + for turn in data.get("warmstart_conv_archive", []) + ] + costorm_runner.discourse_manager.deserialize_experts(data["experts"]) + costorm_runner.knowledge_base = KnowledgeBase.from_dict( + data=data["knowledge_base"], + knowledge_base_lm=costorm_runner.lm_config.knowledge_base_lm, + node_expansion_trigger_count=costorm_runner.runner_argument.node_expansion_trigger_count, + ) + return costorm_runner + + def warm_start(self): + """ + Warm start co-storm system to conduct background information search in order to build shared conceptual space with user. + This stage is a mini-STORM, spawning multiple LLM agent with different perspective and perform multi-round conversation. + The knowledge base (i.e. mind map) will be initialize using the collected information. + + It will also generate a first draft of report and use it to produce an engaging and concise conversation presented to the + user to catch up with system's knowledge about the topic. + """ + with self.logging_wrapper.log_pipeline_stage( + pipeline_stage=f"warm start stage" + ): + if not self.runner_argument.rag_only_baseline_mode: + warm_start_module = WarmStartModule( + lm_config=self.lm_config, + runner_argument=self.runner_argument, + logging_wrapper=self.logging_wrapper, + rm=self.rm, + callback_handler=self.callback_handler, + ) + + warmstart_conv, warmstart_revised_conv, warmstart_experts = ( + warm_start_module.initiate_warm_start( + topic=self.runner_argument.topic, + knowledge_base=self.knowledge_base, + ) + ) + self.discourse_manager.experts = ( + self.discourse_manager._parse_expert_names_to_agent( + warmstart_experts + ) + ) + self.discourse_manager.next_turn_moderator_override = True + self.conversation_history = ( + warmstart_revised_conv if warmstart_revised_conv else warmstart_conv + ) + self.warmstart_conv_archive = warmstart_conv + self.knowledge_base.reogranize() + else: + if self.knowledge_base is None: + self.knowledge_base = KnowledgeBase( + topic=self.runner_argument.topic + ) + if self.conversation_history is None: + self.conversation_history = [] + conv_turn = ( + self.discourse_manager.pure_rag_agent.generate_topic_background() + ) + self.conversation_history.append(conv_turn) + self.knowledge_base.update_from_conv_turn( + conv_turn=conv_turn, + allow_create_new_node=True, + insert_under_root=self.runner_argument.rag_only_baseline_mode, + ) + + def generate_report(self) -> str: + """ + Generate report leveraging organized collected information in the knowledge base (i.e. mind map). + The article generation follows the paradigm in STORM paper, where it considers mind map nodes as section names, and generate the report section by section. 
+ + Returns: + str: A string representing the report, with "#" "##" indicating hierarchical sections and [1][2] indicating references. + """ + with self.logging_wrapper.log_pipeline_stage("report generation stage"): + with self.logging_wrapper.log_event( + "report generation stage: generate report" + ): + return self.knowledge_base.to_report() + + def dump_logging_and_reset(self): + return self.logging_wrapper.dump_logging_and_reset() + + def step( + self, + user_utterance: str = "", + simulate_user: bool = False, + simulate_user_intent: str = "", + ) -> ConversationTurn: + """ + Yields a single turn in the conversation flow. + + This method take a user input when user choose to inject an utterance or generates the next system utterance based on the current conversation history and defined discourse policies. + It handles updating the conversation history, managing expert lists, and interacting with the knowledge base. + Additionally, it logs each stage of the conversation for monitoring and debugging purposes. + + Args: + user_utterance (str, optional): The input provided by the user. If provided, this utterance is added directly to the conversation history and returns with no further action. + simulate_user (bool, optional): This is designed for automatic experiments using a LLM agent to simulate user actions. Flag indicating whether to simulate user behavior. When set to `True`, the system will generate user intents based on predefined simulation logic. Defaults to `False`. + simulate_user_intent (str, optional): This is designed for automatic experiments using a LLM agent to simulate user actions. Specifies the intent to simulate for the user. This is used when `simulate_user` is `True` to guide the simulated user's responses, + + Returns: + ConversationTurn: An object representing the latest turn in the conversation. + + Workflow: + 1. User Utterance Handling + - If `user_utterance` is provided, it is appended to the `conversation_history` + + 2. System Utterance Generation + - If no `user_utterance` is provided, the method proceeds to generate the next system utterance. + - Determines the next turn policy by consulting the `discourse_manager` with the current conversation history. + - Generates a new utterance using the agent defined in the turn policy, leveraging the `knowledge_base` and `conversation_history`. + - If the turn policy indicates that the experts list should be updated, it updates the expert list based on the latest utterances. + + 4. Knowledge Base Update + - Inserts the new turn into the `knowledge_base`, optionally allowing the creation of new nodes or inserting under the root based on the `rag_only_baseline_mode` flag. + - If the turn policy specifies, it reorganizes the `knowledge_base` to maintain optimal structure and relevance. 
+ """ + last_conv_turn = self.conversation_history[-1] + cur_turn_name = f"conv turn: {len(self.conversation_history) + 1}" + with self.logging_wrapper.log_pipeline_stage( + pipeline_stage=f"{cur_turn_name} stage" + ): + conv_turn = None + if user_utterance: + self.discourse_manager.next_turn_moderator_override = False + conv_turn = ConversationTurn( + role="Guest", + raw_utterance=user_utterance, + utterance_type="Original Question", + ) + self.conversation_history.append(conv_turn) + else: + with self.logging_wrapper.log_event( + f"{cur_turn_name}: get turn policy" + ): + if self.callback_handler is not None: + self.callback_handler.on_turn_policy_planning_start() + turn_policy = self.discourse_manager.get_next_turn_policy( + conversation_history=self.conversation_history, + simulate_user=simulate_user, + simulate_user_intent=simulate_user_intent, + dry_run=False, + ) + + with self.logging_wrapper.log_event( + f"{cur_turn_name}: generate utterance" + ): + conv_turn = turn_policy.agent.generate_utterance( + knowledge_base=self.knowledge_base, + conversation_history=self.conversation_history, + ) + + if turn_policy.should_update_experts_list: + with self.logging_wrapper.log_event( + f"{cur_turn_name}: update experts list" + ): + self.discourse_manager._update_expert_list_from_utterance( + focus=last_conv_turn.raw_utterance, + background_info=conv_turn.raw_utterance, + ) + + if conv_turn is not None: + self.conversation_history.append(conv_turn) + with self.logging_wrapper.log_event( + f"{cur_turn_name}: insert into knowledge base" + ): + if self.callback_handler is not None: + self.callback_handler.on_mindmap_insert_start() + self.knowledge_base.update_from_conv_turn( + conv_turn=conv_turn, + allow_create_new_node=True, + insert_under_root=self.runner_argument.rag_only_baseline_mode, + ) + if self.callback_handler is not None: + self.callback_handler.on_mindmap_insert_end() + if turn_policy.should_reorganize_knowledge_base: + with self.logging_wrapper.log_event( + f"{cur_turn_name}: reorganize knowledge base" + ): + if self.callback_handler is not None: + self.callback_handler.on_mindmap_reorg_start() + self.knowledge_base.reogranize() + return conv_turn diff --git a/knowledge_storm/collaborative_storm/modules/__init__.py b/knowledge_storm/collaborative_storm/modules/__init__.py new file mode 100644 index 00000000..3807c645 --- /dev/null +++ b/knowledge_storm/collaborative_storm/modules/__init__.py @@ -0,0 +1,8 @@ +from .article_generation import * +from .grounded_question_answering import * +from .grounded_question_generation import * +from .information_insertion_module import * +from .simulate_user import * +from .warmstart_hierarchical_chat import * +from .knowledge_base_summary import * +from .costorm_expert_utterance_generator import * diff --git a/knowledge_storm/collaborative_storm/modules/article_generation.py b/knowledge_storm/collaborative_storm/modules/article_generation.py new file mode 100644 index 00000000..be614007 --- /dev/null +++ b/knowledge_storm/collaborative_storm/modules/article_generation.py @@ -0,0 +1,123 @@ +import dspy +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Set, Union + +from .collaborative_storm_utils import clean_up_section +from ...dataclass import KnowledgeBase, KnowledgeNode + + +class ArticleGenerationModule(dspy.Module): + """Use the information collected from the information-seeking conversation to write a section.""" + + def __init__( + self, + engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + ): + 
super().__init__() + self.write_section = dspy.Predict(WriteSection) + self.engine = engine + + def _get_cited_information_string( + self, + all_citation_index: Set[int], + knowledge_base: KnowledgeBase, + max_words: int = 1500, + ): + information = [] + cur_word_count = 0 + for index in sorted(list(all_citation_index)): + info = knowledge_base.info_uuid_to_info_dict[index] + snippet = info.snippets[0] + info_text = f"[{index}]: {snippet} (Question: {info.meta['question']}. Query: {info.meta['query']})" + cur_snippet_length = len(info_text.split()) + if cur_snippet_length + cur_word_count > max_words: + break + cur_word_count += cur_snippet_length + information.append(info_text) + return "\n".join(information) + + def gen_section( + self, topic: str, node: KnowledgeNode, knowledge_base: KnowledgeBase + ): + if node is None or len(node.content) == 0: + return "" + if ( + node.synthesize_output is not None + and node.synthesize_output + and not node.need_regenerate_synthesize_output + ): + return node.synthesize_output + all_citation_index = node.collect_all_content() + information = self._get_cited_information_string( + all_citation_index=all_citation_index, knowledge_base=knowledge_base + ) + with dspy.settings.context(lm=self.engine): + synthesize_output = clean_up_section( + self.write_section( + topic=topic, info=information, section=node.name + ).output + ) + node.synthesize_output = synthesize_output + node.need_regenerate_synthesize_output = False + return node.synthesize_output + + def forward(self, knowledge_base: KnowledgeBase): + all_nodes = knowledge_base.collect_all_nodes() + node_to_paragraph = {} + + # Define a function to generate paragraphs for nodes + def _node_generate_paragraph(node): + node_gen_paragraph = self.gen_section( + topic=knowledge_base.topic, node=node, knowledge_base=knowledge_base + ) + lines = node_gen_paragraph.split("\n") + if lines[0].strip().replace("*", "").replace("#", "") == node.name: + lines = lines[1:] + node_gen_paragraph = "\n".join(lines) + path = " -> ".join(node.get_path_from_root()) + return path, node_gen_paragraph + + with ThreadPoolExecutor(max_workers=5) as executor: + # Submit all tasks + future_to_node = { + executor.submit(_node_generate_paragraph, node): node + for node in all_nodes + } + + # Collect the results as they complete + for future in as_completed(future_to_node): + path, node_gen_paragraph = future.result() + node_to_paragraph[path] = node_gen_paragraph + + def helper(cur_root, level): + to_return = [] + if cur_root is not None: + hash_tag = "#" * level + " " + cur_path = " -> ".join(cur_root.get_path_from_root()) + node_gen_paragraph = node_to_paragraph[cur_path] + to_return.append(f"{hash_tag}{cur_root.name}\n{node_gen_paragraph}") + for child in cur_root.children: + to_return.extend(helper(child, level + 1)) + return to_return + + to_return = [] + for child in knowledge_base.root.children: + to_return.extend(helper(child, level=1)) + + return "\n".join(to_return) + + +class WriteSection(dspy.Signature): + """Write a Wikipedia section based on the collected information. You will be given the topic, the section you are writing and relevant information. + Each information will be provided with the raw content along with question and query lead to that information. + Here is the format of your writing: + Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end. 
+ """ + + info = dspy.InputField(prefix="The collected information:\n", format=str) + topic = dspy.InputField(prefix="The topic of the page: ", format=str) + section = dspy.InputField(prefix="The section you need to write: ", format=str) + output = dspy.OutputField( + prefix="Write the section with proper inline citations (Start your writing. Don't include the page title, section name, or try to write other sections. Do not start the section with topic name.):\n", + format=str, + ) diff --git a/knowledge_storm/collaborative_storm/modules/callback.py b/knowledge_storm/collaborative_storm/modules/callback.py new file mode 100644 index 00000000..9d610125 --- /dev/null +++ b/knowledge_storm/collaborative_storm/modules/callback.py @@ -0,0 +1,110 @@ +from typing import List +from ...interface import Information + + +class BaseCallbackHandler: + """Base callback handler to manage callbacks from the Co-STORM pipeline.""" + + def on_turn_policy_planning_start(self, **kwargs): + """Run when the turn policy planning begins, before deciding the direction or goal for the next conversation turn.""" + pass + + def on_expert_action_planning_start(self, **kwargs): + """Run when the expert action planning begins, preparing to determine the actions that each expert should take.""" + pass + + def on_expert_action_planning_end(self, **kwargs): + """Run when the expert action planning ends, after deciding the actions that each expert should take.""" + pass + + def on_expert_information_collection_start(self, **kwargs): + """Run when the expert information collection starts, start gathering all necessary data from selected sources.""" + pass + + def on_expert_information_collection_end(self, info: List[Information], **kwargs): + """Run when the expert information collection ends, after gathering all necessary data from selected sources.""" + pass + + def on_expert_utterance_generation_end(self, **kwargs): + """Run when the expert utterance generation ends, before creating responses or statements from each expert.""" + pass + + def on_expert_utterance_polishing_start(self, **kwargs): + """Run when the expert utterance polishing begins, to refine and improve the clarity and coherence of generated content.""" + pass + + def on_mindmap_insert_start(self, **kwargs): + """Run when the process of inserting new information into the mindmap starts.""" + pass + + def on_mindmap_insert_end(self, **kwargs): + """Run when the process of inserting new information into the mindmap ends.""" + pass + + def on_mindmap_reorg_start(self, **kwargs): + """Run when the reorganization of the mindmap begins, to restructure and optimize the flow of information.""" + pass + + def on_expert_list_update_start(self, **kwargs): + """Run when the expert list update starts, to modify or refresh the list of active experts.""" + pass + + def on_article_generation_start(self, **kwargs): + """Run when the article generation process begins, to compile and format the final article content.""" + pass + + def on_warmstart_update(self, message, **kwargs): + """Run when the warm start process has update.""" + pass + + +class LocalConsolePrintCallBackHandler(BaseCallbackHandler): + def __init__(self): + pass + + def on_turn_policy_planning_start(self, **kwargs): + """Run when the turn policy planning begins, before deciding the direction or goal for the next conversation turn.""" + print("Start planning next expert; inspect mind map; inspect system state.") + + def on_expert_action_planning_start(self, **kwargs): + """Run when the expert action planning 
begins, preparing to determine the actions that each expert should take.""" + print("Reviewing discourse history; Deciding utterance intent.") + + def on_expert_information_collection_start(self, **kwargs): + """Run when the expert information collection ends, after gathering all necessary data from selected sources.""" + print("Start searching with the search engine; browsing collected information.") + + def on_expert_information_collection_end(self, info: List[Information], **kwargs): + """Run when the expert information collection ends, after gathering all necessary data from selected sources.""" + if info: + urls = [i.url for i in info] + information_string = "\n".join([f"Finish browsing {url}" for url in urls]) + print(information_string) + + def on_expert_utterance_generation_end(self, **kwargs): + """Run when the expert utterance generation ends, before creating responses or statements from each expert.""" + print("Finish generating utterance from collected information.") + + def on_expert_utterance_polishing_start(self, **kwargs): + """Run when the expert utterance polishing begins, to refine and improve the clarity and coherence of generated content.""" + print("Start polishing utterance.") + + def on_mindmap_insert_start(self, **kwargs): + """Run when the process of inserting new information into the mindmap starts.""" + print("Start inserting information into mind map.") + + def on_mindmap_insert_end(self, **kwargs): + """Run when the process of inserting new information into the mindmap ends.""" + print("Finish inserting information into mind map.") + + def on_mindmap_reorg_start(self, **kwargs): + """Run when the reorganization of the mindmap begins, to restructure and optimize the flow of information.""" + print("Start re-organizing mind map.") + + def on_expert_list_update_start(self, **kwargs): + """Run when the expert list update starts, to modify or refresh the list of active experts.""" + print("Start updating expert candidates.") + + def on_warmstart_update(self, message, **kwargs): + """Run when the warm start process has update.""" + print(f"Warm start update: {message}") diff --git a/knowledge_storm/collaborative_storm/modules/co_storm_agents.py b/knowledge_storm/collaborative_storm/modules/co_storm_agents.py new file mode 100644 index 00000000..e7f60299 --- /dev/null +++ b/knowledge_storm/collaborative_storm/modules/co_storm_agents.py @@ -0,0 +1,381 @@ +import dspy +from itertools import zip_longest +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity +from typing import List, Optional, TYPE_CHECKING + +from .callback import BaseCallbackHandler +from .collaborative_storm_utils import ( + extract_storm_info_snippet, + _get_answer_question_module_instance, +) +from .costorm_expert_utterance_generator import CoStormExpertUtteranceGenerationModule +from .grounded_question_generation import GroundedQuestionGenerationModule +from .simulate_user import GenSimulatedUserUtterance +from ...dataclass import ConversationTurn, KnowledgeBase +from ...encoder import get_text_embeddings +from ...interface import Agent, Information, LMConfigs +from ...logging_wrapper import LoggingWrapper + +if TYPE_CHECKING: + from ..engine import RunnerArgument + + +class CoStormExpert(Agent): + """ + Represents an expert agent in the Co-STORM framework. + The `CoStormExpert` is a specialized type of `Agent` that is tasked with participating in roundtable discussions within the Co-STORM system. 
+ The expert uses language models to generate action plans, answer questions, and polish its utterances based on the current conversation history and knowledge base. + It interacts with modules for action planning and question answering grounding on provided retrieval models. + + Args: + topic (str): The conversation topic that the expert specializes in. + role_name (str): The perspective of the expert's role (e.g. AI enthusiast, drug discovery expert, etc.) + role_description (str): A description of the perspective of the experts + lm_config (LMConfigs): Configuration for the language models + runner_argument (RunnerArgument): Co-STORM runner argument + logging_wrapper (LoggingWrapper): An instance of `LoggingWrapper` to log events. + rm (Optional[dspy.Retrieve], optional): A retrieval module used for fetching external knowledge or context. + callback_handler (BaseCallbackHandler, optional): Handles log message printing + """ + + def __init__( + self, + topic: str, + role_name: str, + role_description: str, + lm_config: LMConfigs, + runner_argument: "RunnerArgument", + logging_wrapper: LoggingWrapper, + rm: Optional[dspy.Retrieve] = None, + callback_handler: BaseCallbackHandler = None, + ): + super().__init__(topic, role_name, role_description) + self.lm_config = lm_config + self.runner_argument = runner_argument + self.logging_wrapper = logging_wrapper + self.callback_handler = callback_handler + self.costorm_agent_utterance_generator = ( + self._get_costorm_expert_utterance_generator(rm=rm) + ) + + def _get_costorm_expert_utterance_generator( + self, rm: Optional[dspy.Retrieve] = None + ): + return CoStormExpertUtteranceGenerationModule( + action_planning_lm=self.lm_config.discourse_manage_lm, + utterance_polishing_lm=self.lm_config.utterance_polishing_lm, + answer_question_module=_get_answer_question_module_instance( + lm_config=self.lm_config, + runner_argument=self.runner_argument, + logging_wrapper=self.logging_wrapper, + rm=rm, + ), + logging_wrapper=self.logging_wrapper, + callback_handler=self.callback_handler, + ) + + def generate_utterance( + self, + knowledge_base: KnowledgeBase, + conversation_history: List[ConversationTurn], + ): + with self.logging_wrapper.log_event( + "CoStormExpert generate utternace: get knowledge base summary" + ): + if self.callback_handler is not None: + self.callback_handler.on_expert_action_planning_start() + conversation_summary = knowledge_base.get_knowledge_base_summary() + with self.logging_wrapper.log_event( + "CoStormExpert.generate_utterance generate utterance" + ): + last_conv_turn = conversation_history[-1] + conv_turn = self.costorm_agent_utterance_generator( + topic=self.topic, + current_expert=self.get_role_description(), + conversation_summary=conversation_summary, + last_conv_turn=last_conv_turn, + ).conversation_turn + with self.logging_wrapper.log_event( + "CoStormExpert generate utterance: polish utterance" + ): + if self.callback_handler is not None: + self.callback_handler.on_expert_utterance_polishing_start() + self.costorm_agent_utterance_generator.polish_utterance( + conversation_turn=conv_turn, last_conv_turn=last_conv_turn + ) + return conv_turn + + +class SimulatedUser(Agent): + """ + Simulated Users is a special type of Agent in Co-STORM that simulates real user interaction behavior based on the given intent. + + This class can be used for automatic experiments. 
+ For more information, please refer to Section 3.4 of Co-STORM paper: https://www.arxiv.org/pdf/2408.15232 + """ + + def __init__( + self, + topic: str, + role_name: str, + role_description: str, + intent: str, + lm_config: LMConfigs, + runner_argument: "RunnerArgument", + logging_wrapper: LoggingWrapper, + callback_handler: BaseCallbackHandler = None, + ): + super().__init__(topic, role_name, role_description) + self.intent = intent + self.lm_config = lm_config + self.runner_argument = runner_argument + self.logging_wrapper = logging_wrapper + self.gen_simulated_user_utterance = GenSimulatedUserUtterance( + engine=self.lm_config.question_answering_lm + ) + self.callback_handler = callback_handler + + def generate_utterance( + self, + knowledge_base: KnowledgeBase, + conversation_history: List[ConversationTurn], + ): + assert ( + self.intent is not None and self.intent + ), "Simulate user intent is not initialized." + + with self.logging_wrapper.log_event( + "SimulatedUser generate utternace: generate utterance" + ): + utterance = self.gen_simulated_user_utterance( + topic=self.topic, intent=self.intent, conv_history=conversation_history + ) + return ConversationTurn( + role="Guest", raw_utterance=utterance, utterance_type="Original Question" + ) + + +class Moderator(Agent): + """ + The moderator's role in the Co-STORM framework is to inject new perspectives into the conversation to avoid stagnation, repetition, or overly niche discussions. + This is achieved by generating questions based on unused, uncited snippets of information retrieved since the last moderator's turn. + The selected information is reranked according to its relevance to the conversation topic and its dissimilarity to the original question. + The resulting top-ranked snippets are used to generate an informed question to be presented to the conversation participants. 
+ + For more information, please refer to Section 3.5 of Co-STORM paper: https://www.arxiv.org/pdf/2408.15232 + """ + + def __init__( + self, + topic: str, + role_name: str, + role_description: str, + lm_config: LMConfigs, + runner_argument: "RunnerArgument", + logging_wrapper: LoggingWrapper, + callback_handler: BaseCallbackHandler = None, + ): + super().__init__(topic, role_name, role_description) + self.lm_config = lm_config + self.runner_argument = runner_argument + self.logging_wrapper = logging_wrapper + self.grounded_question_generation_module = GroundedQuestionGenerationModule( + engine=self.lm_config.question_asking_lm + ) + self.callback_handler = callback_handler + + def _get_conv_turn_unused_information( + self, conv_turn: ConversationTurn, knowledge_base: KnowledgeBase + ): + # extract all snippets from raw retrieved information + raw_retrieved_info: List[Information] = conv_turn.raw_retrieved_info + raw_retrieved_single_snippet_info: List[Information] = [] + for info in raw_retrieved_info: + for snippet_idx in range(len(info.snippets)): + raw_retrieved_single_snippet_info.append( + extract_storm_info_snippet(info, snippet_index=snippet_idx) + ) + # get all cited information + cited_info = list(knowledge_base.info_uuid_to_info_dict.values()) + cited_info_hash_set = set([hash(info) for info in cited_info]) + cited_snippets = [info.snippets[0] for info in cited_info] + # get list of unused information + unused_information: List[Information] = [ + info + for info in raw_retrieved_single_snippet_info + if hash(info) not in cited_info_hash_set + ] + if not unused_information: + return [] + # extract snippets to get embeddings + unused_information_snippets = [info.snippets[0] for info in unused_information] + # get embeddings + cache = knowledge_base.embedding_cache + unused_snippets_embeddings, _ = get_text_embeddings( + unused_information_snippets, embedding_cache=cache, max_workers=100 + ) + claim_embedding, _ = get_text_embeddings( + conv_turn.claim_to_make, embedding_cache=cache + ) + query_embedding, _ = get_text_embeddings( + conv_turn.queries, embedding_cache=cache + ) + cited_snippets_embedding, _ = get_text_embeddings( + cited_snippets, embedding_cache=cache + ) + # calculate similarity + query_similarities = cosine_similarity( + unused_snippets_embeddings, query_embedding + ) + max_query_similarity = np.max(query_similarities, axis=1) + cited_snippets_similarity = np.max( + cosine_similarity(unused_snippets_embeddings, cited_snippets_embedding), + axis=1, + ) + cited_snippets_similarity = np.clip(cited_snippets_similarity, 0, 1) + # use claim similarity to filter out "real" not useful data + claim_similarity = cosine_similarity( + unused_snippets_embeddings, claim_embedding.reshape(1, -1) + ).flatten() + claim_similarity = np.where(claim_similarity >= 0.25, 1.0, 0.0) + # calculate score: snippet that is close to topic but far from query + query_sim_weight = 0.5 + cited_snippets_sim_weight = 1 - query_sim_weight + combined_scores = ( + ((1 - max_query_similarity) ** query_sim_weight) + * ((1 - cited_snippets_similarity) ** cited_snippets_sim_weight) + * claim_similarity + ) + sorted_indices = np.argsort(combined_scores)[::-1] + return [unused_information[idx] for idx in sorted_indices] + + def _get_sorted_unused_snippets( + self, + knowledge_base: KnowledgeBase, + conversation_history: List[ConversationTurn], + last_n_conv_turn: int = 2, + ): + # get last N conv turn and batch encode all related strings + considered_conv_turn = [] + batch_snippets = [self.topic] + for 
conv_turn in reversed(conversation_history): + if len(considered_conv_turn) == last_n_conv_turn: + break + if conv_turn.utterance_type == "Questioning": + break + considered_conv_turn.append(conv_turn) + batch_snippets.extend( + sum([info.snippets for info in conv_turn.raw_retrieved_info], []) + ) + batch_snippets.append(conv_turn.claim_to_make) + batch_snippets.extend(conv_turn.queries) + cache = knowledge_base.embedding_cache + get_text_embeddings(batch_snippets, embedding_cache=cache, max_workers=300) + + # get sorted unused snippets for each turn + sorted_snippets = [] + for conv_turn in considered_conv_turn: + sorted_snippets.append( + self._get_conv_turn_unused_information( + conv_turn=conv_turn, knowledge_base=knowledge_base + ) + ) + + # use round robin rule to merge these snippets + merged_snippets = [] + for elements in zip_longest(*sorted_snippets, fillvalue=None): + merged_snippets.extend(e for e in elements if e is not None) + return merged_snippets + + def generate_utterance( + self, + knowledge_base: KnowledgeBase, + conversation_history: List[ConversationTurn], + ): + with self.logging_wrapper.log_event( + "Moderator generate utternace: get unused snippets" + ): + unused_snippets: List[Information] = self._get_sorted_unused_snippets( + knowledge_base=knowledge_base, conversation_history=conversation_history + ) + with self.logging_wrapper.log_event( + "Moderator generate utternace: QuestionGeneration module" + ): + generated_question = self.grounded_question_generation_module( + topic=self.topic, + knowledge_base=knowledge_base, + last_conv_turn=conversation_history[-1], + unused_snippets=unused_snippets, + ) + return ConversationTurn( + role=self.role_name, + raw_utterance=generated_question.raw_utterance, + utterance_type="Original Question", + utterance=generated_question.utterance, + cited_info=generated_question.cited_info, + ) + + +class PureRAGAgent(Agent): + """ + PureRAGAgent only handles grounded question generation by retrieving information from the retriever based on the query. + It does not utilize any other information besides the query itself. + + It's designed for Co-STORM paper baseline comparison. 
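+
+    Example (illustrative sketch; the argument values below are placeholders, not defaults):
+        rag_agent = PureRAGAgent(
+            topic="Quantum error correction",
+            role_name="PureRAG",
+            role_description="Answers the latest question using retrieval only.",
+            lm_config=lm_config,
+            runner_argument=runner_argument,
+            logging_wrapper=logging_wrapper,
+            rm=rm,
+        )
+        opening_turn = rag_agent.generate_topic_background()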
+ """ + + def __init__( + self, + topic: str, + role_name: str, + role_description: str, + lm_config: LMConfigs, + runner_argument: "RunnerArgument", + logging_wrapper: LoggingWrapper, + rm: Optional[dspy.Retrieve] = None, + callback_handler: BaseCallbackHandler = None, + ): + super().__init__(topic, role_name, role_description) + self.lm_config = lm_config + self.runner_argument = runner_argument + self.logging_wrapper = logging_wrapper + self.grounded_question_answering_module = _get_answer_question_module_instance( + lm_config=self.lm_config, + runner_argument=self.runner_argument, + logging_wrapper=self.logging_wrapper, + rm=rm, + ) + + def _gen_utterance_from_question(self, question: str): + grounded_answer = self.grounded_question_answering_module( + topic=self.topic, + question=question, + mode="brief", + style="conversational and concise", + ) + conversation_turn = ConversationTurn( + role=self.role_name, raw_utterance="", utterance_type="Potential Answer" + ) + conversation_turn.claim_to_make = question + conversation_turn.raw_utterance = grounded_answer.response + conversation_turn.utterance = grounded_answer.response + conversation_turn.queries = grounded_answer.queries + conversation_turn.raw_retrieved_info = grounded_answer.raw_retrieved_info + conversation_turn.cited_info = grounded_answer.cited_info + return conversation_turn + + def generate_topic_background(self): + return self._gen_utterance_from_question(self.topic) + + def generate_utterance( + self, + knowledge_base: KnowledgeBase, + conversation_history: List[ConversationTurn], + ): + with self.logging_wrapper.log_event( + "PureRAGAgent generate utternace: generate utterance" + ): + return self._gen_utterance_from_question( + question=conversation_history[-1].utterance + ) diff --git a/knowledge_storm/collaborative_storm/modules/collaborative_storm_utils.py b/knowledge_storm/collaborative_storm/modules/collaborative_storm_utils.py new file mode 100644 index 00000000..d337de60 --- /dev/null +++ b/knowledge_storm/collaborative_storm/modules/collaborative_storm_utils.py @@ -0,0 +1,261 @@ +import dspy +import os +import re +import sys +import toml +from typing import List, Tuple, Dict, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from ..engine import RunnerArgument +from ...interface import Information, Retriever, LMConfigs +from ...logging_wrapper import LoggingWrapper +from ...rm import BingSearch + + +def extract_storm_info_snippet(info: Information, snippet_index: int) -> Information: + """ + Constructs a new Information instance with only the specified snippet index. + + Args: + storm_info (Information): The original Information instance. + snippet_index (int): The index of the snippet to retain. + + Returns: + Information: A new Information instance with only the specified snippet. + """ + if snippet_index < 0 or snippet_index >= len(info.snippets): + raise ValueError("Snippet index out of range") + + new_snippets = [info.snippets[snippet_index]] + new_storm_info = Information( + info.url, info.description, new_snippets, info.title, info.meta + ) + return new_storm_info + + +def format_search_results( + searched_results: List[Information], + info_max_num_words: int = 1000, + mode: str = "brief", +) -> Tuple[str, Dict[int, Information]]: + """ + Constructs a string from a list of search results with a specified word limit and returns a mapping of indices to Information. + + Args: + searched_results (List[Information]): List of Information objects to process. 
+ info_max_num_words (int, optional): Maximum number of words allowed in the output string. Defaults to 1000. + mode (str, optional): Mode of summarization. 'brief' takes only the first snippet of each Information. + 'extensive' adds snippets iteratively until the word limit is reached. Defaults to 'brief'. + + Returns: + Tuple[str, Dict[int, Information]]: + - Formatted string with search results, constrained by the word limit. + - Dictionary mapping indices to the corresponding Information objects. + """ + total_length = 0 + + extracted_snippet_queue = [] + max_snippets = ( + max(len(info.snippets) for info in searched_results) if searched_results else 0 + ) + max_snippets = 1 if mode == "brief" else max_snippets + abort = False + included_snippets = set() + for i in range(max_snippets): + for info in searched_results: + if i < len(info.snippets) and not abort: + cur_snippet = info.snippets[i] + cur_snippet_len = len(info.snippets[i].split()) + if total_length + cur_snippet_len > info_max_num_words: + abort = True + break + if cur_snippet not in included_snippets: + included_snippets.add(cur_snippet) + info = extract_storm_info_snippet(info, snippet_index=i) + extracted_snippet_queue.append(info) + total_length += cur_snippet_len + output = [] + index_mapping = {} + for idx, info in enumerate(extracted_snippet_queue): + output.append(f"[{idx + 1}]: {info.snippets[0]}") + index_mapping[idx + 1] = info + assert -1 not in index_mapping + return "\n".join(output), index_mapping + + +def extract_cited_storm_info( + response: str, index_to_storm_info: Dict[int, Information] +) -> Dict[int, Information]: + """ + Extracts a sub-dictionary of Information instances that are cited in the response. + + Args: + response (str): The response string containing inline citations like [1], [2], etc. + index_to_storm_info (Dict[int, Information]): A dictionary mapping indices to Information instances. + + Returns: + Dict[int, Information]: A sub-dictionary with only the indices that appear in the response. + """ + cited_indices = set(map(int, re.findall(r"\[(\d+)\]", response))) + cited_storm_info = { + index: info + for index, info in index_to_storm_info.items() + if index in cited_indices + } + return cited_storm_info + + +def trim_output_after_hint(response: str, hint: str) -> str: + """ + Trims the output string to only keep the substring after the given hint (not including the hint). + + Args: + response (str): The original output string. + hint (str): The hint string after which the substring should be kept. + + Returns: + str: The trimmed output string, or the original string if the hint is not found. + """ + if hint in response: + start_index = response.find(hint) + len(hint) + return response[start_index:].strip() + return response.strip("\n") + + +def separate_citations(text: str) -> str: + """ + Separates multiple citations within square brackets into individual citations. + + Args: + text (str): The input string containing citations. + + Returns: + str: The string with separated citations. 
+ """ + + # Define a function to process each match + def replace_citations(match): + citations = match.group(1).split(",") + return "".join(f"[{citation.strip()}]" for citation in citations) + + # Use regular expressions to find and replace citations + pattern = re.compile(r"\[(\d+(?:,\s*\d+)*)\]") + return pattern.sub(replace_citations, text) + + +def extract_and_remove_citations(text: str) -> Tuple[str, List[int]]: + """ + Removes single inline citations from the input string and returns the modified string and a list of citation integers. + + Args: + text (str): The input string containing citations. + + Returns: + Tuple[str, List[int]]: The string after removal of citations and a list of citation integers. + """ + citations = [] + + # Define a function to process each match + def extract_citation(match): + citation = int(match.group(1)) + citations.append(citation) + return "" + + # Use regular expressions to find and replace citations + pattern = re.compile(r"\[(\d+)\]") + modified_text = pattern.sub(extract_citation, text) + + return modified_text, citations + + +def keep_first_and_last_paragraph(text: str) -> str: + """ + Processes the input text to keep the first and last paragraphs and replace + the middle paragraphs with '[content omitted due to space limit]'. + + Args: + text (str): The input text containing paragraphs separated by '\n\n'. + + Returns: + str: The processed text. + """ + paragraphs = text.split("\n\n") + + if len(paragraphs) <= 3: + return text + + first_paragraph = paragraphs[0] + last_paragraph = "\n\n".join(paragraphs[-2:]) + return ( + f"{first_paragraph}\n\n[content omitted due to space limit]\n\n{last_paragraph}" + ) + + +def clean_up_section(text): + """Clean up a section: + 1. Remove uncompleted sentences (usually due to output token limitation). + 2. Deduplicate individual groups of citations. + 3. Remove unnecessary summary.""" + + paragraphs = text.split("\n") + output_paragraphs = [] + summary_sec_flag = False + for p in paragraphs: + p = p.strip() + if len(p) == 0: + continue + if not p.startswith("#"): + p = separate_citations(p) + if summary_sec_flag: + if p.startswith("#"): + summary_sec_flag = False + else: + continue + if ( + p.startswith("Overall") + or p.startswith("In summary") + or p.startswith("In conclusion") + ): + continue + if "# Summary" in p or "# Conclusion" in p: + summary_sec_flag = True + continue + output_paragraphs.append(p) + + return "\n\n".join(output_paragraphs) # Join with '\n\n' for markdown format. 
+ + +def load_api_key(toml_file_path): + try: + with open(toml_file_path, "r") as file: + data = toml.load(file) + except FileNotFoundError: + print(f"File not found: {toml_file_path}", file=sys.stderr) + return + except toml.TomlDecodeError: + print(f"Error decoding TOML file: {toml_file_path}", file=sys.stderr) + return + # Set environment variables + for key, value in data.items(): + os.environ[key] = str(value) + + +def _get_answer_question_module_instance( + lm_config: LMConfigs, + runner_argument: "RunnerArgument", + logging_wrapper: LoggingWrapper, + rm: Optional[dspy.Retrieve] = None, +): + from .grounded_question_answering import AnswerQuestionModule + + # configure retriever + if rm is None: + rm = BingSearch(k=runner_argument.retrieve_top_k) + retriever = Retriever(rm=rm, max_thread=runner_argument.max_search_thread) + # return AnswerQuestionModule instance + return AnswerQuestionModule( + retriever=retriever, + max_search_queries=runner_argument.max_search_queries, + question_answering_lm=lm_config.question_answering_lm, + logging_wrapper=logging_wrapper, + ) diff --git a/knowledge_storm/collaborative_storm/modules/costorm_expert_utterance_generator.py b/knowledge_storm/collaborative_storm/modules/costorm_expert_utterance_generator.py new file mode 100644 index 00000000..2c7c3505 --- /dev/null +++ b/knowledge_storm/collaborative_storm/modules/costorm_expert_utterance_generator.py @@ -0,0 +1,160 @@ +import dspy +from typing import Union + +from .callback import BaseCallbackHandler +from .collaborative_storm_utils import ( + trim_output_after_hint, + extract_and_remove_citations, + keep_first_and_last_paragraph, +) + +from .grounded_question_answering import AnswerQuestionModule +from .grounded_question_generation import ConvertUtteranceStyle +from ...dataclass import ConversationTurn +from ...logging_wrapper import LoggingWrapper + + +class GenExpertActionPlanning(dspy.Signature): + """ + You are an invited speaker in the round table conversation. Your task is to make a very short note to your assistant to help you prepare for your turn in the conversation. + You will be given the topic we are discussing, your expertise, and the conversation history. + Take a look at conversation history, especially last few turns, then let your assistant prepare the material for you with one of following ways. + 1. Original Question: Initiates a new question to other speakers. + 2. Further Details: Provides additional information. + 3. Information Request: Requests information from other speakers. + 4. Potential Answer: Offers a possible solution or answer. + + Strictly follow this format: [type of contribution]: [one sentence description]. For example, Original Question: [description] + """ + + topic = dspy.InputField(prefix="topic of discussion: ", format=str) + expert = dspy.InputField(prefix="You are inivited as: ", format=str) + summary = dspy.InputField(prefix="Discussion history: \n", format=str) + last_utterance = dspy.InputField( + prefix="Last utterance in the conversation: \n", format=str + ) + resposne = dspy.OutputField( + prefix="Now give your note. 
Start with one of [Original Question, Further Details, Information Request, Potential Answer] with one sentence description\n", + format=str, + ) + + +class CoStormExpertUtteranceGenerationModule(dspy.Module): + def __init__( + self, + action_planning_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + utterance_polishing_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + answer_question_module: AnswerQuestionModule, + logging_wrapper: LoggingWrapper, + callback_handler: BaseCallbackHandler = None, + ): + self.action_planning_lm = action_planning_lm + self.utterance_polishing_lm = utterance_polishing_lm + self.expert_action = dspy.Predict(GenExpertActionPlanning) + self.change_style = dspy.Predict(ConvertUtteranceStyle) + self.answer_question_module = answer_question_module + self.logging_wrapper = logging_wrapper + self.callback_handler = callback_handler + + def parse_action(self, action): + action_types = [ + "Original Question", + "Further Details", + "Information Request", + "Potential Answer", + ] + for action_type in action_types: + if f"{action_type}:" in action: + return action_type, trim_output_after_hint(action, f"{action_type}:") + elif f"[{action_type}]:" in action: + return action_type, trim_output_after_hint(action, f"[{action_type}]:") + return "Undefined", "" + + def polish_utterance( + self, conversation_turn: ConversationTurn, last_conv_turn: ConversationTurn + ): + # change utterance style + action_type = conversation_turn.utterance_type + with self.logging_wrapper.log_event( + "RoundTableConversationModule.ConvertUtteranceStyle" + ): + with dspy.settings.context( + lm=self.utterance_polishing_lm, show_guidelines=False + ): + action_string = ( + f"{action_type} about: {conversation_turn.claim_to_make}" + ) + if action_type in ["Original Question", "Information Request"]: + action_string = f"{action_type}" + last_expert_utterance_wo_citation, _ = extract_and_remove_citations( + last_conv_turn.utterance + ) + trimmed_last_expert_utterance = keep_first_and_last_paragraph( + last_expert_utterance_wo_citation + ) + utterance = self.change_style( + expert=conversation_turn.role, + action=action_string, + prev=trimmed_last_expert_utterance, + content=conversation_turn.raw_utterance, + ).utterance + conversation_turn.utterance = utterance + + def forward( + self, + topic: str, + current_expert: str, + conversation_summary: str, + last_conv_turn: ConversationTurn, + ): + last_utterance, _ = extract_and_remove_citations(last_conv_turn.utterance) + if last_conv_turn.utterance_type in [ + "Original Question", + "Information Request", + ]: + action_type = "Potential Answer" + action_content = last_utterance + else: + with self.logging_wrapper.log_event( + "CoStormExpertUtteranceGenerationModule: GenExpertActionPlanning" + ): + with dspy.settings.context( + lm=self.action_planning_lm, show_guidelines=False + ): + action = self.expert_action( + topic=topic, + expert=current_expert, + summary=conversation_summary, + last_utterance=last_utterance, + ).resposne + action_type, action_content = self.parse_action(action) + + if self.callback_handler is not None: + self.callback_handler.on_expert_action_planning_end() + # get response + conversation_turn = ConversationTurn( + role=current_expert, raw_utterance="", utterance_type=action_type + ) + + if action_type == "Undefined": + raise Exception(f"unexpected output: {action}") + elif action_type in ["Further Details", "Potential Answer"]: + with self.logging_wrapper.log_event( + "RoundTableConversationModule: QuestionAnswering" + ): + grounded_answer = 
self.answer_question_module( + topic=topic, + question=action_content, + mode="brief", + style="conversational and concise", + callback_handler=self.callback_handler, + ) + conversation_turn.claim_to_make = action_content + conversation_turn.raw_utterance = grounded_answer.response + conversation_turn.queries = grounded_answer.queries + conversation_turn.raw_retrieved_info = grounded_answer.raw_retrieved_info + conversation_turn.cited_info = grounded_answer.cited_info + elif action_type in ["Original Question", "Information Request"]: + conversation_turn.raw_utterance = action_content + + return dspy.Prediction(conversation_turn=conversation_turn) diff --git a/knowledge_storm/collaborative_storm/modules/expert_generation.py b/knowledge_storm/collaborative_storm/modules/expert_generation.py new file mode 100644 index 00000000..a95915a6 --- /dev/null +++ b/knowledge_storm/collaborative_storm/modules/expert_generation.py @@ -0,0 +1,83 @@ +import dspy +import re +from typing import Union + + +class GenerateExpertGeneral(dspy.Signature): + """You need to select a group of diverse experts who will be suitable to be invited to a roundtable discussion on the given topic. + Each expert should represent a different perspective, role, or affiliation related to this topic. + You can use the background information provided about the topic for inspiration. For each expert, add a description of their expertise and what they will focus on during the discussion. + No need to include speakers name in the output. + Strictly follow format below: + 1. [speaker 1 role]: [speaker 1 short description] + 2. [speaker 2 role]: [speaker 2 short description] + """ + + topic = dspy.InputField(prefix="Topic of interest:", format=str) + background_info = dspy.InputField( + prefix="Background information about the topic:\n", format=str + ) + topN = dspy.InputField(prefix="Number of speakers needed: ", format=str) + experts = dspy.OutputField(format=str) + + +class GenerateExpertWithFocus(dspy.Signature): + """ + You need to select a group of speakers who will be suitable to have roundtable discussion on the [topic] of specific [focus]. + You may consider inviting speakers having opposite stands on the topic; speakers representing different interest parties; Ensure that the selected speakers are directly connected to the specific context and scenario provided. + For example, if the discussion focus is about a recent event at a specific university, consider inviting students, faculty members, journalists covering the event, university officials, and local community members. + Use the background information provided about the topic for inspiration. For each speaker, add a description of their interests and what they will focus on during the discussion. + No need to include speakers name in the output. + Strictly follow format below: + 1. [speaker 1 role]: [speaker 1 short description] + 2. 
[speaker 2 role]: [speaker 2 short description] + """ + + topic = dspy.InputField(prefix="Topic of interest:", format=str) + background_info = dspy.InputField(prefix="Background information:\n", format=str) + focus = dspy.InputField(prefix="Discussion focus: ", format=str) + topN = dspy.InputField(prefix="Number of speakers needed: ", format=str) + experts = dspy.OutputField(format=str) + + +class GenerateExpertModule(dspy.Module): + def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + self.engine = engine + self.generate_expert_general = dspy.Predict(GenerateExpertGeneral) + self.generate_expert_w_focus = dspy.ChainOfThought(GenerateExpertWithFocus) + + def trim_background(self, background: str, max_words: int = 100): + words = background.split() + cur_len = len(words) + if cur_len <= max_words: + return background + trimmed_words = words[: min(cur_len, max_words)] + trimmed_background = " ".join(trimmed_words) + return f"{trimmed_background} [rest content omitted]." + + def forward( + self, topic: str, num_experts: int, background_info: str = "", focus: str = "" + ): + with dspy.settings.context(lm=self.engine, show_guidelines=False): + if not focus: + output = self.generate_expert_general( + topic=topic, background_info=background_info, topN=num_experts + ).experts + else: + background_info = self.trim_background( + background=background_info, max_words=100 + ) + output = self.generate_expert_w_focus( + topic=topic, + background_info=background_info, + focus=focus, + topN=num_experts, + ).experts + output = output.replace("*", "").replace("[", "").replace("]", "") + expert_list = [] + for s in output.split("\n"): + match = re.search(r"\d+\.\s*(.*)", s) + if match: + expert_list.append(match.group(1)) + expert_list = [expert.strip() for expert in expert_list if expert.strip()] + return dspy.Prediction(experts=expert_list, raw_output=output) diff --git a/knowledge_storm/collaborative_storm/modules/grounded_question_answering.py b/knowledge_storm/collaborative_storm/modules/grounded_question_answering.py new file mode 100644 index 00000000..4065018b --- /dev/null +++ b/knowledge_storm/collaborative_storm/modules/grounded_question_answering.py @@ -0,0 +1,163 @@ +import dspy +from typing import Union, List + +from .callback import BaseCallbackHandler +from .collaborative_storm_utils import ( + trim_output_after_hint, + format_search_results, + extract_cited_storm_info, + separate_citations, +) +from ...logging_wrapper import LoggingWrapper +from ...utils import ArticleTextProcessing +from ...interface import Information + + +class QuestionToQuery(dspy.Signature): + """You want to answer the question or support a claim using Google search. What do you type in the search box? + The question is raised in a round table discussion on a topic. The question may or may not focus on the topic itself. + Write the queries you will use in the following format: + - query 1 + - query 2 + ... + - query n""" + + topic = dspy.InputField(prefix="Topic context:", format=str) + question = dspy.InputField( + prefix="I want to collect information about: ", format=str + ) + queries = dspy.OutputField(prefix="Queries: \n", format=str) + + +class AnswerQuestion(dspy.Signature): + """You are an expert who can use information effectively. You have gathered the related information and will now use the information to form a response. + Make your response as informative as possible and make sure every sentence is supported by the gathered information. 
+ If [Gathered information] is not directly related to the [Topic] and [Question], provide the most relevant answer you can based on the available information, and explain any limitations or gaps. + Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). + You DO NOT need to include a References or Sources section to list the sources at the end. The style of writing should be formal. + """ + + topic = dspy.InputField(prefix="Topic you are discussing about:", format=str) + question = dspy.InputField(prefix="You want to provide insight on: ", format=str) + info = dspy.InputField(prefix="Gathered information:\n", format=str) + style = dspy.InputField(prefix="Style of your response should be:", format=str) + answer = dspy.OutputField( + prefix="Now give your response. (Try to use as many different sources as possible and do not hallucinate.)", + format=str, + ) + + +class AnswerQuestionModule(dspy.Module): + def __init__( + self, + retriever: dspy.Retrieve, + max_search_queries: int, + question_answering_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + logging_wrapper: LoggingWrapper, + ): + super().__init__() + self.question_answering_lm = question_answering_lm + self.question_to_query = dspy.Predict(QuestionToQuery) + self.answer_question = dspy.Predict(AnswerQuestion) + self.retriever = retriever + self.max_search_queries = max_search_queries + self.logging_wrapper = logging_wrapper + + def retrieve_information(self, topic, question): + # decompose question to queries + with self.logging_wrapper.log_event( + f"AnswerQuestionModule.question_to_query ({hash(question)})" + ): + with dspy.settings.context(lm=self.question_answering_lm): + queries = self.question_to_query(topic=topic, question=question).queries + queries = trim_output_after_hint(queries, hint="Queries:") + queries = [ + q.replace("-", "").strip().strip('"').strip('"').strip() + for q in queries.split("\n") + ] + queries = queries[: self.max_search_queries] + self.logging_wrapper.add_query_count(count=len(queries)) + with self.logging_wrapper.log_event( + f"AnswerQuestionModule.retriever.retrieve ({hash(question)})" + ): + # retrieve information using retriever + searched_results: List[Information] = self.retriever.retrieve( + list(set(queries)), exclude_urls=[] + ) + # update storm information meta to include the question + for storm_info in searched_results: + storm_info.meta["question"] = question + return queries, searched_results + + def forward( + self, + topic: str, + question: str, + mode: str = "brief", + style: str = "conversational", + callback_handler: BaseCallbackHandler = None, + ): + """ + Processes a topic and question to generate a response with relevant information and citations. + + Args: + topic (str): The topic of interest. + question (str): The specific question related to the topic. + mode (str, optional): Mode of summarization. 'brief' takes only the first snippet of each Information. + 'extensive' adds snippets iteratively until the word limit is reached. Defaults to 'brief'. + + Returns: + dspy.Prediction: An object containing the following: + - question (str): the question to answer + - queries (List[str]): List of query strings used for information retrieval. + - raw_retrieved_info (List[Information]): List of Information instances retrieved. + - cited_info (Dict[int, Information]): Dictionary of cited Information instances, indexed by their citation number. + - response (str): The generated response string with inline citations. 
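+
+        Example (illustrative sketch; assumes the module was created via `_get_answer_question_module_instance` with a configured LM and retriever):
+            prediction = answer_question_module(
+                topic="Solid-state batteries",
+                question="What are the main manufacturing bottlenecks?",
+                mode="brief",
+                style="conversational and concise",
+            )
+            print(prediction.response)            # answer text with inline [n] citations
+            print(sorted(prediction.cited_info))  # citation indices that were actually used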
+ """ + # retrieve information + if callback_handler is not None: + callback_handler.on_expert_information_collection_start() + queries, searched_results = self.retrieve_information( + topic=topic, question=question + ) + if callback_handler is not None: + callback_handler.on_expert_information_collection_end(searched_results) + # format information string for answer generation + info_text, index_to_information_mapping = format_search_results( + searched_results, mode=mode + ) + answer = "Sorry, there is insufficient information to answer the question." + # generate answer to the question + if info_text: + with self.logging_wrapper.log_event( + f"AnswerQuestionModule.answer_question ({hash(question)})" + ): + with dspy.settings.context( + lm=self.question_answering_lm, show_guidelines=False + ): + answer = self.answer_question( + topic=topic, question=question, info=info_text, style=style + ).answer + answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations( + answer + ) + answer = trim_output_after_hint( + answer, + hint="Now give your response. (Try to use as many different sources as possible and do not hallucinate.)", + ) + # enforce single citation index bracket. [1, 2] -> [1][2] + answer = separate_citations(answer) + if callback_handler is not None: + callback_handler.on_expert_utterance_generation_end() + # construct cited search result + cited_searched_results = extract_cited_storm_info( + response=answer, index_to_storm_info=index_to_information_mapping + ) + + return dspy.Prediction( + question=question, + queries=queries, + raw_retrieved_info=searched_results, + cited_info=cited_searched_results, + response=answer, + ) diff --git a/knowledge_storm/collaborative_storm/modules/grounded_question_generation.py b/knowledge_storm/collaborative_storm/modules/grounded_question_generation.py new file mode 100644 index 00000000..331692ca --- /dev/null +++ b/knowledge_storm/collaborative_storm/modules/grounded_question_generation.py @@ -0,0 +1,113 @@ +""" +This module handles question generation within the Co-STORM framework, specifically designed to support the Moderator role. + +The Moderator generates insightful, thought-provoking questions that introduce new directions into the conversation. +By leveraging uncited or unused snippets of information retrieved during the discussion, the Moderator ensures the conversation remains dynamic and avoids repetitive or overly niche topics. + +For more detailed information, refer to Section 3.5 of the Co-STORM paper: https://www.arxiv.org/pdf/2408.15232. +""" + +import dspy +from typing import List, Union + +from .collaborative_storm_utils import ( + format_search_results, + extract_and_remove_citations, + keep_first_and_last_paragraph, + extract_cited_storm_info, +) +from ...dataclass import ConversationTurn, KnowledgeBase +from ...interface import Information + + +class KnowledgeBaseSummmary(dspy.Signature): + """Your job is to give brief summary of what's been discussed in a roundtable conversation. Contents are themantically organized into hierarchical sections. + You will be presented with these sections where "#" denotes level of section. + """ + + topic = dspy.InputField(prefix="topic: ", format=str) + structure = dspy.InputField(prefix="Tree structure: \n", format=str) + output = dspy.OutputField(prefix="Now give brief summary:\n", format=str) + + +class ConvertUtteranceStyle(dspy.Signature): + """ + You are an invited speaker in the round table conversation. 
+ Your task is to make the question or the response more conversational and engaging to facilicate the flow of conversation. + Note that this is ongoing conversation so no need to have welcoming and concluding words. Previous speaker utterance is provided only for making the conversation more natural. + Note that do not hallucinate and keep the citation index like [1] as it is. Also, + """ + + expert = dspy.InputField(prefix="You are inivited as: ", format=str) + action = dspy.InputField( + prefix="You want to contribute to conversation by: ", format=str + ) + prev = dspy.InputField(prefix="Previous speaker said: ", format=str) + content = dspy.InputField( + prefix="Question or response you want to say: ", format=str + ) + utterance = dspy.OutputField( + prefix="Your utterance (keep the information as much as you can with citations, prefer shorter answers without loss of information): ", + format=str, + ) + + +class GroundedQuestionGeneration(dspy.Signature): + """Your job is to find next discussion focus in a roundtable conversation. You will be given previous conversation summary and some information that might assist you discover new discussion focus. + Note that the new discussion focus should bring new angle and perspective to the discussion and avoid repetition. The new discussion focus should be grounded on the available information and push the boundaries of the current discussion for broader exploration. + The new discussion focus should have natural flow from last utterance in the conversation. + Use [1][2] in line to ground your question. + """ + + topic = dspy.InputField(prefix="topic: ", format=str) + summary = dspy.InputField(prefix="Discussion history: \n", format=str) + information = dspy.InputField(prefix="Available information: \n", format=str) + last_utterance = dspy.InputField( + prefix="Last utterance in the conversation: \n", format=str + ) + output = dspy.OutputField( + prefix="Now give next discussion focus in the format of one sentence question:\n", + format=str, + ) + + +class GroundedQuestionGenerationModule(dspy.Module): + def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + self.engine = engine + self.gen_focus = dspy.Predict(GroundedQuestionGeneration) + self.polish_style = dspy.Predict(ConvertUtteranceStyle) + self.gen_summary = dspy.Predict(KnowledgeBaseSummmary) + + def forward( + self, + topic: str, + knowledge_base: KnowledgeBase, + last_conv_turn: ConversationTurn, + unused_snippets: List[Information], + ): + information, index_to_information_mapping = format_search_results( + unused_snippets, info_max_num_words=1000 + ) + summary = knowledge_base.get_knowledge_base_summary() + last_utterance, _ = extract_and_remove_citations(last_conv_turn.utterance) + with dspy.settings.context(lm=self.engine, show_guidelines=False): + raw_utterance = self.gen_focus( + topic=topic, + summary=summary, + information=information, + last_utterance=last_utterance, + ).output + utterance = self.polish_style( + expert="Roundtable conversation moderator", + action="Raising a new question by natural transit from previous utterance.", + prev=keep_first_and_last_paragraph(last_utterance), + content=raw_utterance, + ).utterance + cited_searched_results = extract_cited_storm_info( + response=utterance, index_to_storm_info=index_to_information_mapping + ) + return dspy.Prediction( + raw_utterance=raw_utterance, + utterance=utterance, + cited_info=cited_searched_results, + ) diff --git a/knowledge_storm/collaborative_storm/modules/information_insertion_module.py 
b/knowledge_storm/collaborative_storm/modules/information_insertion_module.py new file mode 100644 index 00000000..c858671b --- /dev/null +++ b/knowledge_storm/collaborative_storm/modules/information_insertion_module.py @@ -0,0 +1,422 @@ +import dspy +import numpy as np +import re +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed +from sklearn.metrics.pairwise import cosine_similarity +from typing import List, Union, Dict, Optional + +from .collaborative_storm_utils import trim_output_after_hint +from ...dataclass import KnowledgeNode, KnowledgeBase +from ...encoder import get_text_embeddings +from ...interface import Information + + +class InsertInformation(dspy.Signature): + """Your job is to insert the given information to the knowledge base. The knowledge base is a tree based data structure to organize the collection information. Each knowledge node contains information derived from themantically similar question or intent. + To decide the best placement of the information, you will be navigated in this tree based data structure layer by layer. + You will be presented with the question and query leads to ththeis information, and tree structure. + + Output should strictly follow one of options presetned below with no other information. + - 'insert': to place the information under the current node. + - 'step: [child node name]': to step into a specified child node. + - 'create: [new child node name]': to create new child node and insert the info under it. + + Example outputs: + - insert + - step: node2 + - create: node3 + """ + + intent = dspy.InputField( + prefix="Question and query leads to this info: ", format=str + ) + structure = dspy.InputField(prefix="Tree structure: \n", format=str) + choice = dspy.OutputField(prefix="Choice:\n", format=str) + + +class InsertInformationCandidateChoice(dspy.Signature): + """Your job is to insert the given information to the knowledge base. The knowledge base is a tree based data structure to organize the collection information. Each knowledge node contains information derived from themantically similar question or intent. + You will be presented with the question and query leads to this information, and candidate choices of placement. In these choices, -> denotes parent-child relationship. Note that reasonable may not be in these choices. + + If there exists reasonable choice, output "Best placement: [choice index]"; otherwise, output "No reasonable choice". + """ + + intent = dspy.InputField( + prefix="Question and query leads to this info: ", format=str + ) + choices = dspy.InputField(prefix="Candidate placement:\n", format=str) + decision = dspy.OutputField(prefix="Decision:\n", format=str) + + +class InsertInformationModule(dspy.Module): + def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + self.engine = engine + self.insert_info = dspy.ChainOfThought(InsertInformation) + self.candidate_choosing = dspy.Predict(InsertInformationCandidateChoice) + + def _construct_intent(self, question: str, query: str): + intent = "" + if query == "Not applicable": + return question + if question: + intent += f"Question: {question}\n" + if query: + intent += f"Query: {query}\n" + if not intent: + intent = "Not available." 
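+        # the assembled intent string is what the LM sees when deciding where to place this information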
+ return intent + + def _get_navigation_choice( + self, knowledge_node: KnowledgeNode, question: str, query: str + ): + # construct information intent + intent = self._construct_intent(question, query) + # construct current kb structure + structure = f"Current Node: {knowledge_node.name}\n" + child_names = ", ".join(knowledge_node.get_children_names()) + if child_names: + structure += f"Child Nodes: {child_names}" + navigated_path = " -> ".join(knowledge_node.get_path_from_root()) + structure += f"Path you have nagivated: {navigated_path}" + + # get predicted action + with dspy.settings.context(lm=self.engine): + predicted_action = self.insert_info( + intent=intent, structure=structure + ).choice + + # parse action + cleaned_predicted_action = trim_output_after_hint( + predicted_action, "Choice:" + ).strip() + cleaned_predicted_action = cleaned_predicted_action.strip("-").strip() + if cleaned_predicted_action.startswith("insert"): + return "insert", "" + elif cleaned_predicted_action.startswith("step:"): + node_name = trim_output_after_hint(cleaned_predicted_action, "step:") + return "step", node_name + elif cleaned_predicted_action.startswith("create:"): + node_name = trim_output_after_hint(cleaned_predicted_action, "create:") + return "create", node_name + raise Exception( + f"Undefined predicted action in knowledge navigation. {predicted_action}" + ) + + def layer_by_layer_navigation_placement( + self, + knowledge_base: KnowledgeBase, + question: str, + query: str, + allow_create_new_node: bool = False, + root: Optional[KnowledgeNode] = None, + ): + current_node: KnowledgeNode = knowledge_base.root if root is None else root + + while True: + action_type, node_name = self._get_navigation_choice( + knowledge_node=current_node, question=question, query=query + ) + if action_type == "insert": + return dspy.Prediction( + information_placement=" -> ".join( + current_node.get_path_from_root(root) + ), + note="None", + ) + elif action_type == "step": + for child in current_node.children: + if child.name == node_name: + current_node = child + break + else: + raise ValueError(f"Child node with name {node_name} not found.") + elif action_type == "create": + placement_path = current_node.get_path_from_root(root) + if allow_create_new_node: + placement_path.append(node_name) + note = f"create new node: {{{node_name}}} under {{{current_node.name}}}" + else: + note = f"attempt to create new node: {{{node_name}}} under {{{current_node.name}}}" + return dspy.Prediction( + information_placement=" -> ".join(placement_path), note=note + ) + else: + raise ValueError(f"Unknown action type: {action_type}") + + def _get_sorted_embed_sim_section( + self, + encoded_outline: np.ndarray, + outlines: List[str], + question: str, + query: str, + ): + if encoded_outline is not None and encoded_outline.size > 0: + encoded_query, token_usage = get_text_embeddings(f"{question}, {query}") + sim = cosine_similarity([encoded_query], encoded_outline)[0] + sorted_indices = np.argsort(sim) + sorted_outlines = np.array(outlines)[sorted_indices[::-1]] + return sorted_outlines + else: + return outlines + + def _parse_selected_index(self, string: str): + match = re.search(r"\[(\d+)\]", string) + if match: + return int(match.group(1)) + try: + return int(string.strip()) + except: + pass + return None + + def choose_candidate_from_embedding_ranking( + self, + question: str, + query: str, + encoded_outlines: np.ndarray, + outlines: List[str], + top_N_candidates: int = 5, + ): + sorted_candidates = self._get_sorted_embed_sim_section( 
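+            # candidates are ranked by embedding similarity to the (question, query) intent before the LM picks one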
+ encoded_outlines, outlines, question, query + ) + considered_candidates = sorted_candidates[ + : min(len(sorted_candidates), top_N_candidates) + ] + choices_string = "\n".join( + [ + f"{idx + 1}: {candidate}" + for idx, candidate in enumerate(considered_candidates) + ] + ) + with dspy.settings.context(lm=self.engine, show_guidelines=False): + decision = self.candidate_choosing( + intent=self._construct_intent(question=question, query=query), + choices=choices_string, + ).decision + decision = trim_output_after_hint(decision, hint="Decision:") + if "Best placement:" in decision: + decision = trim_output_after_hint(decision, hint="Best placement:") + selected_index = self._parse_selected_index(decision) + if selected_index is not None: + selected_index = selected_index - 1 + if selected_index < len(sorted_candidates) and selected_index >= 0: + return dspy.Prediction( + information_placement=sorted_candidates[selected_index], + note=f"Choosing from:\n{considered_candidates}", + ) + return None + + def _info_list_to_intent_mapping(self, information_list: List[Information]): + intent_to_placement_dict = {} + for info in information_list: + intent = (info.meta.get("question", ""), info.meta.get("query", "")) + if intent not in intent_to_placement_dict: + intent_to_placement_dict[intent] = None + return intent_to_placement_dict + + def forward( + self, + knowledge_base: KnowledgeBase, + information: Union[Information, List[Information]], + allow_create_new_node: bool = False, + max_thread: int = 5, + insert_root: Optional[KnowledgeNode] = None, + skip_candidate_from_embedding: bool = False, + ): + + if not isinstance(information, List): + information = [information] + intent_to_placement_dict: Dict = self._info_list_to_intent_mapping( + information_list=information + ) + + # process one intent + def process_intent(question: str, query: str): + candidate_placement = None + try: + if not skip_candidate_from_embedding: + candidate_placement = self.choose_candidate_from_embedding_ranking( + question=question, + query=query, + encoded_outlines=encoded_outlines, + outlines=outlines, + top_N_candidates=8, + ) + if candidate_placement is None: + candidate_placement = self.layer_by_layer_navigation_placement( + knowledge_base=knowledge_base, + question=question, + query=query, + allow_create_new_node=allow_create_new_node, + root=insert_root, + ) + return (question, query), candidate_placement + except Exception as e: + print(traceback.format_exc()) + return (question, query), None + + def insert_info_to_kb(info, placement_prediction): + if placement_prediction is not None: + missing_node_handling = ( + "raise error" if not allow_create_new_node else "create" + ) + knowledge_base.insert_information( + path=placement_prediction.information_placement, + information=info, + missing_node_handling=missing_node_handling, + root=insert_root, + ) + + encoded_outlines, outlines = ( + knowledge_base.get_knowledge_base_structure_embedding(root=insert_root) + ) + to_return = [] + if not allow_create_new_node: + # use multi thread as knowledge base structure does not change + with ThreadPoolExecutor(max_workers=max_thread) as executor: + futures = { + executor.submit(process_intent, question, query): (question, query) + for (question, query) in intent_to_placement_dict + } + + for future in as_completed(futures): + (question, query), candidate_placement = future.result() + intent_to_placement_dict[(question, query)] = candidate_placement + # back mapping placement to each information + for info in information: + 
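+                # map the placement predicted for this info's (question, query) intent back onto the info itself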
intent = (info.meta.get("question", ""), info.meta.get("query", "")) + placement_prediction = intent_to_placement_dict.get(intent, None) + insert_info_to_kb(info, placement_prediction) + to_return.append((info, placement_prediction)) + return to_return + else: + # use sequential insert as knowledge base structure might change + for question, query in intent_to_placement_dict: + encoded_outlines, outlines = ( + knowledge_base.get_knowledge_base_structure_embedding( + root=insert_root + ) + ) + _, placement_prediction = process_intent(question=question, query=query) + intent_to_placement_dict[(question, query)] = placement_prediction + + for info in information: + intent = (info.meta.get("question", ""), info.meta.get("query", "")) + placement_prediction = intent_to_placement_dict.get(intent, None) + insert_info_to_kb(info, placement_prediction) + to_return.append((info, placement_prediction)) + return to_return + + +class ExpandSection(dspy.Signature): + """Your task is to expand a section in the mind map by creating new subsections under the given section. + You will be given a list of question and query that are used to collect information. + Output should be subsection names where each section should serve as a coherent and themantic organization of information and corresponding citation numbers. These subsection names are preferred to be concise and precise. + Output follows the format below: + subsection 1 + subsection 2 + subsection 3 + """ + + section = dspy.InputField(prefix="The section you need to expand: ", format=str) + info = dspy.InputField(prefix="The collected information:\n", format=str) + output = dspy.OutputField( + prefix="Now provide the expanded subsection names (If there's no need to expand current section as itself serves good organization, then output None):\n", + format=str, + ) + + +class ExpandNodeModule(dspy.Module): + def __init__( + self, + engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + information_insert_module: dspy.Module, + node_expansion_trigger_count: int, + ): + self.engine = engine + self.expand_section = dspy.Predict(ExpandSection) + self.information_insert_module = information_insert_module + self.node_expansion_trigger_count = node_expansion_trigger_count + + def _get_cited_info_meta_string(self, node, knowledge_base): + meta_string = set() + for index in sorted(list(node.content)): + info = knowledge_base.info_uuid_to_info_dict[index] + intent = f"Question: {info.meta['question']}\nQuery: {info.meta['query']}" + meta_string.add(intent) + + return "\n\n".join(meta_string) + + def _get_expand_subnode_names(self, node, knowledge_base): + information = self._get_cited_info_meta_string(node, knowledge_base) + node_path = node.get_path_from_root() + with dspy.settings.context(lm=self.engine, show_guidelines=False): + output = self.expand_section(section=node_path, info=information).output + subsections = [] + if "\n" in output and output != "None": + subsections = output.split("\n") + # remove any integer followed by a dot and a space, a leading dashline, + # or a specific hint at the start of the string + subsections = [ + re.sub(r"^\d+\.\s|-|" + re.escape(node.name), "", text) + .replace("*", "") + .strip() + for text in subsections + ] + return subsections + + def _find_first_node_to_expand( + self, root: KnowledgeNode, expanded_nodes: List[KnowledgeNode] + ): + if root is None: + return None + if ( + root not in expanded_nodes + and len(root.content) >= self.node_expansion_trigger_count + ): + return root + for child in root.children: + to_return = 
self._find_first_node_to_expand( + root=child, expanded_nodes=expanded_nodes + ) + if to_return is not None: + return to_return + return None + + def _expand_node(self, node: KnowledgeNode, knowledge_base: KnowledgeBase): + subsection_names = self._get_expand_subnode_names(node, knowledge_base) + if len(subsection_names) <= 1: + return + # create new nodes + for subsection_name in subsection_names: + # remove citation bracket in the subsection name + subsection_name = re.sub(r"\[.*?\]", "", subsection_name) + knowledge_base.insert_node(new_node_name=subsection_name, parent_node=node) + # reset original information placement + original_cited_index = node.content + original_cited_information = [ + knowledge_base.info_uuid_to_info_dict[index] + for index in original_cited_index + ] + node.content = set() + # re-insert under expanded section + self.information_insert_module( + knowledge_base=knowledge_base, + information=original_cited_information, + allow_create_new_node=False, + insert_root=node, + ) + + def forward(self, knowledge_base: KnowledgeBase): + expanded_nodes = [] + while True: + node_to_expand = self._find_first_node_to_expand( + root=knowledge_base.root, expanded_nodes=expanded_nodes + ) + if node_to_expand is None: + break + self._expand_node(node=node_to_expand, knowledge_base=knowledge_base) + expanded_nodes.append(node_to_expand) diff --git a/knowledge_storm/collaborative_storm/modules/knowledge_base_summary.py b/knowledge_storm/collaborative_storm/modules/knowledge_base_summary.py new file mode 100644 index 00000000..fb8e3403 --- /dev/null +++ b/knowledge_storm/collaborative_storm/modules/knowledge_base_summary.py @@ -0,0 +1,32 @@ +import dspy +from typing import Union +from ...dataclass import KnowledgeBase + + +class KnowledgeBaseSummmary(dspy.Signature): + """Your job is to give brief summary of what's been discussed in a roundtable conversation. Contents are themantically organized into hierarchical sections. + You will be presented with these sections where "#" denotes level of section. 
+ """ + + topic = dspy.InputField(prefix="topic: ", format=str) + structure = dspy.InputField(prefix="Tree structure: \n", format=str) + output = dspy.OutputField(prefix="Now give brief summary:\n", format=str) + + +class KnowledgeBaseSummaryModule(dspy.Module): + def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + self.engine = engine + self.gen_summary = dspy.Predict(KnowledgeBaseSummmary) + + def forward(self, knowledge_base: KnowledgeBase): + structure = knowledge_base.get_node_hierarchy_string( + include_indent=False, + include_full_path=False, + include_hash_tag=True, + include_node_content_count=False, + ) + with dspy.settings.context(lm=self.engine, show_guidelines=False): + summary = self.gen_summary( + topic=knowledge_base.topic, structure=structure + ).output + return summary diff --git a/knowledge_storm/collaborative_storm/modules/simulate_user.py b/knowledge_storm/collaborative_storm/modules/simulate_user.py new file mode 100644 index 00000000..2d03e9db --- /dev/null +++ b/knowledge_storm/collaborative_storm/modules/simulate_user.py @@ -0,0 +1,37 @@ +import dspy +from typing import List, Union + +from .collaborative_storm_utils import extract_and_remove_citations +from ...dataclass import ConversationTurn +from ...storm_wiki.modules.knowledge_curation import AskQuestionWithPersona + + +class GenSimulatedUserUtterance(dspy.Module): + def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + self.engine = engine + self.ask_qeustion = dspy.Predict(AskQuestionWithPersona) + + def gen_conv_history_string(self, conversation_turns: List[ConversationTurn]): + conv_history = [] + total_turns = len(conversation_turns) + + for i, turn in enumerate(conversation_turns): + utterance, _ = extract_and_remove_citations(turn.utterance) + if i >= total_turns - 4: + conv_history.append(f"{turn.role}: {utterance}") + else: + if turn.claim_to_make: + conv_history.append(f"{turn.role}: {turn.claim_to_make}") + else: + conv_history.append(f"{turn.role}: {utterance}") + + return "\n".join(conv_history) + + def forward(self, topic: str, intent: str, conv_history: List[ConversationTurn]): + conv_history_string = self.gen_conv_history_string(conv_history) + with dspy.settings.context(lm=self.engine, show_guidelines=False): + return self.ask_qeustion( + topic=topic, + persona=f"researcher with interest in {intent}", + conv=conv_history_string, + ).question diff --git a/knowledge_storm/collaborative_storm/modules/warmstart_hierarchical_chat.py b/knowledge_storm/collaborative_storm/modules/warmstart_hierarchical_chat.py new file mode 100644 index 00000000..3357cbc2 --- /dev/null +++ b/knowledge_storm/collaborative_storm/modules/warmstart_hierarchical_chat.py @@ -0,0 +1,408 @@ +""" +Warm starts the Co-STORM system by conducting a background information search to establish a shared conceptual space with the user. + +This stage functions as a mini-STORM, where multiple LLM agents are spawned with different perspectives to engage in multi-round conversations. +The knowledge base (represented as a mind map) is initialized using the information gathered during these exchanges. + +Additionally, the system generates a first draft of the report, which is then used to create a concise and engaging conversation. +The synthesized conversation is presented to the user to help them quickly catch up on the system's current knowledge about the topic. 
+""" + +import dspy +import concurrent.futures +from threading import Lock +from typing import List, Optional, Union, TYPE_CHECKING + +from .callback import BaseCallbackHandler +from .collaborative_storm_utils import _get_answer_question_module_instance +from .expert_generation import GenerateExpertModule +from .grounded_question_answering import AnswerQuestionModule +from ...dataclass import ConversationTurn, KnowledgeBase +from ...interface import LMConfigs +from ...logging_wrapper import LoggingWrapper +from ...storm_wiki.modules.outline_generation import WritePageOutline +from ...utils import ArticleTextProcessing as AP + + +if TYPE_CHECKING: + from ..engine import RunnerArgument + + +class WarmStartModerator(dspy.Signature): + """ + You are a moderator in a roundtable discussion. The goal is to chat with multiple experts to discuss the facts and background of the topic to familiarize the audience with the topic. + You will be presented with the topic, the history of question you have already asked, and the current expert you are discussing with. + Based on these information, generate the next question for the current expert to further the discussion. + + The output should only include the next question for the current expert. Do not include any other information or preamble. + """ + + topic = dspy.InputField(prefix="Topic for roundtable discussion: ", format=str) + history = dspy.InputField( + prefix="Experts you have already interacted with: ", format=str + ) + current_expert = dspy.InputField(prefix="Expert you are talking with:", format=str) + question = dspy.OutputField( + prefix="Next question for the expert you are talking with: ", format=str + ) + + +class SectionToConvTranscript(dspy.Signature): + """ + You are given a section of a brief report on a specific topic. Your task is to transform this section into an engaging opening discussion for a roundtable conversation. + The goal is to help participants and the audience quickly understand the key information. + Both question and answer should be in the tone of roundtable discussion talking to audiences. + + Specifically, you need to: + 1. Generate an engaging question that leverages section name and topic that opens discussion of the content. + 2. Provide a brief and engaging answer (with all inline citations from original text) derived from the section serving as pointers and avoid too much details. 
+ """ + + topic = dspy.InputField(prefix="topic:", format=str) + section_name = dspy.InputField(prefix="section name:", format=str) + section_content = dspy.InputField(prefix="section content:", format=str) + question = dspy.OutputField(prefix="Now give engaging question only.\nQuestion:") + answer = dspy.OutputField( + prefix="Now give engaging answer only with all inline citations from original text.\nAnswer:" + ) + + +class ReportToConversation(dspy.Module): + def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + self.engine = engine + self.section_to_conv_transcript = dspy.Predict(SectionToConvTranscript) + + def forward(self, knowledge_base: KnowledgeBase): + def process_node(node, topic): + with dspy.settings.context(lm=self.engine, show_guidelines=False): + output = self.section_to_conv_transcript( + topic=topic, + section_name=node.get_path_from_root(), + section_content=node.synthesize_output, + ) + question = output.question.replace("Question:", "").strip() + answer = output.answer.replace("Answer:", "").strip() + return question, answer + + conversations = [] + nodes = knowledge_base.collect_all_nodes() + nodes = [node for node in nodes if node.name != "root" and node.content] + topic = knowledge_base.topic + + with concurrent.futures.ThreadPoolExecutor() as executor: + future_to_node = { + executor.submit(process_node, node, topic): node for node in nodes + } + for future in concurrent.futures.as_completed(future_to_node): + node = future_to_node[future] + question, answer = future.result() + conversations.append( + ConversationTurn( + role="Background discussion moderator", + raw_utterance=question, + utterance_type="Original Question", + utterance=question, + cited_info=[ + knowledge_base.info_uuid_to_info_dict[idx] + for idx in AP.parse_citation_indices(question) + ], + ) + ) + conversations.append( + ConversationTurn( + role="Background discussion expert", + raw_utterance=answer, + utterance_type="Potential Answer", + utterance=answer, + cited_info=[ + knowledge_base.info_uuid_to_info_dict[idx] + for idx in AP.parse_citation_indices(answer) + ], + ) + ) + return conversations + + +class WarmStartConversation(dspy.Module): + def __init__( + self, + question_asking_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + generate_expert_module: GenerateExpertModule, + answer_question_module: AnswerQuestionModule, + logging_wrapper: LoggingWrapper, + max_num_experts: int = 3, + max_turn_per_experts: int = 2, + max_thread: int = 3, + callback_handler: BaseCallbackHandler = None, + ): + self.ask_question = dspy.Predict(WarmStartModerator) + self.max_num_experts = max_num_experts + self.max_turn_per_experts = max_turn_per_experts + self.question_asking_lm = question_asking_lm + self.answer_question_module = answer_question_module + self.max_thread = max_thread + self.generate_experts_module = generate_expert_module + self.logging_wrapper = logging_wrapper + self.callback_handler = callback_handler + + def format_dialogue_question_history_string( + self, conversation_history: List[ConversationTurn] + ): + output = [] + for idx, turn in enumerate(conversation_history): + info = turn.claim_to_make if turn.claim_to_make else turn.utterance + output.append(f"{idx + 1}: {info}") + return "\n".join(output) + + def generate_warmstart_experts(self, topic: str): + background_seeking_dialogue = self.get_background_info(topic=topic) + background_info = background_seeking_dialogue.utterance + gen_expert_output = self.generate_experts_module( + topic=topic, + background_info=background_info, 
+ num_experts=self.max_num_experts, + ) + return gen_expert_output.experts, background_seeking_dialogue + + def get_background_info(self, topic: str): + question = f"Background information about {topic}" + answer = self.answer_question_module( + topic=topic, question=question, mode="extensive", style="conversational" + ) + + return ConversationTurn( + role="Default Background Researcher", + raw_utterance=answer.response, + utterance_type="Questioning", + claim_to_make=question, + queries=answer.queries, + raw_retrieved_info=answer.raw_retrieved_info, + cited_info=answer.cited_info, + ) + + def forward(self, topic: str): + with self.logging_wrapper.log_event( + "warm start, perspective guided QA: identify experts" + ): + # do background research, generate some experts + experts, background_seeking_dialogue = self.generate_warmstart_experts( + topic=topic + ) + # init list to store the dialogue history + conversation_history: List[ConversationTurn] = [] + lock = Lock() + + # hierarchical chat: chat with one expert. Generate question, get answer + def process_expert(expert): + expert_name, expert_descriptoin = expert.split(":") + for idx in range(self.max_turn_per_experts): + with self.logging_wrapper.log_event( + f"warm start, perspective guided QA: expert {expert_name}; turn {idx + 1}" + ): + try: + with lock: + history = self.format_dialogue_question_history_string( + conversation_history + ) + with dspy.settings.context(lm=self.question_asking_lm): + question = self.ask_question( + topic=topic, history=history, current_expert=expert + ).question + answer = self.answer_question_module( + topic=topic, + question=question, + mode="brief", + style="conversational", + ) + conversation_turn = ConversationTurn( + role=expert, + claim_to_make=question, + raw_utterance=answer.response, + utterance_type="Support", + queries=answer.queries, + raw_retrieved_info=answer.raw_retrieved_info, + cited_info=answer.cited_info, + ) + if self.callback_handler is not None: + self.callback_handler.on_warmstart_update( + message="\n".join( + [ + f"Finish browsing {url}" + for url in [ + i.url for i in answer.raw_retrieved_info + ] + ] + ) + ) + with lock: + conversation_history.append(conversation_turn) + except Exception as e: + print(f"Error processing expert {expert}: {e}") + + # multi-thread conversation + with concurrent.futures.ThreadPoolExecutor( + max_workers=self.max_thread + ) as executor: + futures = [ + executor.submit(process_expert, expert) + for expert in experts[: min(len(experts), self.max_num_experts)] + ] + concurrent.futures.wait(futures) + + conversation_history = [background_seeking_dialogue] + conversation_history + + return dspy.Prediction( + conversation_history=conversation_history, experts=experts + ) + + +class GenerateWarmStartOutline(dspy.Signature): + """Generate a outline of the wikipedia-like report from a roundtable discussion. You will be presented discussion points in the conversation and corresponding queries. + You will be given a draft outline which you can borrow some inspiration. Do not include sections that are not mentioned in the given discussion history. + Use "#" to denote section headings, "##" to denote subsection headings, and so on. + Follow these guidelines: + 1. Use "#" for section titles, "##" for subsection titles, "###" for subsubsection titles, and so on. + 2. Do not include any additional information. + 3. Exclude the topic name from the outline. + The organization of outline should adopt wikiepdia style. 
+ """ + + topic = dspy.InputField(prefix="The topic discussed: ", format=str) + draft = dspy.InputField(prefix="Draft outline you can reference to: ", format=str) + conv = dspy.InputField(prefix="Discussion history:\n", format=str) + outline = dspy.OutputField( + prefix='Write the conversation outline (Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, ...):\n', + format=str, + ) + + +class GenerateWarmStartOutlineModule(dspy.Module): + def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + self.engine = engine + self.gen_outline = dspy.Predict(GenerateWarmStartOutline) + self.draft_outline = dspy.Predict(WritePageOutline) + + def extract_questions_and_queries(self, conv: List[ConversationTurn]): + context = [] + for turn in conv: + focus = turn.claim_to_make + queries = turn.queries + queries_string = "\n\t".join( + f"Query {idx + 1}: {query}" for idx, query in enumerate(queries) + ) + string = f"Discussion focus {len(context) + 1}: {focus}\n\t{queries_string}" + context.append(string) + return "\n".join(context) + + def get_draft_outline(self, topic: str): + with dspy.settings.context(lm=self.engine): + return self.draft_outline(topic=topic).outline + + def forward(self, topic: str, conv: List[ConversationTurn]): + discussion_history = self.extract_questions_and_queries(conv) + draft_outline = self.get_draft_outline(topic=topic) + with dspy.settings.context(lm=self.engine): + outline = self.gen_outline( + topic=topic, draft=draft_outline, conv=discussion_history + ).outline + outline = AP.clean_up_outline(outline) + return dspy.Prediction(outline=outline, draft_outline=draft_outline) + + +class WarmStartModule: + def __init__( + self, + lm_config: LMConfigs, + runner_argument: "RunnerArgument", + logging_wrapper: LoggingWrapper, + rm: Optional[dspy.Retrieve] = None, + callback_handler: BaseCallbackHandler = None, + ): + generate_expert_module = GenerateExpertModule( + engine=lm_config.discourse_manage_lm + ) + self.warmstart_conv = WarmStartConversation( + question_asking_lm=lm_config.question_asking_lm, + generate_expert_module=generate_expert_module, + answer_question_module=_get_answer_question_module_instance( + lm_config=lm_config, + runner_argument=runner_argument, + logging_wrapper=logging_wrapper, + rm=rm, + ), + max_num_experts=runner_argument.warmstart_max_num_experts, + max_turn_per_experts=runner_argument.warmstart_max_turn_per_experts, + max_thread=runner_argument.warmstart_max_thread, + logging_wrapper=logging_wrapper, + callback_handler=callback_handler, + ) + self.warmstart_outline_gen_module = GenerateWarmStartOutlineModule( + engine=lm_config.warmstart_outline_gen_lm + ) + self.report_to_conversation = ReportToConversation(lm_config.knowledge_base_lm) + self.logging_wrapper = logging_wrapper + self.callback_handler = callback_handler + + def initiate_warm_start(self, topic: str, knowledge_base: KnowledgeBase): + """ + Initiates a warm start process for the given topic by generating a warm start conversation and inserting the + resulting information into a knowledge base. + + Args: + topic (str): The topic for which to initiate the warm start process. + + Returns: + Tuple[List[ConversationTurn], List[str], KnowledgeBase]: + - A list of ConversationTurn instances representing the conversation history. + - A list of strings representing the experts involved in the conversation. + - A KnowledgeBase instance containing the organized information. 
+ """ + warm_start_conversation_history: List[ConversationTurn] = [] + warm_start_experts = None + # get warm start conversations + with self.logging_wrapper.log_event("warm start: perspective guided QA"): + if self.callback_handler is not None: + self.callback_handler.on_warmstart_update( + message="Start getting familiar with the topic by chatting with multiple LLM experts (Step 1 / 4)" + ) + warm_start_result = self.warmstart_conv(topic=topic) + warm_start_conversation_history = warm_start_result.conversation_history + warm_start_experts = warm_start_result.experts + + # get warm start conv outline + with self.logging_wrapper.log_event("warm start: outline generation"): + if self.callback_handler is not None: + self.callback_handler.on_warmstart_update( + "Organizing collected information (Step 2 / 4)" + ) + warm_start_outline_output = self.warmstart_outline_gen_module( + topic=topic, conv=warm_start_conversation_history + ) + # init knowledge base + with self.logging_wrapper.log_event("warm start: insert into knowledge base"): + if self.callback_handler is not None: + self.callback_handler.on_warmstart_update( + "Inserting collected information into knowledge base (Step 3 / 4)" + ) + knowledge_base.insert_from_outline_string( + outline_string=warm_start_outline_output.outline + ) + # insert information to knowledge base + for turn in warm_start_conversation_history: + knowledge_base.update_from_conv_turn( + conv_turn=turn, allow_create_new_node=False + ) + # knowledge base to report + if self.callback_handler is not None: + self.callback_handler.on_warmstart_update( + "Synthesizing background information discussion utterances (Step 4 / 4)" + ) + knowledge_base.to_report() + + # generate engaging conversations + engaging_conversations = self.report_to_conversation(knowledge_base) + return ( + warm_start_conversation_history, + engaging_conversations, + warm_start_experts, + ) diff --git a/knowledge_storm/dataclass.py b/knowledge_storm/dataclass.py new file mode 100644 index 00000000..fd981905 --- /dev/null +++ b/knowledge_storm/dataclass.py @@ -0,0 +1,849 @@ +import dspy +import numpy as np +import re +import threading +from typing import Set, Dict, List, Optional, Union, Tuple + +from .encoder import get_text_embeddings +from .interface import Information + + +class ConversationTurn: + """ + A class to represent a turn in a conversation. + + Attributes: + role (str): A short phrase of the role of the speaker for the current conversation turn. + raw_utterance (str): The response generated by the LM model without polished style and tone. + utterance_type (str): The type of utterance (e.g., statement, question). + claim_to_make (Optional[str]): The point that this utterance tries to make. Should be empty if the utterance type is questioning. + utterance (Optional[str]): The response generated by the model with polished style and tone. Defaults to raw_utterance if not provided. + queries (List[str]): The queries used to gather information to have a grounded answer. + raw_retrieved_info (List['Information']): A list of Information type that is retrieved. + cited_info (Dict[int, 'Information']): A dictionary where the key is the citation index and the value is Information type. + role_description (Optional[str]): A few sentences description of the role. Defaults to an empty string if not provided. 
+ """ + + def __init__( + self, + role: str, + raw_utterance: str, + utterance_type: str, + claim_to_make: Optional[str] = None, + utterance: Optional[str] = None, + queries: Optional[List[str]] = None, + raw_retrieved_info: Optional[List[Information]] = None, + cited_info: Optional[List[Information]] = None, + ): + self.utterance = utterance if utterance is not None else raw_utterance + self.raw_utterance = raw_utterance + self.role = role if ":" not in role else role.split(":")[0] + self.role_description = "" if ":" not in role else role.split(":")[1] + self.queries = queries if queries is not None else [] + self.raw_retrieved_info = ( + raw_retrieved_info if raw_retrieved_info is not None else [] + ) + self.cited_info = cited_info if cited_info is not None else {} + self.utterance_type = utterance_type + self.claim_to_make = claim_to_make if claim_to_make is not None else "" + + def get_all_citation_index(self): + citation_pattern = re.compile(r"\[(\d+)\]") + return list(map(int, citation_pattern.findall(self.utterance))) + + def to_dict(self): + raw_retrieved_info = [info.to_dict() for info in self.raw_retrieved_info] + return { + "utterance": self.utterance, + "raw_utterance": self.raw_utterance, + "role": self.role, + "role_description": self.role_description, + "queries": self.queries, + "utterance_type": self.utterance_type, + "claim_to_make": self.claim_to_make, + "raw_retrieved_info": raw_retrieved_info, + "cited_info": None, + } + + @classmethod + def from_dict(cls, conv_turn_dict: Dict): + raw_retrieved_info = [ + Information.from_dict(info) for info in conv_turn_dict["raw_retrieved_info"] + ] + + return cls( + utterance=conv_turn_dict["utterance"], + raw_utterance=conv_turn_dict["raw_utterance"], + role=f"{conv_turn_dict['role']}: {conv_turn_dict['role_description']}", + queries=conv_turn_dict["queries"], + raw_retrieved_info=raw_retrieved_info, + cited_info=None, + utterance_type=conv_turn_dict["utterance_type"], + claim_to_make=conv_turn_dict["claim_to_make"], + ) + + +class KnowledgeNode: + """ + Class representing a node in the knowledge base. + + Attributes: + name (str): The name of the node. + content (list): A list of Information instances. + children (list): A list of child KnowledgeNode instances. + parent (KnowledgeNode): The parent node of the current node. + """ + + def __init__( + self, + name: str, + content: Optional[str] = None, + parent: Optional["KnowledgeNode"] = None, + children: Optional[List["KnowledgeNode"]] = None, + synthesize_output: Optional[str] = None, + need_regenerate_synthesize_output: bool = True, + ): + """ + Initializes a KnowledgeNode instance. + + Args: + name (str): The name of the node. + content (list, optional): A list of information uuid. Defaults to None. + parent (KnowledgeNode, optional): The parent node of the current node. Defaults to None. + """ + self.name = name + self.content: Set[int] = set(content) if content is not None else set() + self.children = [] if children is None else children + self.parent = parent + self.synthesize_output = synthesize_output + self.need_regenerate_synthesize_output = need_regenerate_synthesize_output + + def collect_all_content(self): + """ + Collects all content from the current node and its descendants. + + Returns: + Set[int]: A set containing all content from the current node and its descendants. 
+ """ + all_content = set(self.content) + for child in self.children: + all_content.update(child.collect_all_content()) + return all_content + + def has_child(self, child_node_name: str): + """ + Check if the node has the child of given name. + """ + return child_node_name in [child.name for child in self.children] + + def add_child(self, child_node_name: str, duplicate_handling: str = "skip"): + """ + Adds a child node to the current node. + duplicate_handling (str): How to handle duplicate nodes. Options are "skip", "none", and "raise error". + """ + if self.has_child(child_node_name): + if duplicate_handling == "skip": + for child in self.children: + if child.name == child_node_name: + return child + elif duplicate_handling == "raise error": + raise Exception( + f"Insert node error. Node {child_node_name} already exists under its parent node {self.name}." + ) + child_node = KnowledgeNode(name=child_node_name, parent=self) + self.children.append(child_node) + return child_node + + def get_parent(self): + """ + Returns the parent node of the current node. + + Returns: + KnowledgeNode: The parent node of the current node. + """ + return self.parent + + def get_children(self): + """ + Returns the children of the current node. + + Returns: + list: A list of child KnowledgeNode instances. + """ + return self.children + + def get_children_names(self): + """ + Returns a list of children names. + """ + return [child.name for child in self.children] + + def __repr__(self): + """ + Returns a string representation of the KnowledgeNode instance. + + Returns: + str: String representation of the KnowledgeNode instance. + """ + return f"KnowledgeNode(name={self.name}, content={self.content}, children={len(self.children)})" + + def get_path_from_root(self, root: Optional["KnowledgeNode"] = None): + """ + Get a list of names from the root to this node. + + Returns: + List[str]: A list of node names from the root to this node. + """ + path = [] + current_node = self + while current_node: + path.append(current_node.name) + if root is not None and current_node.name == root.name: + break + current_node = current_node.parent + return path[::-1] + + def insert_information(self, information_index: int): + if information_index not in self.content: + self.need_regenerate_synthesize_output = True + self.content.add(information_index) + + def get_all_descendents(self) -> List["KnowledgeNode"]: + """ + Get a list of all descendant nodes. + + Returns: + List[KnowledgeNode]: A list of all descendant nodes. + """ + descendents = [] + + def collect_descendents(node): + for child in node.children: + descendents.append(child) + collect_descendents(child) + + collect_descendents(self) + return descendents + + def get_all_predecessors(self) -> List["KnowledgeNode"]: + """ + Get a list of all predecessor nodes (from current node to root). + + Returns: + List[KnowledgeNode]: A list of all predecessor nodes. + """ + predecessors = [] + current_node = self.parent + while current_node is not None: + predecessors.append(current_node) + current_node = current_node.parent + return predecessors + + def to_dict(self): + """ + Converts the KnowledgeNode instance to a dictionary representation. + + Returns: + dict: The dictionary representation of the KnowledgeNode. 
+ """ + return { + "name": self.name, + "content": list(self.content), + "children": [child.to_dict() for child in self.children], + "parent": self.parent.name if self.parent else None, + "synthesize_output": self.synthesize_output, + "need_regenerate_synthesize_output": self.need_regenerate_synthesize_output, + } + + @classmethod + def from_dict(cls, data): + """ + Constructs a KnowledgeNode instance from a dictionary representation. + + Args: + data (dict): The dictionary representation of the KnowledgeNode. + + Returns: + KnowledgeNode: The constructed KnowledgeNode instance. + """ + + def helper(cls, data, parent_node=None): + if parent_node is not None: + assert data["parent"] is not None and data["parent"] == parent_node.name + node = cls( + name=data["name"], + content=data["content"], + parent=parent_node, + children=None, + synthesize_output=data.get("synthesize_output", None), + need_regenerate_synthesize_output=data.get( + "need_regenerate_synthesize_output", True + ), + ) + for child_data in data["children"]: + child_node = helper(cls, child_data, parent_node=node) + node.children.append(child_node) + return node + + return helper(cls, data) + + +class KnowledgeBase: + """ + Represents the dynamic, hierarchical mind map used in Co-STORM to track and organize discourse. + + The knowledge base serves as a shared conceptual space between the user and the system, allowing for effective collaboration by reducing the user's cognitive load and ensuring that the discourse is easy to follow. + + The knowledge base is structured as a tree (or mind map) that dynamically organizes collected information and concepts as the conversation progresses. + + The mind map consists of concepts (nodes) and edges that represent parent-child relationships among topics. Each concept is linked to retrieved information, + which is placed under the most appropriate concept based on its associated question and semantic similarity. + + For more details, please refer to Section 3.2 of Co-STORM paper: https://www.arxiv.org/pdf/2408.15232 + Attributes: + root (KnowledgeNode): The root node of the hierarchical knowledge base, representing the top-level concept. + + """ + + def __init__( + self, + topic: str, + knowledge_base_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + node_expansion_trigger_count: int, + ): + """ + Initializes a KnowledgeBase instance. + + Args: + topic (str): The topic of the knowledge base + expand_node_module (dspy.Module): The module that organize knowledge base in place. + The module should accept knowledge base as param. E.g. expand_node_module(self) + article_generation_module (dspy.Module): The module that generate report from knowledge base. + The module should return string. E.g. 
report = article_generation_module(self) + """ + from .collaborative_storm.modules.article_generation import ( + ArticleGenerationModule, + ) + from .collaborative_storm.modules.information_insertion_module import ( + InsertInformationModule, + ExpandNodeModule, + ) + from .collaborative_storm.modules.knowledge_base_summary import ( + KnowledgeBaseSummaryModule, + ) + + self.topic: str = topic + + self.information_insert_module = InsertInformationModule( + engine=knowledge_base_lm + ) + self.expand_node_module = ExpandNodeModule( + engine=knowledge_base_lm, + information_insert_module=self.information_insert_module, + node_expansion_trigger_count=node_expansion_trigger_count, + ) + self.article_generation_module = ArticleGenerationModule( + engine=knowledge_base_lm + ) + self.gen_summary_module = KnowledgeBaseSummaryModule(engine=knowledge_base_lm) + + self.root: KnowledgeNode = KnowledgeNode(name="root") + self.kb_embedding = { + "hash": hash(""), + "encoded_structure": np.array([[]]), + "structure_string": "", + } + self.embedding_cache: Dict[str, np.ndarray] = {} + self.info_uuid_to_info_dict: Dict[int, Information] = {} + self.info_hash_to_uuid_dict: Dict[int, int] = {} + self._lock = threading.Lock() + + def to_dict(self): + info_uuid_to_info_dict = { + key: value.to_dict() for key, value in self.info_uuid_to_info_dict.items() + } + return { + "topic": self.topic, + "tree": self.root.to_dict(), + "info_uuid_to_info_dict": info_uuid_to_info_dict, + "info_hash_to_uuid_dict": self.info_hash_to_uuid_dict, + } + + @classmethod + def from_dict( + cls, + data: Dict, + knowledge_base_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + node_expansion_trigger_count: int, + ): + knowledge_base = cls( + topic=data["topic"], + knowledge_base_lm=knowledge_base_lm, + node_expansion_trigger_count=node_expansion_trigger_count, + ) + knowledge_base.root = KnowledgeNode.from_dict(data["tree"]) + knowledge_base.info_hash_to_uuid_dict = { + int(key): int(value) + for key, value in data["info_hash_to_uuid_dict"].items() + } + info_uuid_to_info_dict = { + int(key): Information.from_dict(value) + for key, value in data["info_uuid_to_info_dict"].items() + } + knowledge_base.info_uuid_to_info_dict = info_uuid_to_info_dict + return knowledge_base + + def get_knowledge_base_structure_embedding( + self, root: Optional[KnowledgeNode] = None + ) -> Tuple[np.ndarray, List[str]]: + outline_string = self.get_node_hierarchy_string( + include_indent=False, + include_full_path=True, + include_hash_tag=False, + root=root, + ) + outline_string_hash = hash(outline_string) + if outline_string_hash != self.kb_embedding["hash"]: + outline_strings: List[str] = outline_string.split("\n") + cleaned_outline_strings = [ + outline.replace(" -> ", ", ") for outline in outline_strings + ] + encoded_outline, _ = get_text_embeddings( + cleaned_outline_strings, embedding_cache=self.embedding_cache + ) + self.kb_embedding = { + "hash": outline_string_hash, + "encoded_structure": encoded_outline, + "structure_string": outline_strings, + } + return ( + self.kb_embedding["encoded_structure"], + self.kb_embedding["structure_string"], + ) + + def traverse_down(self, node): + """ + Traverses the tree downward from the given node. + + Args: + node (KnowledgeNode): The node to start the traversal from. + + Returns: + list: A list of KnowledgeNode instances in the order they were visited. 
+ """ + nodes = [] + + def _traverse(current_node): + nodes.append(current_node) + for child in current_node.get_children(): + _traverse(child) + + _traverse(node) + return nodes + + def traverse_up(self, node): + """ + Traverses the tree upward from the given node. + + Args: + node (KnowledgeNode): The node to start the traversal from. + + Returns: + list: A list of KnowledgeNode instances in the order they were visited. + """ + nodes = [] + while node is not None: + nodes.append(node) + node = node.get_parent() + return nodes + + def collect_all_nodes(self): + nodes = [] + + def _collect(node): + nodes.append(node) + for child in node.children: + _collect(child) + + _collect(self.root) + return nodes + + def insert_node( + self, + new_node_name, + parent_node: Optional[KnowledgeNode] = None, + duplicate_handling="skip", + ): + """ + Inserts a new node into the knowledge base under the specified parent node. + + Args: + new_node_name (str): The name of the new node. + parent_node_name (str): The name of the parent node. If None, the new node is inserted under the root. + duplicate_handling (str): How to handle duplicate nodes. Options are "skip", "none", and "raise error". + """ + if parent_node is None: + return self.root.add_child( + new_node_name, duplicate_handling=duplicate_handling + ) + else: + return parent_node.add_child( + new_node_name, duplicate_handling=duplicate_handling + ) + + def find_node(self, current_node, node_name): + """ + Finds a node by name in the knowledge base. + + Args: + current_node (KnowledgeNode): The node to start the search from. + node_name (str): The name of the node to find. + + Returns: + KnowledgeNode: The node with the specified name, or None if not found. + """ + if current_node.name == node_name: + return current_node + for child in current_node.get_children(): + result = self.find_node(child, node_name) + if result is not None: + return result + return None + + def insert_from_outline_string(self, outline_string, duplicate_handling="skip"): + """ + Creates and inserts nodes into the knowledge base from a string outline. + + Args: + outline_string (str): The outline string where each line starts with '#' denoting the level. + duplicate_handling (str): How to handle duplicate nodes. Options are "skip", "none", and "raise error". + """ + last_node_at_level = {} + for line in outline_string.split("\n"): + level = line.count("#") + if level > 0: + title = line.strip("# ").strip() + if title.lower() in ["overview", "summary", "introduction"]: + continue + parent_node = None if level == 1 else last_node_at_level.get(level - 1) + new_node = self.insert_node( + new_node_name=title, + parent_node=parent_node, + duplicate_handling=duplicate_handling, + ) + last_node_at_level[level] = new_node + for deeper_level in list(last_node_at_level.keys()): + if deeper_level > level: + del last_node_at_level[deeper_level] + + def get_node_hierarchy_string( + self, + include_indent=False, + include_full_path=False, + include_hash_tag=True, + include_node_content_count=False, + cited_indices: Optional[List[int]] = None, + root: Optional[KnowledgeNode] = None, + ) -> str: + + def find_node_contain_index(node, index): + """ + Traverses the tree downward from the given node. + + Args: + node (KnowledgeNode): The node to start the traversal from. + + Returns: + list: A list of KnowledgeNode instances in the order they were visited. 
+ """ + nodes = [] + + def _traverse(current_node): + if current_node is not None and index in current_node.content: + nodes.append(current_node) + for child in current_node.get_children(): + _traverse(child) + + _traverse(node) + return nodes + + paths_to_highlight = set() + nodes_to_include = set() + if cited_indices is not None: + for index in cited_indices: + for cur_node in find_node_contain_index(self.root, index): + paths_to_highlight.add(" -> ".join(cur_node.get_path_from_root())) + nodes_to_include.add(cur_node) + nodes_to_include.update(cur_node.get_all_descendents()) + predecessors = cur_node.get_all_predecessors() + for predecessor in predecessors: + nodes_to_include.update(predecessor.children) + nodes_to_include.update(predecessors) + + def should_include_node(node): + if cited_indices is None: + return True + return node in nodes_to_include + + def should_omit_child_nodes(node): + if cited_indices is None: + return False + for child in node.children: + if should_include_node(child): + return False + return True + + def helper(cur_root, level): + to_return = [] + if cur_root is not None: + should_include_current_node = should_include_node(cur_root) + + indent = "" if not include_indent else "\t" * (level - 1) + full_path = " -> ".join(cur_root.get_path_from_root(root=root)) + node_info = cur_root.name if not include_full_path else full_path + hash_tag = "#" * level + " " if include_hash_tag else "" + content_count = ( + f" ({len(cur_root.content)})" if include_node_content_count else "" + ) + special_note = ( + "" + if cited_indices is None or full_path not in paths_to_highlight + else " ⭐" + ) + + if should_include_current_node: + to_return.append( + f"{indent}{hash_tag}{node_info}{content_count}{special_note}" + ) + if should_omit_child_nodes(cur_root): + if len(cur_root.children) > 0: + child_indent = indent = ( + "" if not include_indent else "\t" * (level) + ) + to_return.append(f"{child_indent}...") + else: + for child in cur_root.children: + to_return.extend(helper(child, level + 1)) + return to_return + + to_return = [] + if root is None and self.root is not None: + for child in self.root.children: + to_return.extend(helper(child, level=1)) + else: + to_return.extend(helper(root, level=1)) + + return "\n".join(to_return) + + def find_node_by_path( + self, + path: str, + missing_node_handling="abort", + root: Optional[KnowledgeNode] = None, + ): + """ + Returns the target node given a path string. + + Args: + path (str): The path to the node, with node names connected by " -> ". + missing_node_handling (str): How to handle missing nodes. Options are "abort", "create", and "raise error". + + Returns: + KnowledgeNode: The target node. + """ + node_names = path.split(" -> ") + current_node = self.root if root is None else root + + for name in node_names[1:]: + found_node = next( + (child for child in current_node.children if child.name == name), None + ) + if found_node is None: + if missing_node_handling == "abort": + return + elif missing_node_handling == "create": + new_node = current_node.add_child(child_node_name=name) + current_node = new_node + elif missing_node_handling == "raise error": + structure = self.get_node_hierarchy_string( + include_indent=True, + include_full_path=False, + include_hash_tag=True, + ) + raise Exception( + f"Insert information error. 
Unable to find node {{{name}}} under {{{current_node.name}}}\n{structure}" + ) + else: + current_node = found_node + return current_node + + def insert_information( + self, + path: str, + information: Information, + missing_node_handling="abort", + root: Optional[KnowledgeNode] = None, + ): + """ + Inserts information into the knowledge base at the specified path. + + Args: + path (str): The placement path string, connected by " -> " linking the name of nodes. + information (Information): The information to insert. + missing_node_handling (str): How to handle missing nodes. Options are "abort", "create", and "raise error". + Return: + uuid of insertion information + """ + with self._lock: + target_node: KnowledgeNode = self.find_node_by_path( + path=path, missing_node_handling=missing_node_handling, root=root + ) + information_hash = hash(information) + if information.citation_uuid == -1: + info_citation_uuid = self.info_hash_to_uuid_dict.get( + information_hash, len(self.info_hash_to_uuid_dict) + 1 + ) + information.citation_uuid = info_citation_uuid + self.info_hash_to_uuid_dict[information_hash] = info_citation_uuid + self.info_uuid_to_info_dict[info_citation_uuid] = information + if target_node is not None: + self.info_uuid_to_info_dict[information.citation_uuid].meta[ + "placement" + ] = " -> ".join(target_node.get_path_from_root()) + target_node.insert_information(information.citation_uuid) + + def trim_empty_leaf_nodes(self): + """ + Trims all leaf nodes that do not have any content. Iteratively does it until all leaf nodes have at least one content. + """ + + def trim_node(node): + if not node.children and not node.content: + return True + node.children = [child for child in node.children if not trim_node(child)] + return not node.children and not node.content + + # Start the trimming process from the root + while True: + before_trim = len(self.get_all_leaf_nodes()) + trim_node(self.root) + after_trim = len(self.get_all_leaf_nodes()) + if before_trim == after_trim: + break + + def get_all_leaf_nodes(self): + """ + Helper function to get all leaf nodes. + + Returns: + List[KnowledgeNode]: A list of all leaf nodes in the knowledge base. + """ + leaf_nodes = [] + + def find_leaf_nodes(node): + if not node.children: + leaf_nodes.append(node) + for child in node.children: + find_leaf_nodes(child) + + find_leaf_nodes(self.root) + return leaf_nodes + + def merge_single_child_nodes(self): + """ + Merges content of a node with its single child and removes the child node. + Iteratively does this from leaf nodes back to the root. 
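+
+        For example, if node A has a single child B, B's content is merged into A and B's children are re-parented to A.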
+        """
+
+        def merge_node(node):
+            # Recursively merge children first
+            for child in node.children:
+                merge_node(child)
+
+            # If the node has exactly one child, merge its content with the child and remove the child
+            if len(node.children) == 1:
+                single_child = node.children[0]
+                node.content.update(single_child.content)
+                node.children = single_child.children
+                for grandchild in node.children:
+                    grandchild.parent = node
+
+        merge_node(self.root)
+
+    def update_all_info_path(self):
+        def _helper(node):
+            for citation_idx in node.content:
+                self.info_uuid_to_info_dict[citation_idx].meta["placement"] = (
+                    " -> ".join(node.get_path_from_root())
+                )
+            for child in node.children:
+                _helper(child)
+
+        _helper(self.root)
+
+    def update_from_conv_turn(
+        self,
+        conv_turn: ConversationTurn,
+        allow_create_new_node: bool = False,
+        insert_under_root: bool = False,
+    ):
+        if conv_turn is None:
+            return
+        info_to_insert = list(conv_turn.cited_info.values())
+        if insert_under_root:
+            for info in info_to_insert:
+                self.insert_information(path=self.root.name, information=info)
+        else:
+            self.information_insert_module(
+                knowledge_base=self,
+                information=info_to_insert,
+                allow_create_new_node=allow_create_new_node,
+            )
+        old_to_new_citation_idx_mapping = {
+            old_idx: info.citation_uuid
+            for old_idx, info in conv_turn.cited_info.items()
+        }
+
+        # First pass: rewrite old citation indices as [_new_] placeholders so that
+        # overlapping old/new indices cannot clobber each other.
+        for old_idx, new_idx in old_to_new_citation_idx_mapping.items():
+            conv_turn.utterance = conv_turn.utterance.replace(
+                f"[{old_idx}]", f"[_{new_idx}_]"
+            )
+            conv_turn.raw_utterance = conv_turn.raw_utterance.replace(
+                f"[{old_idx}]", f"[_{new_idx}_]"
+            )
+        # Second pass: strip the placeholders and drop unresolved [-1] citations.
+        for _, new_idx in old_to_new_citation_idx_mapping.items():
+            conv_turn.utterance = conv_turn.utterance.replace(
+                f"[_{new_idx}_]", f"[{new_idx}]"
+            )
+            conv_turn.utterance = conv_turn.utterance.replace("[-1]", "")
+            conv_turn.raw_utterance = conv_turn.raw_utterance.replace(
+                f"[_{new_idx}_]", f"[{new_idx}]"
+            )
+            conv_turn.raw_utterance = conv_turn.raw_utterance.replace("[-1]", "")
+        conv_turn.cited_info = None
+
+    def get_knowledge_base_summary(self):
+        return self.gen_summary_module(self)
+
+    def reogranize(self):
+        """
+        Reorganizes the knowledge base through two main processes: top-down expansion and bottom-up cleaning.
+
+        The reorganization process ensures that the knowledge base remains well-structured and relevant as new information is added. It consists of the following steps:
+        1. Top-Down Expansion: Expands nodes that have accumulated significant amounts of information by creating subtopics,
+        ensuring that each concept remains specific and manageable.
+        2. Bottom-Up Cleaning: Cleans the knowledge base by removing empty leaf nodes (nodes with no supporting information)
+        and merging nodes that have only a single child, simplifying the structure and maintaining clarity.
+        """
+        # pre-processing
+        self.trim_empty_leaf_nodes()
+        self.merge_single_child_nodes()
+        # expand nodes
+        self.expand_node_module(knowledge_base=self)
+        # clean up
+        self.trim_empty_leaf_nodes()
+        self.merge_single_child_nodes()
+        self.update_all_info_path()
+
+    def to_report(self):
+        return self.article_generation_module(knowledge_base=self)
diff --git a/knowledge_storm/encoder.py b/knowledge_storm/encoder.py
new file mode 100644
index 00000000..3d14e63c
--- /dev/null
+++ b/knowledge_storm/encoder.py
@@ -0,0 +1,169 @@
+import requests
+import os
+from typing import List, Tuple, Union, Optional, Dict, Literal
+import numpy as np
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+
+class EmbeddingModel:
+    def __init__(self):
+        pass
+
+    def get_embedding(self, text: str) -> Tuple[np.ndarray, int]:
+        raise Exception("Not implemented")
+
+
+class OpenAIEmbeddingModel(EmbeddingModel):
+    def __init__(self, model: str = "text-embedding-3-small", api_key: str = None):
+        if not api_key:
+            api_key = os.getenv("OPENAI_API_KEY")
+
+        self.url = "https://api.openai.com/v1/embeddings"
+        self.headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {api_key}",
+        }
+        self.model = model
+
+    def get_embedding(self, text: str) -> Tuple[np.ndarray, int]:
+        data = {"input": text, "model": self.model}
+
+        response = requests.post(self.url, headers=self.headers, json=data)
+        if response.status_code == 200:
+            data = response.json()
+            embedding = np.array(data["data"][0]["embedding"])
+            token = data["usage"]["prompt_tokens"]
+            return embedding, token
+        else:
+            response.raise_for_status()
+
+
+class TogetherEmbeddingModel:
+    def __init__(self, model: str = "BAAI/bge-large-en-v1.5", api_key: str = None):
+        import together
+
+        self.model = model
+        if not api_key:
+            api_key = os.getenv("TOGETHER_API_KEY")
+        self.together_client = together.Together(api_key=api_key)
+
+    def get_embedding(self, text: str) -> Tuple[np.ndarray, int]:
+        response = self.together_client.embeddings.create(input=text, model=self.model)
+        return response.data[0].embedding, -1
+
+
+class AzureOpenAIEmbeddingModel:
+    def __init__(self, model: str = "text-embedding-3-small", api_key: str = None):
+        from openai import AzureOpenAI
+
+        self.model = model
+        if not api_key:
+            api_key = os.getenv("AZURE_API_KEY")
+
+        self.client = AzureOpenAI(
+            api_key=api_key,
+            api_version=os.getenv("AZURE_API_VERSION"),
+            azure_endpoint=os.getenv("AZURE_API_BASE"),
+        )
+
+    def get_embedding(self, text: str) -> Tuple[np.ndarray, int]:
+        response = self.client.embeddings.create(input=text, model=self.model)
+
+        embedding = np.array(response.data[0].embedding)
+        token = response.usage.prompt_tokens
+        return embedding, token
+
+
+def get_text_embeddings(
+    texts: Union[str, List[str]],
+    max_workers: int = 5,
+    embedding_cache: Optional[Dict[str, np.ndarray]] = None,
+) -> Tuple[np.ndarray, int]:
+    """
+    Get text embeddings using the encoder selected by the ENCODER_API_TYPE environment variable.
+
+    Args:
+        texts (Union[str, List[str]]): A single text string or a list of text strings to embed.
+        max_workers (int): The maximum number of workers for parallel processing.
+        embedding_cache (Optional[Dict[str, np.ndarray]]): A cache to store previously computed embeddings.
+
+    Returns:
+        Tuple[np.ndarray, int]: The 2D array of embeddings and the total token usage.
+    """
+    embedding_model = None
+    encoder_type = os.getenv("ENCODER_API_TYPE")
+    if encoder_type and encoder_type == "openai":
+        embedding_model = OpenAIEmbeddingModel()
+    elif encoder_type and encoder_type == "azure":
+        embedding_model = AzureOpenAIEmbeddingModel()
+    elif encoder_type and encoder_type == "together":
+        embedding_model = TogetherEmbeddingModel()
+    else:
+        raise Exception(
+            "No valid encoder type is provided. Check /secrets.toml for the field ENCODER_API_TYPE"
+        )
+
+    def fetch_embedding(text: str) -> Tuple[str, np.ndarray, int]:
+        if embedding_cache is not None and text in embedding_cache:
+            return (
+                text,
+                embedding_cache[text],
+                0,
+            )  # Returning 0 tokens since no API call is made
+        embedding, token_usage = embedding_model.get_embedding(text)
+        return text, embedding, token_usage
+
+    if isinstance(texts, str):
+        _, embedding, tokens = fetch_embedding(texts)
+        return np.array(embedding), tokens
+
+    embeddings = []
+    total_tokens = 0
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = {executor.submit(fetch_embedding, text): text for text in texts}
+
+        for future in as_completed(futures):
+            try:
+                text, embedding, tokens = future.result()
+                embeddings.append((text, embedding, tokens))
+                total_tokens += tokens
+            except Exception as e:
+                print(f"An error occurred for text: {futures[future]}")
+                print(e)
+
+    # Sort results to match the order of the input texts
+    embeddings.sort(key=lambda x: texts.index(x[0]))
+    if embedding_cache is not None:
+        for text, embedding, _ in embeddings:
+            embedding_cache[text] = embedding
+    embeddings = [result[1] for result in embeddings]
+
+    return np.array(embeddings), total_tokens
diff --git a/knowledge_storm/interface.py b/knowledge_storm/interface.py
index f6c11bd9..5922602f 100644
--- a/knowledge_storm/interface.py
+++ b/knowledge_storm/interface.py
@@ -1,27 +1,23 @@
+import concurrent.futures
+import dspy
 import functools
+import hashlib
+import json
 import logging
 import time
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, TYPE_CHECKING
+
+from .utils import ArticleTextProcessing
 
 logging.basicConfig(
     level=logging.INFO, format="%(name)s : %(levelname)-8s : %(message)s"
 )
 logger = logging.getLogger(__name__)
 
-
-class Information(ABC):
-    """Abstract base class to represent basic information.
-
-    Attributes:
-        uuid (str): The unique identifier for the information.
-        meta (dict): The meta information associated with the information.
-    """
-
-    def __init__(self, uuid, meta={}):
-        self.uuid = uuid
-        self.meta = meta
+if TYPE_CHECKING:
+    from .logging_wrapper import LoggingWrapper
 
 
 class InformationTable(ABC):
@@ -42,6 +38,101 @@ def retrieve_information(**kwargs):
         pass
 
 
+class Information:
+    """Class to represent detailed information.
+ + Inherits from Information to include a unique identifier (URL), and extends + it with a description, snippets, and title of the storm information. + + Attributes: + description (str): Brief description. + snippets (list): List of brief excerpts or snippets. + title (str): The title or headline of the information. + url (str): The unique URL (serving as UUID) of the information. + """ + + def __init__(self, url, description, snippets, title, meta=None): + """Initialize the Information object with detailed attributes. + + Args: + url (str): The unique URL serving as the identifier for the information. + description (str): Detailed description. + snippets (list): List of brief excerpts or snippet. + title (str): The title or headline of the information. + """ + self.description = description + self.snippets = snippets + self.title = title + self.url = url + self.meta = meta if meta is not None else {} + self.citation_uuid = -1 + + def __hash__(self): + return hash( + ( + self.url, + tuple(sorted(self.snippets)), + ) + ) + + def __eq__(self, other): + if not isinstance(other, Information): + return False + return ( + self.url == other.url + and set(self.snippets) == set(other.snippets) + and self._meta_str() == other._meta_str() + ) + + def __hash__(self): + return int( + self._md5_hash((self.url, tuple(sorted(self.snippets)), self._meta_str())), + 16, + ) + + def _meta_str(self): + """Generate a string representation of relevant meta information.""" + return f"Question: {self.meta.get('question', '')}, Query: {self.meta.get('query', '')}" + + def _md5_hash(self, value): + """Generate an MD5 hash for a given value.""" + if isinstance(value, (dict, list, tuple)): + value = json.dumps(value, sort_keys=True) + return hashlib.md5(str(value).encode("utf-8")).hexdigest() + + @classmethod + def from_dict(cls, info_dict): + """Create a Information object from a dictionary. + Usage: info = Information.from_dict(storm_info_dict) + + Args: + info_dict (dict): A dictionary containing keys 'url', 'description', + 'snippets', and 'title' corresponding to the object's attributes. + + Returns: + Information: An instance of Information. + """ + info = cls( + url=info_dict["url"], + description=info_dict["description"], + snippets=info_dict["snippets"], + title=info_dict["title"], + meta=info_dict.get("meta", None), + ) + info.citation_uuid = int(info_dict.get("citation_uuid", -1)) + return info + + def to_dict(self): + return { + "url": self.url, + "description": self.description, + "snippets": self.snippets, + "title": self.title, + "meta": self.meta, + "citation_uuid": self.citation_uuid, + } + + class ArticleSectionNode: """ The ArticleSectionNode is the dataclass for handling the section of the article. @@ -166,7 +257,7 @@ def prune_empty_nodes(self, node=None): return node -class Retriever(ABC): +class Retriever: """ An abstract base class for retriever modules. It provides a template for retrieving information based on a query. @@ -175,19 +266,14 @@ class Retriever(ABC): The retrieval model/search engine used for each part should be declared with a suffix '_rm' in the attribute name. 
""" - def __init__(self, search_top_k): - self.search_top_k = search_top_k - - def update_search_top_k(self, k): - self.search_top_k = k + def __init__(self, rm: dspy.Retrieve, max_thread: int = 1): + self.max_thread = max_thread + self.rm = rm def collect_and_reset_rm_usage(self): combined_usage = [] - for attr_name in self.__dict__: - if "_rm" in attr_name and hasattr( - getattr(self, attr_name), "get_usage_and_reset" - ): - combined_usage.append(getattr(self, attr_name).get_usage_and_reset()) + if hasattr(getattr(self, "rm"), "get_usage_and_reset"): + combined_usage.append(getattr(self, "rm").get_usage_and_reset()) name_to_usage = {} for usage in combined_usage: @@ -199,21 +285,38 @@ def collect_and_reset_rm_usage(self): return name_to_usage - @abstractmethod - def retrieve(self, query: Union[str, List[str]], **kwargs) -> List[Information]: - """ - Retrieves information based on a query. - - This method must be implemented by subclasses to specify how information is retrieved. - - Args: - query (Union[str, List[str]]): The query or list of queries to retrieve information for. - **kwargs: Additional keyword arguments that might be necessary for the retrieval process. - - Returns: - List[Information]: A list of Information objects retrieved based on the query. - """ - pass + def retrieve( + self, query: Union[str, List[str]], exclude_urls: List[str] = [] + ) -> List[Information]: + queries = query if isinstance(query, list) else [query] + to_return = [] + + def process_query(q): + retrieved_data_list = self.rm( + query_or_queries=[q], exclude_urls=exclude_urls + ) + local_to_return = [] + for data in retrieved_data_list: + for i in range(len(data["snippets"])): + # STORM generate the article with citations. We do not consider multi-hop citations. + # Remove citations in the source to avoid confusion. + data["snippets"][i] = ArticleTextProcessing.remove_citations( + data["snippets"][i] + ) + storm_info = Information.from_dict(data) + storm_info.meta["query"] = q + local_to_return.append(storm_info) + return local_to_return + + with concurrent.futures.ThreadPoolExecutor( + max_workers=self.max_thread + ) as executor: + results = list(executor.map(process_query, queries)) + + for result in results: + to_return.extend(result) + + return to_return class KnowledgeCurationModule(ABC): @@ -458,3 +561,49 @@ def reset(self): self.time = {} self.lm_cost = {} self.rm_cost = {} + + +class Agent(ABC): + """ + Interface for STORM and Co-STORM LLM agent + + This class must be implemented by any subclass of `Agent` to define how the agent generates an utterance. + The generated utterance can be influenced by the conversation history, knowledge base, and any additional parameters passed via `kwargs`. + The implementation should align with the specific role and perspective of the agent, as defined by the agent's topic, role name, and role description. + + Args: + knowledge_base (KnowledgeBase): The current knowledge base (e.g., mind map in Co-STORM) that contains the accumulated information relevant to the conversation. + conversation_history (List[ConversationTurn]): A list of past conversation turns, providing context for generating the next utterance. + The agent can refer to this history to maintain continuity and relevance in the conversation. + logging_wrapper (LoggingWrapper): A wrapper used for logging important events during the utterance generation process. 
+ **kwargs: Additional arguments that can be passed to the method for more specialized utterance generation behavior depending on the agent's specific implementation. + + Returns: + ConversationTurn: A new conversation turn generated by the agent, containing the agent's response, including the role, utterance type, and relevant information from the knowledge base. + + Notes: + - Subclasses of `Agent` should define the exact strategy for generating the utterance, which could involve interacting with a language model, retrieving relevant knowledge, or following specific conversational policies. + - The agent's role, perspective, and the knowledge base content will influence how the utterance is formulated. + """ + + from .dataclass import KnowledgeBase, ConversationTurn + + def __init__(self, topic: str, role_name: str, role_description: str): + self.topic = topic + self.role_name = role_name + self.role_description = role_description + + def get_role_description(self): + if self.role_description: + return f"{self.role_name}: {self.role_description}" + return self.role_name + + @abstractmethod + def generate_utterance( + self, + knowledge_base: KnowledgeBase, + conversation_history: List[ConversationTurn], + logging_wrapper: "LoggingWrapper", + **kwargs, + ): + pass diff --git a/knowledge_storm/lm.py b/knowledge_storm/lm.py index 2c0773d4..0cae49be 100644 --- a/knowledge_storm/lm.py +++ b/knowledge_storm/lm.py @@ -24,7 +24,7 @@ class OpenAIModel(dspy.OpenAI): def __init__( self, - model: str = "gpt-3.5-turbo-instruct", + model: str = "gpt-4o-mini", api_key: Optional[str] = None, model_type: Literal["chat", "text"] = None, **kwargs, @@ -211,7 +211,7 @@ def __init__( self, api_base: Optional[str] = None, api_version: Optional[str] = None, - model: str = "gpt-3.5-turbo-instruct", + model: str = "gpt-4o-mini", api_key: Optional[str] = None, model_type: Literal["chat", "text"] = "chat", **kwargs, @@ -674,21 +674,28 @@ class TogetherClient(dspy.HFModel): def __init__( self, model, + api_key: Optional[str] = None, apply_tokenizer_chat_template=False, hf_tokenizer_name=None, + model_type: Literal["chat", "text"] = "chat", **kwargs, ): """Copied from dspy/dsp/modules/hf_client.py with the support of applying tokenizer chat template.""" super().__init__(model=model, is_client=True) self.session = requests.Session() - self.api_base = ( - "https://api.together.xyz/v1/completions" - if os.getenv("TOGETHER_API_BASE") is None - else os.getenv("TOGETHER_API_BASE") + self.api_key = api_key = ( + os.environ.get("TOGETHER_API_KEY") if api_key is None else api_key ) - self.token = os.getenv("TOGETHER_API_KEY") self.model = model + self.model_type = model_type + if os.getenv("TOGETHER_API_BASE") is None: + if self.model_type == "chat": + self.api_base = "https://api.together.xyz/v1/chat/completions" + else: + self.api_base = "https://api.together.xyz/v1/completions" + else: + self.api_base = os.getenv("TOGETHER_API_BASE") # self.use_inst_template = False # if any(keyword in self.model.lower() for keyword in ["inst", "instruct"]): @@ -705,12 +712,12 @@ def __init__( stop_default = "\n\n---" self.kwargs = { - "temperature": 0.0, - "max_tokens": 512, - "top_p": 1, - "top_k": 20, + "temperature": kwargs.get("temperature", 0.0), + "max_tokens": min(kwargs.get("max_tokens", 4096), 4096), + "top_p": kwargs.get("top_p", 1.0), + "top_k": kwargs.get("top_k", 1), "repetition_penalty": 1, - "n": 1, + "n": kwargs.pop("n", kwargs.pop("num_generations", 1)), "stop": stop_default if "stop" not in kwargs else kwargs["stop"], 
**kwargs, } @@ -745,9 +752,7 @@ def get_usage_and_reset(self): max_time=1000, on_backoff=backoff_hdlr, ) - def _generate(self, prompt, use_chat_api=False, **kwargs): - url = f"{self.api_base}" - + def _generate(self, prompt, **kwargs): kwargs = {**self.kwargs, **kwargs} stop = kwargs.get("stop") @@ -762,8 +767,7 @@ def _generate(self, prompt, use_chat_api=False, **kwargs): ) # prompt = f"[INST]{prompt}[/INST]" if self.use_inst_template else prompt - if use_chat_api: - url = f"{self.api_base}/chat/completions" + if self.model_type == "chat": messages = [ { "role": "system", @@ -793,13 +797,13 @@ def _generate(self, prompt, use_chat_api=False, **kwargs): "stop": stop, } - headers = {"Authorization": f"Bearer {self.token}"} + headers = {"Authorization": f"Bearer {self.api_key}"} - with self.session.post(url, headers=headers, json=body) as resp: + with self.session.post(self.api_base, headers=headers, json=body) as resp: resp_json = resp.json() # Log the token usage from the Together API response. self.log_usage(resp_json) - if use_chat_api: + if self.model_type == "chat": # completions = [resp_json['output'].get('choices', [])[0].get('message', {}).get('content', "")] completions = [ resp_json.get("choices", [])[0] diff --git a/knowledge_storm/logging_wrapper.py b/knowledge_storm/logging_wrapper.py new file mode 100644 index 00000000..48d7e294 --- /dev/null +++ b/knowledge_storm/logging_wrapper.py @@ -0,0 +1,212 @@ +from contextlib import contextmanager +import time +import pytz +from datetime import datetime + +# Define California timezone +CALIFORNIA_TZ = pytz.timezone("America/Los_Angeles") + + +class EventLog: + def __init__(self, event_name): + self.event_name = event_name + self.start_time = None + self.end_time = None + self.child_events = {} + + def record_start_time(self): + self.start_time = datetime.now( + pytz.utc + ) # Store in UTC for consistent timezone conversion + + def record_end_time(self): + self.end_time = datetime.now( + pytz.utc + ) # Store in UTC for consistent timezone conversion + + def get_total_time(self): + if self.start_time and self.end_time: + return (self.end_time - self.start_time).total_seconds() + return 0 + + def get_start_time(self): + if self.start_time: + # Format to milliseconds + return self.start_time.astimezone(CALIFORNIA_TZ).strftime( + "%Y-%m-%d %H:%M:%S.%f" + )[:-3] + return None + + def get_end_time(self): + if self.end_time: + # Format to milliseconds + return self.end_time.astimezone(CALIFORNIA_TZ).strftime( + "%Y-%m-%d %H:%M:%S.%f" + )[:-3] + return None + + def add_child_event(self, child_event): + self.child_events[child_event.event_name] = child_event + + def get_child_events(self): + return self.child_events + + +class LoggingWrapper: + def __init__(self, lm_config): + self.logging_dict = {} + self.lm_config = lm_config + self.current_pipeline_stage = None + self.event_stack = [] + self.pipeline_stage_active = False + + def _pipeline_stage_start(self, pipeline_stage: str): + if self.pipeline_stage_active: + raise RuntimeError( + "A pipeline stage is already active. End the current stage before starting a new one." 
+ ) + + self.current_pipeline_stage = pipeline_stage + self.logging_dict[pipeline_stage] = { + "time_usage": {}, + "lm_usage": {}, + "lm_history": [], + "query_count": 0, + } + self.pipeline_stage_active = True + + def _event_start(self, event_name: str): + if not self.pipeline_stage_active: + raise RuntimeError("No pipeline stage is currently active.") + + if not self.event_stack and self.current_pipeline_stage: + # Top-level event (directly under the pipeline stage) + if ( + event_name + not in self.logging_dict[self.current_pipeline_stage]["time_usage"] + ): + event = EventLog(event_name=event_name) + event.record_start_time() + self.logging_dict[self.current_pipeline_stage]["time_usage"][ + event_name + ] = event + self.event_stack.append(event) + else: + self.logging_dict[self.current_pipeline_stage]["time_usage"][ + event_name + ].record_start_time() + elif self.event_stack: + # Nested event (under another event) + parent_event = self.event_stack[-1] + if event_name not in parent_event.get_child_events(): + event = EventLog(event_name=event_name) + event.record_start_time() + parent_event.add_child_event(event) + self.logging_dict[self.current_pipeline_stage]["time_usage"][ + event_name + ] = event + self.event_stack.append(event) + else: + parent_event.get_child_events()[event_name].record_start_time() + else: + raise RuntimeError( + "Cannot start an event without an active pipeline stage or parent event." + ) + + def _event_end(self, event_name: str): + if not self.pipeline_stage_active: + raise RuntimeError("No pipeline stage is currently active.") + + if not self.event_stack: + raise RuntimeError("No parent event is currently active.") + + if self.event_stack: + current_event_log = self.event_stack[-1] + if event_name in current_event_log.get_child_events(): + current_event_log.get_child_events()[event_name].record_end_time() + elif ( + event_name + in self.logging_dict[self.current_pipeline_stage]["time_usage"] + ): + self.logging_dict[self.current_pipeline_stage]["time_usage"][ + event_name + ].record_end_time() + else: + raise AssertionError( + f"Failure to record end time for event {event_name}. Start time is not recorded." + ) + if current_event_log.event_name == event_name: + self.event_stack.pop() + else: + raise RuntimeError("Cannot end an event without an active parent event.") + + def _pipeline_stage_end(self): + if not self.pipeline_stage_active: + raise RuntimeError("No pipeline stage is currently active to end.") + + self.logging_dict[self.current_pipeline_stage][ + "lm_usage" + ] = self.lm_config.collect_and_reset_lm_usage() + self.logging_dict[self.current_pipeline_stage][ + "lm_history" + ] = self.lm_config.collect_and_reset_lm_history() + self.pipeline_stage_active = False + + def add_query_count(self, count): + if not self.pipeline_stage_active: + raise RuntimeError( + "No pipeline stage is currently active to add query count." + ) + + self.logging_dict[self.current_pipeline_stage]["query_count"] += count + + @contextmanager + def log_event(self, event_name): + if not self.pipeline_stage_active: + raise RuntimeError("No pipeline stage is currently active.") + + self._event_start(event_name) + yield + self._event_end(event_name) + + @contextmanager + def log_pipeline_stage(self, pipeline_stage): + if self.pipeline_stage_active: + print( + "A pipeline stage is already active, ending the current stage safely." 
+ ) + self._pipeline_stage_end() + + start_time = time.time() + try: + self._pipeline_stage_start(pipeline_stage) + yield + except Exception as e: + print(f"Error occurred during pipeline stage '{pipeline_stage}': {e}") + finally: + self.logging_dict[self.current_pipeline_stage]["total_wall_time"] = ( + time.time() - start_time + ) + self._pipeline_stage_end() + + def dump_logging_and_reset(self, reset_logging=True): + log_dump = {} + for pipeline_stage, pipeline_log in self.logging_dict.items(): + time_stamp_log = { + event_name: { + "total_time_seconds": event.get_total_time(), + "start_time": event.get_start_time(), + "end_time": event.get_end_time(), + } + for event_name, event in pipeline_log["time_usage"].items() + } + log_dump[pipeline_stage] = { + "time_usage": time_stamp_log, + "lm_usage": pipeline_log["lm_usage"], + "lm_history": pipeline_log["lm_history"], + "query_count": pipeline_log["query_count"], + "total_wall_time": pipeline_log["total_wall_time"], + } + if reset_logging: + self.logging_dict.clear() + return log_dump diff --git a/knowledge_storm/rm.py b/knowledge_storm/rm.py index 80b8a385..7f029e79 100644 --- a/knowledge_storm/rm.py +++ b/knowledge_storm/rm.py @@ -333,10 +333,79 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st return collected_results +class StanfordOvalArxivRM(dspy.Retrieve): + """[Alpha] This retrieval class is for internal use only, not intended for the public.""" + + def __init__(self, endpoint, k=3): + super().__init__(k=k) + self.endpoint = endpoint + self.usage = 0 + + def get_usage_and_reset(self): + usage = self.usage + self.usage = 0 + + return {"CS224vArxivRM": usage} + + def _retrieve(self, query: str): + payload = {"query": query, "num_blocks": self.k} + + response = requests.post( + self.endpoint, json=payload, headers={"Content-Type": "application/json"} + ) + + # Check if the request was successful + if response.status_code == 200: + data = response.json()[0] + results = [] + for i in range(len(data["title"])): + result = { + "title": data["title"][i], + "url": data["title"][i], + "snippets": [data["text"][i]], + "description": "N/A", + "meta": {"section_title": data["full_section_title"][i]}, + } + results.append(result) + + return results + else: + raise Exception( + f"Error: Unable to retrieve results. Status code: {response.status_code}" + ) + + def forward( + self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = [] + ): + collected_results = [] + queries = ( + [query_or_queries] + if isinstance(query_or_queries, str) + else query_or_queries + ) + + for query in queries: + try: + results = self._retrieve(query) + collected_results.extend(results) + except Exception as e: + logging.error(f"Error occurs when searching query {query}: {e}") + return collected_results + + class SerperRM(dspy.Retrieve): """Retrieve information from custom queries using Serper.dev.""" - def __init__(self, serper_search_api_key=None, query_params=None): + def __init__( + self, + serper_search_api_key=None, + k=3, + query_params=None, + ENABLE_EXTRA_SNIPPET_EXTRACTION=False, + min_char_count: int = 150, + snippet_chunk_size: int = 1000, + webpage_helper_max_threads=10, + ): """Args: serper_search_api_key str: API key to run serper, can be found by creating an account on https://serper.dev/ query_params (dict or list of dict): parameters in dictionary or list of dictionaries that has a max size of 100 that will be used to query. 
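A minimal configuration sketch (illustrative only, not part of this patch) of the widened SerperRM constructor above; the key and numbers below are placeholders, shown only to make the new parameters concrete.

from knowledge_storm.rm import SerperRM

# Hypothetical values exercising the new constructor arguments.
rm = SerperRM(
    serper_search_api_key="YOUR_SERPER_KEY",  # placeholder; the SERPER_API_KEY env var is also accepted
    k=5,                                      # top-k results per query; also sets query_params["num"]
    ENABLE_EXTRA_SNIPPET_EXTRACTION=True,     # fetch result pages and chunk them into extra snippets
    min_char_count=150,
    snippet_chunk_size=1000,
    webpage_helper_max_threads=10,
)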
@@ -355,9 +424,21 @@ def __init__(self, serper_search_api_key=None, query_params=None):
             qdr:m str: Date time range for past month.
             qdr:y str: Date time range for past year.
         """
-        super().__init__()
+        super().__init__(k=k)
         self.usage = 0
-        self.query_params = query_params
+        self.query_params = None
+        self.ENABLE_EXTRA_SNIPPET_EXTRACTION = ENABLE_EXTRA_SNIPPET_EXTRACTION
+        self.webpage_helper = WebPageHelper(
+            min_char_count=min_char_count,
+            snippet_chunk_size=snippet_chunk_size,
+            max_thread_num=webpage_helper_max_threads,
+        )
+
+        if query_params is None:
+            self.query_params = {"num": k, "autocorrect": True, "page": 1}
+        else:
+            self.query_params = query_params
+            self.query_params.update({"num": k})
         self.serper_search_api_key = serper_search_api_key
         if not self.serper_search_api_key and not os.environ.get("SERPER_API_KEY"):
             raise RuntimeError(
@@ -435,34 +516,42 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
         # Array of dictionaries that will be used by Storm to create the jsons
         collected_results = []
 
+        if self.ENABLE_EXTRA_SNIPPET_EXTRACTION:
+            urls = []
+            for result in self.results:
+                organic_results = result.get("organic", [])
+                for organic in organic_results:
+                    url = organic.get("link")
+                    if url:
+                        urls.append(url)
+            valid_url_to_snippets = self.webpage_helper.urls_to_snippets(urls)
+        else:
+            valid_url_to_snippets = {}
+
         for result in self.results:
             try:
                 # An array of dictionaries that contains the snippets, title of the document and url that will be used.
                 organic_results = result.get("organic")
                 knowledge_graph = result.get("knowledgeGraph")
                 for organic in organic_results:
-                    snippets = []
-                    snippets.append(organic.get("snippet"))
-                    if knowledge_graph != None:
-                        collected_results.append(
-                            {
-                                "snippets": snippets,
-                                "title": organic.get("title"),
-                                "url": organic.get("link"),
-                                "description": knowledge_graph.get("description"),
-                            }
-                        )
-                    else:
-                        # Common for knowledge graph to be None, set description to empty string
-                        collected_results.append(
-                            {
-                                "snippets": snippets,
-                                "title": organic.get("title"),
-                                "url": organic.get("link"),
-                                "description": "",
-                            }
+                    snippets = [organic.get("snippet")]
+                    if self.ENABLE_EXTRA_SNIPPET_EXTRACTION:
+                        snippets.extend(
+                            valid_url_to_snippets.get(organic.get("link"), {}).get("snippets", [])
+                        )
+                    collected_results.append(
+                        {
+                            "snippets": snippets,
+                            "title": organic.get("title"),
+                            "url": organic.get("link"),
+                            "description": (
+                                knowledge_graph.get("description")
+                                if knowledge_graph is not None
+                                else ""
+                            ),
+                        }
+                    )
             except:
                 continue
diff --git a/knowledge_storm/storm_wiki/engine.py b/knowledge_storm/storm_wiki/engine.py
index de9f5f1c..62887a2a 100644
--- a/knowledge_storm/storm_wiki/engine.py
+++ b/knowledge_storm/storm_wiki/engine.py
@@ -12,10 +12,9 @@
 from .modules.knowledge_curation import StormKnowledgeCurationModule
 from .modules.outline_generation import StormOutlineGenerationModule
 from .modules.persona_generator import StormPersonaGenerator
-from .modules.retriever import StormRetriever
 from .modules.storm_dataclass import StormInformationTable, StormArticle
-from ..interface import Engine, LMConfigs
-from ..lm import OpenAIModel
+from ..interface import Engine, LMConfigs, Retriever
+from ..lm import OpenAIModel, AzureOpenAIModel
 from ..utils import FileIOHelper, makeStringRed, truncate_filename
 
 
@@ -39,6 +38,7 @@ def __init__(self):
     def init_openai_model(
         self,
         openai_api_key: str,
+        azure_api_key: str,
         openai_type: Literal["openai", "azure"],
         api_base: Optional[str] = None,
         api_version: Optional[str] = None,
@@
-46,19 +46,27 @@ def init_openai_model( top_p: Optional[float] = 0.9, ): """Legacy: Corresponding to the original setup in the NAACL'24 paper.""" + azure_kwargs = { + "api_key": azure_api_key, + "temperature": temperature, + "top_p": top_p, + "api_base": api_base, + "api_version": api_version, + } + openai_kwargs = { "api_key": openai_api_key, - "api_provider": openai_type, + "api_provider": "openai", "temperature": temperature, "top_p": top_p, "api_base": None, } if openai_type and openai_type == "openai": self.conv_simulator_lm = OpenAIModel( - model="gpt-3.5-turbo-instruct", max_tokens=500, **openai_kwargs + model="gpt-4o-mini-2024-07-18", max_tokens=500, **openai_kwargs ) self.question_asker_lm = OpenAIModel( - model="gpt-3.5-turbo", max_tokens=500, **openai_kwargs + model="gpt-4o-mini-2024-07-18", max_tokens=500, **openai_kwargs ) # 1/12/2024: Update gpt-4 to gpt-4-1106-preview. (Currently keep the original setup when using azure.) self.outline_gen_lm = OpenAIModel( @@ -70,6 +78,32 @@ def init_openai_model( self.article_polish_lm = OpenAIModel( model="gpt-4o-2024-05-13", max_tokens=4000, **openai_kwargs ) + elif openai_type and openai_type == "azure": + self.conv_simulator_lm = OpenAIModel( + model="gpt-4o-mini-2024-07-18", max_tokens=500, **openai_kwargs + ) + self.question_asker_lm = AzureOpenAIModel( + model="gpt-4o-mini-2024-07-18", + max_tokens=500, + **azure_kwargs, + model_type="chat", + ) + # use combination of openai and azure-openai as azure-openai does not support gpt-4 in standard deployment + self.outline_gen_lm = AzureOpenAIModel( + model="gpt-4o", max_tokens=400, **azure_kwargs, model_type="chat" + ) + self.article_gen_lm = AzureOpenAIModel( + model="gpt-4o-mini-2024-07-18", + max_tokens=700, + **azure_kwargs, + model_type="chat", + ) + self.article_polish_lm = AzureOpenAIModel( + model="gpt-4o-mini-2024-07-18", + max_tokens=4000, + **azure_kwargs, + model_type="chat", + ) else: logging.warning( "No valid OpenAI API provider is provided. Cannot use default LLM configurations." 
@@ -145,7 +179,7 @@ def __init__( self.args = args self.lm_configs = lm_configs - self.retriever = StormRetriever(rm=rm, k=self.args.retrieve_top_k) + self.retriever = Retriever(rm=rm, max_thread=self.args.max_thread_num) storm_persona_generator = StormPersonaGenerator( self.lm_configs.question_asker_lm ) diff --git a/knowledge_storm/storm_wiki/modules/article_generation.py b/knowledge_storm/storm_wiki/modules/article_generation.py index 2e711465..a23b7886 100644 --- a/knowledge_storm/storm_wiki/modules/article_generation.py +++ b/knowledge_storm/storm_wiki/modules/article_generation.py @@ -7,8 +7,8 @@ import dspy from .callback import BaseCallbackHandler -from .storm_dataclass import StormInformationTable, StormArticle, StormInformation -from ...interface import ArticleGenerationModule +from .storm_dataclass import StormInformationTable, StormArticle +from ...interface import ArticleGenerationModule, Information from ...utils import ArticleTextProcessing @@ -33,7 +33,7 @@ def __init__( def generate_section( self, topic, section_name, information_table, section_outline, section_query ): - collected_info: List[StormInformation] = [] + collected_info: List[Information] = [] if information_table is not None: collected_info = information_table.retrieve_information( queries=section_query, search_top_k=self.retrieve_top_k @@ -143,11 +143,7 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.engine = engine def forward( - self, - topic: str, - outline: str, - section: str, - collected_info: List[StormInformation], + self, topic: str, outline: str, section: str, collected_info: List[Information] ): info = "" for idx, storm_info in enumerate(collected_info): diff --git a/knowledge_storm/storm_wiki/modules/article_polish.py b/knowledge_storm/storm_wiki/modules/article_polish.py index fb85b0f3..171054db 100644 --- a/knowledge_storm/storm_wiki/modules/article_polish.py +++ b/knowledge_storm/storm_wiki/modules/article_polish.py @@ -85,14 +85,16 @@ def __init__( self.polish_page = dspy.Predict(PolishPage) def forward(self, topic: str, draft_page: str, polish_whole_page: bool = True): - with dspy.settings.context(lm=self.write_lead_engine): + # NOTE: Change show_guidelines to false to make the generation more robust to different LM families. + with dspy.settings.context(lm=self.write_lead_engine, show_guidelines=False): lead_section = self.write_lead( topic=topic, draft_page=draft_page ).lead_section if "The lead section:" in lead_section: lead_section = lead_section.split("The lead section:")[1].strip() if polish_whole_page: - with dspy.settings.context(lm=self.polish_engine): + # NOTE: Change show_guidelines to false to make the generation more robust to different LM families. 
+ with dspy.settings.context(lm=self.polish_engine, show_guidelines=False): page = self.polish_page(draft_page=draft_page).page else: page = draft_page diff --git a/knowledge_storm/storm_wiki/modules/knowledge_curation.py b/knowledge_storm/storm_wiki/modules/knowledge_curation.py index bde27678..d6b89295 100644 --- a/knowledge_storm/storm_wiki/modules/knowledge_curation.py +++ b/knowledge_storm/storm_wiki/modules/knowledge_curation.py @@ -8,8 +8,8 @@ from .callback import BaseCallbackHandler from .persona_generator import StormPersonaGenerator -from .storm_dataclass import DialogueTurn, StormInformationTable, StormInformation -from ...interface import KnowledgeCurationModule, Retriever +from .storm_dataclass import DialogueTurn, StormInformationTable +from ...interface import KnowledgeCurationModule, Retriever, Information from ...utils import ArticleTextProcessing try: @@ -166,7 +166,7 @@ class QuestionToQuery(dspy.Signature): class AnswerQuestion(dspy.Signature): """You are an expert who can use information effectively. You are chatting with a Wikipedia writer who wants to write a Wikipedia page on topic you know. You have gathered the related information and will now use the information to form a response. - Make your response as informative as possible and make sure every sentence is supported by the gathered information. If [Gathered information] is not related to he [Topic] and [Question], output "Sorry, I don't have enough information to answer the question.". + Make your response as informative as possible, ensuring that every sentence is supported by the gathered information. If the [gathered information] is not directly related to the [topic] or [question], provide the most relevant answer based on the available information. If no appropriate answer can be formulated, respond with, “I cannot answer this question based on the available information,” and explain any limitations or gaps. """ topic = dspy.InputField(prefix="Topic you are discussing about:", format=str) @@ -196,14 +196,13 @@ def __init__( super().__init__() self.generate_queries = dspy.Predict(QuestionToQuery) self.retriever = retriever - self.retriever.update_search_top_k(search_top_k) self.answer_question = dspy.Predict(AnswerQuestion) self.engine = engine self.max_search_queries = max_search_queries self.search_top_k = search_top_k def forward(self, topic: str, question: str, ground_truth_url: str): - with dspy.settings.context(lm=self.engine): + with dspy.settings.context(lm=self.engine, show_guidelines=False): # Identify: Break down question into queries. 
queries = self.generate_queries(topic=topic, question=question).queries queries = [ @@ -212,7 +211,7 @@ def forward(self, topic: str, question: str, ground_truth_url: str): ] queries = queries[: self.max_search_queries] # Search - searched_results: List[StormInformation] = self.retriever.retrieve( + searched_results: List[Information] = self.retriever.retrieve( list(set(queries)), exclude_urls=[ground_truth_url] ) if len(searched_results) > 0: diff --git a/knowledge_storm/storm_wiki/modules/retriever.py b/knowledge_storm/storm_wiki/modules/retriever.py index 85df63ec..691382b0 100644 --- a/knowledge_storm/storm_wiki/modules/retriever.py +++ b/knowledge_storm/storm_wiki/modules/retriever.py @@ -3,7 +3,6 @@ import dspy -from .storm_dataclass import StormInformation from ...interface import Retriever, Information from ...utils import ArticleTextProcessing @@ -232,26 +231,3 @@ def is_valid_wikipedia_source(url): return False return True - - -class StormRetriever(Retriever): - def __init__(self, rm: dspy.Retrieve, k=3): - super().__init__(search_top_k=k) - self._rm = rm - if hasattr(rm, "is_valid_source"): - rm.is_valid_source = is_valid_wikipedia_source - - def retrieve( - self, query: Union[str, List[str]], exclude_urls: List[str] = [] - ) -> List[Information]: - retrieved_data_list = self._rm( - query_or_queries=query, exclude_urls=exclude_urls - ) - for data in retrieved_data_list: - for i in range(len(data["snippets"])): - # STORM generate the article with citations. We do not consider multi-hop citations. - # Remove citations in the source to avoid confusion. - data["snippets"][i] = ArticleTextProcessing.remove_citations( - data["snippets"][i] - ) - return [StormInformation.from_dict(data) for data in retrieved_data_list] diff --git a/knowledge_storm/storm_wiki/modules/storm_dataclass.py b/knowledge_storm/storm_wiki/modules/storm_dataclass.py index 43826ecc..75812d9c 100644 --- a/knowledge_storm/storm_wiki/modules/storm_dataclass.py +++ b/knowledge_storm/storm_wiki/modules/storm_dataclass.py @@ -11,69 +11,13 @@ from ...utils import ArticleTextProcessing, FileIOHelper -class StormInformation(Information): - """Class to represent detailed information. - - Inherits from Information to include a unique identifier (URL), and extends - it with a description, snippets, and title of the storm information. - - Attributes: - description (str): Brief description. - snippets (list): List of brief excerpts or snippets. - title (str): The title or headline of the information. - url (str): The unique URL (serving as UUID) of the information. - """ - - def __init__(self, uuid, description, snippets, title): - """Initialize the StormInformation object with detailed attributes. - - Args: - uuid (str): The unique URL serving as the identifier for the information. - description (str): Detailed description. - snippets (list): List of brief excerpts or snippet. - title (str): The title or headline of the information. - """ - super().__init__(uuid=uuid, meta={}) - self.description = description - self.snippets = snippets - self.title = title - self.url = self.uuid - - @classmethod - def from_dict(cls, info_dict): - """Create a StormInformation object from a dictionary. - Usage: storm_info = StormInformation.from_dict(storm_info_dict) - - Args: - info_dict (dict): A dictionary containing keys 'uuid', 'description', - 'snippets', and 'title' corresponding to the object's attributes. - - Returns: - StormInformation: An instance of StormInformation. 
- """ - return cls( - info_dict["url"], - info_dict["description"], - info_dict["snippets"], - info_dict["title"], - ) - - def to_dict(self): - return { - "url": self.uuid, - "description": self.description, - "snippets": self.snippets, - "title": self.title, - } - - class DialogueTurn: def __init__( self, agent_utterance: str = None, user_utterance: str = None, search_queries: Optional[List[str]] = None, - search_results: Optional[List[Union[StormInformation, Dict]]] = None, + search_results: Optional[List[Union[Information, Dict]]] = None, ): self.agent_utterance = agent_utterance self.user_utterance = user_utterance @@ -83,7 +27,7 @@ def __init__( if self.search_results: for idx in range(len(self.search_results)): if type(self.search_results[idx]) == dict: - self.search_results[idx] = StormInformation.from_dict( + self.search_results[idx] = Information.from_dict( self.search_results[idx] ) @@ -91,7 +35,6 @@ def log(self): """ Returns a json object that contains all information inside `self` """ - return OrderedDict( { "agent_utterance": self.agent_utterance, @@ -115,14 +58,14 @@ class StormInformationTable(InformationTable): def __init__(self, conversations=List[Tuple[str, List[DialogueTurn]]]): super().__init__() self.conversations = conversations - self.url_to_info: Dict[str, StormInformation] = ( + self.url_to_info: Dict[str, Information] = ( StormInformationTable.construct_url_to_info(self.conversations) ) @staticmethod def construct_url_to_info( conversations: List[Tuple[str, List[DialogueTurn]]] - ) -> Dict[str, StormInformation]: + ) -> Dict[str, Information]: url_to_info = {} for persona, conv in conversations: @@ -177,7 +120,7 @@ def prepare_table_for_retrieval(self): def retrieve_information( self, queries: Union[List[str], str], search_top_k - ) -> List[StormInformation]: + ) -> List[Information]: selected_urls = [] selected_snippets = [] if type(queries) is str: @@ -231,13 +174,13 @@ def find_section( return None def _merge_new_info_to_references( - self, new_info_list: List[StormInformation], index_to_keep=None + self, new_info_list: List[Information], index_to_keep=None ) -> Dict[int, int]: """ Merges new storm information into existing references and updates the citation index mapping. Args: - new_info_list (List[StormInformation]): A list of dictionaries representing new storm information. + new_info_list (List[Information]): A list of dictionaries representing new storm information. index_to_keep (List[int]): A list of index of the new_info_list to keep. If none, keep all. 
Returns: @@ -308,7 +251,7 @@ def insert_or_create_section( def update_section( self, current_section_content: str, - current_section_info_list: List[StormInformation], + current_section_info_list: List[Information], parent_section_name: Optional[str] = None, ) -> Optional[ArticleSectionNode]: """ @@ -552,7 +495,7 @@ def from_string(cls, topic_name: str, article_text: str, references: dict): article = cls(topic_name=topic_name) article.insert_or_create_section(article_dict=article_dict) for url in list(references["url_to_info"]): - references["url_to_info"][url] = StormInformation.from_dict( + references["url_to_info"][url] = Information.from_dict( references["url_to_info"][url] ) article.reference = references diff --git a/knowledge_storm/utils.py b/knowledge_storm/utils.py index 1749609b..2e3cbb65 100644 --- a/knowledge_storm/utils.py +++ b/knowledge_storm/utils.py @@ -4,7 +4,9 @@ import os import pickle import re +import regex import sys +import time from typing import List, Dict import httpx @@ -18,6 +20,8 @@ from tqdm import tqdm from trafilatura import extract +from .lm import OpenAIModel + logging.getLogger("httpx").setLevel(logging.WARNING) # Disable INFO logging for httpx. @@ -415,12 +419,14 @@ def deduplicate_group(match): @staticmethod def clean_up_citation(conv): for turn in conv.dlg_history: - turn.agent_utterance = turn.agent_utterance[ - : turn.agent_utterance.find("References:") - ] - turn.agent_utterance = turn.agent_utterance[ - : turn.agent_utterance.find("Sources:") - ] + if "References:" in turn.agent_utterance: + turn.agent_utterance = turn.agent_utterance[ + : turn.agent_utterance.find("References:") + ] + if "Sources:" in turn.agent_utterance: + turn.agent_utterance = turn.agent_utterance[ + : turn.agent_utterance.find("Sources:") + ] turn.agent_utterance = turn.agent_utterance.replace("Answer:", "").strip() try: max_ref_num = max( @@ -484,7 +490,8 @@ def clean_up_outline(outline, topic=""): outline = re.sub(r"#[#]? Summary.*?(?=##|$)", "", outline, flags=re.DOTALL) outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", "", outline, flags=re.DOTALL) outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", "", outline, flags=re.DOTALL) - + # clean up citation in outline + outline = re.sub(r"\[.*?\]", "", outline) return outline @staticmethod @@ -519,7 +526,8 @@ def clean_up_section(text): continue output_paragraphs.append(p) - return "\n\n".join(output_paragraphs) # Join with '\n\n' for markdown format. + # Join with '\n\n' for markdown format. + return "\n\n".join(output_paragraphs) @staticmethod def update_citation_index(s, citation_map): @@ -693,3 +701,89 @@ def urls_to_snippets(self, urls: List[str]) -> Dict: articles[u]["snippets"] = self.text_splitter.split_text(articles[u]["text"]) return articles + + +def user_input_appropriateness_check(user_input): + my_openai_model = OpenAIModel( + api_key=os.getenv("OPENAI_API_KEY"), + api_provider="openai", + model="gpt-4o-mini-2024-07-18", + max_tokens=10, + temperature=0.0, + top_p=0.9, + ) + + if len(user_input.split()) > 20: + return "The input is too long. Please make your input topic more concise!" + + if not re.match(r'^[a-zA-Z0-9\s\-"\,\.?\']*$', user_input): + return "The input contains invalid characters. The input should only contain a-z, A-Z, 0-9, space, -/\"/,./?/'." + + prompt = f"""Here is a topic input into a knowledge curation engine that can write a Wikipedia-like article for the topic. Please judge whether it is appropriate or not for the engine to curate information for this topic based on English search engine. 
The following types of inputs are inappropriate: +1. Inputs that may be related to illegal, harmful, violent, racist, or sexual purposes. +2. Inputs that are given using languages other than English. Currently, the engine can only support English. +3. Inputs that are related to personal experience or personal information. Currently, the engine can only use information from the search engine. +4. Inputs that are not aimed at topic research or inquiry. For example, asks requiring detailed execution, such as calculations, programming, or specific service searches fall outside the engine's scope of capabilities. +If the topic is appropriate for the engine to process, output "Yes."; otherwise, output "No. The input violates reason [1/2/3/4]". +User input: {user_input}""" + reject_reason_info = { + 1: "Sorry, this input may be related to sensitive topics. Please try another topic. " + "(Our input filtering uses OpenAI GPT-4o-mini, which may result in false positives. " + "We apologize for any inconvenience.)", + 2: "Sorry, the current engine can only support English. Please try another topic. " + "(Our input filtering uses OpenAI GPT-4o-mini, which may result in false positives. " + "We apologize for any inconvenience.)", + 3: "Sorry, the current engine cannot process topics related to personal experience. Please try another topic. " + "(Our input filtering uses OpenAI GPT-4o-mini, which may result in false positives. " + "We apologize for any inconvenience.)", + 4: "Sorry, STORM cannot follow arbitrary instruction. Please input a topic you want to learn about. " + "(Our input filtering uses OpenAI GPT-4o-mini, which may result in false positives. " + "We apologize for any inconvenience.)", + } + + try: + response = my_openai_model(prompt)[0].replace("[", "").replace("]", "") + if response.startswith("No"): + match = regex.search(r"reason\s(\d+)", response) + if match: + reject_reason = int(match.group(1)) + if reject_reason in reject_reason_info: + return reject_reason_info[reject_reason] + else: + return ( + "Sorry, the input is inappropriate. Please try another topic!" + ) + return "Sorry, the input is inappropriate. Please try another topic!" + + except Exception as e: + return "Sorry, the input is inappropriate. Please try another topic!" + return "Approved" + + +def purpose_appropriateness_check(user_input): + my_openai_model = OpenAIModel( + api_key=os.getenv("OPENAI_API_KEY"), + api_provider="openai", + model="gpt-4o-mini-2024-07-18", + max_tokens=10, + temperature=0.0, + top_p=0.9, + ) + + prompt = f""" + Here is a purpose input into a report generation engine that can create a long-form report on any topic of interest. + Please judge whether the provided purpose is valid for using this service. + Try to judge if given purpose is non-sense like random words or just try to get around the sanity check. + You should not make the rule too strict. + + If the purpose is valid, output "Yes."; otherwise, output "No" followed by reason. + User input: {user_input} + """ + try: + response = my_openai_model(prompt)[0].replace("[", "").replace("]", "") + if response.startswith("No"): + return "Please provide a more detailed explanation on your purpose of requesting this article." + + except Exception as e: + return "Please provide a more detailed explanation on your purpose of requesting this article." 
+ return "Approved" diff --git a/setup.py b/setup.py index 076ea6d9..d7bd6cf7 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( name="knowledge-storm", - version="0.2.8", + version="1.0.0", author="Yijia Shao, Yucheng Jiang", author_email="shaoyj@stanford.edu, yuchengj@stanford.edu", description="STORM: A language model-powered knowledge curation engine.", @@ -30,10 +30,9 @@ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ], - python_requires=">=3.9", + python_requires=">=3.10", install_requires=requirements, )
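To make the new pieces above concrete, here is a minimal end-to-end sketch (illustrative only, not part of this patch) combining the generalized Retriever from knowledge_storm.interface with the new knowledge_storm.encoder module. It assumes ENCODER_API_TYPE=openai, a valid OPENAI_API_KEY, and a SERPER_API_KEY; the query strings are made up.

import os

from knowledge_storm.encoder import get_text_embeddings
from knowledge_storm.interface import Retriever
from knowledge_storm.rm import SerperRM

os.environ.setdefault("ENCODER_API_TYPE", "openai")  # assumption: OpenAI embeddings are configured

# The Retriever wrapper fans queries out over a thread pool and returns Information objects.
rm = SerperRM(serper_search_api_key=os.getenv("SERPER_API_KEY"), k=3)
retriever = Retriever(rm=rm, max_thread=4)
results = retriever.retrieve(["history of coffee", "coffee trade routes"])

# Embed the retrieved snippets; the cache avoids re-encoding duplicate strings.
cache = {}
snippets = [snippet for info in results for snippet in info.snippets]
embeddings, total_tokens = get_text_embeddings(snippets, max_workers=5, embedding_cache=cache)
print(embeddings.shape, total_tokens)

Both retrieval and embedding are I/O-bound, so max_thread on the retriever and max_workers on the encoder are the main throughput knobs.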