1
1
import random
2
+ import pulp
2
3
3
4
from typing import Sequence
4
5
5
6
from rna3db .utils import PathLike , read_json , write_json
6
7
7
8
8
def find_optimal_components(
    components: Sequence[int], bins: Sequence[int], verbose: bool = False
) -> Sequence[set[int]]:
    """Find an optimal placement of components into training/testing sets.

    We use an ILP formulation that is very similar to the classic ILP
    formulation of the bin packing problem: binary assignment variables
    place each component into exactly one bin, and we minimise the total
    absolute deviation between each bin's target size and the summed
    size of the components assigned to it (the absolute value is
    linearised with one auxiliary variable and two inequalities per bin).

    Args:
        components (Sequence[int]): list of component sizes
        bins (Sequence[int]): list of target bin sizes
        verbose (bool): whether to print solver output

    Returns:
        Sequence[set[int]]: list of sets, where each set contains the
            indices of the components that go into that bin
    """
    n, k = len(components), len(bins)

    # set up problem
    p = pulp.LpProblem("OptimalComponentSolver", pulp.LpMinimize)
    # x[(i, j)] == 1 iff component i is assigned to bin j
    x = pulp.LpVariable.dicts(
        "x", ((i, j) for i in range(n) for j in range(k)), cat="Binary"
    )
    # deviation[j] bounds |actual weight of bin j - target size of bin j|
    deviation = pulp.LpVariable.dicts(
        "d", (j for j in range(k)), lowBound=0, cat="Continuous"
    )

    # we want to minimise total "deviation"
    # (deviation is the total sum of the difference between target bins and found bins)
    p += pulp.lpSum(deviation[j] for j in range(k))

    # components can go into exactly one bin
    for i in range(n):
        p += pulp.lpSum(x[(i, j)] for j in range(k)) == 1, f"AssignComponent_{i}"

    # deviation constraints (to handle abs): d_j >= w_j - b_j and d_j >= b_j - w_j
    for j in range(k):
        total_weight_in_bin = pulp.lpSum(components[i] * x[(i, j)] for i in range(n))
        p += total_weight_in_bin - bins[j] <= deviation[j], f"DeviationPos_{j}"
        p += bins[j] - total_weight_in_bin <= deviation[j], f"DeviationNeg_{j}"

    # solve ILP problem with PuLP's bundled CBC solver
    p.solve(pulp.PULP_CBC_CMD(msg=int(verbose)))

    # extract solution in sensible format; solver values are floats
    # (e.g. 1.0 or 0.9999999), so test against 0.5 instead of == 1
    sol = [set() for _ in range(k)]
    for j in range(k):
        for i in range(n):
            if pulp.value(x[(i, j)]) > 0.5:
                sol[j].add(i)

    return sol
21
63
22
64
23
65
def split (
24
66
input_path : PathLike ,
25
- output_path : PathLike ,
67
+ output_path : PathLike = None ,
26
68
splits : Sequence [float ] = [0.7 , 0.0 , 0.3 ],
27
69
split_names : Sequence [str ] = ["train_set" , "valid_set" , "test_set" ],
28
70
shuffle : bool = False ,
@@ -41,33 +83,53 @@ def split(
41
83
if sum (splits ) != 1.0 :
42
84
raise ValueError ("Sum of splits must equal 1.0." )
43
85
44
- # read json
86
+ if len (splits ) != len (split_names ):
87
+ raise ValueError ("Number of splits must match number of split names." )
88
+
45
89
cluster_json = read_json (input_path )
46
90
47
- # count number of repr sequences
48
- lengths = {k : len (v ) for k , v in cluster_json .items ()}
49
- total_repr_clusters = sum (lengths .values ())
91
+ # get lengths of the components, and mapping from idx to keys
92
+ keys , lengths = [], []
93
+ for k , v in cluster_json .items ():
94
+ if force_zero_last and k == "component_0" :
95
+ continue
96
+ keys .append (k )
97
+ lengths .append (len (v ))
50
98
51
- # shuffle if we want to add randomness
52
- if shuffle :
53
- L = list (zip (component_name , lengths ))
54
- random .shuffle (L )
55
- component_name , lengths = zip (* L )
56
- component_name , lengths = list (component_name ), list (lengths )
99
+ # calculate actual bin capacities
100
+ # rounding is probably close enough
101
+ bins = [round (sum (lengths ) * ratio ) for ratio in splits ]
57
102
103
+ # create output dict
58
104
output = {k : {} for k in split_names }
59
105
106
+ # force `component_0` into the last bin
60
107
if force_zero_last :
108
+ if bins [- 1 ] < len (cluster_json ["component_0" ]):
109
+ print (
110
+ "ERROR: cannot force `component_0` into the last bin. Increase the last bin size."
111
+ )
112
+ raise ValueError
113
+ bins [- 1 ] -= len (cluster_json ["component_0" ])
61
114
output [split_names [- 1 ]]["component_0" ] = cluster_json ["component_0" ]
62
- lengths .pop ("component_0" )
115
+ del cluster_json ["component_0" ]
116
+
117
+ if shuffle :
118
+ L = list (zip (keys , lengths ))
119
+ random .shuffle (L )
120
+ keys , lengths = zip (* L )
121
+ keys , lengths = list (keys ), list (lengths )
122
+
123
+ # find optimal split with ILP
124
+ sol = find_optimal_components (lengths , bins )
63
125
64
- capacities = [round (total_repr_clusters * ratio ) for ratio in splits ]
65
- for name , capacity in zip (split_names , capacities ):
66
- components = find_optimal_components (lengths , capacity )
67
- for k in sorted (components ):
68
- lengths .pop (k )
126
+ # write output to dict
127
+ for idx , name in enumerate (split_names ):
128
+ for k in sorted (sol [idx ]):
129
+ k = keys [k ]
69
130
output [name ][k ] = cluster_json [k ]
70
131
71
- assert len (lengths ) == 0
132
+ if output_path :
133
+ write_json (output , output_path )
72
134
73
- write_json ( output , output_path )
135
+ return output
0 commit comments