|
17 | 17 | package org.deidentifier.arx.framework.check.distribution;
|
18 | 18 |
|
19 | 19 | import java.io.Serializable;
|
| 20 | +import java.security.SecureRandom; |
20 | 21 | import java.util.ArrayList;
|
21 | 22 | import java.util.Arrays;
|
22 | 23 | import java.util.Collections;
|
23 | 24 | import java.util.Iterator;
|
24 | 25 | import java.util.List;
|
| 26 | +import java.util.Map; |
| 27 | +import java.util.Random; |
25 | 28 |
|
26 | 29 | import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
|
27 | 30 | import org.deidentifier.arx.DataType;
|
28 | 31 | import org.deidentifier.arx.DataType.DataTypeWithRatioScale;
|
29 | 32 |
|
| 33 | +import com.carrotsearch.hppc.IntArrayList; |
| 34 | +import com.carrotsearch.hppc.IntDoubleOpenHashMap; |
| 35 | + |
30 | 36 | import cern.colt.GenericSorting;
|
31 | 37 | import cern.colt.Swapper;
|
32 | 38 | import cern.colt.function.IntComparator;
|
@@ -459,7 +465,6 @@ private <T> T getValueAt(List<T> values, List<Integer> frequencies, int index) {
|
459 | 465 | }
|
460 | 466 | }
|
461 | 467 |
|
462 |
| - |
463 | 468 | /**
|
464 | 469 | * This class calculates the mode for a given distribution.
|
465 | 470 | *
|
@@ -583,6 +588,223 @@ private int getMode(Distribution distribution) {
|
583 | 588 | }
|
584 | 589 | }
|
585 | 590 |
|
| 591 | + /** |
| 592 | + * This class calculates the mode for a given distribution falling back to drawing from multiple values that would qualify as |
| 593 | + * mode using the provided distribution |
| 594 | + * |
| 595 | + * @author Fabian Prasser |
| 596 | + * |
| 597 | + */ |
| 598 | + public static class DistributionAggregateFunctionModeWithDistributionFallback extends DistributionAggregateFunction { |
| 599 | + |
| 600 | + /** SVUID. */ |
| 601 | + private static final long serialVersionUID = 6285156778817664604L; |
| 602 | + |
| 603 | + /** Minimum */ |
| 604 | + private double minimum = 0d; |
| 605 | + |
| 606 | + /** Maximum */ |
| 607 | + private double maximum = 0d; |
| 608 | + |
| 609 | + /** Distribution*/ |
| 610 | + private Map<String, Double> distribution; |
| 611 | + |
| 612 | + /** Integer distribution*/ |
| 613 | + private IntDoubleOpenHashMap intDistribution; |
| 614 | + |
| 615 | + /** The seed to use*/ |
| 616 | + private Long seed; |
| 617 | + |
| 618 | + /** The random source to use*/ |
| 619 | + private Random random; |
| 620 | + |
| 621 | + /** |
| 622 | + * Instantiates. |
| 623 | + * |
| 624 | + * @param ignoreMissingData |
| 625 | + * @param distribution |
| 626 | + * @param seed Maybe null |
| 627 | + */ |
| 628 | + public DistributionAggregateFunctionModeWithDistributionFallback(boolean ignoreMissingData, |
| 629 | + Map<String, Double> distribution, |
| 630 | + Long seed) { |
| 631 | + super(ignoreMissingData, true); |
| 632 | + this.distribution = distribution; |
| 633 | + this.seed = seed; |
| 634 | + if (this.seed == null) { |
| 635 | + this.random = new SecureRandom(); |
| 636 | + } else { |
| 637 | + this.random = new Random(this.seed); |
| 638 | + } |
| 639 | + } |
| 640 | + |
| 641 | + /** |
| 642 | + * Clone constructor |
| 643 | + * @param ignoreMissingData |
| 644 | + * @param minimum |
| 645 | + * @param maximum |
| 646 | + * @param distribution |
| 647 | + * @param seed Maybe null |
| 648 | + */ |
| 649 | + private DistributionAggregateFunctionModeWithDistributionFallback(boolean ignoreMissingData, |
| 650 | + double minimum, |
| 651 | + double maximum, |
| 652 | + Map<String, Double> distribution, |
| 653 | + Long seed) { |
| 654 | + this(ignoreMissingData, distribution, seed); |
| 655 | + this.minimum = minimum; |
| 656 | + this.maximum = maximum; |
| 657 | + } |
| 658 | + |
| 659 | + @Override |
| 660 | + public <T> String aggregate(Distribution distribution) { |
| 661 | + |
| 662 | + // Determine mode |
| 663 | + int mode = getModeWithDistributionFallback(distribution); |
| 664 | + return mode == -1 ? DataType.NULL_VALUE : dictionary[mode]; |
| 665 | + } |
| 666 | + |
| 667 | + /** |
| 668 | + * Clone method |
| 669 | + */ |
| 670 | + public DistributionAggregateFunctionModeWithDistributionFallback clone() { |
| 671 | + DistributionAggregateFunctionModeWithDistributionFallback result = new DistributionAggregateFunctionModeWithDistributionFallback(this.ignoreMissingData, |
| 672 | + this.minimum, |
| 673 | + this.maximum, |
| 674 | + this.distribution, |
| 675 | + this.seed); |
| 676 | + if (dictionary != null) { |
| 677 | + result.initialize(dictionary, type); |
| 678 | + } |
| 679 | + return result; |
| 680 | + } |
| 681 | + |
| 682 | + @Override |
| 683 | + public <T> double getError(Distribution distribution) { |
| 684 | + |
| 685 | + if (!(type instanceof DataTypeWithRatioScale)) { |
| 686 | + return 0d; |
| 687 | + } |
| 688 | + |
| 689 | + @SuppressWarnings("unchecked") |
| 690 | + DataTypeWithRatioScale<T> rType = (DataTypeWithRatioScale<T>) this.type; |
| 691 | + DoubleArrayList list = new DoubleArrayList(); |
| 692 | + Iterator<Double> it = DistributionIterator.createIteratorDouble(distribution, dictionary, rType); |
| 693 | + while (it.hasNext()) { |
| 694 | + Double value = it.next(); |
| 695 | + value = value == null ? (ignoreMissingData ? null : 0d) : value; |
| 696 | + if (value != null) { |
| 697 | + list.add(value); |
| 698 | + } |
| 699 | + } |
| 700 | + |
| 701 | + // Determine and check mode |
| 702 | + int mode = getModeWithDistributionFallback(distribution); |
| 703 | + if (mode == -1) { |
| 704 | + return 1d; |
| 705 | + } |
| 706 | + |
| 707 | + // Compute error |
| 708 | + return getNMSE(minimum, maximum, Arrays.copyOf(list.elements(), list.size()), |
| 709 | + rType.toDouble(rType.parse(dictionary[mode]))); |
| 710 | + } |
| 711 | + |
| 712 | + @Override |
| 713 | + public void initialize(String[] dictionary, DataType<?> type) { |
| 714 | + super.initialize(dictionary, type); |
| 715 | + if (type instanceof DataTypeWithRatioScale) { |
| 716 | + double[] values = getMinMax(dictionary, (DataTypeWithRatioScale<?>)type); |
| 717 | + this.minimum = values[0]; |
| 718 | + this.maximum = values[1]; |
| 719 | + } |
| 720 | + intDistribution = new IntDoubleOpenHashMap(); |
| 721 | + int index = 0; |
| 722 | + for (String value : dictionary) { |
| 723 | + Double frequency = this.distribution.get(value); |
| 724 | + if (frequency != null) { |
| 725 | + intDistribution.put(index, frequency); |
| 726 | + } |
| 727 | + index++; |
| 728 | + } |
| 729 | + } |
| 730 | + |
| 731 | + /** |
| 732 | + * Returns the index of the most frequent element from the distribution. Draws from the most frequent items if |
| 733 | + * there are multiple, using the provided distribution. Returns -1 if there is no such element. |
| 734 | + * @param distribution |
| 735 | + * @return |
| 736 | + */ |
| 737 | + private int getModeWithDistributionFallback(Distribution distribution) { |
| 738 | + |
| 739 | + // Prepare |
| 740 | + int[] buckets = distribution.getBuckets(); |
| 741 | + int max = -1; |
| 742 | + IntArrayList mode = new IntArrayList(); |
| 743 | + |
| 744 | + // Iterate through distribution, collecting the mode |
| 745 | + for (int i = 0; i < buckets.length; i += 2) { |
| 746 | + int value = buckets[i]; |
| 747 | + int frequency = buckets[i + 1]; |
| 748 | + if (value != -1) { |
| 749 | + // Same frequency |
| 750 | + if (Math.abs(max - frequency) < 1e-9) { |
| 751 | + mode.add(value); |
| 752 | + // More frequent |
| 753 | + } else if (frequency > max) { |
| 754 | + max = frequency; |
| 755 | + mode.clear(); |
| 756 | + mode.add(value); |
| 757 | + } |
| 758 | + } |
| 759 | + } |
| 760 | + |
| 761 | + // Weird |
| 762 | + if (mode.isEmpty()) { |
| 763 | + return -1; |
| 764 | + |
| 765 | + // Exactly one mode |
| 766 | + } else if (mode.size() == 1) { |
| 767 | + return mode.get(0); |
| 768 | + |
| 769 | + // Need to draw from distribution |
| 770 | + } else { |
| 771 | + |
| 772 | + // Collect frequencies |
| 773 | + DoubleArrayList frequencies = new DoubleArrayList(); |
| 774 | + for (int i = 0; i < mode.size(); i++) { |
| 775 | + int code = mode.get(i); |
| 776 | + if (this.intDistribution.containsKey(code)) { |
| 777 | + frequencies.add(this.intDistribution.get(code)); |
| 778 | + } else { |
| 779 | + frequencies.add(0d); |
| 780 | + } |
| 781 | + } |
| 782 | + |
| 783 | + // Convert frequencies to cumulative frequencies |
| 784 | + for (int i = 1; i < frequencies.size(); i++) { |
| 785 | + frequencies.set(i, frequencies.get(i) + frequencies.get(i - 1)); |
| 786 | + } |
| 787 | + |
| 788 | + // Normalize |
| 789 | + double maxFrequency = frequencies.get(frequencies.size() - 1); |
| 790 | + for (int i = 0; i < frequencies.size(); i++) { |
| 791 | + frequencies.set(i, frequencies.get(i) / maxFrequency); |
| 792 | + } |
| 793 | + |
| 794 | + // Draw |
| 795 | + double r = random.nextDouble(); |
| 796 | + for (int i = 0; i < frequencies.size(); i++) { |
| 797 | + if (r <= frequencies.get(i)) { |
| 798 | + return mode.get(i); |
| 799 | + } |
| 800 | + } |
| 801 | + } |
| 802 | + |
| 803 | + // Should never happen |
| 804 | + return -1; |
| 805 | + } |
| 806 | + } |
| 807 | + |
586 | 808 | /**
|
587 | 809 | * This class calculates a set for a given distribution.
|
588 | 810 | *
|
|
0 commit comments