@@ -28,6 +28,7 @@ trait MatrixWrapper {
/// Builds a matrix whose entries are seeded from the entity hashes exposed by
/// `sparse_matrix_reader`.
///
/// `rows` and `cols` give the matrix dimensions. `fixed_random_value` is forwarded by the
/// implementations in this file to `init_value`, where it is added to the key that gets hashed.
2828 fn init_with_hashes < T : SparseMatrixReader + Sync + Send > (
2929 rows : usize ,
3030 cols : usize ,
31+ fixed_random_value : i64 ,
3132 sparse_matrix_reader : Arc < T > ,
3233 ) -> Self ;
3334
@@ -55,14 +56,15 @@ impl MatrixWrapper for TwoDimVectorMatrix {
5556 fn init_with_hashes < T : SparseMatrixReader + Sync + Send > (
5657 rows : usize ,
5758 cols : usize ,
59+ fixed_random_value : i64 ,
5860 sparse_matrix_reader : Arc < T > ,
5961 ) -> Self {
6062 let result: Vec < Vec < f32 > > = ( 0 ..cols)
6163 . into_par_iter ( )
6264 . map ( |i| {
6365 let mut col: Vec < f32 > = Vec :: with_capacity ( rows) ;
6466 for hsh in sparse_matrix_reader. iter_hashes ( ) {
65- let col_value = init_value ( i, hsh. value ) ;
67+ let col_value = init_value ( i, hsh. value , fixed_random_value ) ;
6668 col. push ( col_value) ;
6769 }
6870 col
@@ -129,8 +131,8 @@ impl MatrixWrapper for TwoDimVectorMatrix {
129131 }
130132}
131133
132- fn init_value ( col : usize , hsh : u64 ) -> f32 {
133- ( ( hash ( ( hsh as i64 ) + ( col as i64 ) ) % MAX_HASH_I64 ) as f32 ) / MAX_HASH_F32
134+ fn init_value ( col : usize , hsh : u64 , fixed_random_value : i64 ) -> f32 {
135+ ( ( hash ( ( hsh as i64 ) + ( col as i64 ) + fixed_random_value ) % MAX_HASH_I64 ) as f32 ) / MAX_HASH_F32
134136}
135137
136138fn hash ( num : i64 ) -> i64 {
@@ -160,6 +162,7 @@ impl MatrixWrapper for MMapMatrix {
160162 fn init_with_hashes < T : SparseMatrixReader + Sync + Send > (
161163 rows : usize ,
162164 cols : usize ,
165+ fixed_random_value : i64 ,
163166 sparse_matrix_reader : Arc < T > ,
164167 ) -> Self {
165168 let uuid = Uuid :: new_v4 ( ) ;
@@ -172,7 +175,7 @@ impl MatrixWrapper for MMapMatrix {
172175 // i - number of dimension
173176 // chunk - column/vector of bytes
174177 for ( j, hsh) in sparse_matrix_reader. iter_hashes ( ) . enumerate ( ) {
175- let col_value = init_value ( i, hsh. value ) ;
178+ let col_value = init_value ( i, hsh. value , fixed_random_value ) ;
176179 MMapMatrix :: update_column ( j, chunk, |value| unsafe { * value = col_value } ) ;
177180 }
178181 } ) ;
@@ -338,6 +341,7 @@ pub fn calculate_embeddings<T1, T2>(
/// Bundles the inputs that are handed to `M::init_with_hashes` for the matrix type `M`.
338341struct MatrixMultiplicator < T : SparseMatrixReader + Sync + Send , M : MatrixWrapper > {
// Taken from `config.embeddings_dimension`; passed to `init_with_hashes` as `cols`.
339342 dimension : usize ,
// Taken from `sparse_matrix_reader.get_number_of_entities()`; passed as `rows`.
340343 number_of_entities : usize ,
// `hash(seed)` when the configuration carries a seed, otherwise 0.
344+ fixed_random_value : i64 ,
// Shared reader; its `iter_hashes()` supplies the entity hashes used for initialisation.
341345 sparse_matrix_reader : Arc < T > ,
// Ties the struct to the matrix implementation `M` without storing a value of that type.
342346 _marker : PhantomData < M > ,
343347}
@@ -348,9 +352,11 @@ where
348352 M : MatrixWrapper ,
349353{
350354 fn new ( config : Arc < Configuration > , sparse_matrix_reader : Arc < T > ) -> Self {
355+ let rand_value = config. seed . map ( hash) . unwrap_or ( 0 ) ;
351356 Self {
352357 dimension : config. embeddings_dimension as usize ,
353358 number_of_entities : sparse_matrix_reader. get_number_of_entities ( ) as usize ,
359+ fixed_random_value : rand_value,
354360 sparse_matrix_reader,
355361 _marker : PhantomData ,
356362 }
@@ -366,6 +372,7 @@ where
366372 let result = M :: init_with_hashes (
367373 self . number_of_entities ,
368374 self . dimension ,
375+ self . fixed_random_value ,
369376 self . sparse_matrix_reader . clone ( ) ,
370377 ) ;
371378
0 commit comments