diff --git a/libautoscoper/src/PSO.cpp b/libautoscoper/src/PSO.cpp index be6808d4..5aa42ffd 100644 --- a/libautoscoper/src/PSO.cpp +++ b/libautoscoper/src/PSO.cpp @@ -2,6 +2,11 @@ #include #include #include +#include +#include + +// Prevents multiple threads from writing to the same memory location +std::mutex MTX; // New Particle Swarm Optimization float host_fitness_function(std::vector x) @@ -46,6 +51,27 @@ float getRandomClamped() return (float)rand() / (float)RAND_MAX; } +void thread_handler(Particle* p, Particle* pBest, Particle* gBest, float OMEGA) { + // Update the velocities and positions + p->updateVelocityAndPosition(pBest, gBest, OMEGA); + + // Get the NCC of the current particle + p->ncc_val = host_fitness_function(p->position); + + // Update the pBest if the current particle is better + if (p->ncc_val < pBest->ncc_val) { + *pBest = *p; + } + + // Critial Section + MTX.lock(); + // Update the gBest if the current particle is better + if (p->ncc_val < gBest->ncc_val) { + *gBest = *p; + } + MTX.unlock(); +} + void pso(std::vector* particles, Particle* gBest, unsigned int MAX_EPOCHS, unsigned int MAX_STALL) { int stall_iter = 0; @@ -80,28 +106,20 @@ void pso(std::vector* particles, Particle* gBest, unsigned int MAX_EP *currentBest = *gBest; // We want this to be a copy not a pointer + // Create a thread for each particle + std::vector threads; for (int i = 0; i < NUM_OF_PARTICLES; i++) { p = particles->at(i); // We want these to be pointers not copies curPBest = pBest.at(i); - // Update the velocities and positions - p->updateVelocityAndPosition(curPBest, gBest, OMEGA); - - // Get the NCC of the current particle - p->ncc_val = host_fitness_function(p->position); - - // Update the pBest if the current particle is better - if (p->ncc_val < curPBest->ncc_val) { - *curPBest = *p; - } - - // Update the gBest if the current particle is better - if (p->ncc_val < gBest->ncc_val) { - *gBest = *p; - } + threads.push_back(std::thread(thread_handler, p, curPBest, currentBest, OMEGA)); } + // Wait for all threads to finish + for (auto& t : threads) t.join(); + + // Update the OMEGA OMEGA = OMEGA * 0.9f; std::cout << "Current Best NCC: " << gBest->ncc_val << std::endl;