Skip to content

Commit 00d6c61

Browse files
author
eric
committed
优化代码结构
1 parent 6f41708 commit 00d6c61

File tree

6 files changed

+285
-77
lines changed

6 files changed

+285
-77
lines changed

kr2r/src/args.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ pub struct ClassifyArgs {
133133
#[clap(short = 'z', long, value_parser, default_value_t = false)]
134134
pub report_zero_counts: bool,
135135

136+
#[clap(long, value_parser, default_value_t = false)]
137+
pub full_output: bool,
138+
136139
/// A list of input file paths (FASTA/FASTQ) to be processed by the classify program.
137140
// #[clap(short = 'F', long = "files")]
138141
pub input_files: Vec<String>,

kr2r/src/bin/annotate.rs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use clap::Parser;
2-
use kr2r::compact_hash::{CHPage, Compact, HashConfig, K2Compact, Slot};
2+
use kr2r::compact_hash::{CHPage, Compact, HashConfig, K2Compact, Row, Slot};
33
use kr2r::utils::find_and_sort_files;
44
// use std::collections::HashMap;
55
use rayon::prelude::*;
@@ -103,12 +103,15 @@ where
103103
R: Read + Send,
104104
{
105105
let slot_size = std::mem::size_of::<Slot<u64>>();
106+
let row_size = std::mem::size_of::<Row>();
106107
let mut batch_buffer = vec![0u8; slot_size * batch_size];
107108
let mut last_file_index: Option<u64> = None;
108109
let mut writer: Option<BufWriter<File>> = None;
109110

110111
let value_mask = chtm.get_value_mask();
111112
let value_bits = chtm.get_value_bits();
113+
let idx_mask = chtm.get_idx_mask();
114+
let idx_bits = chtm.get_idx_bits();
112115

113116
while let Ok(bytes_read) = reader.read(&mut batch_buffer) {
114117
if bytes_read == 0 {
@@ -125,14 +128,19 @@ where
125128
let result: HashMap<u64, Vec<u8>> = slots
126129
.into_par_iter()
127130
.filter_map(|slot| {
128-
let taxid = chtm.get_from_page(slot);
131+
let indx = slot.idx & idx_mask;
132+
let taxid = chtm.get_from_page(indx, slot.value);
129133

130134
if taxid > 0 {
135+
let kmer_id = slot.idx >> idx_bits;
131136
let file_index = slot.value.right(value_mask) >> 32;
137+
let seq_id = slot.get_seq_id() as u32;
132138
let left = slot.value.left(value_bits) as u32;
133-
let high = u32::combined(left, taxid, value_bits) as u64;
134-
let value = slot.to_b(high);
135-
let value_bytes = value.to_le_bytes(); // 将u64转换为[u8; 8]
139+
let high = u32::combined(left, taxid, value_bits);
140+
let row = Row::new(high, seq_id, kmer_id as u32);
141+
// let value = slot.to_b(high);
142+
// let value_bytes = value.to_le_bytes(); // 将u64转换为[u8; 8]
143+
let value_bytes = row.as_slice(row_size);
136144
Some((file_index, value_bytes.to_vec()))
137145
} else {
138146
None
@@ -223,7 +231,7 @@ pub fn run(args: Args) -> Result<()> {
223231
// 计算持续时间
224232
let duration = start.elapsed();
225233
// 打印运行时间
226-
println!("squid took: {:?}", duration);
234+
println!("annotate took: {:?}", duration);
227235

228236
Ok(())
229237
}

kr2r/src/bin/resolve.rs

Lines changed: 150 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
use clap::Parser;
2-
use dashmap::DashMap;
3-
use kr2r::compact_hash::{Compact, HashConfig};
2+
use dashmap::{DashMap, DashSet};
3+
use kr2r::compact_hash::{Compact, HashConfig, Row};
44
use kr2r::iclassify::{resolve_tree, trim_pair_info};
55
use kr2r::readcounts::{TaxonCounters, TaxonCountersDash};
66
use kr2r::report::report_kraken_style;
77
use kr2r::taxonomy::Taxonomy;
88
use kr2r::utils::find_and_sort_files;
99
use rayon::prelude::*;
10-
use std::collections::{HashMap, HashSet};
10+
use std::collections::HashMap;
1111
use std::fs::File;
1212
use std::io::{self, BufRead, BufReader, BufWriter, Read, Result, Write};
1313
use std::path::{Path, PathBuf};
@@ -16,57 +16,130 @@ use std::sync::Mutex;
1616

1717
const BATCH_SIZE: usize = 8 * 1024 * 1024;
1818

19-
pub fn read_id_to_seq_map<P: AsRef<Path>>(filename: P) -> Result<DashMap<u32, (String, usize)>> {
19+
pub fn read_id_to_seq_map<P: AsRef<Path>>(
20+
filename: P,
21+
) -> Result<DashMap<u32, (String, String, u32, Option<u32>)>> {
2022
let file = File::open(filename)?;
2123
let reader = BufReader::new(file);
2224
let id_map = DashMap::new();
2325

2426
reader.lines().par_bridge().for_each(|line| {
2527
let line = line.expect("Could not read line");
2628
let parts: Vec<&str> = line.trim().split_whitespace().collect();
27-
if parts.len() >= 3 {
29+
if parts.len() >= 4 {
2830
// 解析序号为u32类型的键
2931
if let Ok(id) = parts[0].parse::<u32>() {
3032
// 第二列是序列标识符,直接作为字符串
3133
let seq_id = parts[1].to_string();
32-
if let Ok(count) = parts[2].parse::<usize>() {
33-
// 插入到DashMap中
34-
id_map.insert(id, (seq_id, count));
35-
}
34+
let seq_size = parts[2].to_string();
35+
let count_parts: Vec<&str> = parts[3].split('|').collect();
36+
let kmer_count1 = count_parts[0].parse::<u32>().unwrap();
37+
let kmer_count2 = count_parts[1].parse::<u32>().map_or(None, |i| Some(i));
38+
id_map.insert(id, (seq_id, seq_size, kmer_count1, kmer_count2));
3639
}
3740
}
3841
});
3942

4043
Ok(id_map)
4144
}
4245

46+
/// Renders the hits that fall inside one k-mer segment as a space-separated
/// list of `code:count` tokens.
///
/// The segment covers the k-mer ids `offset..offset + count`. Each row inside
/// that range contributes one `<external_id>:1` token, where the external id
/// is looked up in `taxonomy.nodes` using the masked row value. Stretches of
/// positions with no row are written as `0:<length>`.
///
/// * `count` - number of k-mer positions in the segment.
/// * `rows` - hit rows; rows outside the segment are ignored.
///   NOTE(review): gap lengths are only meaningful if the rows arrive in
///   ascending `kmer_id` order - confirm the ordering produced by the caller's sort.
/// * `taxonomy` - supplies the external id for every taxon key.
/// * `value_mask` - mask passed to `right` to extract the taxon key from a row value.
/// * `offset` - k-mer id of the first position of the segment.
///
/// A segment with no hits yields `0:<count>`, and an empty segment
/// (`count == 0`) yields an empty string. No zero-length `0:0` filler is
/// emitted between adjacent hits.
fn generate_hit_string(
    count: u32,
    rows: &Vec<Row>,
    taxonomy: &Taxonomy,
    value_mask: usize,
    offset: u32,
) -> String {
    let mut tokens: Vec<String> = Vec::new();
    // First position of the segment that has not been described yet.
    let mut next_pos: u32 = 0;
    let segment_end = offset.saturating_add(count);

    for row in rows {
        // Skip positions that do not belong to the current segment.
        if row.kmer_id < offset || row.kmer_id >= segment_end {
            continue;
        }

        // Position relative to the start of the segment.
        let pos = row.kmer_id - offset;

        // Fill the gap (if any) between the previous hit and this one.
        if pos > next_pos {
            tokens.push(format!("0:{}", pos - next_pos));
        }

        // Resolve the taxon only for rows that are actually reported.
        let key = row.value.right(value_mask);
        let ext_code = taxonomy.nodes[key as usize].external_id;
        tokens.push(format!("{}:1", ext_code));

        // `pos < count`, so this cannot overflow.
        next_pos = pos + 1;
    }

    // Fill the tail of the segment after the last hit.
    if next_pos < count {
        tokens.push(format!("0:{}", count - next_pos));
    }

    tokens.join(" ")
}
93+
94+
/// Builds the hit-list text for a read, or for a read pair.
///
/// The first segment always spans k-mer ids `0..kmer_count1`. When
/// `kmer_count2` is present, a second segment starting at k-mer id
/// `kmer_count1` is rendered as well and the two parts are joined with
/// ` |:| `. Both segments are produced by `generate_hit_string` from the same
/// `rows`.
pub fn add_hitlist_string(
    rows: &Vec<Row>,
    value_mask: usize,
    kmer_count1: u32,
    kmer_count2: Option<u32>,
    taxonomy: &Taxonomy,
) -> String {
    let first_segment = generate_hit_string(kmer_count1, rows, taxonomy, value_mask, 0);

    match kmer_count2 {
        // Single segment: nothing to join.
        None => first_segment,
        // Paired segment: it begins right after the first one.
        Some(second_len) => {
            let second_segment =
                generate_hit_string(second_len, rows, taxonomy, value_mask, kmer_count1);
            format!("{} |:| {}", first_segment, second_segment)
        }
    }
}
109+
43110
pub fn count_values(
44-
vec: Vec<u32>,
111+
rows: &Vec<Row>,
45112
value_mask: usize,
113+
kmer_count1: u32,
46114
) -> (HashMap<u32, u64>, TaxonCountersDash, usize) {
47115
let mut counts = HashMap::new();
48116

49-
let mut unique_elements = HashSet::new();
117+
let mut hit_count: usize = 0;
50118

119+
let mut last_row: Row = Row::new(0, 0, 0);
51120
let cur_taxon_counts = TaxonCountersDash::new();
52121

53-
for value in vec {
54-
// 使用entry API处理计数
55-
// entry返回的是一个Entry枚举,它代表了可能存在也可能不存在的值
56-
// or_insert方法在键不存在时插入默认值(在这里是0)
57-
// 然后无论哪种情况,我们都对计数器加1
122+
for row in rows {
123+
let value = row.value;
58124
let key = value.right(value_mask);
59125
*counts.entry(key).or_insert(0) += 1;
60-
if !unique_elements.contains(&value) {
126+
127+
// 如果切换到第2条seq,就重新计算
128+
if last_row.kmer_id < kmer_count1 && row.kmer_id > kmer_count1 {
129+
last_row = Row::new(0, 0, 0);
130+
}
131+
if !(last_row.value == value && row.kmer_id - last_row.kmer_id == 1) {
61132
cur_taxon_counts
62133
.entry(key as u64)
63134
.or_default()
64135
.add_kmer(value as u64);
136+
hit_count += 1;
65137
}
66-
unique_elements.insert(value);
138+
139+
last_row = *row;
67140
}
68141

69-
(counts, cur_taxon_counts, unique_elements.len())
142+
(counts, cur_taxon_counts, hit_count)
70143
}
71144

72145
#[derive(Parser, Debug, Clone)]
@@ -84,13 +157,8 @@ pub struct Args {
84157
#[clap(long, value_parser, required = true)]
85158
pub chunk_dir: PathBuf,
86159

87-
// /// The file path for the Kraken 2 index.
88-
// #[clap(short = 'H', long = "index-filename", value_parser, required = true)]
89-
// index_filename: PathBuf,
90-
91-
// /// The file path for the Kraken 2 taxonomy.
92-
// #[clap(short = 't', long = "taxonomy-filename", value_parser, required = true)]
93-
// taxonomy_filename: String,
160+
#[clap(long, value_parser, default_value_t = false)]
161+
pub full_output: bool,
94162
/// Confidence score threshold, default is 0.0.
95163
#[clap(
96164
short = 'T',
@@ -126,20 +194,21 @@ pub struct Args {
126194
pub kraken_output_dir: Option<PathBuf>,
127195
}
128196

129-
fn process_batch<P: AsRef<Path>, B: Compact>(
197+
fn process_batch<P: AsRef<Path>>(
130198
sample_file: P,
131199
args: &Args,
132200
taxonomy: &Taxonomy,
133-
id_map: DashMap<u32, (String, usize)>,
134-
writer: Box<dyn Write + Send>,
201+
id_map: &DashMap<u32, (String, String, u32, Option<u32>)>,
202+
writer: &Mutex<Box<dyn Write + Send>>,
135203
value_mask: usize,
136-
) -> Result<(TaxonCountersDash, usize)> {
204+
) -> Result<(TaxonCountersDash, usize, DashSet<u32>)> {
137205
let file = File::open(sample_file)?;
138206
let mut reader = BufReader::new(file);
139-
let size = std::mem::size_of::<B>();
207+
let size = std::mem::size_of::<Row>();
140208
let mut batch_buffer = vec![0u8; size * BATCH_SIZE];
141209

142210
let hit_counts = DashMap::new();
211+
let hit_seq_id_set = DashSet::new();
143212
let confidence_threshold = args.confidence_threshold;
144213
let minimum_hit_groups = args.minimum_hit_groups;
145214

@@ -150,31 +219,36 @@ fn process_batch<P: AsRef<Path>, B: Compact>(
150219

151220
// 处理读取的数据批次
152221
let slots_in_batch = bytes_read / size;
153-
154222
let slots = unsafe {
155-
std::slice::from_raw_parts(batch_buffer.as_ptr() as *const B, slots_in_batch)
223+
std::slice::from_raw_parts(batch_buffer.as_ptr() as *const Row, slots_in_batch)
156224
};
157225

158226
slots.into_par_iter().for_each(|item| {
159-
let cell = item.left(0).to_u32();
160-
let seq_id = item.right(0).to_u32();
161-
hit_counts.entry(seq_id).or_insert_with(Vec::new).push(cell)
227+
let seq_id = item.seq_id;
228+
hit_seq_id_set.insert(seq_id);
229+
hit_counts
230+
.entry(seq_id)
231+
.or_insert_with(Vec::new)
232+
.push(*item)
162233
});
163234
}
164235

165-
let writer = Mutex::new(writer);
236+
// let writer = Mutex::new(writer);
166237
let classify_counter = AtomicUsize::new(0);
167238
let cur_taxon_counts = TaxonCountersDash::new();
168239

169-
hit_counts.into_par_iter().for_each(|(k, cells)| {
240+
hit_counts.into_par_iter().for_each(|(k, mut rows)| {
170241
if let Some(item) = id_map.get(&k) {
171-
let total_kmers: usize = item.1;
242+
rows.sort_unstable();
243+
let total_kmers: usize = item.2 as usize + item.3.unwrap_or(0) as usize;
172244
let dna_id = trim_pair_info(&item.0);
173-
let (counts, cur_counts, hit_groups) = count_values(cells, value_mask);
245+
let (counts, cur_counts, hit_groups) = count_values(&rows, value_mask, item.2);
246+
let hit_string = add_hitlist_string(&rows, value_mask, item.2, item.3, taxonomy);
174247
let mut call = resolve_tree(&counts, taxonomy, total_kmers, confidence_threshold);
175248
if call > 0 && hit_groups < minimum_hit_groups {
176249
call = 0;
177250
};
251+
178252
cur_counts.iter().for_each(|entry| {
179253
cur_taxon_counts
180254
.entry(*entry.key())
@@ -184,20 +258,31 @@ fn process_batch<P: AsRef<Path>, B: Compact>(
184258
});
185259

186260
let ext_call = taxonomy.nodes[call as usize].external_id;
187-
if call > 0 {
188-
let output_line = format!("C\t{}\t{}\n", dna_id, ext_call);
189-
// 使用锁来同步写入
190-
let mut file = writer.lock().unwrap();
191-
file.write_all(output_line.as_bytes()).unwrap();
261+
let clasify = if call > 0 {
192262
classify_counter.fetch_add(1, Ordering::SeqCst);
193263
cur_taxon_counts
194264
.entry(call as u64)
195265
.or_default()
196266
.increment_read_count();
197-
}
267+
268+
"C"
269+
} else {
270+
"U"
271+
};
272+
// 使用锁来同步写入
273+
let output_line = format!(
274+
"{}\t{}\t{}\t{}\t{}\n",
275+
clasify, dna_id, ext_call, item.1, hit_string
276+
);
277+
let mut file = writer.lock().unwrap();
278+
file.write_all(output_line.as_bytes()).unwrap();
198279
}
199280
});
200-
Ok((cur_taxon_counts, classify_counter.load(Ordering::SeqCst)))
281+
Ok((
282+
cur_taxon_counts,
283+
classify_counter.load(Ordering::SeqCst),
284+
hit_seq_id_set,
285+
))
201286
}
202287

203288
pub fn run(args: Args) -> Result<()> {
@@ -228,14 +313,29 @@ pub fn run(args: Args) -> Result<()> {
228313
}
229314
None => Box::new(io::stdout()) as Box<dyn Write + Send>,
230315
};
231-
let (thread_taxon_counts, thread_classified) = process_batch::<&PathBuf, u64>(
316+
let writer = Mutex::new(writer);
317+
let (thread_taxon_counts, thread_classified, hit_seq_set) = process_batch::<&PathBuf>(
232318
sample_file,
233319
&args,
234320
&taxo,
235-
sample_id_map,
236-
writer,
321+
&sample_id_map,
322+
&writer,
237323
value_mask,
238324
)?;
325+
326+
if args.full_output {
327+
sample_id_map
328+
.iter()
329+
.filter(|item| !hit_seq_set.contains(item.key()))
330+
.for_each(|item| {
331+
let dna_id = trim_pair_info(&item.0);
332+
let hit_string = add_hitlist_string(&vec![], value_mask, item.2, item.3, &taxo);
333+
let output_line = format!("U\t{}\t0\t{}\t{}\n", dna_id, item.1, hit_string);
334+
let mut file = writer.lock().unwrap();
335+
file.write_all(output_line.as_bytes()).unwrap();
336+
});
337+
}
338+
239339
thread_taxon_counts.iter().for_each(|entry| {
240340
total_taxon_counts
241341
.entry(*entry.key())

0 commit comments

Comments
 (0)