diff --git a/tmu/clause_bank/clause_bank.py b/tmu/clause_bank/clause_bank.py index c4aa1021..90705cf4 100644 --- a/tmu/clause_bank/clause_bank.py +++ b/tmu/clause_bank/clause_bank.py @@ -157,6 +157,18 @@ def calculate_clause_outputs_predict(self, encoded_X, e): xi_p ) + lib.cb_calculate_clause_outputs_predict_recurrent( + self.ptr_ta_state, + self.number_of_clauses, + self.number_of_literals, + self.number_of_state_bits_ta, + self.number_of_patches, + self.co_p, + xi_p + ) + + return self.clause_output + if not self.incremental: lib.cb_calculate_clause_outputs_predict( self.ptr_ta_state, @@ -167,6 +179,7 @@ def calculate_clause_outputs_predict(self, encoded_X, e): self.co_p, xi_p ) + return self.clause_output xi_p = ffi.cast("unsigned int *", encoded_X[e, :].ctypes.data) @@ -216,16 +229,27 @@ def calculate_clause_outputs_update(self, literal_active, encoded_X, e): xi_p ) - lib.cb_calculate_clause_outputs_update( - self.ptr_ta_state, - self.number_of_clauses, - self.number_of_literals, - self.number_of_state_bits_ta, - self.number_of_patches, - self.co_p, - la_p, - xi_p - ) + lib.cb_calculate_clause_outputs_update_recurrent( + self.ptr_ta_state, + self.number_of_clauses, + self.number_of_literals, + self.number_of_state_bits_ta, + self.number_of_patches, + self.co_p, + la_p, + xi_p + ) + else: + lib.cb_calculate_clause_outputs_update( + self.ptr_ta_state, + self.number_of_clauses, + self.number_of_literals, + self.number_of_state_bits_ta, + self.number_of_patches, + self.co_p, + la_p, + xi_p + ) return self.clause_output