Skip to content

Commit 4a6df21

Browse files
authored
Merge pull request #27 from Synerise/seed-train-init
Init embedding with seed during training.
2 parents f00187a + 6ebeb61 commit 4a6df21

File tree

5 files changed: 26 additions and 4 deletions.

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ Command line options (for more info use `--help` as program argument):
187187
-r --relation-name (name of the relation, for output filename generation)
188188
-d --dimension (number of dimensions for output embeddings)
189189
-n --number-of-iterations (number of iterations for the algorithm, usually 3 or 4 works well)
190+
-s --seed (seed integer for embedding initialization)
190191
-c --columns (column format specification)
191192
-p --prepend-field-name (prepend field name to entity in output)
192193
-l --log-every-n (log output every N lines)

src/configuration.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ pub struct Configuration {
2222
/// Maximum number of iterations for training
2323
pub max_number_of_iteration: u8,
2424

25+
/// Seed for embedding initialization
26+
pub seed: Option<i64>,
27+
2528
/// Prepend field name to entity in the output file. It differentiates entities with the same
2629
/// name from different columns
2730
pub prepend_field: bool,
@@ -78,6 +81,7 @@ impl Configuration {
7881
produce_entity_occurrence_count: true,
7982
embeddings_dimension: 128,
8083
max_number_of_iteration: 4,
84+
seed: None,
8185
prepend_field: true,
8286
log_every_n: 1000,
8387
in_memory_embedding_calculation: true,

src/embedding.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ trait MatrixWrapper {
2828
fn init_with_hashes<T: SparseMatrixReader + Sync + Send>(
2929
rows: usize,
3030
cols: usize,
31+
fixed_random_value: i64,
3132
sparse_matrix_reader: Arc<T>,
3233
) -> Self;
3334

@@ -55,14 +56,15 @@ impl MatrixWrapper for TwoDimVectorMatrix {
5556
fn init_with_hashes<T: SparseMatrixReader + Sync + Send>(
5657
rows: usize,
5758
cols: usize,
59+
fixed_random_value: i64,
5860
sparse_matrix_reader: Arc<T>,
5961
) -> Self {
6062
let result: Vec<Vec<f32>> = (0..cols)
6163
.into_par_iter()
6264
.map(|i| {
6365
let mut col: Vec<f32> = Vec::with_capacity(rows);
6466
for hsh in sparse_matrix_reader.iter_hashes() {
65-
let col_value = init_value(i, hsh.value);
67+
let col_value = init_value(i, hsh.value, fixed_random_value);
6668
col.push(col_value);
6769
}
6870
col
@@ -129,8 +131,8 @@ impl MatrixWrapper for TwoDimVectorMatrix {
129131
}
130132
}
131133

132-
fn init_value(col: usize, hsh: u64) -> f32 {
133-
((hash((hsh as i64) + (col as i64)) % MAX_HASH_I64) as f32) / MAX_HASH_F32
134+
fn init_value(col: usize, hsh: u64, fixed_random_value: i64) -> f32 {
135+
((hash((hsh as i64) + (col as i64) + fixed_random_value) % MAX_HASH_I64) as f32) / MAX_HASH_F32
134136
}
135137

136138
fn hash(num: i64) -> i64 {
@@ -160,6 +162,7 @@ impl MatrixWrapper for MMapMatrix {
160162
fn init_with_hashes<T: SparseMatrixReader + Sync + Send>(
161163
rows: usize,
162164
cols: usize,
165+
fixed_random_value: i64,
163166
sparse_matrix_reader: Arc<T>,
164167
) -> Self {
165168
let uuid = Uuid::new_v4();
@@ -172,7 +175,7 @@ impl MatrixWrapper for MMapMatrix {
172175
// i - index of the dimension
173176
// chunk - column/vector of bytes
174177
for (j, hsh) in sparse_matrix_reader.iter_hashes().enumerate() {
175-
let col_value = init_value(i, hsh.value);
178+
let col_value = init_value(i, hsh.value, fixed_random_value);
176179
MMapMatrix::update_column(j, chunk, |value| unsafe { *value = col_value });
177180
}
178181
});
@@ -338,6 +341,7 @@ pub fn calculate_embeddings<T1, T2>(
338341
struct MatrixMultiplicator<T: SparseMatrixReader + Sync + Send, M: MatrixWrapper> {
339342
dimension: usize,
340343
number_of_entities: usize,
344+
fixed_random_value: i64,
341345
sparse_matrix_reader: Arc<T>,
342346
_marker: PhantomData<M>,
343347
}
@@ -348,9 +352,11 @@ where
348352
M: MatrixWrapper,
349353
{
350354
fn new(config: Arc<Configuration>, sparse_matrix_reader: Arc<T>) -> Self {
355+
let rand_value = config.seed.map(hash).unwrap_or(0);
351356
Self {
352357
dimension: config.embeddings_dimension as usize,
353358
number_of_entities: sparse_matrix_reader.get_number_of_entities() as usize,
359+
fixed_random_value: rand_value,
354360
sparse_matrix_reader,
355361
_marker: PhantomData,
356362
}
@@ -366,6 +372,7 @@ where
366372
let result = M::init_with_hashes(
367373
self.number_of_entities,
368374
self.dimension,
375+
self.fixed_random_value,
369376
self.sparse_matrix_reader.clone(),
370377
);
371378

src/main.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ fn main() {
6464
.help("Max number of iterations")
6565
.takes_value(true),
6666
)
67+
.arg(
68+
Arg::with_name("seed")
69+
.short("s")
70+
.long("seed")
71+
.help("Seed (integer) for embedding initialization")
72+
.takes_value(true),
73+
)
6774
.arg(
6875
Arg::with_name("columns")
6976
.short("c")
@@ -140,6 +147,7 @@ fn main() {
140147
.unwrap()
141148
.parse()
142149
.unwrap();
150+
let seed: Option<i64> = matches.value_of("seed").map(|s| s.parse().unwrap());
143151
let relation_name = matches.value_of("relation-name").unwrap();
144152
let prepend_field_name = {
145153
let value: u8 = matches
@@ -180,6 +188,7 @@ fn main() {
180188
produce_entity_occurrence_count: true,
181189
embeddings_dimension: dimension,
182190
max_number_of_iteration: max_iter,
191+
seed,
183192
prepend_field: prepend_field_name,
184193
log_every_n: log_every,
185194
in_memory_embedding_calculation,

tests/snapshot.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ fn prepare_config() -> Configuration {
8383
produce_entity_occurrence_count: true,
8484
embeddings_dimension: 128,
8585
max_number_of_iteration: 4,
86+
seed: None,
8687
prepend_field: false,
8788
log_every_n: 10000,
8889
in_memory_embedding_calculation: true,

0 commit comments

Comments
 (0)