This project implements the Causal-CoT framework, a system designed to enhance the reliability of Large Language Models (LLMs) by validating each step of their reasoning process against causal principles and external knowledge.
The framework's core loop identifies causal fallacies (e.g., confounding, spurious correlation) in real time. When a flawed step is detected, a reflection-and-regeneration cycle is triggered to self-correct and build a more robust and trustworthy reasoning path.
- Core Design Principles
- Framework Architecture
- Supported Datasets
- Setup and Installation
- Configuration
- How to Run an Experiment
- Evaluation Metrics
- Reasoning as Hypothesis Testing: An LLM's Chain-of-Thought is treated as a Chain of Hypotheses. Each step is a claim that is independently scrutinized.
- Causal Validation: Each step is passed to a Knowledge Prober that performs a causal analysis based on Judea Pearl's structural causal model framework, using ConceptNet to identify Causal Chains, Forks, and Colliders (a minimal probing sketch follows this list).
- Reflective Self-Correction: When a step is invalidated, the framework enters a reflection loop. The LLM is informed of its error and tasked with regenerating a new, more sound reasoning path.
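The framework ships its own Knowledge Prober; purely as an illustration of the idea, the sketch below queries the public ConceptNet API for causal-style edges between two concepts and uses the result to label a triple as a chain, fork, or collider. The relation set, helper names, and example concepts are assumptions for illustration, not code from this repository.

```python
import requests

# Illustrative only: look for ConceptNet edges that could support a claimed
# causal link between two concepts mentioned in a reasoning step.
CONCEPTNET_QUERY = "https://api.conceptnet.io/query"
CAUSAL_RELATIONS = {"/r/Causes", "/r/HasPrerequisite", "/r/CausesDesire"}  # assumed relation set

def causal_link_exists(cause: str, effect: str) -> bool:
    """Return True if ConceptNet contains a causal-style edge cause -> effect."""
    params = {"start": f"/c/en/{cause}", "end": f"/c/en/{effect}", "limit": 50}
    edges = requests.get(CONCEPTNET_QUERY, params=params, timeout=10).json().get("edges", [])
    return any(edge["rel"]["@id"] in CAUSAL_RELATIONS for edge in edges)

def classify_structure(a: str, b: str, c: str) -> str:
    """Label the concept triple (a, b, c) as a chain, fork, collider, or unknown."""
    if causal_link_exists(a, b) and causal_link_exists(b, c):
        return "chain"     # a -> b -> c
    if causal_link_exists(b, a) and causal_link_exists(b, c):
        return "fork"      # a <- b -> c
    if causal_link_exists(a, b) and causal_link_exists(c, b):
        return "collider"  # a -> b <- c
    return "unknown"

print(classify_structure("rain", "wet_ground", "slippery_road"))
```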
The system operates via a four-phase, iterative pipeline for each question: CoT Generation → Iterative Causal Probing → Reflection & Regeneration → Final Synthesis.
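Purely as a conceptual sketch of how those four phases could fit together (all callables are hypothetical stand-ins, not the project's actual interfaces):

```python
from typing import Callable, List, Optional, Tuple

def causal_cot(
    question: str,
    generate_cot: Callable[[str], List[str]],                # Phase 1: CoT Generation
    probe_step: Callable[[str], Tuple[bool, str]],           # Phase 2: returns (is_valid, feedback)
    regenerate: Callable[[str, List[str], str], List[str]],  # Phase 3: Reflection & Regeneration
    synthesize: Callable[[str, List[str]], str],             # Phase 4: Final Synthesis
    max_reflections: int = 3,
) -> str:
    steps = generate_cot(question)
    for _ in range(max_reflections):
        flaw: Optional[int] = None
        feedback = ""
        for i, step in enumerate(steps):                     # probe each step in turn
            valid, feedback = probe_step(step)
            if not valid:
                flaw = i
                break
        if flaw is None:
            break                                            # every step passed causal validation
        # Keep the sound prefix, regenerate from the first flawed step onward.
        steps = steps[:flaw] + regenerate(question, steps[:flaw], feedback)
    return synthesize(question, steps)
```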
The framework is configured to work with several standard reasoning benchmarks out of the box. All listed datasets are verified to be available on the Hugging Face Hub. Each has a corresponding configuration file in the configs/ directory.
| Config File | Dataset Name | Hugging Face ID | Task Type |
|---|---|---|---|
| `dataset_commonsense_qa.json` | CommonsenseQA | `commonsense_qa` | Multiple Choice |
| `dataset_arc_challenge.json` | ARC-Challenge | `ai2_arc` | Multiple Choice |
| `dataset_openbookqa.json` | OpenBookQA | `openbookqa` | Multiple Choice |
| `dataset_piqa.json` | PIQA | `piqa` | Multiple Choice |
| `dataset_siqa.json` | SocialIQA | `social_i_qa` | Multiple Choice |
| `dataset_boolq.json` | BoolQ | `boolq` | Yes/No Reasoning |
| `dataset_gsm8k.json` | GSM8K | `gsm8k` | Math Word Problem |
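The experiment runner loads these datasets itself, but if you want to sanity-check that they resolve on the Hub, something like the snippet below works with the `datasets` library. The configuration and split names shown are those of the public Hub copies and are not taken from this repository's config files.

```python
# Optional sanity check that the benchmark datasets resolve on the Hugging Face Hub.
from datasets import load_dataset

boolq = load_dataset("boolq", split="validation")                     # single-configuration dataset
arc   = load_dataset("ai2_arc", "ARC-Challenge", split="validation")  # multi-configuration dataset
gsm8k = load_dataset("gsm8k", "main", split="test")                   # multi-configuration dataset

print(len(boolq), len(arc), len(gsm8k))
```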
- Clone the Repository:
  ```bash
  git clone https://github.com/your-username/causal-cot-final.git
  cd causal-cot-final
  ```
- Create and Activate a Python Environment:
  ```bash
  python -m venv venv
  source venv/bin/activate  # On Windows: venv\Scripts\activate
  ```
- Install Dependencies: A `requirements.txt` file is provided.
  ```bash
  pip install -r requirements.txt
  ```
  Note: For local model usage, you will need to install `torch`, `transformers`, and `accelerate`.
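If you intend to use local mode, a quick, optional way to confirm those extra packages are importable:

```python
# Quick check that the optional local-model dependencies are installed and importable.
import torch, transformers, accelerate
print(torch.__version__, transformers.__version__, accelerate.__version__)
```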
All experiments are driven by JSON configuration files in the configs/ directory.
Choose between api or local mode by pointing to the correct file.
- API Mode (`configs/model_api.json`): Set the `api_key_env` to the name of the environment variable holding your API key (e.g., `"DEEPINFRA_API_KEY"`).
- Local Mode (`configs/model_local.json`): Set the `path` to the directory of your locally saved Hugging Face model.
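For orientation, the two files might look roughly like the sketches below. Only `api_key_env` and `path` are fields documented here; every other key and value is an illustrative placeholder, not the project's actual schema.

`configs/model_api.json` (illustrative):
```json
{
  "model_name": "your-model-name",
  "api_key_env": "DEEPINFRA_API_KEY"
}
```

`configs/model_local.json` (illustrative):
```json
{
  "path": "/path/to/your/local/hf_model"
}
```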
- Select a dataset by providing the path to its config file (e.g., `configs/dataset_boolq.json`).
- The `hf_id` field must match the dataset's ID on the Hugging Face Hub.
- The `hf_config` field should only be present if the dataset has multiple configurations (like ARC). Otherwise, it should be omitted from the JSON file.
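To illustrate those two fields (any other keys the real files contain are omitted here), a multi-configuration dataset such as ARC needs both, while a single-configuration dataset such as BoolQ omits `hf_config`:

```json
{
  "hf_id": "ai2_arc",
  "hf_config": "ARC-Challenge"
}
```

```json
{
  "hf_id": "boolq"
}
```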
The run_experiment.py script orchestrates the entire process.
First, export your API key if you are using API mode:
```bash
# Example for DeepInfra
export DEEPINFRA_API_KEY="your_actual_api_key_here"
```
Then provide the paths to your desired model and dataset configurations:
```bash
# API mode
python run_experiment.py --model_config configs/model_api.json --dataset_config configs/dataset_boolq.json

# Local mode
python run_experiment.py --model_config configs/model_local.json --dataset_config configs/dataset_gsm8k.json
```
- The console will show a verbose, real-time log of the Causal-CoT process for each sample.
- Upon completion, a summary of the final metrics is printed.
- A detailed JSON file is saved in the `results/` directory, named after the experiment configuration (e.g., `results/BoolQ_model_api.json`).
- `accuracy`: Final task accuracy after corrections.
- `causal_metrics`:
  - `intervention_rate`: Average number of self-corrections per problem.
  - `reasoning_fidelity`: Proportion of initial CoT steps that were valid.
  - `fallacy_rate`: Percentage of steps identified as a causal fallacy.
  - `avg_correction_depth_percent`: Average point (%) in the CoT where the first error occurred.
  - `causal_structure_distribution`: Frequency count of identified causal structures.
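As a rough guide to how these figures relate to per-step probe results, the sketch below aggregates them from a hypothetical per-sample record format; the field names (`correct`, `steps`, `interventions`, `valid`, `fallacy`, `structure`) are assumptions, not the schema of the files written to `results/`.

```python
from collections import Counter
from typing import Dict, List

def summarize(samples: List[Dict]) -> Dict:
    """Aggregate causal metrics from per-sample probe records (hypothetical schema).

    Each sample is assumed to look like:
      {"correct": bool,
       "interventions": int,  # number of reflection/regeneration cycles
       "steps": [{"valid": bool, "fallacy": bool, "structure": "chain"}, ...]}
    """
    all_steps = [st for sample in samples for st in sample["steps"]]
    first_error_depths = []
    for sample in samples:
        for i, step in enumerate(sample["steps"]):
            if not step["valid"]:
                first_error_depths.append(100.0 * i / len(sample["steps"]))
                break
    return {
        "accuracy": sum(s["correct"] for s in samples) / len(samples),
        "causal_metrics": {
            "intervention_rate": sum(s["interventions"] for s in samples) / len(samples),
            "reasoning_fidelity": sum(st["valid"] for st in all_steps) / len(all_steps),
            "fallacy_rate": 100.0 * sum(st["fallacy"] for st in all_steps) / len(all_steps),
            "avg_correction_depth_percent": (
                sum(first_error_depths) / len(first_error_depths) if first_error_depths else None
            ),
            "causal_structure_distribution": dict(Counter(st["structure"] for st in all_steps)),
        },
    }
```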