Skip to content

Commit 668f8aa

Browse files
committed
Implement microaggregation function mode with fallback distribution
1 parent 5d9add0 commit 668f8aa

File tree

3 files changed

+269
-1
lines changed

3 files changed

+269
-1
lines changed

src/main/org/deidentifier/arx/AttributeType.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@
2828
import java.util.Arrays;
2929
import java.util.Iterator;
3030
import java.util.List;
31+
import java.util.Map;
3132

3233
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction;
3334
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionArithmeticMean;
3435
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionGeometricMean;
3536
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionInterval;
3637
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionMedian;
3738
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionMode;
39+
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionModeWithDistributionFallback;
3840
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionSet;
3941
import org.deidentifier.arx.io.CSVDataOutput;
4042
import org.deidentifier.arx.io.CSVHierarchyInput;
@@ -732,6 +734,18 @@ public static MicroAggregationFunction createMedian(boolean ignoreMissingData) {
732734
public static MicroAggregationFunction createMode() {
733735
return createMode(true);
734736
}
737+
738+
/**
739+
* Creates a microaggregation function returning the mode. If more than one value qualifies as mode,
740+
* the function draws from the qualifying values using the provided distribution. Ignores missing data.
741+
*
742+
* @param distribution Map from values to frequencies
743+
* @param seed Seed to use for drawing, can be null
744+
* @return
745+
*/
746+
public static MicroAggregationFunction createModeWithDistributionFallback(Map<String, Double> distribution, long seed) {
747+
return createModeWithDistributionFallback(true, distribution, seed);
748+
}
735749

736750
/**
737751
* Creates a microaggregation function returning the mode.
@@ -743,6 +757,23 @@ public static MicroAggregationFunction createMode(boolean ignoreMissingData) {
743757
return new MicroAggregationFunction(new DistributionAggregateFunctionMode(ignoreMissingData),
744758
DataScale.NOMINAL, "Mode");
745759
}
760+
761+
/**
762+
* Creates a microaggregation function returning the mode. If more than one value qualifies as mode,
763+
* the function draws from the qualifying values using the provided distribution.
764+
*
765+
* @param ignoreMissingData Should the function ignore missing data. Default is true.
766+
* @param distribution Map from values to frequencies
767+
* @param seed Seed to use for drawing, can be null
768+
* @return
769+
*/
770+
public static MicroAggregationFunction createModeWithDistributionFallback(boolean ignoreMissingData,
771+
Map<String, Double> distribution,
772+
Long seed) {
773+
return new MicroAggregationFunction(new DistributionAggregateFunctionModeWithDistributionFallback(ignoreMissingData, distribution, seed),
774+
DataScale.NOMINAL, "Mode with distribution fallback");
775+
}
776+
746777
/**
747778
* Creates a microaggregation function returning sets. This variant will ignore missing data.
748779
*/

src/main/org/deidentifier/arx/aggregates/StatisticsFrequencyDistribution.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
*/
1717
package org.deidentifier.arx.aggregates;
1818

19+
import java.util.LinkedHashMap;
20+
import java.util.Map;
21+
1922
/**
2023
* A frequency distribution.
2124
*
@@ -44,4 +47,16 @@ public class StatisticsFrequencyDistribution {
4447
this.count = count;
4548
this.frequency = frequency;
4649
}
50+
51+
/**
52+
* Returns the distribution as a map
53+
* @return
54+
*/
55+
public Map<String, Double> asMap() {
56+
Map<String, Double> map = new LinkedHashMap<>();
57+
for (int i = 0; i < frequency.length; i++) {
58+
map.put(values[i], frequency[i]);
59+
}
60+
return map;
61+
}
4762
}

src/main/org/deidentifier/arx/framework/check/distribution/DistributionAggregateFunction.java

Lines changed: 223 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,22 @@
1717
package org.deidentifier.arx.framework.check.distribution;
1818

1919
import java.io.Serializable;
20+
import java.security.SecureRandom;
2021
import java.util.ArrayList;
2122
import java.util.Arrays;
2223
import java.util.Collections;
2324
import java.util.Iterator;
2425
import java.util.List;
26+
import java.util.Map;
27+
import java.util.Random;
2528

2629
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
2730
import org.deidentifier.arx.DataType;
2831
import org.deidentifier.arx.DataType.DataTypeWithRatioScale;
2932

33+
import com.carrotsearch.hppc.IntArrayList;
34+
import com.carrotsearch.hppc.IntDoubleOpenHashMap;
35+
3036
import cern.colt.GenericSorting;
3137
import cern.colt.Swapper;
3238
import cern.colt.function.IntComparator;
@@ -459,7 +465,6 @@ private <T> T getValueAt(List<T> values, List<Integer> frequencies, int index) {
459465
}
460466
}
461467

462-
463468
/**
464469
* This class calculates the mode for a given distribution.
465470
*
@@ -583,6 +588,223 @@ private int getMode(Distribution distribution) {
583588
}
584589
}
585590

591+
/**
592+
* This class calculates the mode for a given distribution falling back to drawing from multiple values that would qualify as
593+
* mode using the provided distribution
594+
*
595+
* @author Fabian Prasser
596+
*
597+
*/
598+
public static class DistributionAggregateFunctionModeWithDistributionFallback extends DistributionAggregateFunction {
599+
600+
/** SVUID. */
601+
private static final long serialVersionUID = 6285156778817664604L;
602+
603+
/** Minimum */
604+
private double minimum = 0d;
605+
606+
/** Maximum */
607+
private double maximum = 0d;
608+
609+
/** Distribution*/
610+
private Map<String, Double> distribution;
611+
612+
/** Integer distribution*/
613+
private IntDoubleOpenHashMap intDistribution;
614+
615+
/** The seed to use*/
616+
private Long seed;
617+
618+
/** The random source to use*/
619+
private Random random;
620+
621+
/**
622+
* Instantiates.
623+
*
624+
* @param ignoreMissingData
625+
* @param distribution
626+
* @param seed Maybe null
627+
*/
628+
public DistributionAggregateFunctionModeWithDistributionFallback(boolean ignoreMissingData,
629+
Map<String, Double> distribution,
630+
Long seed) {
631+
super(ignoreMissingData, true);
632+
this.distribution = distribution;
633+
this.seed = seed;
634+
if (this.seed == null) {
635+
this.random = new SecureRandom();
636+
} else {
637+
this.random = new Random(this.seed);
638+
}
639+
}
640+
641+
/**
642+
* Clone constructor
643+
* @param ignoreMissingData
644+
* @param minimum
645+
* @param maximum
646+
* @param distribution
647+
* @param seed Maybe null
648+
*/
649+
private DistributionAggregateFunctionModeWithDistributionFallback(boolean ignoreMissingData,
650+
double minimum,
651+
double maximum,
652+
Map<String, Double> distribution,
653+
Long seed) {
654+
this(ignoreMissingData, distribution, seed);
655+
this.minimum = minimum;
656+
this.maximum = maximum;
657+
}
658+
659+
@Override
660+
public <T> String aggregate(Distribution distribution) {
661+
662+
// Determine mode
663+
int mode = getModeWithDistributionFallback(distribution);
664+
return mode == -1 ? DataType.NULL_VALUE : dictionary[mode];
665+
}
666+
667+
/**
668+
* Clone method
669+
*/
670+
public DistributionAggregateFunctionModeWithDistributionFallback clone() {
671+
DistributionAggregateFunctionModeWithDistributionFallback result = new DistributionAggregateFunctionModeWithDistributionFallback(this.ignoreMissingData,
672+
this.minimum,
673+
this.maximum,
674+
this.distribution,
675+
this.seed);
676+
if (dictionary != null) {
677+
result.initialize(dictionary, type);
678+
}
679+
return result;
680+
}
681+
682+
@Override
683+
public <T> double getError(Distribution distribution) {
684+
685+
if (!(type instanceof DataTypeWithRatioScale)) {
686+
return 0d;
687+
}
688+
689+
@SuppressWarnings("unchecked")
690+
DataTypeWithRatioScale<T> rType = (DataTypeWithRatioScale<T>) this.type;
691+
DoubleArrayList list = new DoubleArrayList();
692+
Iterator<Double> it = DistributionIterator.createIteratorDouble(distribution, dictionary, rType);
693+
while (it.hasNext()) {
694+
Double value = it.next();
695+
value = value == null ? (ignoreMissingData ? null : 0d) : value;
696+
if (value != null) {
697+
list.add(value);
698+
}
699+
}
700+
701+
// Determine and check mode
702+
int mode = getModeWithDistributionFallback(distribution);
703+
if (mode == -1) {
704+
return 1d;
705+
}
706+
707+
// Compute error
708+
return getNMSE(minimum, maximum, Arrays.copyOf(list.elements(), list.size()),
709+
rType.toDouble(rType.parse(dictionary[mode])));
710+
}
711+
712+
@Override
713+
public void initialize(String[] dictionary, DataType<?> type) {
714+
super.initialize(dictionary, type);
715+
if (type instanceof DataTypeWithRatioScale) {
716+
double[] values = getMinMax(dictionary, (DataTypeWithRatioScale<?>)type);
717+
this.minimum = values[0];
718+
this.maximum = values[1];
719+
}
720+
intDistribution = new IntDoubleOpenHashMap();
721+
int index = 0;
722+
for (String value : dictionary) {
723+
Double frequency = this.distribution.get(value);
724+
if (frequency != null) {
725+
intDistribution.put(index, frequency);
726+
}
727+
index++;
728+
}
729+
}
730+
731+
/**
732+
* Returns the index of the most frequent element from the distribution. Draws from the most frequent items if
733+
* there are multiple, using the provided distribution. Returns -1 if there is no such element.
734+
* @param distribution
735+
* @return
736+
*/
737+
private int getModeWithDistributionFallback(Distribution distribution) {
738+
739+
// Prepare
740+
int[] buckets = distribution.getBuckets();
741+
int max = -1;
742+
IntArrayList mode = new IntArrayList();
743+
744+
// Iterate through distribution, collecting the mode
745+
for (int i = 0; i < buckets.length; i += 2) {
746+
int value = buckets[i];
747+
int frequency = buckets[i + 1];
748+
if (value != -1) {
749+
// Same frequency
750+
if (Math.abs(max - frequency) < 1e-9) {
751+
mode.add(value);
752+
// More frequent
753+
} else if (frequency > max) {
754+
max = frequency;
755+
mode.clear();
756+
mode.add(value);
757+
}
758+
}
759+
}
760+
761+
// Weird
762+
if (mode.isEmpty()) {
763+
return -1;
764+
765+
// Exactly one mode
766+
} else if (mode.size() == 1) {
767+
return mode.get(0);
768+
769+
// Need to draw from distribution
770+
} else {
771+
772+
// Collect frequencies
773+
DoubleArrayList frequencies = new DoubleArrayList();
774+
for (int i = 0; i < mode.size(); i++) {
775+
int code = mode.get(i);
776+
if (this.intDistribution.containsKey(code)) {
777+
frequencies.add(this.intDistribution.get(code));
778+
} else {
779+
frequencies.add(0d);
780+
}
781+
}
782+
783+
// Convert frequencies to cumulative frequencies
784+
for (int i = 1; i < frequencies.size(); i++) {
785+
frequencies.set(i, frequencies.get(i) + frequencies.get(i - 1));
786+
}
787+
788+
// Normalize
789+
double maxFrequency = frequencies.get(frequencies.size() - 1);
790+
for (int i = 0; i < frequencies.size(); i++) {
791+
frequencies.set(i, frequencies.get(i) / maxFrequency);
792+
}
793+
794+
// Draw
795+
double r = random.nextDouble();
796+
for (int i = 0; i < frequencies.size(); i++) {
797+
if (r <= frequencies.get(i)) {
798+
return mode.get(i);
799+
}
800+
}
801+
}
802+
803+
// Should never happen
804+
return -1;
805+
}
806+
}
807+
586808
/**
587809
* This class calculates a set for a given distribution.
588810
*

0 commit comments

Comments
 (0)