/// Expands one `(input, target)` sequence pair into a "staircase" of growing prefixes.
///
/// One pair is produced for every position in `target`: step `i` (counting from 1)
/// holds the first `i` elements of `input` together with the first `i` elements of
/// `target`, so the last pair always contains the whole `target`.
///
/// The number of steps is driven by `target` alone:
/// * an empty `target` yields an empty vector;
/// * if `input` is shorter than `target`, the input prefix stops growing once it
///   reaches the full `input`;
/// * if `input` is longer than `target`, its trailing elements never appear.
///
/// Accepting slices keeps existing call sites that pass `&Vec<T>` working through
/// deref coercion, while also allowing arrays and sub-slices.
fn generate_staircase_pairs<T: Clone>(input: &[T], target: &[T]) -> Vec<(Vec<T>, Vec<T>)> {
    (1..=target.len())
        .map(|step| {
            // Clamp the input bound so a short `input` saturates instead of panicking,
            // matching what `iter().take(step)` would yield.
            let input_prefix = input[..step.min(input.len())].to_vec();
            // `step` never exceeds `target.len()`, so this slice is always in bounds.
            let target_prefix = target[..step].to_vec();
            (input_prefix, target_prefix)
        })
        .collect()
}