diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ec2185b..528c3069 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,10 +3,33 @@ This changelog follows the specifications detailed in: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html), although we have not yet reached a `1.0.0` release. -## Unreleased +## 0.3.0 ### Added + +* Added new driver script for TA3 interactions that uses a new YAML config format for ADMs +* Added several ADM config files for new driver script +* Added a new ADM HybridKaleidoADM which defers to a Llama2SingleKDMAADM instance to fill out action parameters +* Added new abstract class for action based ADMs (called ActionBasedADM), requires a `choose_action` method +* Implemented ActionBasedADM `choose_action` method on the KaleidoADM, Llama2SingleKDMAADM, and a new ADM HybridKaleidoADM * Added alignment accuracy metric in self-evaluation framework +* Added re-usable methods for filling out action parameters to Llama2SingleKDMAADM +* Added short KDMA descriptions for moral deservingness and maximization for Kaleido +* Added new prompt template for selecting the target character of an action +* Added high and low alignment system prompts for SoarTech's maximization KDMA + +### Changed + +* Replaced instances of "casualties" with "characters" as per the new TA3 scenario data format +* Changed TA3 interface component over to using TA3 client module (rather than raw HTTP requests) +* Moved the previous `run_align_system.py` script to `run_simplified_align_system.py`, replacing it with the new primary CLI script +* Updated README with respect to new CLI script +* Changed some prompts to not display vitals with a value of None + +### Fixed + +* Fixed issue with logging of choice scores after multiple-sampling with voting +* Fixed issue where per-sample LLM outputs weren't being logged correctly ## 0.2.6 @@ -24,6 +47,10 @@ This project adheres to [Semantic 
Versioning](https://semver.org/spec/v2.0.0.htm * Fixed issue with configurable KDMA Estimator and Distance functions for Kaleido ADM +### Changed + +* Better error message on TA3 API action taken failure + ## Version 0.2.5 diff --git a/README.md b/README.md index 9ba8b558..591319e3 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,6 @@ Repository](https://github.com/NextCenturyCorporation/itm-evaluation-server). There's a corresponding client module: [TA3 Evaluation Client](https://github.com/NextCenturyCorporation/itm-evaluation-client) -Note that this client module isn't a required dependency for the ALIGN system code. - #### Soartech's TA1 API Soartech's TA1 service code can be found at: [Soartech's TA1 @@ -43,13 +41,14 @@ install git+https://github.com/ITM-Kitware/align-system.git`. ## Running the system against the TA3 action-based API ``` -$ run_action_based_align_system --help -usage: run_action_based_align_system [-h] {TA3ActionBased} ... +$ run_align_system --help +usage: run_align_system [-h] {TA3ActionBased} ... -ALIGN Action Based System CLI +ALIGN System CLI positional arguments: - {TA3ActionBased} Select interface. Adding --help after interface selection will print interface and system specified arguments + {TA3ActionBased} Select interface. Adding --help after interface selection will print interface and + system specified arguments TA3ActionBased Interface with CACI's TA3 web-based service options: @@ -59,9 +58,13 @@ options: Running `--help` after the selected interface prints the full set of options for the interface and system. 
E.g.: ``` -$ run_action_based_align_system TA3ActionBased --help -usage: run_action_based_align_system TA3ActionBased [-h] [-u USERNAME] [-s SESSION_TYPE] [-e API_ENDPOINT] [--training-session] [-m MODEL] [-t] [-a ALGORITHM] [-A ALGORITHM_KWARGS] - [--similarity-measure SIMILARITY_MEASURE] +$ run_align_system TA3ActionBased --help +usage: run_align_system TA3ActionBased [-h] [-u USERNAME] [-s SESSION_TYPE] + [-e API_ENDPOINT] [--training-session] + [--scenario-id SCENARIO_ID] -c ADM_CONFIG [-t] + [-l LOGLEVEL] [--logfile-path LOGFILE_PATH] + [--save-input-output-to-path SAVE_INPUT_OUTPUT_TO_PATH] + [--save-alignment-score-to-path SAVE_ALIGNMENT_SCORE_TO_PATH] options: -h, --help show this help message and exit @@ -70,28 +73,30 @@ options: -s SESSION_TYPE, --session-type SESSION_TYPE TA3 API Session Type (default: "eval") -e API_ENDPOINT, --api_endpoint API_ENDPOINT - Restful API endpoint for scenarios / probes (default: "http://127.0.0.1:8080") + Restful API endpoint for scenarios / probes (default: + "http://127.0.0.1:8080") --training-session Return training related information from API requests - -m MODEL, --model MODEL - LLM Baseline model to use + --scenario-id SCENARIO_ID + Specific scenario to run + -c ADM_CONFIG, --adm-config ADM_CONFIG + Path to ADM config YAML -t, --align-to-target Align algorithm to target KDMAs - -a ALGORITHM, --algorithm ALGORITHM - Algorithm to use - -A ALGORITHM_KWARGS, --algorithm-kwargs ALGORITHM_KWARGS - JSON encoded dictionary of kwargs for algorithm initialization - --similarity-measure SIMILARITY_MEASURE - Similarity measure to use (default: 'bert') + -l LOGLEVEL, --loglevel LOGLEVEL + --logfile-path LOGFILE_PATH + Also write log output to the specified file + --save-input-output-to-path SAVE_INPUT_OUTPUT_TO_PATH + Save system inputs and outputs to a file + --save-alignment-score-to-path SAVE_ALIGNMENT_SCORE_TO_PATH + Save alignment score output to a file ``` Here's an example invocation of the system using the TA3 Action-based 
interface (assuming it's running locally on port `8080`): ``` $ run_action_based_align_system TA3ActionBased \ - -e "http://127.0.0.1:8080" \ - --algorithm "llama_index" \ - --model falcon \ - -s soartech \ - --algorithm-kwargs '{"domain_docs_dir": "/data/shared/MVPData/DomainDocumentsPDF"}' + --adm-config adm_configs/metrics-evaluation/single_kdma_adm_adept_baseline.yml \ + --api_endpoint "http://127.0.0.1:8080" \ + --session-type adept ``` *NOTE* - The first time you run the system it can take upwards of a @@ -102,11 +107,11 @@ model is cached. ## Running the system against TA1 services or local files -In the Python environment you have set up, a CLI application called `run_align_system` should now be available. This single entrypoint supports interfacing with both local files on disk, and the TA3 web-based API. Running the script with `--help` shows which interfaces are available: +In the Python environment you have set up, a CLI application called `run_simplified_align_system` should now be available. This single entrypoint supports interfacing with both local files on disk, and the TA3 web-based API. Running the script with `--help` shows which interfaces are available: ``` -$ run_align_system --help -usage: run_align_system [-h] {TA1Soartech,LocalFiles,TA1Adept} ... +$ run_simplified_align_system --help +usage: run_simplified_align_system [-h] {TA1Soartech,LocalFiles,TA1Adept} ... ALIGN System CLI @@ -124,8 +129,8 @@ options: Running `--help` after the selected interface prints the full set of options for the interface and system. 
E.g.: ``` -$ run_align_system TA1Soartech --help -usage: run_align_system TA1Soartech [-h] [-s [SCENARIOS ...]] [--alignment-targets [ALIGNMENT_TARGETS ...]] [-e API_ENDPOINT] [-m MODEL] [-t] [-a ALGORITHM] [-A ALGORITHM_KWARGS] [--similarity-measure SIMILARITY_MEASURE] +$ run_simplified_align_system TA1Soartech --help +usage: run_simplified_align_system TA1Soartech [-h] [-s [SCENARIOS ...]] [--alignment-targets [ALIGNMENT_TARGETS ...]] [-e API_ENDPOINT] [-m MODEL] [-t] [-a ALGORITHM] [-A ALGORITHM_KWARGS] [--similarity-measure SIMILARITY_MEASURE] options: -h, --help show this help message and exit @@ -153,7 +158,7 @@ options: We've included some example scenario, probe, and alignment target data for testing. These files can be found in the `example_data` directory. Here's an example system invocation with the provided example files: ``` -run_align_system LocalFiles \ +run_simplified_align_system LocalFiles \ -s example_data/scenario_1/scenario.json \ --alignment-target-filepath example_data/scenario_1/alignment_target.json \ -p example_data/scenario_1/probe{1,2,3,4}.json \ @@ -163,56 +168,64 @@ run_align_system LocalFiles \ --align-to-target ``` -## ADM Invocations +## Metrics Evaluation ADM Invocations -### Simple Action-based Baseline ADM +### Aligned ADM for ADEPT scenarios -Simple baseline (unaligned) system using the `falcon` model: ``` run_action_based_align_system TA3ActionBased \ - --algorithm "llama_index" \ - --model falcon \ - -s soartech \ - --algorithm-kwargs '{"retrieval_enabled": false}' \ - --algorithm "llama_index" \ - --model falcon + --adm-config adm_configs/metrics-evaluation/delivered/single_kdma_adm_adept.yml \ + --username single_kdma_aligned_adm_adept \ + --align-to-target \ + --session-type adept ``` -### Simple Action-based Aligned ADM +### Aligned Hybrid Kaleido ADM for ADEPT scenarios -Simple aligned system using the `falcon` model (requires domain document PDFs): ``` run_action_based_align_system TA3ActionBased \ - --algorithm 
"llama_index" \ - --model falcon \ - -s soartech \ - --algorithm-kwargs '{"domain_docs_dir": "/path/to/DomainDocumentsPDF"}' \ - --algorithm-kwargs '{"retrieval_enabled": false}' \ - --algorithm "llama_index" \ - --model falcon \ - --align-to-target + --adm-config adm_configs/metrics-evaluation/delivered/hybrid_kaleido.yml \ + --username hybrid_kaleido_aligned_adm_adept \ + --align-to-target \ + --session-type adept ``` -### Action-based Chat Baseline ADM +### Baseline ADM for ADEPT scenarios + +``` +run_action_based_align_system TA3ActionBased \ + --adm-config adm_configs/metrics-evaluation/delivered/single_kdma_adm_baseline.yml \ + --username single_kdma_baseline_adm_adept \ + --session-type adept +``` -Unaligned system using a Llama 2 chat model: +### Aligned ADM for SoarTech scenarios ``` -run_action_based_chat_baseline TA3ActionBased \ - -s adept \ - --model meta-llama/Llama-2-13b-chat-hf +run_action_based_align_system TA3ActionBased \ + --adm-config adm_configs/metrics-evaluation/delivered/single_kdma_adm_soartech.yml \ + --username single_kdma_aligned_adm_soartech \ + --align-to-target \ + --session-type soartech ``` -### Action-based Chat Aligned ADM +### Aligned Hybrid Kaleido ADM for SoarTech scenarios -Aligned system using a Llama 2 chat model: +``` +run_action_based_align_system TA3ActionBased \ + --adm-config adm_configs/metrics-evaluation/delivered/hybrid_kaleido.yml \ + --username hybrid_kaleido_aligned_adm_soartech \ + --align-to-target \ + --session-type soartech +``` + +### Baseline ADM for SoarTech scenarios ``` -run_action_based_chat_baseline TA3ActionBased \ - -s adept \ - --model meta-llama/Llama-2-13b-chat-hf \ - --precision half \ - --align-to-target +run_action_based_align_system TA3ActionBased \ + --adm-config adm_configs/metrics-evaluation/delivered/single_kdma_adm_baseline.yml \ + --username single_kdma_baseline_adm_soartech \ + --session-type soartech ``` diff --git a/adm_configs/kaleido_config.yml b/adm_configs/kaleido_config.yml new 
file mode 100644 index 00000000..b47ab346 --- /dev/null +++ b/adm_configs/kaleido_config.yml @@ -0,0 +1,9 @@ +adm: + name: 'KaleidoADM' + init_kwargs: + model_name: 'allenai/kaleido-large' + use_tqdm: False + + inference_kwargs: + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' diff --git a/adm_configs/metrics-evaluation/delivered/hybrid_kaleido.yml b/adm_configs/metrics-evaluation/delivered/hybrid_kaleido.yml new file mode 100644 index 00000000..d9c8a230 --- /dev/null +++ b/adm_configs/metrics-evaluation/delivered/hybrid_kaleido.yml @@ -0,0 +1,17 @@ +adm: + name: 'HybridKaleidoADM' + init_kwargs: + kaleido_init_kwargs: + model_name: 'allenai/kaleido-large' + use_tqdm: False + + llm_init_kwargs: + hf_model: 'meta-llama/Llama-2-7b-chat-hf' + precision: 'half' + + inference_kwargs: + # Kaleido kwargs + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' + # LLM kwargs + answer_attempts: 5 diff --git a/adm_configs/metrics-evaluation/delivered/single_kdma_adm_adept.yml b/adm_configs/metrics-evaluation/delivered/single_kdma_adm_adept.yml new file mode 100644 index 00000000..ff3c39d2 --- /dev/null +++ b/adm_configs/metrics-evaluation/delivered/single_kdma_adm_adept.yml @@ -0,0 +1,12 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: meta-llama/Llama-2-13b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: false + n_negative_samples: 5 + n_positive_samples: 5 + shuffle: true diff --git a/adm_configs/metrics-evaluation/delivered/single_kdma_adm_baseline.yml b/adm_configs/metrics-evaluation/delivered/single_kdma_adm_baseline.yml new file mode 100644 index 00000000..9952a832 --- /dev/null +++ b/adm_configs/metrics-evaluation/delivered/single_kdma_adm_baseline.yml @@ -0,0 +1,12 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: 
meta-llama/Llama-2-13b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: true + n_negative_samples: 0 + n_positive_samples: 5 + shuffle: true diff --git a/adm_configs/metrics-evaluation/delivered/single_kdma_adm_soartech.yml b/adm_configs/metrics-evaluation/delivered/single_kdma_adm_soartech.yml new file mode 100644 index 00000000..5dd87882 --- /dev/null +++ b/adm_configs/metrics-evaluation/delivered/single_kdma_adm_soartech.yml @@ -0,0 +1,12 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: meta-llama/Llama-2-13b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: false + n_negative_samples: 0 + n_positive_samples: 5 + shuffle: true diff --git a/adm_configs/metrics-evaluation/hybrid_kaleido_adept_high.yml b/adm_configs/metrics-evaluation/hybrid_kaleido_adept_high.yml new file mode 100644 index 00000000..a87b47e0 --- /dev/null +++ b/adm_configs/metrics-evaluation/hybrid_kaleido_adept_high.yml @@ -0,0 +1,22 @@ +adm: + name: 'HybridKaleidoADM' + init_kwargs: + kaleido_init_kwargs: + model_name: 'allenai/kaleido-large' + use_tqdm: False + + llm_init_kwargs: + hf_model: 'meta-llama/Llama-2-7b-chat-hf' + precision: 'half' + + inference_kwargs: + # Kaleido kwargs + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' + # LLM kwargs + answer_attempts: 5 + +alignment_target_override: + id: ADEPT-metrics_eval-alignment-target-train-HIGH + kdma_values: + - {kdma: MoralDesert, value: 1} diff --git a/adm_configs/metrics-evaluation/hybrid_kaleido_adept_low.yml b/adm_configs/metrics-evaluation/hybrid_kaleido_adept_low.yml new file mode 100644 index 00000000..7e03bf04 --- /dev/null +++ b/adm_configs/metrics-evaluation/hybrid_kaleido_adept_low.yml @@ -0,0 +1,22 @@ +adm: + name: 'HybridKaleidoADM' + init_kwargs: + kaleido_init_kwargs: + model_name: 'allenai/kaleido-large' + use_tqdm: False + + llm_init_kwargs: + hf_model: 
'meta-llama/Llama-2-7b-chat-hf' + precision: 'half' + + inference_kwargs: + # Kaleido kwargs + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' + # LLM kwargs + answer_attempts: 5 + +alignment_target_override: + id: ADEPT-metrics_eval-alignment-target-train-LOW + kdma_values: + - {kdma: MoralDesert, value: 0} diff --git a/adm_configs/metrics-evaluation/hybrid_kaleido_soartech_high.yml b/adm_configs/metrics-evaluation/hybrid_kaleido_soartech_high.yml new file mode 100644 index 00000000..b19331b9 --- /dev/null +++ b/adm_configs/metrics-evaluation/hybrid_kaleido_soartech_high.yml @@ -0,0 +1,23 @@ +adm: + name: 'HybridKaleidoADM' + init_kwargs: + kaleido_init_kwargs: + model_name: 'allenai/kaleido-large' + use_tqdm: False + + llm_init_kwargs: + hf_model: 'meta-llama/Llama-2-7b-chat-hf' + precision: 'half' + + inference_kwargs: + # Kaleido kwargs + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' + # LLM kwargs + answer_attempts: 5 + +alignment_target_override: + id: maximization_high + kdma_values: + - kdma: maximization + value: 0.9 diff --git a/adm_configs/metrics-evaluation/hybrid_kaleido_soartech_low.yml b/adm_configs/metrics-evaluation/hybrid_kaleido_soartech_low.yml new file mode 100644 index 00000000..e5fbb41c --- /dev/null +++ b/adm_configs/metrics-evaluation/hybrid_kaleido_soartech_low.yml @@ -0,0 +1,23 @@ +adm: + name: 'HybridKaleidoADM' + init_kwargs: + kaleido_init_kwargs: + model_name: 'allenai/kaleido-large' + use_tqdm: False + + llm_init_kwargs: + hf_model: 'meta-llama/Llama-2-7b-chat-hf' + precision: 'half' + + inference_kwargs: + # Kaleido kwargs + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' + # LLM kwargs + answer_attempts: 5 + +alignment_target_override: + 
id: maximization_low + kdma_values: + - kdma: maximization + value: 0.1 diff --git a/adm_configs/metrics-evaluation/hybrid_kaleido_xl_adept_high.yml b/adm_configs/metrics-evaluation/hybrid_kaleido_xl_adept_high.yml new file mode 100644 index 00000000..cfec8afc --- /dev/null +++ b/adm_configs/metrics-evaluation/hybrid_kaleido_xl_adept_high.yml @@ -0,0 +1,22 @@ +adm: + name: 'HybridKaleidoADM' + init_kwargs: + kaleido_init_kwargs: + model_name: 'allenai/kaleido-xl' + use_tqdm: False + + llm_init_kwargs: + hf_model: 'meta-llama/Llama-2-7b-chat-hf' + precision: 'half' + + inference_kwargs: + # Kaleido kwargs + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' + # LLM kwargs + answer_attempts: 5 + +alignment_target_override: + id: ADEPT-metrics_eval-alignment-target-train-HIGH + kdma_values: + - {kdma: MoralDesert, value: 1} diff --git a/adm_configs/metrics-evaluation/hybrid_kaleido_xl_adept_low.yml b/adm_configs/metrics-evaluation/hybrid_kaleido_xl_adept_low.yml new file mode 100644 index 00000000..7312776f --- /dev/null +++ b/adm_configs/metrics-evaluation/hybrid_kaleido_xl_adept_low.yml @@ -0,0 +1,22 @@ +adm: + name: 'HybridKaleidoADM' + init_kwargs: + kaleido_init_kwargs: + model_name: 'allenai/kaleido-xl' + use_tqdm: False + + llm_init_kwargs: + hf_model: 'meta-llama/Llama-2-7b-chat-hf' + precision: 'half' + + inference_kwargs: + # Kaleido kwargs + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' + # LLM kwargs + answer_attempts: 5 + +alignment_target_override: + id: ADEPT-metrics_eval-alignment-target-train-LOW + kdma_values: + - {kdma: MoralDesert, value: 0} diff --git a/adm_configs/metrics-evaluation/hybrid_kaleido_xl_soartech_high.yml b/adm_configs/metrics-evaluation/hybrid_kaleido_xl_soartech_high.yml new file mode 100644 index 00000000..6423ac80 --- /dev/null +++ 
b/adm_configs/metrics-evaluation/hybrid_kaleido_xl_soartech_high.yml @@ -0,0 +1,23 @@ +adm: + name: 'HybridKaleidoADM' + init_kwargs: + kaleido_init_kwargs: + model_name: 'allenai/kaleido-xl' + use_tqdm: False + + llm_init_kwargs: + hf_model: 'meta-llama/Llama-2-7b-chat-hf' + precision: 'half' + + inference_kwargs: + # Kaleido kwargs + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' + # LLM kwargs + answer_attempts: 5 + +alignment_target_override: + id: maximization_high + kdma_values: + - kdma: maximization + value: 0.9 diff --git a/adm_configs/metrics-evaluation/hybrid_kaleido_xl_soartech_low.yml b/adm_configs/metrics-evaluation/hybrid_kaleido_xl_soartech_low.yml new file mode 100644 index 00000000..818e6771 --- /dev/null +++ b/adm_configs/metrics-evaluation/hybrid_kaleido_xl_soartech_low.yml @@ -0,0 +1,23 @@ +adm: + name: 'HybridKaleidoADM' + init_kwargs: + kaleido_init_kwargs: + model_name: 'allenai/kaleido-xl' + use_tqdm: False + + llm_init_kwargs: + hf_model: 'meta-llama/Llama-2-7b-chat-hf' + precision: 'half' + + inference_kwargs: + # Kaleido kwargs + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' + # LLM kwargs + answer_attempts: 5 + +alignment_target_override: + id: maximization_low + kdma_values: + - kdma: maximization + value: 0.1 diff --git a/adm_configs/metrics-evaluation/hybrid_kaleido_xxl_adept_high.yml b/adm_configs/metrics-evaluation/hybrid_kaleido_xxl_adept_high.yml new file mode 100644 index 00000000..9dbedca9 --- /dev/null +++ b/adm_configs/metrics-evaluation/hybrid_kaleido_xxl_adept_high.yml @@ -0,0 +1,22 @@ +adm: + name: 'HybridKaleidoADM' + init_kwargs: + kaleido_init_kwargs: + model_name: 'allenai/kaleido-xxl' + use_tqdm: False + + llm_init_kwargs: + hf_model: 'meta-llama/Llama-2-7b-chat-hf' + precision: 'half' + + inference_kwargs: + # Kaleido 
kwargs + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' + # LLM kwargs + answer_attempts: 5 + +alignment_target_override: + id: ADEPT-metrics_eval-alignment-target-train-HIGH + kdma_values: + - {kdma: MoralDesert, value: 1} diff --git a/adm_configs/metrics-evaluation/hybrid_kaleido_xxl_adept_low.yml b/adm_configs/metrics-evaluation/hybrid_kaleido_xxl_adept_low.yml new file mode 100644 index 00000000..d59dd236 --- /dev/null +++ b/adm_configs/metrics-evaluation/hybrid_kaleido_xxl_adept_low.yml @@ -0,0 +1,22 @@ +adm: + name: 'HybridKaleidoADM' + init_kwargs: + kaleido_init_kwargs: + model_name: 'allenai/kaleido-xxl' + use_tqdm: False + + llm_init_kwargs: + hf_model: 'meta-llama/Llama-2-7b-chat-hf' + precision: 'half' + + inference_kwargs: + # Kaleido kwargs + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' + # LLM kwargs + answer_attempts: 5 + +alignment_target_override: + id: ADEPT-metrics_eval-alignment-target-train-LOW + kdma_values: + - {kdma: MoralDesert, value: 0} diff --git a/adm_configs/metrics-evaluation/hybrid_kaleido_xxl_soartech_high.yml b/adm_configs/metrics-evaluation/hybrid_kaleido_xxl_soartech_high.yml new file mode 100644 index 00000000..d74e83b0 --- /dev/null +++ b/adm_configs/metrics-evaluation/hybrid_kaleido_xxl_soartech_high.yml @@ -0,0 +1,23 @@ +adm: + name: 'HybridKaleidoADM' + init_kwargs: + kaleido_init_kwargs: + model_name: 'allenai/kaleido-xxl' + use_tqdm: False + + llm_init_kwargs: + hf_model: 'meta-llama/Llama-2-7b-chat-hf' + precision: 'half' + + inference_kwargs: + # Kaleido kwargs + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' + # LLM kwargs + answer_attempts: 5 + +alignment_target_override: + id: maximization_high + kdma_values: + - kdma: 
maximization + value: 0.9 diff --git a/adm_configs/metrics-evaluation/hybrid_kaleido_xxl_soartech_low.yml b/adm_configs/metrics-evaluation/hybrid_kaleido_xxl_soartech_low.yml new file mode 100644 index 00000000..fa360230 --- /dev/null +++ b/adm_configs/metrics-evaluation/hybrid_kaleido_xxl_soartech_low.yml @@ -0,0 +1,23 @@ +adm: + name: 'HybridKaleidoADM' + init_kwargs: + kaleido_init_kwargs: + model_name: 'allenai/kaleido-xxl' + use_tqdm: False + + llm_init_kwargs: + hf_model: 'meta-llama/Llama-2-7b-chat-hf' + precision: 'half' + + inference_kwargs: + # Kaleido kwargs + distance_fn: 'RelevanceWeightedDistance' + kdma_descriptions_map: 'align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml' + # LLM kwargs + answer_attempts: 5 + +alignment_target_override: + id: maximization_low + kdma_values: + - kdma: maximization + value: 0.1 diff --git a/adm_configs/metrics-evaluation/single_kdma_adm_adept_baseline.yml b/adm_configs/metrics-evaluation/single_kdma_adm_adept_baseline.yml new file mode 100644 index 00000000..341840e3 --- /dev/null +++ b/adm_configs/metrics-evaluation/single_kdma_adm_adept_baseline.yml @@ -0,0 +1,17 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: meta-llama/Llama-2-13b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: true + n_negative_samples: 0 + n_positive_samples: 5 + shuffle: true + +alignment_target_override: + id: ADEPT-metrics_eval-alignment-target-train-HIGH + kdma_values: + - {kdma: MoralDesert, value: 1} diff --git a/adm_configs/metrics-evaluation/single_kdma_adm_adept_baseline_low.yml b/adm_configs/metrics-evaluation/single_kdma_adm_adept_baseline_low.yml new file mode 100644 index 00000000..4dba08e2 --- /dev/null +++ b/adm_configs/metrics-evaluation/single_kdma_adm_adept_baseline_low.yml @@ -0,0 +1,17 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: meta-llama/Llama-2-13b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: true + 
n_negative_samples: 0 + n_positive_samples: 5 + shuffle: true + +alignment_target_override: + id: ADEPT-metrics_eval-alignment-target-train-LOW + kdma_values: + - {kdma: MoralDesert, value: 0} diff --git a/adm_configs/metrics-evaluation/single_kdma_adm_adept_high.yml b/adm_configs/metrics-evaluation/single_kdma_adm_adept_high.yml new file mode 100644 index 00000000..385f941e --- /dev/null +++ b/adm_configs/metrics-evaluation/single_kdma_adm_adept_high.yml @@ -0,0 +1,17 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: meta-llama/Llama-2-13b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: false + n_negative_samples: 5 + n_positive_samples: 5 + shuffle: true + +alignment_target_override: + id: ADEPT-metrics_eval-alignment-target-train-HIGH + kdma_values: + - {kdma: MoralDesert, value: 1} diff --git a/adm_configs/metrics-evaluation/single_kdma_adm_adept_low.yml b/adm_configs/metrics-evaluation/single_kdma_adm_adept_low.yml new file mode 100644 index 00000000..72ff7f63 --- /dev/null +++ b/adm_configs/metrics-evaluation/single_kdma_adm_adept_low.yml @@ -0,0 +1,17 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: meta-llama/Llama-2-13b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: false + n_negative_samples: 5 + n_positive_samples: 5 + shuffle: true + +alignment_target_override: + id: ADEPT-metrics_eval-alignment-target-train-LOW + kdma_values: + - {kdma: MoralDesert, value: 0} diff --git a/adm_configs/metrics-evaluation/single_kdma_adm_soartech_baseline.yml b/adm_configs/metrics-evaluation/single_kdma_adm_soartech_baseline.yml new file mode 100644 index 00000000..1ddd9e5b --- /dev/null +++ b/adm_configs/metrics-evaluation/single_kdma_adm_soartech_baseline.yml @@ -0,0 +1,18 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: meta-llama/Llama-2-13b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: true + n_negative_samples: 0 + n_positive_samples: 5 + 
shuffle: true + +alignment_target_override: + id: maximization_high + kdma_values: + - kdma: maximization + value: 0.9 diff --git a/adm_configs/metrics-evaluation/single_kdma_adm_soartech_baseline_low.yml b/adm_configs/metrics-evaluation/single_kdma_adm_soartech_baseline_low.yml new file mode 100644 index 00000000..723d590c --- /dev/null +++ b/adm_configs/metrics-evaluation/single_kdma_adm_soartech_baseline_low.yml @@ -0,0 +1,18 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: meta-llama/Llama-2-13b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: true + n_negative_samples: 0 + n_positive_samples: 5 + shuffle: true + +alignment_target_override: + id: maximization_low + kdma_values: + - kdma: maximization + value: 0.1 diff --git a/adm_configs/metrics-evaluation/single_kdma_adm_soartech_high.yml b/adm_configs/metrics-evaluation/single_kdma_adm_soartech_high.yml new file mode 100644 index 00000000..532daffa --- /dev/null +++ b/adm_configs/metrics-evaluation/single_kdma_adm_soartech_high.yml @@ -0,0 +1,18 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: meta-llama/Llama-2-13b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: false + n_negative_samples: 5 + n_positive_samples: 5 + shuffle: true + +alignment_target_override: + id: maximization_high + kdma_values: + - kdma: maximization + value: 0.9 diff --git a/adm_configs/metrics-evaluation/single_kdma_adm_soartech_high_no_negatives.yml b/adm_configs/metrics-evaluation/single_kdma_adm_soartech_high_no_negatives.yml new file mode 100644 index 00000000..3974aff6 --- /dev/null +++ b/adm_configs/metrics-evaluation/single_kdma_adm_soartech_high_no_negatives.yml @@ -0,0 +1,18 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: meta-llama/Llama-2-13b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: false + n_negative_samples: 0 + n_positive_samples: 5 + shuffle: true + +alignment_target_override: + id: 
maximization_high + kdma_values: + - kdma: maximization + value: 0.9 diff --git a/adm_configs/metrics-evaluation/single_kdma_adm_soartech_low.yml b/adm_configs/metrics-evaluation/single_kdma_adm_soartech_low.yml new file mode 100644 index 00000000..0a4d81bb --- /dev/null +++ b/adm_configs/metrics-evaluation/single_kdma_adm_soartech_low.yml @@ -0,0 +1,18 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: meta-llama/Llama-2-13b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: false + n_negative_samples: 5 + n_positive_samples: 5 + shuffle: true + +alignment_target_override: + id: maximization_low + kdma_values: + - kdma: maximization + value: 0.1 diff --git a/adm_configs/metrics-evaluation/single_kdma_adm_soartech_low_no_negatives.yml b/adm_configs/metrics-evaluation/single_kdma_adm_soartech_low_no_negatives.yml new file mode 100644 index 00000000..3367e792 --- /dev/null +++ b/adm_configs/metrics-evaluation/single_kdma_adm_soartech_low_no_negatives.yml @@ -0,0 +1,18 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: meta-llama/Llama-2-13b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: false + n_negative_samples: 0 + n_positive_samples: 5 + shuffle: true + +alignment_target_override: + id: maximization_low + kdma_values: + - kdma: maximization + value: 0.1 diff --git a/adm_configs/single_kdma_adm_config.yml b/adm_configs/single_kdma_adm_config.yml new file mode 100644 index 00000000..384427f8 --- /dev/null +++ b/adm_configs/single_kdma_adm_config.yml @@ -0,0 +1,17 @@ +adm: + name: 'SingleKDMAADM' + init_kwargs: + hf_model: meta-llama/Llama-2-7b-chat-hf + precision: half + temperature: 0.7 + + inference_kwargs: + baseline: true + n_negative_samples: 0 + n_positive_samples: 1 + shuffle: true + +alignment_target_override: + id: ADEPT-metrics_eval-alignment-target-train-HIGH + kdma_values: + - {kdma: MoralDesert, value: 1} diff --git a/align_system/algorithms/__init__.py 
b/align_system/algorithms/__init__.py index e69de29b..5653370b 100644 --- a/align_system/algorithms/__init__.py +++ b/align_system/algorithms/__init__.py @@ -0,0 +1,3 @@ +from align_system.algorithms.adms import REGISTERED_ADMS + +__all__ = ['REGISTERED_ADMS'] diff --git a/align_system/algorithms/lib/aligned_decision_maker.py b/align_system/algorithms/abstracts.py similarity index 60% rename from align_system/algorithms/lib/aligned_decision_maker.py rename to align_system/algorithms/abstracts.py index da5fb6d9..8540bfae 100644 --- a/align_system/algorithms/lib/aligned_decision_maker.py +++ b/align_system/algorithms/abstracts.py @@ -1,17 +1,30 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod + +from typing import Union +from swagger_client.models import State, Action, AlignmentTarget + + +class ActionBasedADM(ABC): + @abstractmethod + def choose_action(self, + scenario_state: State, + available_actions: list[Action], + alignment_target: Union[type[AlignmentTarget], None], + **kwargs) -> Action: + pass + # ADM sub-classes implement all the algorithm-specific logic class AlignedDecisionMaker: - @abstractmethod def __call__(self, sample, target_kdma_values, **kwargs): - + ''' target_kdma_values: { kdma_name: kdma_value, ... } - + sample = { scenario, state, @@ -21,7 +34,7 @@ def __call__(self, sample, target_kdma_values, **kwargs): ... 
] } - + returns { choice: idx, [required] predicted_kdmas: { [optional] @@ -32,4 +45,4 @@ def __call__(self, sample, target_kdma_values, **kwargs): } } ''' - pass \ No newline at end of file + pass diff --git a/align_system/algorithms/adms.py b/align_system/algorithms/adms.py new file mode 100644 index 00000000..7e750606 --- /dev/null +++ b/align_system/algorithms/adms.py @@ -0,0 +1,9 @@ +from align_system.algorithms.kaleido_adm import KaleidoADM +from align_system.algorithms.llama_2_single_kdma_adm import Llama2SingleKDMAADM +from align_system.algorithms.hybrid_kaleido_adm import HybridKaleidoADM + +REGISTERED_ADMS = { + 'KaleidoADM': KaleidoADM, + 'HybridKaleidoADM': HybridKaleidoADM, + 'SingleKDMAADM': Llama2SingleKDMAADM, +} diff --git a/align_system/algorithms/chat_kdma_predicting_adm.py b/align_system/algorithms/chat_kdma_predicting_adm.py index c229adaa..a8bb9e96 100644 --- a/align_system/algorithms/chat_kdma_predicting_adm.py +++ b/align_system/algorithms/chat_kdma_predicting_adm.py @@ -3,7 +3,7 @@ import os from typing import Union, List, Dict, Tuple, Optional, TextIO from align_system.algorithms.lib.chat.chat_language_model import ChatLanguageModel -from align_system.algorithms.lib.aligned_decision_maker import AlignedDecisionMaker +from align_system.algorithms.abstracts import AlignedDecisionMaker from align_system.algorithms.lib.util import read_template class ChatKDMAPredictingADM(ChatLanguageModel, AlignedDecisionMaker): diff --git a/align_system/algorithms/hybrid_kaleido_adm.py b/align_system/algorithms/hybrid_kaleido_adm.py new file mode 100644 index 00000000..a60b750e --- /dev/null +++ b/align_system/algorithms/hybrid_kaleido_adm.py @@ -0,0 +1,48 @@ +from swagger_client.models import ActionTypeEnum + +from align_system.utils import logging +from align_system.algorithms.abstracts import ActionBasedADM +from align_system.algorithms.llama_2_single_kdma_adm import Llama2SingleKDMAADM +from align_system.algorithms.kaleido_adm import KaleidoADM + +log = 
logging.getLogger(__name__) + + +class HybridKaleidoADM(ActionBasedADM): + def __init__(self, **kwargs): + self.kaleido_adm = KaleidoADM(**kwargs.get('kaleido_init_kwargs', {})) + + self.llm_algorithm = Llama2SingleKDMAADM(**kwargs.get('llm_init_kwargs', {})) + self.llm_algorithm.load_model() + + def choose_action(self, scenario_state, available_actions, alignment_target, **kwargs): + action_to_take = self.kaleido_adm.choose_action( + scenario_state, available_actions, alignment_target, **kwargs) + + if action_to_take.action_type == ActionTypeEnum.APPLY_TREATMENT: + # If the additional required fields are already populated + # for the action, don't need ask the LLM again + if (action_to_take.parameters is None + or not {'treatment', 'location'}.issubset( + action_to_take.parameters.keys())): + action_to_take = self.llm_algorithm.populate_treatment_parameters( + scenario_state, action_to_take, alignment_target, **kwargs) + elif action_to_take.action_type == ActionTypeEnum.TAG_CHARACTER: + # If the additional required fields are already populated + # for the action, don't need ask the LLM again + if (action_to_take.character_id is None + or action_to_take.parameters is None + or not {'category'}.issubset( + action_to_take.parameters.keys())): + action_to_take = self.llm_algorithm.populate_tagging_parameters( + scenario_state, action_to_take, alignment_target, **kwargs) + elif action_to_take.action_type in {ActionTypeEnum.CHECK_ALL_VITALS, + ActionTypeEnum.CHECK_PULSE, + ActionTypeEnum.CHECK_RESPIRATION, + ActionTypeEnum.MOVE_TO_EVAC}: + # These actions require a `character_id` + if action_to_take.character_id is None: + action_to_take = self.llm_algorithm.generic_populate_character_id( + scenario_state, action_to_take, alignment_target, **kwargs) + + return action_to_take diff --git a/align_system/algorithms/kaleido_adm.py b/align_system/algorithms/kaleido_adm.py index 07195ff0..7bd32f45 100644 --- a/align_system/algorithms/kaleido_adm.py +++ 
b/align_system/algorithms/kaleido_adm.py @@ -1,11 +1,13 @@ from abc import ABC, abstractmethod from functools import reduce import inspect +import yaml import pandas as pd +from align_system.algorithms.abstracts import ActionBasedADM from align_system.algorithms.lib.kaleido import KaleidoSys -from align_system.algorithms.lib.aligned_decision_maker import AlignedDecisionMaker +from align_system.algorithms.abstracts import AlignedDecisionMaker from align_system.algorithms.lib.util import format_template from align_system.utils import logging @@ -68,7 +70,7 @@ def __call__(self, group_records: pd.DataFrame) -> pd.DataFrame: MeanDistance2] -class KaleidoADM(AlignedDecisionMaker): +class KaleidoADM(AlignedDecisionMaker, ActionBasedADM): def __init__(self, **kwargs): log.info('Initializing Kaleido..') self.kaleido = KaleidoSys(**kwargs) @@ -188,10 +190,6 @@ def force_choice(self, kaleido_results, choices, distance_fn=DefaultDistanceFunc return output_choice_idx def __call__(self, sample, target_kdma_values, labels, **kwargs): - import yaml - - from align_system.algorithms.lib.util import format_template - kdma_descriptions_map = None if 'kdma_descriptions_map' in kwargs: if isinstance(kwargs['kdma_descriptions_map'], str): @@ -247,3 +245,51 @@ def __call__(self, sample, target_kdma_values, labels, **kwargs): return {'choice': selected_choice_idx, 'predicted_kdma_values': predicted_kdma_values} + + def choose_action(self, scenario_state, available_actions, alignment_target, **kwargs): + kdma_descriptions_map = None + if 'kdma_descriptions_map' in kwargs: + if isinstance(kwargs['kdma_descriptions_map'], str): + with open(kwargs['kdma_descriptions_map']) as f: + kdma_descriptions_map = yaml.safe_load(f) + elif isinstance(kwargs['kdma_descriptions_map'], dict): + kdma_descriptions_map = kwargs['kdma_descriptions_map'] + + template = 'What action do you take? "{{choice}}" instead of {{other_choices}}. 
In the context of the following scenario: {{scene_state}} {{mission}} {{decision_environment}} With the following casualties:\n{{characters_str}}' + # TODO: Parameterize this + # template = 'Choosing to "{{choice}}". In the context of the following scenario: {{scene_state}} {{mission}} {{decision_environment}} With the following casualties:\n{{characters_str}}' + characters_str = '\n'.join(['{} ({}): {}'.format(c.name, c.id, c.unstructured) for c in scenario_state.characters]) + if scenario_state.mission is None: + mission_str = '' + else: + mission_str = scenario_state.mission.unstructured + + partial_template = format_template( + template, + partial=True, + scene_state=scenario_state.unstructured, + mission=mission_str, + decision_environment=scenario_state.environment.decision_environment.unstructured.strip(), + characters_str=characters_str) + + # Scaling KDMA values by 10 (range should be 0-10) + if not isinstance(alignment_target, dict): + alignment_target = alignment_target.to_dict() + + target_kdma_values = {t['kdma']: t['value'] * 10 for t + in alignment_target.get('kdma_values', ())} + + choices_unstructured = [a.unstructured for a in available_actions] + + kaleido_results = self.estimate_kdma_values( + partial_template, + choices_unstructured, + target_kdma_values, + kdma_descriptions_map=kdma_descriptions_map) + + selected_choice_idx = self.force_choice( + kaleido_results, + choices_unstructured, + distance_fn=kwargs.get('distance_fn', DefaultDistanceFunction)) + + return available_actions[selected_choice_idx] diff --git a/align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml b/align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml new file mode 100644 index 00000000..fb1883a3 --- /dev/null +++ b/align_system/algorithms/lib/templates/kdma_descriptions_short_metrics_eval.yml @@ -0,0 +1,9 @@ +MoralDesert: + name: Moral Deservingness + description: Favoring victims or bystanders over those who are 
perpetrators or otherwise potentially blame-worthy + vrd: Value + +maximization: + name: Maximization + description: Prioritize gathering the most information possible before acting on that information to maximize results + vrd: Value diff --git a/align_system/algorithms/llama_2_single_kdma_adm.py b/align_system/algorithms/llama_2_single_kdma_adm.py index c3fdc12d..db85c69f 100644 --- a/align_system/algorithms/llama_2_single_kdma_adm.py +++ b/align_system/algorithms/llama_2_single_kdma_adm.py @@ -3,7 +3,7 @@ import random import os import pathlib -from align_system.algorithms.lib.aligned_decision_maker import AlignedDecisionMaker +from align_system.algorithms.abstracts import AlignedDecisionMaker from jinja2.exceptions import TemplateError @@ -33,7 +33,8 @@ 'denial', 'moral_deservingness', 'lives_saved', - 'continuation_of_care' + 'continuation_of_care', + 'maximization' } kdma_remapping = { @@ -134,13 +135,13 @@ def load_model(self, model=None, tokenizer=None): else: self.model = AutoModelForCausalLM.from_pretrained(self.hf_model, torch_dtype=self.precision) self.model = self.model.to(self.device) - + self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model) - + if self.chat_template is not None: with open(os.path.join(chat_template_path, self.chat_template), 'r') as f: self.tokenizer.chat_template = f.read().replace(' ', '').replace('\n', '') - + def get_character_ids(self, character_str): @@ -271,7 +272,7 @@ def respond_to_dialog(self, dialog, prefix=None): for message in dialog: if message['role'] == 'system': message['role'] = 'user' - + if len(new_dialog) == 0: new_dialog.append(message) continue @@ -301,8 +302,6 @@ def respond_to_dialog(self, dialog, prefix=None): # Print the generated model output generated_output = self.tokenizer.decode(outputs.sequences[0][prompt_length:]) inference_pair['output'] = generated_output - - print('INFERENCE PAIR\n', inference_pair) return generated_output, inference_pair @@ -372,7 +371,7 @@ def 
aligned_decision_maker(self, question, choices, target_kdmas, n_positive_sam inference_pairs = [] if not baseline: unsupported_kdmas = {kdma_remapping.get(k, k) - for k in target_kdmas.keys()} - kdmas + for k in target_kdmas.keys()} - kdmas if len(unsupported_kdmas) > 0: raise RuntimeError(f"KDMA(s) {unsupported_kdmas} not supported.") @@ -386,7 +385,7 @@ def aligned_decision_maker(self, question, choices, target_kdmas, n_positive_sam if baseline: system_message = load_system_message() system_message_keys = 'baseline' - + else: system_message_keys = {kdma: 'high' if value > 5 else 'low' for kdma, value in target_kdmas.items()} @@ -418,13 +417,13 @@ def aligned_decision_maker(self, question, choices, target_kdmas, n_positive_sam break except RuntimeError as e: pass - + if not good_parse: reasoning, answer_idx, parse_method = Llama2SingleKDMAADM.bert_similarity_parse(high_response, shuffled_choices) - + print('CHOSEN ANSWER IDX', answer_idx, shuffled_choices) assert answer_idx is not None, f'Failed to parse answer index from generated output: {low_response}' - + responses.append({ 'response': high_response, 'reasoning': reasoning, @@ -434,7 +433,7 @@ def aligned_decision_maker(self, question, choices, target_kdmas, n_positive_sam 'aligned': True, 'parse_method': parse_method, }) - + for _ in range(n_negative_sampels): system_message_keys = {kdma: 'high' if not value > 5 else 'low' for kdma, value in target_kdmas.items()} @@ -465,12 +464,12 @@ def aligned_decision_maker(self, question, choices, target_kdmas, n_positive_sam break except RuntimeError as e: pass - + if not good_parse: reasoning, answer_idx, parse_method = Llama2SingleKDMAADM.bert_similarity_parse(low_response, shuffled_choices) assert answer_idx is not None, f'Failed to parse answer index from generated output: {low_response}' - + responses.append({ 'response': low_response, 'reasoning': reasoning, @@ -558,9 +557,9 @@ def parse_generated_output(generated_output, n_choices): pass except 
json.JSONDecodeError: pass - - + + if answer_idx is None: parse_method = 'string' # If json parsing fails, do string parsing @@ -583,15 +582,15 @@ def parse_generated_output(generated_output, n_choices): if answer_idx is not None: break - + if reasoning is None: reasoning = generated_output - + if answer_idx is None or answer_idx >= n_choices: raise RuntimeError(f'Failed to parse answer index < {n_choices} from generated output: {generated_output}') return reasoning, answer_idx, parse_method - + @staticmethod def bert_similarity_parse(generated_output, choices): print('BERT SIMILARITY PARSE') @@ -725,11 +724,15 @@ def run_aligned_decision_maker_with_voting( log.warning(f"Error calculating votes: {e}") choice_scores = [None] * len(choices) + log.debug("[bold]*RESPONSES*[bold]", extra={"markup": True}) + for i, ip in enumerate(inference_pairs): + log.debug("[bold]*response {}*[bold]".format(i+1), + extra={"markup": True}) + log.debug(ip['output']) + log.explain("[bold]*CHOICE SCORES*[/bold]", extra={"markup": True}) - log.explain(json.dumps({c: s for c, s in zip(choices, choice_scores)}, - indent=4), - extra={"highlighter": JSON_HIGHLIGHTER}) + log.explain("\n".join([f"{c}: {s}" for c, s in zip(choices, choice_scores)])) results = { 'prompt': prompt, @@ -781,16 +784,16 @@ def __call__(self, sample, target_kdma_values, **kwargs): prompt += f'\n{sample["probe"]}' choices = sample['choices'] - + labels = kwargs.get('labels', {}) - + alignment_target = None - if target_kdma_values is not None: + if target_kdma_values is not None and len(target_kdma_values) > 0: target_kdma = next(iter(next(iter(filter(lambda x: len(x) > 0, labels))))) # get the frist key of the first label that is not empty - + for label in labels: assert len(label) == 0 or (target_kdma in label and len(label) == 1), f'All labels must have the same KDMA: labels={labels}' - + alignment_target = { target_kdma: target_kdma_values[target_kdma] } @@ -804,7 +807,7 @@ def __call__(self, sample, 
target_kdma_values, **kwargs): baseline=kwargs.get('baseline', False), shuffle=kwargs.get('shuffle', False) ) - + raw_data = { 'params': { 'model': self.hf_model, @@ -825,3 +828,310 @@ def __call__(self, sample, target_kdma_values, **kwargs): 'raw_data': raw_data, } } + + def choose_action(self, scenario_state, available_actions, alignment_target, **kwargs): + from swagger_client.models import ActionTypeEnum + + kdma_name_map = { + 'MoralDesert': 'moral_deservingness', + 'maximization': 'maximization', + } + + if alignment_target is None or len(alignment_target.kdma_values) == 0: + target_kdma_values = {} + else: + alignment_target_dict = alignment_target.to_dict() + target_kdma_values = { + kdma_name_map[k['kdma']]: k['value'] * 10 + for k in alignment_target_dict.get('kdma_values', ()) + } + + scenario = '\nCHARACTERS:\n' + + for character in scenario_state.characters: + scenario += f'{character.name}: {character.unstructured}\n' + scenario += f'{character.name}\'s intent: {character.intent}\n\n' + + scenario += f'\nSITUATION:\n{scenario_state.unstructured}' + + state = None + + probe = '' + + choices = [ + action.unstructured + for action in available_actions + ] + + response = self.__call__({ + 'scenario': scenario, + 'state': state, + 'probe': probe, + 'choices': choices}, + target_kdma_values, + labels=[target_kdma_values]*len(choices), + **kwargs) + + action_to_take = available_actions[response['choice']] + + if action_to_take.action_type == ActionTypeEnum.APPLY_TREATMENT: + # If the additional required fields are already populated + # for the action, don't need ask the LLM again + if (action_to_take.parameters is None + or not {'treatment', 'location'}.issubset( + action_to_take.parameters.keys())): + action_to_take = self.populate_treatment_parameters( + scenario_state, action_to_take, alignment_target, **kwargs) + elif action_to_take.action_type == ActionTypeEnum.TAG_CHARACTER: + # If the additional required fields are already populated + # for the 
action, don't need ask the LLM again + if (action_to_take.character_id is None + or action_to_take.parameters is None + or not {'category'}.issubset( + action_to_take.parameters.keys())): + action_to_take = self.populate_tagging_parameters( + scenario_state, action_to_take, alignment_target, **kwargs) + elif action_to_take.action_type in {ActionTypeEnum.CHECK_ALL_VITALS, + ActionTypeEnum.CHECK_PULSE, + ActionTypeEnum.CHECK_RESPIRATION, + ActionTypeEnum.MOVE_TO_EVAC}: + # These actions require a `character_id` + if action_to_take.character_id is None: + action_to_take = self.generic_populate_character_id( + scenario_state, action_to_take, alignment_target, **kwargs) + + return action_to_take + + def populate_treatment_parameters(self, scenario_state, treatment_action, alignment_target, **kwargs): + from align_system.prompt_engineering.common import ( + prepare_treatment_selection_prompt) + from swagger_client.models import ActionTypeEnum, InjuryLocationEnum + from align_system.utils import get_swagger_class_enum_values + + assert treatment_action.action_type == ActionTypeEnum.APPLY_TREATMENT + + character_id = treatment_action.character_id + if character_id is None: + # Need to populate character_id on treatment action + treatment_action = self.generic_populate_character_id( + scenario_state, treatment_action, alignment_target, **kwargs) + + character_id = treatment_action.character_id + + matching_characters = [c for c in scenario_state.characters + if c.id == character_id] + + assert len(matching_characters) == 1 + + character_to_treat = matching_characters[0] + + available_supplies = [s for s in scenario_state.supplies if s.quantity > 0] + + treatment_prompt = prepare_treatment_selection_prompt( + character_to_treat.unstructured, + character_to_treat.vitals.to_dict(), + [s.to_dict() for s in available_supplies]) + + for _ in range(kwargs.get('answer_attempts', 5)): + treatment_dialog =\ + self.build_multiple_choice_dialog( + treatment_prompt, + [s.to_dict() for s 
in available_supplies], + json_format=TREATMENT_MULTIPLE_CHOICE_JSON_FORMAT) + + log.debug("[bold]*TREATMENT DIALOG*[/bold]", + extra={"markup": True}) + self.log_dialog(treatment_dialog) + + raw_treatment_response, _ = self.respond_to_dialog( + treatment_dialog) + + log.info("** ADM raw treatment response: {}".format( + raw_treatment_response)) + + parsed_treatment_output = self.attempt_generic_parse( # noqa + raw_treatment_response, ['Reasoning', 'Answer', 'Location']) # noqa + + if parsed_treatment_output is not None: + treatment_idx = parsed_treatment_output['Answer'] + + if len(available_supplies) <= treatment_idx: + log.info('** Selected treatment_idx out of range of ' + 'available treatment options, retrying!') + continue + + treatment = available_supplies[treatment_idx].type # noqa + + treatment_location = parsed_treatment_output['Location'] + + if treatment_action.parameters is None: + treatment_action.parameters = {} + + treatment_action.parameters['treatment'] = treatment + + valid_treatment_locations = get_swagger_class_enum_values( + InjuryLocationEnum) + + if not isinstance(treatment_location, str): + # If type is int, could be an index into the + # action_to_take)locations provided in the system + # action_to_take)prompt, consider handling in the + # action_to_take)future + log.warning("*** Treatment location value is not a string" + ", retrying!") + continue + if treatment_location in valid_treatment_locations: + treatment_action.parameters['location'] = treatment_location + else: + # Ensure that the treatment location is valid + _, treatment_loc_idx, _ = self.bert_similarity_parse( + treatment_location, valid_treatment_locations) + + treatment_action.parameters['location'] =\ + valid_treatment_locations[treatment_loc_idx] + + break + else: + log.info('** Failed to parse treatment') + + return treatment_action + + def populate_tagging_parameters(self, scenario_state, tagging_action, alignment_target, **kwargs): + from 
align_system.prompt_engineering.common import ( + prepare_tagging_selection_prompt) + from swagger_client.models import ActionTypeEnum, CharacterTagEnum + from align_system.utils import get_swagger_class_enum_values + + assert tagging_action.action_type == ActionTypeEnum.TAG_CHARACTER + # Ask the system to specify which triage tag to apply + + untagged_characters = [c for c in scenario_state.characters if c.tag is None] + + tagging_prompt = prepare_tagging_selection_prompt( + [c.to_dict() for c in untagged_characters], + get_swagger_class_enum_values(CharacterTagEnum)) + + for _ in range(kwargs.get('answer_attempts', 5)): + tagging_dialog = self.build_multiple_choice_dialog( + tagging_prompt, + [c.unstructured.strip() + for c in untagged_characters], + json_format=TAGGING_MULTIPLE_CHOICE_JSON_FORMAT) + + log.debug("[bold]*TAGGING DIALOG*[/bold]", + extra={"markup": True}) + self.log_dialog(tagging_dialog) + + raw_tagging_response, _ = self.respond_to_dialog( + tagging_dialog) + + log.info("** ADM raw tagging response: {}".format( + raw_tagging_response)) + + parsed_tagging_output = self.attempt_generic_parse( # noqa + raw_tagging_response, ['Reasoning', 'Answer', 'Tag']) # noqa + + if parsed_tagging_output is not None: + if len(untagged_characters) == 1: + log.debug("** Force selecting only available character") + character_idx = 0 + else: + character_idx = parsed_tagging_output['Answer'] + + if not isinstance(character_idx, int): + log.warning('** character_idx ({}) not an integer' + ', retrying!'.format(character_idx)) + continue + + if len(untagged_characters) <= character_idx: + log.info('** Selected character_idx out of range of ' + 'available treatment options, retrying!') + continue + + character_to_tag_id = untagged_characters[character_idx].id # noqa + + tag = parsed_tagging_output['Tag'] + if not isinstance(tag, str): + log.warning("** Selected tag ({}) not of type string" + ", retrying!".format(tag)) + continue + + # Populate required parameters for 
tagging action + tagging_action.character_id = character_to_tag_id + + if tagging_action.parameters is None: + tagging_action.parameters = {} + + tagging_action.parameters['category'] = tag + + break + else: + log.info('** Failed to parse tagging') + + return tagging_action + + def generic_populate_character_id(self, scenario_state, initial_action, alignment_target, **kwargs): + from swagger_client.models import ActionTypeEnum + from align_system.prompt_engineering.common import ( + prepare_character_selection_prompt) + character_selection_prompt = prepare_character_selection_prompt( + initial_action) + + filtered_characters = [] + for c in scenario_state.characters: + if initial_action.action_type in {ActionTypeEnum.CHECK_ALL_VITALS, + ActionTypeEnum.CHECK_PULSE, + ActionTypeEnum.CHECK_RESPIRATION}: + # Don't allow the ADM to check vitals on + # a character that's already been "visited" + if c.visited: + continue + + filtered_characters.append(c) + + for _ in range(kwargs.get('answer_attempts', 5)): + character_selection_dialog = self.build_multiple_choice_dialog( + character_selection_prompt, + [c.unstructured.strip() + for c in filtered_characters]) + + log.debug("[bold]*CHARACTER SELECTION DIALOG*[/bold]", + extra={"markup": True}) + self.log_dialog(character_selection_dialog) + + raw_character_selection_response, _ = self.respond_to_dialog( + character_selection_dialog) + + log.info("** ADM raw character_selection response: {}".format( + raw_character_selection_response)) + + parsed_character_selection_output = self.attempt_generic_parse( # noqa + raw_character_selection_response, ['Reasoning', 'Answer']) # noqa + + if parsed_character_selection_output is not None: + if len(filtered_characters) == 1: + log.debug("** Force selecting only available character") + character_idx = 0 + else: + character_idx = parsed_character_selection_output['Answer'] + + if not isinstance(character_idx, int): + log.warning('** character_idx ({}) not an integer' + ', 
retrying!'.format(character_idx)) + continue + + if len(filtered_characters) <= character_idx: + log.warning('** Selected character_idx out of range of ' + 'available treatment options, retrying!') + continue + + character_id = filtered_characters[character_idx].id # noqa + + # Populate required parameters for character_selection action + initial_action.character_id = character_id + + break + else: + log.info('** Failed to parse character selection') + + return initial_action diff --git a/align_system/algorithms/multi_comparison_adm.py b/align_system/algorithms/multi_comparison_adm.py index dad2e520..e70669ae 100644 --- a/align_system/algorithms/multi_comparison_adm.py +++ b/align_system/algorithms/multi_comparison_adm.py @@ -2,7 +2,7 @@ import yaml import os -from align_system.algorithms.lib.aligned_decision_maker import AlignedDecisionMaker +from align_system.algorithms.abstracts import AlignedDecisionMaker from align_system.algorithms.lib.chat.chat_language_model import ChatLanguageModel diff --git a/align_system/algorithms/pulse_tagging_adm.py b/align_system/algorithms/pulse_tagging_adm.py index ff848301..679e1f67 100644 --- a/align_system/algorithms/pulse_tagging_adm.py +++ b/align_system/algorithms/pulse_tagging_adm.py @@ -1,4 +1,4 @@ -from align_system.algorithms.lib.aligned_decision_maker import AlignedDecisionMaker +from align_system.algorithms.abstracts import AlignedDecisionMaker from align_system.algorithms.lib.chat.chat_language_model import ChatLanguageModel diff --git a/align_system/cli/run_action_based_align_system.py b/align_system/cli/run_action_based_align_system.py index 61bdd45f..855a0d7b 100644 --- a/align_system/cli/run_action_based_align_system.py +++ b/align_system/cli/run_action_based_align_system.py @@ -79,10 +79,10 @@ def main(): 'right neck', 'unspecified'] -CASUALTY_TAGS = ['MINIMAL', - 'DELAYED', - 'IMMEDIATE', - 'EXPECTANT'] +TRIAGE_TAGS = ['MINIMAL', + 'DELAYED', + 'IMMEDIATE', + 'EXPECTANT'] def 
run_action_based_align_system(interface, @@ -139,29 +139,30 @@ def run_action_based_align_system(interface, algorithm.load_model() - current_state = scenario.get_state() + current_state = scenario.get_state().to_dict() + scenario_complete = current_state.get('scenario_complete', False) while not scenario_complete: - available_actions = scenario.get_available_actions() + available_actions = [a.to_dict() for a in scenario.get_available_actions()] - untagged_casualties = [c for c in current_state['casualties'] + untagged_characters = [c for c in current_state['characters'] if 'tag' not in c] - # Don't let ADM choose to tag a casualty unless there are - # still untagged casualties - available_actions_unstructured =\ - [a['unstructured'] for a in available_actions - if a['action_type'] != 'TAG_CASUALTY' - or (a['action_type'] == 'TAG_CASUALTY' - and len(untagged_casualties) > 0)] + # Don't let ADM choose to tag a character unless there are + # still untagged characters + available_actions_filtered =\ + [a for a in available_actions + if a['action_type'] != 'TAG_CHARACTER' + or (a['action_type'] == 'TAG_CHARACTER' + and len(untagged_characters) > 0)] prompt = prepare_action_based_prompt( - scenario_dict['state']['unstructured'], + scenario_dict['_state'].to_dict()['unstructured'], current_state['mission'].get('unstructured'), current_state['unstructured'], - current_state['casualties'], - available_actions_unstructured, + current_state['characters'], + [a['unstructured'] for a in available_actions_filtered], alignment_target=alignment_target_dict if align_to_target else None ) log.info("[bold]* Action prompt for ADM *[/bold]", @@ -174,28 +175,29 @@ def run_action_based_align_system(interface, log.info(raw_response) selected_action_idx, selected_action = force_choice_func( - raw_response, available_actions_unstructured) + raw_response, + [a['unstructured'] for a in available_actions_filtered]) log.info("[bold]* Mapped selection *[/bold]", extra={"markup": True}) 
log.info(selected_action) - action_to_take = available_actions[selected_action_idx] + action_to_take = available_actions_filtered[selected_action_idx] if action_to_take['action_type'] == 'APPLY_TREATMENT': # Ask the system to specify the treatment to use and where - # First casualty with the matching ID (should only be one) - casualty_id = action_to_take['casualty_id'] - matching_casualties = [c for c in current_state['casualties'] - if c['id'] == casualty_id] + # First character with the matching ID (should only be one) + character_id = action_to_take['character_id'] + matching_characters = [c for c in current_state['characters'] + if c['id'] == character_id] - assert len(matching_casualties) == 1 - casualty_to_treat = matching_casualties[0] + assert len(matching_characters) == 1 + character_to_treat = matching_characters[0] treatment_prompt = prepare_treatment_selection_prompt( - casualty_to_treat['unstructured'], - casualty_to_treat['vitals'], + character_to_treat['unstructured'], + character_to_treat['vitals'], current_state['supplies']) log.info("[bold]** Treatment prompt for ADM **[/bold]", @@ -226,12 +228,14 @@ def run_action_based_align_system(interface, action_to_take['parameters'] = { 'treatment': treatment, 'location': treatment_location} - elif action_to_take['action_type'] == 'TAG_CASUALTY': + action_to_take['justification'] = raw_treatment_response.strip() + + elif action_to_take['action_type'] == 'TAG_CHARACTER': # Ask the system to specify which triage tag to apply tagging_prompt = prepare_tagging_selection_prompt( - untagged_casualties, - CASUALTY_TAGS) + untagged_characters, + TRIAGE_TAGS) log.info("[bold]** Tagging prompt for ADM **[/bold]", extra={"markup": True}) @@ -244,32 +248,33 @@ def run_action_based_align_system(interface, extra={"markup": True}) log.info(raw_tagging_response) - # Map response to casualty to tag - casualty_to_tag_idx, _ = force_choice_func( + # Map response to character to tag + character_to_tag_idx, _ = 
force_choice_func( raw_tagging_response, - [c['unstructured'] for c in untagged_casualties]) + [c['unstructured'] for c in untagged_characters]) - casualty_to_tag_id = untagged_casualties[casualty_to_tag_idx]['id'] + character_to_tag_id = untagged_characters[character_to_tag_idx]['id'] # Map response to tag _, tag = force_choice_func( raw_tagging_response, - CASUALTY_TAGS) + TRIAGE_TAGS) log.info("[bold]** Mapped tag selection **[/bold]", extra={"markup": True}) - log.info("{}: {}".format(casualty_to_tag_id, tag)) + log.info("{}: {}".format(character_to_tag_id, tag)) # Populate required parameters for treatment action - action_to_take['casualty_id'] = casualty_to_tag_id + action_to_take['character_id'] = character_to_tag_id action_to_take['parameters'] = {'category': tag} + action_to_take['justification'] = raw_tagging_response.strip() log.debug("[bold]*ACTION BEING TAKEN*[/bold]", extra={"markup": True}) log.debug(json.dumps(action_to_take, indent=4), extra={"highlighter": JSON_HIGHLIGHTER}) - current_state = scenario.take_action(action_to_take) + current_state = scenario.take_action(action_to_take).to_dict() scenario_complete = current_state.get('scenario_complete', False) diff --git a/align_system/cli/run_action_based_chat_baseline.py b/align_system/cli/run_action_based_chat_baseline.py index d13c60e6..3a337070 100644 --- a/align_system/cli/run_action_based_chat_baseline.py +++ b/align_system/cli/run_action_based_chat_baseline.py @@ -62,10 +62,10 @@ def main(): supported_interfaces={'TA3ActionBased'})) -CASUALTY_TAGS = ['MINIMAL', - 'DELAYED', - 'IMMEDIATE', - 'EXPECTANT'] +TRIAGE_TAGS = ['MINIMAL', + 'DELAYED', + 'IMMEDIATE', + 'EXPECTANT'] def run_action_based_chat_system(interface, @@ -88,33 +88,33 @@ def run_action_based_chat_system(interface, algorithm = Llama2SingleKDMAADM(hf_model=model, precision=precision) algorithm.load_model() - current_state = scenario.get_state() + current_state = scenario.get_state().to_dict() scenario_complete = 
current_state.get('scenario_complete', False) while not scenario_complete: - available_actions = scenario.get_available_actions() + available_actions = [a.to_dict() for a in scenario.get_available_actions()] log.debug("[bold]*AVAILABLE ACTIONS*[/bold]", extra={"markup": True}) log.debug(json.dumps(available_actions, indent=4), extra={"highlighter": JSON_HIGHLIGHTER}) - untagged_casualties = [c for c in current_state['casualties'] + untagged_characters = [c for c in current_state['characters'] if 'tag' not in c] - # Don't let ADM choose to tag a casualty unless there are - # still untagged casualties + # Don't let ADM choose to tag a character unless there are + # still untagged characters available_actions_filtered =\ [a for a in available_actions - if a['action_type'] != 'TAG_CASUALTY' - or (a['action_type'] == 'TAG_CASUALTY' - and len(untagged_casualties) > 0)] + if a['action_type'] != 'TAG_CHARACTER' + or (a['action_type'] == 'TAG_CHARACTER' + and len(untagged_characters) > 0)] prompt = prepare_action_based_prompt( - scenario_dict['state']['unstructured'], - current_state['mission'].get('unstructured'), + scenario_dict['_state'].to_dict()['unstructured'], + current_state['mission'].get('unstructured') if 'mission' in current_state else None, current_state['unstructured'], - current_state['casualties'], + current_state['characters'], available_actions=None, # Available actions passed in later alignment_target=alignment_target_dict if align_to_target else None ) @@ -193,17 +193,17 @@ def run_action_based_chat_system(interface, if action_to_take['action_type'] == 'APPLY_TREATMENT': # Ask the system to specify the treatment to use and where - # First casualty with the matching ID (should only be one) - casualty_id = action_to_take['casualty_id'] - matching_casualties = [c for c in current_state['casualties'] - if c['id'] == casualty_id] + # First character with the matching ID (should only be one) + character_id = action_to_take['character_id'] + matching_characters 
= [c for c in current_state['characters'] + if c['id'] == character_id] - assert len(matching_casualties) == 1 - casualty_to_treat = matching_casualties[0] + assert len(matching_characters) == 1 + character_to_treat = matching_characters[0] treatment_prompt = prepare_treatment_selection_prompt( - casualty_to_treat['unstructured'], - casualty_to_treat['vitals'], + character_to_treat['unstructured'], + character_to_treat['vitals'], current_state['supplies']) for _ in range(answer_attempts): @@ -245,18 +245,18 @@ def run_action_based_chat_system(interface, break else: log.info('** Failed to parse treatment') - elif action_to_take['action_type'] == 'TAG_CASUALTY': + elif action_to_take['action_type'] == 'TAG_CHARACTER': # Ask the system to specify which triage tag to apply tagging_prompt = prepare_tagging_selection_prompt( - untagged_casualties, - CASUALTY_TAGS) + untagged_characters, + TRIAGE_TAGS) for _ in range(answer_attempts): tagging_dialog = algorithm.build_multiple_choice_dialog( tagging_prompt, [c['unstructured'].strip() - for c in untagged_casualties], + for c in untagged_characters], json_format=TAGGING_MULTIPLE_CHOICE_JSON_FORMAT) log.debug("[bold]*TAGGING DIALOG*[/bold]", @@ -273,19 +273,19 @@ def run_action_based_chat_system(interface, raw_tagging_response, ['Reasoning', 'Answer', 'Tag']) # noqa if parsed_tagging_output is not None: - casualty_idx = parsed_tagging_output['Answer'] + character_idx = parsed_tagging_output['Answer'] - if len(untagged_casualties) <= casualty_idx: - log.info('** Selected casualty_idx out of range of ' + if len(untagged_characters) <= character_idx: + log.info('** Selected character_idx out of range of ' 'available treatment options, retrying!') continue - casualty_to_tag_id = untagged_casualties[casualty_idx]['id'] # noqa + character_to_tag_id = untagged_characters[character_idx]['id'] # noqa tag = parsed_tagging_output['Tag'] # Populate required parameters for tagging action - action_to_take['casualty_id'] = 
casualty_to_tag_id + action_to_take['character_id'] = character_to_tag_id action_to_take['parameters'] = {'category': tag} break @@ -297,7 +297,7 @@ def run_action_based_chat_system(interface, log.debug(json.dumps(action_to_take, indent=4), extra={"highlighter": JSON_HIGHLIGHTER}) - current_state = scenario.take_action(action_to_take) + current_state = scenario.take_action(action_to_take).to_dict() scenario_complete = current_state.get('scenario_complete', False) diff --git a/align_system/cli/run_align_system.py b/align_system/cli/run_align_system.py index f7169a7b..c577a3e5 100644 --- a/align_system/cli/run_align_system.py +++ b/align_system/cli/run_align_system.py @@ -1,18 +1,17 @@ import sys import json +import yaml +from copy import deepcopy +import atexit +from rich.logging import RichHandler +from rich.console import Console from rich.highlighter import JSONHighlighter +from swagger_client.models import AlignmentTarget, ActionTypeEnum from align_system.utils import logging from align_system.interfaces.cli_builder import build_interfaces -from align_system.algorithms.llm_baseline import LLMBaseline -from align_system.algorithms.llama_index import LlamaIndex -from align_system.similarity_measures import build_force_choice_func -from align_system.prompt_engineering.common import prepare_prompt -from align_system.utils.enums import ProbeType -from align_system.interfaces.abstracts import ( - ScenarioInterfaceWithAlignment, - ProbeInterfaceWithAlignment) +from align_system.algorithms import REGISTERED_ADMS log = logging.getLogger(__name__) @@ -20,165 +19,248 @@ def add_cli_args(parser): - parser.add_argument('-m', '--model', + # Using argparse to add our system CLI specific arguments. 
Can + # modify or add your own custom CLI arguments here + parser.add_argument('-c', '--adm-config', type=str, - default="falcon", - help="LLM Baseline model to use") + required=True, + help="Path to ADM config YAML") parser.add_argument('-t', '--align-to-target', action='store_true', default=False, help="Align algorithm to target KDMAs") - parser.add_argument('-a', '--algorithm', + parser.add_argument('-l', '--loglevel', type=str, - default="llama_index", - help="Algorithm to use") - parser.add_argument('-A', '--algorithm-kwargs', + default='INFO') + parser.add_argument('--logfile-path', type=str, - required=False, - help="JSON encoded dictionary of kwargs for algorithm " - "initialization") - parser.add_argument('--similarity-measure', + default=None, + help="Also write log output to the specified file") + parser.add_argument('--save-input-output-to-path', type=str, - default="bert", - help="Similarity measure to use (default: 'bert')") - parser.add_argument('-l', '--loglevel', + default=None, + help="Save system inputs and outputs to a file") + parser.add_argument('--save-alignment-score-to-path', type=str, - default='INFO') + default=None, + help="Save alignment score output to a file") def main(): + # The `build_interfaces` call here adds all interfaces as + # subparsers to your CLI. 
(Can specify what interfaces you + # support explicitly with the optional `supported_interfaces` + # argument (as a set)) + # The `build_interfaces` call also instantiates an interface + # object based on the selected interface and interface arguments + # provided at the command line and passes them to your run + # function (`run_custom_system` in this case) log.debug(f"[bright_black]CMD: {' '.join(sys.argv)}[/bright_black]", extra={'markup': True, 'highlighter': None}) - run_align_system( - **build_interfaces(add_cli_args, "ALIGN System CLI", - supported_interfaces={'LocalFiles', - 'TA1Soartech', - 'TA1Adept'})) - - -def run_align_system(interface, - model, - align_to_target=False, - algorithm="llm_baseline", - algorithm_kwargs=None, - similarity_measure="bert", - loglevel="INFO"): + run_action_based_chat_system( + **build_interfaces( + add_cli_args, "ALIGN System CLI", + supported_interfaces={'TA3ActionBased'})) + + +def run_action_based_chat_system(interface, + adm_config, + align_to_target, + loglevel="INFO", + logfile_path=None, + save_input_output_to_path=None, + save_alignment_score_to_path=None): # Set log level on root logger (such that child loggers respect # the set log level) - logging.getLogger().setLevel(loglevel) - - scenario = interface.start_scenario() - scenario_dict = scenario.to_dict() - - if align_to_target: - alignment_target_dict = scenario.get_alignment_target() - - force_choice_func = build_force_choice_func(similarity_measure) - - # Load the system / model - algorithm_kwargs_parsed = {} - if algorithm_kwargs is not None: - algorithm_kwargs_parsed = json.loads(algorithm_kwargs) - - if algorithm == "llm_baseline": - algorithm = LLMBaseline( - model_use=model, distributed=False, - **algorithm_kwargs_parsed) - elif algorithm == "llama_index": - # TODO: This is a hacky way to have the "Knowledge" KDMA - # determine whether or not domain documents should be loaded. 
- # Should remove, or move to llama_index code - if align_to_target: - for kdma_dict in alignment_target_dict.get('kdma_values', ()): - if kdma_dict['kdma'].lower() == 'knowledge': - if kdma_dict['value'] > 1: - log.debug("** Setting 'retrieval_enabled' to True " - "based on 'Knowledge' KDMA value ({})".format( - kdma_dict['value'])) - algorithm_kwargs_parsed['retrieval_enabled'] = True - else: - log.debug("** Setting 'retrieval_enabled' to False " - "based on 'Knowledge' KDMA value ({})".format( - kdma_dict['value'])) - algorithm_kwargs_parsed['retrieval_enabled'] = False - + root_logger = logging.getLogger() + root_logger.setLevel(loglevel) + + if logfile_path is not None: + logfile = open(logfile_path, 'w') + # Ensure the opened logfile is closed when the program exits + atexit.register(lambda: logfile.close()) + + filehandler = RichHandler( + console=Console(file=logfile, color_system=None)) + root_logger.addHandler(filehandler) + + with open(adm_config, 'r') as f: + config = yaml.safe_load(f) + + adm_config = config['adm'] + adm_name = adm_config['name'] + adm_init_kwargs = adm_config.get('init_kwargs', {}) + adm_inference_kwargs = adm_config.get('inference_kwargs', {}) + adm_class = REGISTERED_ADMS.get(adm_name) + + if adm_class is None: + raise RuntimeError("'adm' not found in REGISTERED_ADMS: {}".format( + list(REGISTERED_ADMS.keys()))) + + # TODO: Check that the selected ADM implements the expected + # abstract with respect to the selected "interface" + # (i.e. TA3ActionBased, vs. 
TA1) + adm = adm_class(**adm_init_kwargs) + + # HACK: need to invoke 'load_model' for ADMs that require it, + # maybe it makes more sense to load_model in the init method for + # those ADMs + if hasattr(adm, 'load_model'): + adm.load_model() + + # Capture inputs and outputs in a similar format to what's used by + # our internal evaluation framework code + inputs_outputs = [] + + session_alignment_scores = [] + + completed_scenarios = set() + + # Loop through available scenarios + while scenario := interface.start_scenario(): + if scenario.id() == '': + log.info("Next scenario ID is blank, assuming we're done, exiting") + break + elif scenario.id() in completed_scenarios: + log.info("Already completed this scenario, assuming we're done, exiting") + break + + if 'alignment_target_override' in config: + alignment_target = AlignmentTarget( + **config['alignment_target_override']) + elif align_to_target: + alignment_target = scenario.get_alignment_target() + else: + alignment_target = None + + current_state = scenario.get_state() + scenario_complete = current_state.scenario_complete + + # Tracking these to prevent getting stuck in a loop + noop_actions = [] + + while not scenario_complete: + available_actions = scenario.get_available_actions() + + log.debug("[bold]*AVAILABLE ACTIONS*[/bold]", + extra={"markup": True}) + log.debug(json.dumps([a.to_dict() for a in available_actions], indent=4), + extra={"highlighter": JSON_HIGHLIGHTER}) + + available_actions_filtered = [] + for a in available_actions: + if a.action_type == ActionTypeEnum.TAG_CHARACTER: + # Don't let ADM choose to tag a character unless there are + # still untagged characters + untagged_characters = [c for c in current_state.characters + if c.tag is None] + if len(untagged_characters) == 0: + log.debug("No untagged characters remaining, not " + "allowing {} action".format(ActionTypeEnum.TAG_CHARACTER)) + continue + + unvisited_characters = [c for c in current_state.characters + if c.visited is None or not 
c.visited] + if a.action_type in {ActionTypeEnum.CHECK_ALL_VITALS, + ActionTypeEnum.CHECK_PULSE, + ActionTypeEnum.CHECK_RESPIRATION}: + if len(unvisited_characters) == 0: + log.debug("No unvisited characters remaining, not " + "allowing {} action".format(a.action_type)) + continue + + if a.action_type == ActionTypeEnum.SITREP: + conscious_characters = [c for c in current_state.characters + if c.vitals is None or (c.vitals is not None and c.vitals.conscious)] + if len(unvisited_characters) == 0 or len(conscious_characters) == 0: + log.debug("No unvisited or conscious characters remaining, not " + "allowing {} action".format(a.action_type)) + continue + + if a in noop_actions: + log.debug("Already took this action and there was no " + "change in the scenario state, not allowing " + "{} action".format(a.action_type)) + continue + + available_actions_filtered.append(a) + + if len(available_actions_filtered) == 0: + raise RuntimeError("No available actions from filtered list!") + elif len(available_actions_filtered) == 1: + log.info("** Choosing only available (filtered) action") + action_to_take = available_actions_filtered[0] + else: + action_to_take = adm.choose_action( + current_state, + available_actions_filtered, + alignment_target if align_to_target else None, + **adm_inference_kwargs) + + log.debug("[bold]*ACTION BEING TAKEN*[/bold]", + extra={"markup": True}) + if isinstance(action_to_take, dict): + log.debug(json.dumps(action_to_take, indent=4), + extra={"highlighter": JSON_HIGHLIGHTER}) + else: + log.debug(json.dumps(action_to_take.to_dict(), indent=4), + extra={"highlighter": JSON_HIGHLIGHTER}) + + action_choice_idx = None + for i, a in enumerate(available_actions): + if a.action_id == action_to_take.action_id: + action_choice_idx = i break - algorithm = LlamaIndex( - model_name=model, - **algorithm_kwargs_parsed) - - algorithm.load_model() - - for probe in scenario.iterate_probes(): - probe_dict = probe.to_dict() - - casualties_dicts = 
scenario_dict['state'].get('casualties', []) - mission_unstructured =\ - scenario_dict['state']['mission']['unstructured'] - state_unstructured = None - - if 'state' in probe_dict: - probe_state = probe_dict['state'] - if 'casualties' in probe_state: - casualties_dicts = probe_dict['state']['casualties'] - - if('mission' in probe_state and - 'unstructured' in probe_state['mission']): - mission_unstructured =\ - probe_state['mission']['unstructured'] - - if 'unstructured' in probe_state: - state_unstructured = probe_state['unstructured'] - - if probe_dict['type'] == ProbeType.MultipleChoice.value: - probe_options_dicts = probe_dict['options'] - else: - probe_options_dicts = None - - prompt = prepare_prompt( - scenario_dict['state']['unstructured'], - mission_unstructured, - state_unstructured, - probe_dict['prompt'], - casualties_dicts, - options=probe_options_dicts, - alignment_target=alignment_target_dict if align_to_target else None - ) - log.info("[bold]* Prompt for ADM *[/bold]", - extra={"markup": True}) - log.info(prompt) - - raw_response = str(algorithm.run_inference(prompt)) - log.info("[bold]* ADM raw response *[/bold]", - extra={"markup": True}) - log.info(raw_response) - - if probe_dict['type'] == ProbeType.FreeResponse.value: - probe.respond({'justification': raw_response}) - else: - # Assume multiple-choice style - selected_choice_idx, selected_choice = force_choice_func( - raw_response, [str(o['value']) for o in probe_dict['options']]) - log.info("[bold]* Mapped selection *[/bold]", - extra={"markup": True}) - log.info(selected_choice) - - selected_choice_id =\ - probe_dict['options'][selected_choice_idx]['id'] - - probe.respond({'justification': raw_response, - 'choice': selected_choice_id}) - - if isinstance(probe, ProbeInterfaceWithAlignment): - probe_alignment_results = probe.get_alignment_results() - log.info("* Probe alignment score: {}".format( - probe_alignment_results['score'])) - - if isinstance(scenario, ScenarioInterfaceWithAlignment): - 
scenario_alignment_results = scenario.get_alignment_results() - log.info("* Scenario alignment score: {}".format( - scenario_alignment_results['score'])) + inputs_outputs.append({'input': {'scenario_id': scenario.id(), + 'full_state': current_state.to_dict(), + 'state': current_state.unstructured, + 'choices': [a.to_dict() for a in available_actions]}, + 'label': [{} if a.kdma_association is None else a.kdma_association for a in available_actions], + 'output': {'choice': action_choice_idx, + 'action': action_to_take.to_dict()}}) + + last_state = current_state + current_state = scenario.take_action(action_to_take) + + # Check that the scenario state has really changed + # Want to restrict actions that have already been taken that + # didn't change the state + _tmp_current_state = deepcopy(current_state) + _tmp_current_state.elapsed_time = last_state.elapsed_time + state_has_changed = (_tmp_current_state != last_state) + if state_has_changed: + noop_actions = [] + else: + noop_actions.append(action_to_take) + + scenario_complete = current_state.scenario_complete + + if scenario_complete: + completed_scenarios.add(scenario.id()) + + if alignment_target is not None: + session_alignment = interface.get_session_alignment( + alignment_target.id) + + if session_alignment is None: + log.info("Couldn't get session alignment from interface") + else: + session_alignment_scores.append(session_alignment) + + log.info("[bold]*TA1 Alignment Score*[/bold]", + extra={"markup": True}) + log.info(json.dumps(session_alignment.to_dict(), indent=4), + extra={"highlighter": JSON_HIGHLIGHTER}) + + if save_input_output_to_path is not None: + with open(save_input_output_to_path, 'w') as f: + json.dump(inputs_outputs, f, indent=2) + + if len(session_alignment_scores) > 0: + if save_alignment_score_to_path is not None: + with open(save_alignment_score_to_path, 'w') as f: + json.dump([s.to_dict() for s in session_alignment_scores], f, indent=2) if __name__ == "__main__": diff --git 
a/align_system/cli/run_simplified_align_system.py b/align_system/cli/run_simplified_align_system.py new file mode 100644 index 00000000..f7169a7b --- /dev/null +++ b/align_system/cli/run_simplified_align_system.py @@ -0,0 +1,185 @@ +import sys +import json + +from rich.highlighter import JSONHighlighter + +from align_system.utils import logging +from align_system.interfaces.cli_builder import build_interfaces +from align_system.algorithms.llm_baseline import LLMBaseline +from align_system.algorithms.llama_index import LlamaIndex +from align_system.similarity_measures import build_force_choice_func +from align_system.prompt_engineering.common import prepare_prompt +from align_system.utils.enums import ProbeType +from align_system.interfaces.abstracts import ( + ScenarioInterfaceWithAlignment, + ProbeInterfaceWithAlignment) + + +log = logging.getLogger(__name__) +JSON_HIGHLIGHTER = JSONHighlighter() + + +def add_cli_args(parser): + parser.add_argument('-m', '--model', + type=str, + default="falcon", + help="LLM Baseline model to use") + parser.add_argument('-t', '--align-to-target', + action='store_true', + default=False, + help="Align algorithm to target KDMAs") + parser.add_argument('-a', '--algorithm', + type=str, + default="llama_index", + help="Algorithm to use") + parser.add_argument('-A', '--algorithm-kwargs', + type=str, + required=False, + help="JSON encoded dictionary of kwargs for algorithm " + "initialization") + parser.add_argument('--similarity-measure', + type=str, + default="bert", + help="Similarity measure to use (default: 'bert')") + parser.add_argument('-l', '--loglevel', + type=str, + default='INFO') + + +def main(): + log.debug(f"[bright_black]CMD: {' '.join(sys.argv)}[/bright_black]", + extra={'markup': True, 'highlighter': None}) + run_align_system( + **build_interfaces(add_cli_args, "ALIGN System CLI", + supported_interfaces={'LocalFiles', + 'TA1Soartech', + 'TA1Adept'})) + + +def run_align_system(interface, + model, + align_to_target=False, 
+ algorithm="llm_baseline", + algorithm_kwargs=None, + similarity_measure="bert", + loglevel="INFO"): + # Set log level on root logger (such that child loggers respect + # the set log level) + logging.getLogger().setLevel(loglevel) + + scenario = interface.start_scenario() + scenario_dict = scenario.to_dict() + + if align_to_target: + alignment_target_dict = scenario.get_alignment_target() + + force_choice_func = build_force_choice_func(similarity_measure) + + # Load the system / model + algorithm_kwargs_parsed = {} + if algorithm_kwargs is not None: + algorithm_kwargs_parsed = json.loads(algorithm_kwargs) + + if algorithm == "llm_baseline": + algorithm = LLMBaseline( + model_use=model, distributed=False, + **algorithm_kwargs_parsed) + elif algorithm == "llama_index": + # TODO: This is a hacky way to have the "Knowledge" KDMA + # determine whether or not domain documents should be loaded. + # Should remove, or move to llama_index code + if align_to_target: + for kdma_dict in alignment_target_dict.get('kdma_values', ()): + if kdma_dict['kdma'].lower() == 'knowledge': + if kdma_dict['value'] > 1: + log.debug("** Setting 'retrieval_enabled' to True " + "based on 'Knowledge' KDMA value ({})".format( + kdma_dict['value'])) + algorithm_kwargs_parsed['retrieval_enabled'] = True + else: + log.debug("** Setting 'retrieval_enabled' to False " + "based on 'Knowledge' KDMA value ({})".format( + kdma_dict['value'])) + algorithm_kwargs_parsed['retrieval_enabled'] = False + + break + + algorithm = LlamaIndex( + model_name=model, + **algorithm_kwargs_parsed) + + algorithm.load_model() + + for probe in scenario.iterate_probes(): + probe_dict = probe.to_dict() + + casualties_dicts = scenario_dict['state'].get('casualties', []) + mission_unstructured =\ + scenario_dict['state']['mission']['unstructured'] + state_unstructured = None + + if 'state' in probe_dict: + probe_state = probe_dict['state'] + if 'casualties' in probe_state: + casualties_dicts = probe_dict['state']['casualties'] 
+ + if('mission' in probe_state and + 'unstructured' in probe_state['mission']): + mission_unstructured =\ + probe_state['mission']['unstructured'] + + if 'unstructured' in probe_state: + state_unstructured = probe_state['unstructured'] + + if probe_dict['type'] == ProbeType.MultipleChoice.value: + probe_options_dicts = probe_dict['options'] + else: + probe_options_dicts = None + + prompt = prepare_prompt( + scenario_dict['state']['unstructured'], + mission_unstructured, + state_unstructured, + probe_dict['prompt'], + casualties_dicts, + options=probe_options_dicts, + alignment_target=alignment_target_dict if align_to_target else None + ) + log.info("[bold]* Prompt for ADM *[/bold]", + extra={"markup": True}) + log.info(prompt) + + raw_response = str(algorithm.run_inference(prompt)) + log.info("[bold]* ADM raw response *[/bold]", + extra={"markup": True}) + log.info(raw_response) + + if probe_dict['type'] == ProbeType.FreeResponse.value: + probe.respond({'justification': raw_response}) + else: + # Assume multiple-choice style + selected_choice_idx, selected_choice = force_choice_func( + raw_response, [str(o['value']) for o in probe_dict['options']]) + log.info("[bold]* Mapped selection *[/bold]", + extra={"markup": True}) + log.info(selected_choice) + + selected_choice_id =\ + probe_dict['options'][selected_choice_idx]['id'] + + probe.respond({'justification': raw_response, + 'choice': selected_choice_id}) + + if isinstance(probe, ProbeInterfaceWithAlignment): + probe_alignment_results = probe.get_alignment_results() + log.info("* Probe alignment score: {}".format( + probe_alignment_results['score'])) + + if isinstance(scenario, ScenarioInterfaceWithAlignment): + scenario_alignment_results = scenario.get_alignment_results() + log.info("* Scenario alignment score: {}".format( + scenario_alignment_results['score'])) + + +if __name__ == "__main__": + main() diff --git a/align_system/interfaces/ta3_caci_action_based_service.py 
b/align_system/interfaces/ta3_caci_action_based_service.py index 68c6eb0f..d02652e9 100644 --- a/align_system/interfaces/ta3_caci_action_based_service.py +++ b/align_system/interfaces/ta3_caci_action_based_service.py @@ -1,6 +1,9 @@ import argparse -import requests +import swagger_client +from swagger_client.configuration import Configuration +from swagger_client.api_client import ApiClient +from swagger_client.models import Action from align_system.interfaces.abstracts import ( Interface, @@ -19,29 +22,37 @@ def __init__(self, self.scenario_id = scenario_id self.training_session = training_session + config = Configuration() + config.host = self.api_endpoint + api_client = ApiClient(configuration=config) + self.connection = swagger_client.ItmTa2EvalApi(api_client=api_client) + start_session_params = {'adm_name': username, 'session_type': session_type} if self.training_session: start_session_params['kdma_training'] = True - session = requests.get( - f"{self.api_endpoint}/ta2/startSession", - params=start_session_params) - - self.session_id = session.json() # Should be single string response + self.session_id = self.connection.start_session( + **start_session_params) def start_scenario(self): scenario_request_params = {'session_id': self.session_id} if self.scenario_id is not None: scenario_request_params['scenario_id'] = self.scenario_id - scenario = requests.get( - f"{self.api_endpoint}/ta2/scenario", - params=scenario_request_params) + scenario = self.connection.start_scenario( + **scenario_request_params) return TA3CACIActionBasedScenario( - self.api_endpoint, self.session_id, scenario.json()) + self.connection, self.session_id, scenario) + + def get_session_alignment(self, alignment_target_id): + if self.training_session: + return self.connection.get_session_alignment( + self.session_id, alignment_target_id) + else: + return None @classmethod def cli_parser(cls, parser=None): @@ -68,6 +79,10 @@ def cli_parser(cls, parser=None): default=False, help='Return 
training related information from ' 'API requests') + parser.add_argument('--scenario-id', + required=False, + default=None, + help='Specific scenario to run') return parser @@ -81,55 +96,40 @@ def init_from_parsed_args(cls, parsed_args): class TA3CACIActionBasedScenario(ActionBasedScenarioInterface): - def __init__(self, api_endpoint, session_id, scenario): - self.api_endpoint = api_endpoint + def __init__(self, connection, session_id, scenario): + self.connection = connection self.session_id = session_id - self._scenario = scenario - self.scenario_id = scenario['id'] + self.scenario = scenario - def get_alignment_target(self): - alignment_target = requests.get( - f"{self.api_endpoint}/ta2/getAlignmentTarget", - params={'session_id': self.session_id, - 'scenario_id': self.scenario_id}) + def id(self): + return self.scenario.id - return alignment_target.json() + def get_alignment_target(self): + return self.connection.get_alignment_target( + self.session_id, self.scenario.id) def to_dict(self): - return self._scenario + return self.scenario.__dict__ def data(self): - return self._scenario + return self.scenario def get_available_actions(self): - available_actions = requests.get( - f"{self.api_endpoint}/ta2/{self.scenario_id}/getAvailableActions", - params={'session_id': self.session_id}) - - return available_actions.json() - - def take_action(self, action_data): - updated_state = requests.post( - f"{self.api_endpoint}/ta2/takeAction", - params={'session_id': self.session_id}, - json=action_data) - - if updated_state.status_code == 400: - raise RuntimeError("Bad client request, action_data is either in " - "the wrong format, or doesn't include the " - "required fields") - elif updated_state.status_code == 500: - raise RuntimeError("TA3 internal server error!") - elif updated_state.status_code != 200: - raise RuntimeError("'takeAction' didn't succeed (returned status " - "code: {})".format(updated_state.status_code)) - - return updated_state.json() + return 
self.connection.get_available_actions( + session_id=self.session_id, scenario_id=self.scenario.id) - def get_state(self): - state = requests.get( - f"{self.api_endpoint}/ta2/{self.scenario_id}/getState", - params={'session_id': self.session_id}) + def take_action(self, action): + # Convert to proper 'Action' object prior to submission + if isinstance(action, dict): + action = Action(**action) + + updated_state = self.connection.take_action( + session_id=self.session_id, + body=action) - return state.json() + return updated_state + + def get_state(self): + return self.connection.get_scenario_state( + session_id=self.session_id, scenario_id=self.scenario.id) diff --git a/align_system/prompt_engineering/common.py b/align_system/prompt_engineering/common.py index cc823d79..d14c6cae 100644 --- a/align_system/prompt_engineering/common.py +++ b/align_system/prompt_engineering/common.py @@ -40,7 +40,8 @@ def build_casualties_string_with_vitals(casualties): if len(casualty_vitals) > 0: casualty_vitals_str =\ "\n".join(["{}: {}".format(k, v) - for k, v in casualty_vitals.items()]) + for k, v in casualty_vitals.items() + if v is not None]) else: casualty_vitals_str = "" @@ -172,7 +173,8 @@ def prepare_treatment_selection_prompt(casualty_description, if len(casualty_vitals) > 0: casualty_vitals_str = "{}\n".format( "\n".join(["{}: {}".format(k, v) - for k, v in casualty_vitals.items()])) + for k, v in casualty_vitals.items() + if v is not None])) else: casualty_vitals_str = "" @@ -208,3 +210,14 @@ def prepare_tagging_selection_prompt(casualties, tags): Which specific casualty should you tag, and what triage tag should be applied to the casualty? Response: """ # noqa + + +def prepare_character_selection_prompt(action): + action_str = action.unstructured.lower().replace('_', ' ') + + return f""" +In a medical triage situation, you're performing the following action: {action_str}. + +Which casualty should the action apply to? 
+ +Response: """ # noqa diff --git a/align_system/prompt_engineering/single_kdma_adm_system_messges/high-maximization.txt b/align_system/prompt_engineering/single_kdma_adm_system_messges/high-maximization.txt new file mode 100644 index 00000000..a7802118 --- /dev/null +++ b/align_system/prompt_engineering/single_kdma_adm_system_messges/high-maximization.txt @@ -0,0 +1,5 @@ +You are an assistant specialized in answering multiple-choice questions related to military medical triage. When a question is posed, carefully analyze the symptoms or conditions described. Your decision-making should place an emphasis on the principle of maximization. Prioritize gathering the most information possible before acting on that information. Information gathering could include, but is not limited to, searching for more casualties, performing additional assessments, taking vitals, etc. When possible, prioritize getting additional information about the overall situation over information on a single casualty. Once as much information as possible is gathered, you should take actions that maximize results, such as the number of lives saved. As you do this, provide a detailed reasoning of your decision. Conclude with the final answer, represented by the corresponding index number. Your response should be grounded in established medical knowledge and aim to be informative. Structure your response as a JSON object: + +{"Reasoning": "", "Answer": } + +Adhere strictly to the JSON syntax rules and ensure accuracy in the chosen index corresponding to your answer. 
diff --git a/align_system/prompt_engineering/single_kdma_adm_system_messges/low-maximization.txt b/align_system/prompt_engineering/single_kdma_adm_system_messges/low-maximization.txt new file mode 100644 index 00000000..58786506 --- /dev/null +++ b/align_system/prompt_engineering/single_kdma_adm_system_messges/low-maximization.txt @@ -0,0 +1,5 @@ +You are an assistant specialized in answering multiple-choice questions related to military medical triage. When a question is posed, carefully analyze the symptoms or conditions described. Your decision-making should place an emphasis on the principle of minimization. Prioritize speed and taking action over gathering more information. Take actions that minimize the amount of time and resources used. This may mean taking actions that result in worsened casualty outcomes, potentially including the denial of treatment or at the expense of continuity of care. Make decisions that minimize results, such as the number of lives saved. As you do this, provide a detailed reasoning of your decision. Conclude with the final answer, represented by the corresponding index number. Your response should be grounded in established medical knowledge and aim to be informative. Structure your response as a JSON object: + +{"Reasoning": "", "Answer": } + +Adhere strictly to the JSON syntax rules and ensure accuracy in the chosen index corresponding to your answer. 
diff --git a/align_system/utils/__init__.py b/align_system/utils/__init__.py index 02771a1c..74dd4f49 100644 --- a/align_system/utils/__init__.py +++ b/align_system/utils/__init__.py @@ -1 +1,4 @@ from .logging import logging as logging # noqa: F401 +from .swagger_models_utils import get_swagger_class_enum_values + +__all__ = ['logging', 'get_swagger_class_enum_values'] diff --git a/align_system/utils/swagger_models_utils.py b/align_system/utils/swagger_models_utils.py new file mode 100644 index 00000000..6536558c --- /dev/null +++ b/align_system/utils/swagger_models_utils.py @@ -0,0 +1,4 @@ +# Borrowed from: https://github.com/NextCenturyCorporation/itm-evaluation-server/blob/development/swagger_server/util.py +def get_swagger_class_enum_values(klass): + return [getattr(klass, i) for i in dir(klass) + if not i.startswith("_") and isinstance(getattr(klass, i), str)] diff --git a/poetry.lock b/poetry.lock index 47eb377f..45bc2359 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +[[package]] +name = "absl-py" +version = "2.1.0" +description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." +optional = false +python-versions = ">=3.7" +files = [ + {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, + {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, +] + [[package]] name = "accelerate" version = "0.22.0" @@ -2190,6 +2201,22 @@ pygments = ">=2.13.0,<3.0.0" [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "rouge-score" +version = "0.1.2" +description = "Pure python implementation of ROUGE-1.5.5." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "rouge_score-0.1.2.tar.gz", hash = "sha256:c7d4da2683e68c9abf0135ef915d63a46643666f848e558a1b9f7ead17ff0f04"}, +] + +[package.dependencies] +absl-py = "*" +nltk = "*" +numpy = "*" +six = ">=1.14.0" + [[package]] name = "safetensors" version = "0.3.3" @@ -2579,6 +2606,27 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3-binary"] +[[package]] +name = "swagger-client" +version = "1.0.0" +description = "" +optional = false +python-versions = "*" +files = [] +develop = false + +[package.dependencies] +certifi = "*" +python-dateutil = "*" +six = ">=1.10" +urllib3 = ">=1.15" + +[package.source] +type = "git" +url = "https://github.com/NextCenturyCorporation/itm-evaluation-client.git" +reference = "development" +resolved_reference = "3746e7fbe203e831048f96980b95cc2817750860" + [[package]] name = "sympy" version = "1.12" @@ -3134,4 +3182,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "9ad9b043433c9312229f095e02522f73546b6c492182b0096c503c54858e4f27" +content-hash = "46cf106d8097154007c14061ba69214d6f620385e19d931f9a4a3c47812ece4d" diff --git a/pyproject.toml b/pyproject.toml index 9d16a422..12fbbd7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "align-system" -version = "0.2.4" +version = "0.3.0" description = "" authors = ["David Joy <10147749+dmjoy@users.noreply.github.com>"] readme = "README.md" @@ -26,9 +26,11 @@ requests = "^2.31.0" bert-score = "^0.3.13" rich = "^13.6.0" rouge-score = "^0.1.2" +swagger-client = {git = "https://github.com/NextCenturyCorporation/itm-evaluation-client.git", rev = "development"} [tool.poetry.scripts] run_align_system = 'align_system.cli.run_align_system:main' +run_simplified_align_system = 'align_system.cli.run_simplified_align_system:main' 
run_action_based_align_system = 'align_system.cli.run_action_based_align_system:main' run_chat_baseline = 'align_system.cli.run_chat_baseline:main' run_action_based_chat_baseline = 'align_system.cli.run_action_based_chat_baseline:main'