Data generation / retrieval for Training #14

Merged (8 commits) on Dec 21, 2024
5 changes: 5 additions & 0 deletions src/data/dataset.rs
@@ -0,0 +1,5 @@
#[allow(dead_code)]
struct Dataset {
inputs: Vec<Vec<usize>>, // Each input is a sequence of token IDs
targets: Vec<Vec<usize>>, // Each target is the corresponding output sequence
}
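
The Dataset struct keeps the tokenized training examples as two parallel vectors: inputs[i] and targets[i] hold the token-ID sequences for one example. As a rough sketch (not part of this diff, and the helper name is hypothetical), it could be filled directly from the (input, target) pairs that generation.rs produces:

impl Dataset {
    // Hypothetical helper: collect generated pairs into the two parallel vectors.
    fn from_pairs(pairs: Vec<(Vec<usize>, Vec<usize>)>) -> Self {
        let (inputs, targets) = pairs.into_iter().unzip();
        Dataset { inputs, targets }
    }
}
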
98 changes: 98 additions & 0 deletions src/data/generation.rs
@@ -0,0 +1,98 @@
use crate::data::io::get_input;
use crate::data::tokenizer::Tokenizer;

fn generate_input_target_pairs(
tokenizer: &Tokenizer,
sentences: Vec<String>,
) -> Vec<(Vec<usize>, Vec<usize>)> {
let mut pairs = Vec::new();

for i in 0..sentences.len() {
let sentence = &sentences[i]; // Borrow the sentence
let tokens = tokenizer.tokenize(sentence);

// Prepare input (same as sentence)
let input = tokens.clone();

// Prepare target (shifted version of the sentence)
let mut target = tokens.clone();
if i + 1 < sentences.len() {
// If not the last sentence, append the first token of the next sentence
let next_sentence = &sentences[i + 1];
let next_tokens = tokenizer.tokenize(next_sentence);
if !next_tokens.is_empty() {
target.push(next_tokens[0]); // Add the first token of the next sentence
}
} else {
target.push(tokenizer.vocab["<EOS>"]); // Use EOS token for the last sentence
}

// Remove the first token from target (shifting by 1)
if !target.is_empty() {
target.remove(0);
}

// Add the input-target pair to the result
pairs.push((input, target));
}

pairs
}

// This will convert a list of sentences into tokenized input-output pairs
#[allow(dead_code)]
fn generate_input_target_pairs_by_sentence(
tokenizer: &Tokenizer,
sentences: Vec<&str>,
) -> Vec<(Vec<usize>, Vec<usize>)> {
let mut pairs = Vec::new();

for i in 0..sentences.len() - 1 {
let input = tokenizer.tokenize(sentences[i]);
let target = tokenizer.tokenize(sentences[i + 1]);

pairs.push((input, target));
}

pairs
}

pub fn example_gen() {
let raw_text = get_input();

let tokenizer = Tokenizer::new(raw_text.clone());

// Generate input-target pairs
let pairs = generate_input_target_pairs(&tokenizer, raw_text);

// Display the pairs
for (input, target) in pairs.clone() {
println!("Input: {:?}\nTarget: {:?}\n", input, target);
}
for (input, target) in pairs {
let staircase_pairs = generate_staircase_pairs(&input, &target);

for (staircase_input, staircase_target) in staircase_pairs {
println!(
"Input: {:?}\nTarget: {:?}\n",
tokenizer.detokenize(staircase_input),
tokenizer.detokenize(staircase_target)
);
}
}
}
fn generate_staircase_pairs(input: &[usize], target: &[usize]) -> Vec<(Vec<usize>, Vec<usize>)> {
let mut staircase_pairs = Vec::new();

// The number of steps will be the length of the target sequence
for i in 1..=target.len() {
// Slice input and target incrementally
let staircase_input = input.iter().take(i).cloned().collect::<Vec<usize>>();
let staircase_target = target.iter().take(i).cloned().collect::<Vec<usize>>();

// Add this pair to the staircase pairs vector
staircase_pairs.push((staircase_input, staircase_target));
}

staircase_pairs
}
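
generate_input_target_pairs shifts each sentence by one token (dropping the first token and appending either the first token of the next sentence or <EOS>), and generate_staircase_pairs then expands every pair into growing prefixes, one token longer at each step. An illustrative test with made-up token IDs (a sketch, not included in this PR) shows the expected shape of the output:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn staircase_pairs_grow_one_token_per_step() {
        // The IDs below are arbitrary; only the structure matters.
        let input = vec![1, 10, 11, 12];
        let target = vec![10, 11, 12, 2];

        let pairs = generate_staircase_pairs(&input, &target);

        assert_eq!(pairs.len(), target.len());
        assert_eq!(pairs[0], (vec![1], vec![10]));
        assert_eq!(pairs[3], (input.clone(), target.clone()));
    }
}
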
1 change: 1 addition & 0 deletions src/data/in/training_input.txt
@@ -0,0 +1 @@
Once upon a time, in a land far away, there was a small village. The villagers were known for their kindness and generosity. Every year, they celebrated the harvest festival with music, dance, and delicious food. One day, a traveler came to the village. He was tired and hungry, but the villagers welcomed him with open arms. The traveler shared stories of his adventures as the villagers listened intently. He told them about distant lands and strange creatures. The villagers were fascinated by his tales. As the evening drew to a close, the traveler offered to leave the village, but the villagers insisted he stay for another night. The next morning, the traveler said goodbye and continued his journey. The villagers waved him off, grateful for the stories and the company.
34 changes: 34 additions & 0 deletions src/data/io.rs
@@ -0,0 +1,34 @@
use regex::Regex;
use std::fs;

fn split_into_sentences(text: String) -> Vec<String> {
let re = Regex::new(r"[.!?]").unwrap(); // Matches sentence-ending punctuation
let mut sentences: Vec<String> = Vec::new(); // We want to store owned Strings, not &str

let mut last_index = 0;
for mat in re.find_iter(&text) {
let end = mat.end();
// Extract the sentence up to the matched punctuation
let sentence = text[last_index..end].trim().to_string(); // Convert to String
if !sentence.is_empty() {
sentences.push(sentence);
}
last_index = end;
}

// Add any remaining text as a sentence
if last_index < text.len() {
let remaining = text[last_index..].trim().to_string(); // Convert remaining to String
if !remaining.is_empty() {
sentences.push(remaining);
}
}

sentences
}

pub fn get_input() -> Vec<String> {
let file_path = "src/data/in/training_input.txt";
let content: String = fs::read_to_string(file_path).unwrap(); // Read the file content
split_into_sentences(content) // Call the function to split into sentences
}
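
split_into_sentences cuts the text at ., !, or ?, keeps the punctuation with the preceding sentence, and appends any trailing text that has no terminal punctuation. A small illustrative test (a sketch, not included in this PR):

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn splits_on_terminal_punctuation() {
        let text = "Once upon a time. Was it real? Yes".to_string();
        let sentences = split_into_sentences(text);
        // The trailing "Yes" has no punctuation but is still kept.
        assert_eq!(
            sentences,
            vec!["Once upon a time.", "Was it real?", "Yes"]
        );
    }
}
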
71 changes: 71 additions & 0 deletions src/data/learnable.rs
@@ -0,0 +1,71 @@
use crate::settings::*;
use ndarray::{Array1, Array2};

pub struct LearnableWeights {
// Embedding Layer
pub embedding: Array2<f32>, // (vocab_size, embedding_dim)

// Attention Mechanism
pub query_weights: Array2<f32>, // (embedding_dim, attention_dim)
pub key_weights: Array2<f32>, // (embedding_dim, attention_dim)
pub value_weights: Array2<f32>, // (embedding_dim, attention_dim)
pub output_projection: Array2<f32>, // (attention_dim, embedding_dim)

// Feedforward Network
pub linear1_weights: Array2<f32>, // (embedding_dim, ffn_dim)
pub linear2_weights: Array2<f32>, // (ffn_dim, embedding_dim)
pub bias1: Array1<f32>,
pub bias2: Array1<f32>,

// Layer Normalization
pub layer_norm_scale: Vec<f32>, // (embedding_dim,)
pub layer_norm_shift: Vec<f32>, // (embedding_dim,)

// Output Layer
pub output_projection_vocab: Array2<f32>, // (embedding_dim, vocab_size)
}

impl LearnableWeights {
pub fn new(
output_size: usize,
hidden_size: usize,
vocab_size: usize,
embedding_dim: usize,
attention_dim: usize,
ffn_dim: usize,
) -> Self {
LearnableWeights {
// Embedding Layer
embedding: Array2::ones((vocab_size, embedding_dim)),

// Attention Mechanism
query_weights: Array2::ones((embedding_dim, attention_dim)),
key_weights: Array2::ones((embedding_dim, attention_dim)),
value_weights: Array2::ones((embedding_dim, attention_dim)),
output_projection: Array2::ones((attention_dim, embedding_dim)),

// Feedforward Network
linear1_weights: Array2::ones((embedding_dim, ffn_dim)),
linear2_weights: Array2::ones((ffn_dim, embedding_dim)),
bias1: Array1::zeros(hidden_size),
bias2: Array1::zeros(output_size),

// Layer Normalization
layer_norm_scale: vec![1.0; embedding_dim], // Initialize scale to 1
layer_norm_shift: vec![0.0; embedding_dim], // Initialize shift to 0

// Output Layer
output_projection_vocab: Array2::zeros((embedding_dim, vocab_size)),
}
}
}
pub fn initialize_weights() -> LearnableWeights {
LearnableWeights::new(
OUTPUT_SIZE, // output_size
HIDDEN_SIZE, // hidden_size
D_MODEL, // vocab_size
D_MODEL, // embedding_dim
D_K, // attention_dim (for keys, queries)
D_V, // ffn_dim (could align with embedding_dim or be set separately)
)
}
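
Every weight matrix is initialized to constant ones (biases and the final vocabulary projection to zeros), with shapes taken directly from the constructor arguments; note that initialize_weights currently passes D_MODEL for both vocab_size and embedding_dim and D_V as the feed-forward width. A quick shape check, assuming the constants in src/settings.rs (not part of this diff) are plain usize values:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn weight_shapes_follow_constructor_arguments() {
        let w = initialize_weights();

        // vocab_size and embedding_dim are both D_MODEL in initialize_weights().
        assert_eq!(w.embedding.shape(), &[D_MODEL, D_MODEL]);
        assert_eq!(w.query_weights.shape(), &[D_MODEL, D_K]);
        assert_eq!(w.linear1_weights.shape(), &[D_MODEL, D_V]);
        assert_eq!(w.layer_norm_scale.len(), D_MODEL);
    }
}
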
96 changes: 69 additions & 27 deletions src/data/tokenizer.rs
@@ -1,11 +1,11 @@
#![allow(warnings)]

use regex::Regex;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

#[derive(Debug)]
pub struct Tokenizer {
vocab: HashMap<String, usize>,
pub vocab: HashMap<String, usize>,
reverse_vocab: HashMap<usize, String>,
pad_token: String,
sos_token: String,
@@ -14,7 +14,10 @@ pub struct Tokenizer {
}

impl Tokenizer {
pub fn new(vocab: HashMap<String, usize>) -> Self {
pub fn new(input: Vec<String>) -> Self {
let vocab: HashMap<String, usize> = generate_vocab(input);
println!("size : {:?}", vocab.clone().len());

// Define special tokens
let pad_token = "<PAD>".to_string();
let sos_token = "<SOS>".to_string();
Expand Down Expand Up @@ -46,14 +49,18 @@ impl Tokenizer {

// Tokenize input sentence (word-based tokenization)
pub fn tokenize(&self, sentence: &str) -> Vec<usize> {
let words = self.tokenize_sentence(sentence);
let words = tokenize_sentence(sentence);
let mut tokens: Vec<usize> = vec![self.vocab[&self.sos_token]]; // Start with SOS token

for word in words {
let token = self
.vocab
.get(&word)
.unwrap_or(&self.vocab[&self.unk_token]);
let word_lower = word.to_lowercase(); // Convert to lowercase
let token = self.vocab.get(&word_lower).unwrap_or_else(|| {
eprintln!(
"Warning: Word '{}' not in vocabulary, substituting <UNK>",
word
);
&self.vocab[&self.unk_token]
});
tokens.push(*token);
}

@@ -71,38 +78,73 @@
.join(" ")
}

// Helper function to split sentence into words using a simple regex
fn tokenize_sentence(&self, sentence: &str) -> Vec<String> {
let re = Regex::new(r"\w+").unwrap(); // Matches words (letters and numbers)
re.find_iter(sentence)
.map(|mat| mat.as_str().to_string())
.collect()
// Helper function to split sentence into words using an improved regex
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_tokenizer() {
let input: Vec<String> = vec!["Hello world, my".parse().unwrap()];
let tokenizer = Tokenizer::new(input);

// Empty sentence
let tokens = tokenizer.tokenize("");
assert_eq!(tokens, vec![1, 2]); // <SOS>, <EOS>

// Sentence with punctuation
let tokens = tokenizer.tokenize("hello, world!");
// Detokenization
let decoded_sentence = tokenizer.detokenize(tokens.clone());
assert_eq!(decoded_sentence, "<SOS> hello , world <UNK> <EOS>");
}
}

pub fn example_tokens() {
// Define a small vocabulary (for example purposes)
let vocab = vec![
("hello".to_string(), 4),
("world".to_string(), 5),
("my".to_string(), 6),
("name".to_string(), 7),
("is".to_string(), 8),
]
.into_iter()
.collect::<HashMap<String, usize>>();
// Define a small sentence
let input: Vec<String> = vec!["Hello, world!".parse().unwrap()];

// Instantiate the tokenizer with the vocabulary
let tokenizer = Tokenizer::new(vocab);
let tokenizer = Tokenizer::new(input);

// Example sentence
let sentence = "hello world";
let sentence = "Hello, world! My name is ChatGPT.";

// Tokenize the sentence
let tokens = tokenizer.tokenize(sentence);
println!("Tokens: {:?}", tokens); // Should print indices for "hello", "world", etc.

// Detokenize the sentence
let decoded_sentence = tokenizer.detokenize(tokens);
println!("Decoded Sentence: {}", decoded_sentence); // Should print "hello world"
println!("Decoded Sentence: {}", decoded_sentence); // Should print the sequence with special tokens
}

pub fn generate_vocab(text: Vec<String>) -> HashMap<String, usize> {
let mut word_set = HashSet::new();

// Tokenize each sentence and collect words, treating punctuation separately
for sentence in text {
let words = tokenize_sentence(&sentence); // Updated tokenization method
for word in words {
word_set.insert(word.to_lowercase());
}
}

// Create vocabulary by assigning a unique index to each word
let vocab = word_set
.into_iter()
.enumerate()
.map(|(idx, word)| (word, idx + 4)) // Start index from 4 to leave space for special tokens
.collect::<HashMap<String, usize>>();

vocab
}

pub fn tokenize_sentence(sentence: &str) -> Vec<String> {
let re = Regex::new(r"\w+|[^\w\s]").unwrap(); // Matches words or punctuation
re.find_iter(sentence)
.map(|mat| mat.as_str().to_string())
.collect()
}
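
generate_vocab lowercases every word and punctuation mark and assigns IDs starting at 4, leaving 0 through 3 for the special tokens registered in Tokenizer::new (the existing test shows <SOS> = 1 and <EOS> = 2); tokenize wraps each sentence in <SOS> and <EOS> and substitutes <UNK> for out-of-vocabulary words. A round-trip sketch that could sit next to test_tokenizer above (word IDs depend on HashSet iteration order, so only the structure is asserted; not part of this PR):

#[test]
fn tokenize_detokenize_roundtrip() {
    let corpus = vec!["Hello, world!".to_string()];
    let tokenizer = Tokenizer::new(corpus);

    let tokens = tokenizer.tokenize("Hello, world!");
    // <SOS> hello , world ! <EOS>
    assert_eq!(tokens.len(), 6);
    assert_eq!(tokens[0], 1); // <SOS>
    assert_eq!(*tokens.last().unwrap(), 2); // <EOS>

    let decoded = tokenizer.detokenize(tokens);
    assert_eq!(decoded, "<SOS> hello , world ! <EOS>");
}
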