|
| 1 | +<?php |
| 2 | + |
| 3 | +namespace TextAnalysis\Classifiers; |
| 4 | + |
/**
 * Implementation of Naive Bayes algorithm, borrowed heavily from
 * https://github.com/fieg/bayes
 * @author yooper
 */
class NaiveBayes implements \TextAnalysis\Interfaces\IClassifier
{
    /**
     * Token frequencies per label, shaped as label => [token => count]
     * @var array
     */
    protected $labels = [];

    /**
     * Number of training documents seen for each label, shaped as label => count
     * @var int[]
     */
    protected $labelCount = [];

    /**
     * Token frequencies across all labels, shaped as token => count
     * @var int[]
     */
    protected $tokenCount = [];

    /**
     * Record one training document for the given label.
     *
     * Increments the document count of the label and adds every token
     * occurrence to both the per-label and the global token counts.
     *
     * @param string $label label the document belongs to
     * @param array $tokens tokens of the document (strings or ints, as required by array_count_values)
     */
    public function train(string $label, array $tokens)
    {
        $freqDist = array_count_values($tokens);

        if (!isset($this->labels[$label])) {
            $this->labels[$label] = [];
            $this->labelCount[$label] = 0;
        }
        $this->labelCount[$label]++;

        foreach ($freqDist as $token => $count) {
            $this->tokenCount[$token] = ($this->tokenCount[$token] ?? 0) + $count;
            $this->labels[$label][$token] = ($this->labels[$label][$token] ?? 0) + $count;
        }
    }

    /**
     * Score every trained label against the given tokens.
     *
     * A label is skipped when no documents were trained under any other
     * label, because the "negative" probability would require dividing by
     * zero. Consequently a model trained with a single label (or not
     * trained at all) yields an empty array.
     *
     * @param array $tokens tokens of the document to classify
     * @return array label => score, sorted by score in descending order
     */
    public function predict(array $tokens)
    {
        $totalDocs = $this->getDocCount();
        $scores = [];

        foreach ($this->labelCount as $label => $docCount) {
            $inversedDocCount = $totalDocs - $docCount;
            if ($inversedDocCount <= 0) {
                // No documents outside this label: nothing to contrast against,
                // and 1 / $inversedDocCount would be a division by zero.
                continue;
            }

            $docCountReciprocal = 1 / $docCount;
            $inversedDocCountReciprocal = 1 / $inversedDocCount;
            $sum = 0;

            foreach ($tokens as $token) {
                // Unseen tokens are treated as seen once (outside this label),
                // which keeps the denominator below strictly positive.
                $totalTokenCount = $this->tokenCount[$token] ?? 1;
                $tokenCount = $this->labels[$label][$token] ?? 0;
                $inversedTokenCount = $totalTokenCount - $tokenCount;

                $tokenProbabilityPositive = $tokenCount * $docCountReciprocal;
                $tokenProbabilityNegative = $inversedTokenCount * $inversedDocCountReciprocal;
                $probability = $tokenProbabilityPositive / ($tokenProbabilityPositive + $tokenProbabilityNegative);

                // Blend with a 0.5 prior weighted as a single observation, so the
                // result stays strictly between 0 and 1 and both logs are finite.
                $probability = (0.5 + ($totalTokenCount * $probability)) / (1 + $totalTokenCount);
                $sum += log(1 - $probability) - log($probability);
            }

            $scores[$label] = 1 / (1 + exp($sum));
        }

        arsort($scores, SORT_NUMERIC);
        return $scores;
    }

    /**
     * Total number of documents passed to train(), across all labels.
     *
     * @return int 0 when nothing has been trained yet
     */
    public function getDocCount() : int
    {
        return array_sum($this->labelCount);
    }

    /**
     * Drop the collected counts when the classifier is destroyed.
     */
    public function __destruct()
    {
        unset($this->labelCount);
        unset($this->labels);
        unset($this->tokenCount);
    }
}
0 commit comments