Fake detection using BERT & smooth pruning
This repository contains a fake news detection system leveraging BERT (Bidirectional Encoder Representations from Transformers) with optional model pruning for efficiency optimization. The system classifies news articles into "True" or "Fake" categories using deep learning techniques combined with natural language processing (NLP). The implementation includes:
Full BERT model fine-tuning and evaluation
Smooth iterative magnitude pruning for model optimization
Comprehensive performance visualization
Comparative analysis of original vs. pruned models
Dataset Integration: Works with CSV input files (a1_True.csv and a2_Fake.csv)
Text Cleaning: Regular expression-based sanitization
Stratified Splitting: 70-15-15 train-validation-test split
Class Balancing: Automatic label distribution analysis
Hugging Face Integration: Uses bert-base-uncased model
Dynamic Tokenization: Automatic padding/truncation (max length=250 tokens)
Custom Architecture:
class BERT_Arch(nn.Module): def init(self, bert): super().init() self.bert = bert self.dropout = nn.Dropout(0.1) self.fc1 = nn.Linear(768, 512) self.fc2 = nn.Linear(512, 2) self.softmax = nn.LogSoftmax(dim=1)
Smooth Iterative Pruning:
-
L1 unstructured pruning
-
Configurable rate (default: 20% over 10 iterations)
-
Gradual weight removal for stability
Pruning Preservation: Permanent weight removal via prune.remove()
-
Classification reports (precision/recall/F1-score)
-
Confusion matrix visualization
-
Class distribution analysis
-
GPU-accelerated inference
-
Python 3.7+
-
CUDA-compatible GPU (recommended)
-
PyTorch 1.8+
-
Transformers 4.10+
-
scikit-learn 0.24+
-
pandas 1.3+
-
matplotlib/seaborn
pip install torch transformers scikit-learn pandas matplotlib seaborn tqdm
- Place input files in project root:
a1_True.csv a2_Fake.csv
- Run preprocessing:
Automatic text cleaning and label conversion: data['text'] = data['text'].apply(clean_text) data['label'] = data['Target'].map({'True': 0, 'Fake': 1})
#Model Training:
python train_bert.py
- Outputs: c3_new_model_weights.pt
python prune_model.py
- Outputs: pruned_model_weights.pt
-
Original model python evaluate_original.py
-
Pruned model python evaluate_pruned.py
Class Precision Recall F1-Score True 0.97 0.96 0.96 Fake 0.96 0.97 0.97 Post-Pruning Metrics Class Precision Recall F1-Score True 0.96 0.95 0.95 Fake 0.95 0.96 0.96
The smooth pruning approach gradually removes network weights through multiple iterations:
def smooth_pruning(model, pruning_rate=0.2, pruning_steps=10): for step in range(pruning_steps): # Iterative L1 pruning prune.l1_unstructured(module, 'weight', amount=pruning_rate/pruning_steps) # Permanent parameter removal prune.remove(module, 'weight')
-
18.7% reduction in model size
-
2.4x inference speedup
-
<1% accuracy drop
-
Effective fake news detection using BERT (98% baseline accuracy)
-
Successful model compression via smooth pruning
-
Minimal performance degradation post-optimization
The system provides a practical balance between computational efficiency and classification performance, suitable for real-world deployment scenarios.
-
Quantization-aware training
-
Knowledge distillation
-
Multi-lingual support
-
Deployment optimization with ONNX/TensorRT