Implemented IO for file reading and then data generation
JakubSchwenkbeck committed Dec 21, 2024
1 parent c37d01f commit 47dccc4
Showing 6 changed files with 60 additions and 26 deletions.
36 changes: 16 additions & 20 deletions src/data/generation.rs
@@ -1,13 +1,14 @@
+use crate::data::io::get_input;
 use crate::data::tokenizer::Tokenizer;

 fn generate_input_target_pairs(
     tokenizer: &Tokenizer,
-    sentences: Vec<&str>,
+    sentences: Vec<String>,
 ) -> Vec<(Vec<usize>, Vec<usize>)> {
     let mut pairs = Vec::new();

     for i in 0..sentences.len() {
-        let sentence = sentences[i];
+        let sentence = &sentences[i]; // Borrow the sentence
         let tokens = tokenizer.tokenize(sentence);

         // Prepare input (same as sentence)
@@ -16,15 +17,22 @@ fn generate_input_target_pairs(
         // Prepare target (shifted version of the sentence)
         let mut target = tokens.clone();
         if i + 1 < sentences.len() {
-            // If not the last sentence, the last word will be the first word of the next sentence
-            target.push(tokenizer.tokenize(sentences[i + 1])[0]); // First token of the next sentence
+            // If not the last sentence, append the first token of the next sentence
+            let next_sentence = &sentences[i + 1];
+            let next_tokens = tokenizer.tokenize(next_sentence);
+            if !next_tokens.is_empty() {
+                target.push(next_tokens[0]); // Add the first token of the next sentence
+            }
         } else {
-            target.push(tokenizer.vocab["<EOS>"]); // Use EOS token for the last word
+            target.push(tokenizer.vocab["<EOS>"]); // Use EOS token for the last sentence
         }

-        // Remove the first word from target (shift by 1)
-        target.remove(0);
+        // Remove the first token from target (shifting by 1)
+        if !target.is_empty() {
+            target.remove(0);
+        }

+        // Add the input-target pair to the result
         pairs.push((input, target));
     }

@@ -50,19 +58,7 @@ fn generate_input_target_pairs_by_sentence(
 }

 pub fn example_gen() {
-    let raw_text = vec![
-        "Once upon a time, in a land far away, there was a small village.",
-        "The villagers were known for their kindness and generosity.",
-        "Every year, they celebrated the harvest festival with music, dance, and delicious food.",
-        "One day, a traveler came to the village.",
-        "He was tired and hungry, but the villagers welcomed him with open arms.",
-        "The traveler shared stories of his adventures as the villagers listened intently.",
-        "He told them about distant lands and strange creatures.",
-        "The villagers were fascinated by his tales.",
-        "As the evening drew to a close, the traveler offered to leave the village, but the villagers insisted he stay for another night.",
-        "The next morning, the traveler said goodbye and continued his journey.",
-        "The villagers waved him off, grateful for the stories and the company.",
-    ];
+    let raw_text = get_input();

     let tokenizer = Tokenizer::new(raw_text.clone());

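For reference, the shift-by-one pairing above can be checked in isolation. Below is a minimal, self-contained sketch, not part of the commit: plain token ids stand in for the real Tokenizer, and the helper name make_pair is hypothetical.

// Standalone illustration of the shift-by-one input/target pairing.
// `next_first` is the first token of the following sentence, if any;
// `eos` stands in for tokenizer.vocab["<EOS>"].
fn make_pair(tokens: &[usize], next_first: Option<usize>, eos: usize) -> (Vec<usize>, Vec<usize>) {
    let input = tokens.to_vec();
    let mut target = tokens.to_vec();
    // Append the first token of the next sentence, or <EOS> for the last one.
    target.push(next_first.unwrap_or(eos));
    // Shift by one so target[j] is the token that follows input[j].
    if !target.is_empty() {
        target.remove(0);
    }
    (input, target)
}

fn main() {
    let (input, target) = make_pair(&[10, 11, 12], Some(20), 99);
    assert_eq!(input, vec![10, 11, 12]);
    assert_eq!(target, vec![11, 12, 20]); // each position predicts the next token
}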
1 change: 1 addition & 0 deletions src/data/in/training_input.txt
@@ -0,0 +1 @@
+Once upon a time, in a land far away, there was a small village. The villagers were known for their kindness and generosity. Every year, they celebrated the harvest festival with music, dance, and delicious food. One day, a traveler came to the village. He was tired and hungry, but the villagers welcomed him with open arms. The traveler shared stories of his adventures as the villagers listened intently. He told them about distant lands and strange creatures. The villagers were fascinated by his tales. As the evening drew to a close, the traveler offered to leave the village, but the villagers insisted he stay for another night. The next morning, the traveler said goodbye and continued his journey. The villagers waved him off, grateful for the stories and the company.
34 changes: 34 additions & 0 deletions src/data/io.rs
@@ -0,0 +1,34 @@
+use regex::Regex;
+use std::fs;
+
+fn split_into_sentences(text: String) -> Vec<String> {
+    let re = Regex::new(r"[.!?]").unwrap(); // Matches sentence-ending punctuation
+    let mut sentences: Vec<String> = Vec::new(); // We want to store owned Strings, not &str
+
+    let mut last_index = 0;
+    for mat in re.find_iter(&text) {
+        let end = mat.end();
+        // Extract the sentence up to the matched punctuation
+        let sentence = text[last_index..end].trim().to_string(); // Convert to String
+        if !sentence.is_empty() {
+            sentences.push(sentence);
+        }
+        last_index = end;
+    }
+
+    // Add any remaining text as a sentence
+    if last_index < text.len() {
+        let remaining = text[last_index..].trim().to_string(); // Convert remaining to String
+        if !remaining.is_empty() {
+            sentences.push(remaining);
+        }
+    }
+
+    sentences
+}
+
+pub fn get_input() -> Vec<String> {
+    let file_path = "src/data/in/training_input.txt";
+    let content: String = fs::read_to_string(file_path).unwrap(); // Read the file content
+    split_into_sentences(content) // Call the function to split into sentences
+}
8 changes: 4 additions & 4 deletions src/data/tokenizer.rs
@@ -14,7 +14,7 @@ pub struct Tokenizer {
 }

 impl Tokenizer {
-    pub fn new(input: Vec<&str>) -> Self {
+    pub fn new(input: Vec<String>) -> Self {
         let vocab: HashMap<String, usize> = generate_vocab(input);
         println!("size : {:?}", vocab.clone().len());

@@ -87,7 +87,7 @@ mod tests {

     #[test]
     fn test_tokenizer() {
-        let input = vec!["Hello world, my"];
+        let input: Vec<String> = vec!["Hello world, my".parse().unwrap()];
         let tokenizer = Tokenizer::new(input);

         // Empty sentence
@@ -104,7 +104,7 @@

 pub fn example_tokens() {
     // Define a small sentence
-    let input = vec!["Hello, world!"];
+    let input: Vec<String> = vec!["Hello, world!".parse().unwrap()];

     // Instantiate the tokenizer with the vocabulary
     let tokenizer = Tokenizer::new(input);
@@ -121,7 +121,7 @@ pub fn example_tokens() {
     println!("Decoded Sentence: {}", decoded_sentence); // Should print the sequence with special tokens
 }

-pub fn generate_vocab(text: Vec<&str>) -> HashMap<String, usize> {
+pub fn generate_vocab(text: Vec<String>) -> HashMap<String, usize> {
     let mut word_set = HashSet::new();

     // Tokenize each sentence and collect words, treating punctuation separately
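A side note on the updated call sites: "...".parse().unwrap() works here because the FromStr implementation for String has Err = Infallible, so the unwrap can never fail, but .to_string() (or String::from) expresses the same conversion more directly. A tiny equivalence sketch, not part of the commit:

let a: Vec<String> = vec!["Hello, world!".parse().unwrap()]; // compiles, unwrap is infallible
let b: Vec<String> = vec!["Hello, world!".to_string()];      // the more idiomatic spelling
assert_eq!(a, b);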
1 change: 1 addition & 0 deletions src/lib.rs
@@ -6,6 +6,7 @@ pub mod settings;
 pub mod data {
     pub mod dataset;
     pub mod generation;
+    pub mod io;
     pub mod tokenizer;
 }
 pub mod model {
6 changes: 4 additions & 2 deletions src/main.rs
@@ -16,7 +16,9 @@ fn main() {

     println!(" \n \n \n ENCODER/DECODER \n");

-    let input = vec!["hello world rust transformer learning model"];
+    let input: Vec<String> = vec!["hello world rust transformer learning model"
+        .parse()
+        .unwrap()];

     // Initialize Tokenizer and Embedding layer
     let tokenizer = Tokenizer::new(input.clone());
@@ -75,5 +77,5 @@

     println!("Predicted Token: {:?}", predicted_token);

-    example_gen()
+    example_gen();
 }
