forked from zhichunguo/Meta-MGNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsamples.py
59 lines (41 loc) · 9.69 KB
/
samples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import random
def obtain_distr_list(dataset):
    """Look up the per-task pair of class counts for a named benchmark dataset.

    Args:
        dataset: one of ``"sider"``, ``"tox21"``, ``"muv"`` or ``"toxcast"``.

    Returns:
        A freshly built list holding one ``[first, second]`` integer pair per
        task, or ``None`` when *dataset* is not one of the names above (no
        exception is raised).  The sampling helpers in this module use
        ``first`` as the boundary index between the two classes of a task's
        data: indices ``[0, first)`` form one class, the rest the other.
        NOTE(review): presumably ``[n_negative, n_positive]`` — confirm
        against the data-preparation code.
    """
    if dataset == "sider":
        counts = [
            [684, 743], [431, 996], [1405, 22], [551, 876], [276, 1151], [430, 997],
            [129, 1298], [1176, 251], [403, 1024], [700, 727], [1051, 376], [135, 1292],
            [1104, 323], [1214, 213], [319, 1108], [542, 885], [109, 1318], [1174, 253],
            [421, 1006], [367, 1060], [411, 1016], [516, 911], [1302, 125], [768, 659],
            [439, 988], [123, 1304], [481, 946],
        ]
    elif dataset == "tox21":
        counts = [
            [6956, 309], [6521, 237], [5781, 768], [5521, 300], [5400, 793], [6605, 350],
            [6264, 186], [4890, 942], [6808, 264], [6095, 372], [4892, 918], [6351, 423],
        ]
    elif dataset == "muv":
        counts = [
            [14814, 27], [14705, 29], [14698, 30], [14593, 30], [14873, 29], [14572, 29],
            [14614, 30], [14383, 28], [14807, 29], [14654, 28], [14662, 29], [14615, 29],
            [14637, 30], [14681, 30], [14622, 29], [14745, 29], [14722, 24],
        ]
    elif dataset == "toxcast":
        counts = [
            [1293, 438], [1441, 290], [864, 170], [995, 39], [794, 240], [738, 296], [591, 443], [977, 57], [948, 86], [960, 59],
            [910, 109], [908, 126], [1010, 24], [930, 89], [947, 72], [281, 22], [849, 185], [889, 130], [822, 212], [740, 279],
            [979, 55], [994, 40], [1018, 16], [797, 237], [788, 246], [286, 17], [967, 67], [935, 99], [842, 192], [828, 206],
            [262, 41], [257, 46], [252, 51], [267, 36], [251, 52], [248, 55], [247, 56], [251, 52], [251, 52], [262, 41],
            [263, 40], [283, 20], [274, 29], [286, 17], [284, 19], [3333, 79], [2796, 616], [3198, 214], [3379, 33], [3400, 12],
            [3382, 30], [3184, 228], [2975, 437], [3321, 91], [3032, 380], [3341, 71], [3386, 26], [3279, 133], [2875, 537], [3363, 49],
            [3061, 351], [3341, 71], [3172, 240], [2716, 696], [3357, 55], [3278, 134], [3094, 318], [3287, 125], [3390, 22], [2913, 499],
            [3207, 205], [2543, 869], [3386, 26], [3388, 24], [3402, 10], [2685, 727], [3081, 331], [3340, 72], [3195, 217], [3395, 17],
            [3320, 92], [3353, 59], [3350, 62], [3256, 156], [3370, 42], [3068, 76], [3310, 102], [3376, 36], [3380, 32], [3231, 181],
            [3271, 141], [3367, 45], [3395, 17], [3363, 49], [3193, 219], [3036, 376], [3388, 24], [3373, 39], [3293, 119], [3356, 56],
            [3367, 45], [2998, 414], [3078, 334], [3330, 82], [2947, 465], [3397, 15], [3359, 53], [3319, 93], [3397, 15], [3346, 66],
            [2696, 716], [3400, 12], [3338, 74], [3356, 56], [3386, 26], [3364, 48], [3370, 42], [3363, 49], [3392, 20], [3401, 11],
            [3299, 113], [3371, 41], [3372, 40], [3233, 179], [3365, 47], [3146, 266], [3142, 270], [3282, 130], [3265, 147], [3319, 93],
            [3367, 45], [2123, 1289], [3392, 20], [3208, 204], [3386, 26], [2867, 545], [3392, 20], [3026, 386], [3385, 27], [3184, 228],
            [3351, 61], [2484, 928], [3330, 82], [2887, 525], [3090, 322], [1769, 1643], [3400, 12], [2445, 967], [2963, 449], [3176, 236],
            [3344, 68], [3285, 127], [3397, 15], [3357, 55], [3364, 48], [3393, 19], [3041, 371], [3368, 44], [3393, 19], [3387, 25],
            [3399, 13], [3314, 98], [3324, 88], [2893, 519],
            [3379, 33], [2887, 525], [3323, 89], [3381, 31], [3389, 23], [3198, 214], [3388, 24], [3071, 341], [3357, 55], [3300, 112],
            [3394, 18], [3186, 226], [2958, 454], [3382, 30], [3299, 113], [3285, 127], [3384, 28], [3311, 101], [3403, 9], [2486, 926],
            [3398, 14], [3373, 39], [2648, 245], [3393, 19], [2960, 452], [3083, 329], [3334, 78], [1043, 396], [889, 550], [1240, 199],
            [1114, 325], [1062, 377], [1254, 185], [842, 597], [964, 475], [1383, 56], [1306, 133], [1124, 315], [1385, 54], [1090, 349],
            [1006, 433], [1026, 413], [1006, 433], [1058, 381], [1041, 398], [1423, 16], [1051, 388], [1018, 421], [1229, 210], [1098, 341],
            [1424, 15], [1056, 383], [1174, 265], [1060, 379], [1253, 186], [1253, 186], [1312, 127], [1150, 289], [1235, 204], [1215, 224],
            [1198, 241], [1244, 195], [1406, 33], [1204, 235], [1154, 285], [1235, 204], [1396, 43], [1299, 140], [1281, 158], [1361, 78],
            [1231, 208], [1413, 26], [1180, 259], [1423, 16], [1287, 152], [998, 441], [1417, 22], [1267, 172], [1409, 30], [1193, 246],
            [1371, 68], [1191, 248], [1223, 216], [1160, 279], [1407, 32], [1197, 242], [1422, 17], [1218, 221], [1147, 292], [1121, 318],
            [1420, 19], [1186, 253], [1419, 20], [1053, 386], [1211, 228], [1151, 288], [1119, 320], [1177, 262], [1019, 420], [1138, 301],
            [1423, 16], [1134, 305], [1423, 16], [1124, 315], [1414, 25], [1119, 320], [1047, 392], [1146, 293], [1349, 90], [1070, 369],
            [1151, 288], [1368, 71], [1208, 231], [1390, 49], [1003, 436], [1000, 439], [998, 441], [1040, 399], [1034, 405], [1398, 41],
            [1096, 343], [1402, 37], [1096, 343], [1212, 227], [1123, 316], [1367, 72], [877, 562], [1079, 360], [1006, 433], [1347, 92],
            [1382, 57], [1252, 187], [1023, 416], [1027, 412], [1149, 290], [1178, 261], [1380, 59], [1049, 390], [817, 622], [1112, 327],
            [1176, 263], [1032, 407], [300, 202], [318, 184], [369, 133], [365, 131], [427, 69], [470, 30], [428, 72], [459, 43],
            [436, 66], [411, 51], [353, 147], [387, 113], [351, 118], [358, 142], [283, 17], [279, 21], [176, 120],
            [186, 109], [201, 101], [169, 133], [147, 153], [221, 81], [128, 171], [139, 161], [121, 181], [178, 114], [178, 116],
            [254, 42], [272, 28], [277, 22], [261, 39], [252, 50], [236, 64], [173, 200], [276, 97], [143, 175], [66, 307],
            [22, 31], [221, 482], [168, 71], [105, 70], [39, 134], [86, 27], [35, 101], [76, 301], [38, 187], [37, 80],
            [75, 85], [49, 28], [23, 31], [74, 68], [90, 21], [72, 23], [80, 90], [42, 37], [99, 31], [43, 60],
            [81, 80], [59, 54], [136, 29], [196, 24], [55, 44], [37, 45], [55, 35], [70, 34], [72, 21], [58, 39],
            [53, 26], [80, 58], [113, 67], [92, 20], [65, 31], [63, 24], [54, 25], [51, 24], [76, 32], [29, 38],
            [88, 26], [69, 29], [42, 21], [130, 24], [56, 84], [42, 61], [50, 49], [56, 39], [31, 84], [42, 64],
            [57, 71], [76, 56], [52, 54], [74, 38], [24, 31], [50, 85], [43, 77], [36, 53], [37, 28], [45, 57],
            [55, 91], [63, 46], [66, 89], [35, 65], [40, 120], [46, 21], [34, 84], [20, 66], [30, 61], [31, 81],
            [38, 57], [38, 40], [61, 25], [32, 98], [53, 72], [21, 57], [33, 57], [49, 22], [26, 57], [43, 75],
            [32, 70], [49, 81], [85, 79], [47, 60], [75, 114], [35, 60], [41, 70], [43, 29], [44, 48], [41, 51],
            [40, 53], [25, 53], [42, 23], [66, 46], [57, 28], [57, 72], [57, 65], [37, 33], [915, 27], [25, 30],
            [42, 57], [26, 77], [51, 40], [31, 71], [35, 54], [41, 117], [42, 25], [43, 23], [24, 26], [37, 25],
            [54, 30], [133, 215], [116, 217], [927, 127], [110, 75], [98, 206], [116, 112], [194, 83], [900, 228], [133, 31],
            [198, 59], [120, 225], [304, 72], [602, 178], [196, 85], [405, 109], [231, 29], [145, 21], [168, 55], [742, 186],
            [139, 131], [77, 20], [38, 107], [50, 123], [26, 51], [50, 193], [69, 160], [64, 39], [39, 39], [52, 61],
            [53, 49], [1635, 137], [1629, 111], [1532, 201], [1623, 125], [1575, 99], [1544, 211], [1478, 190], [1543, 201], [1497, 169],
            [1596, 175], [1619, 139], [1424, 311], [1549, 133], [1560, 198], [1657, 80], [6835, 352], [6564, 623], [5583, 1604], [7118, 69],
            [7926, 5], [7562, 369], [7540, 391], [7908, 23], [7746, 185],
            [6773, 1158], [7351, 580], [7565, 366], [7100, 831], [6034, 1153], [7141, 790], [6674, 1257], [7900, 31], [7898, 33], [7899, 32],
            [7926, 5], [7901, 30], [7927, 4], [5151, 120], [7653, 278], [7482, 449], [7480, 451], [7650, 281], [7694, 237], [6919, 1012],
            [7750, 181], [6691, 1240], [7234, 697], [7110, 77], [7094, 93], [6871, 316], [6971, 216], [6843, 344], [6917, 270], [7020, 167],
            [6997, 190], [6243, 944], [6871, 316], [7620, 311], [7721, 210], [7448, 483], [7413, 518], [7492, 439], [7550, 381], [6909, 278],
            [6830, 357], [6592, 595], [7035, 152], [4425, 846], [5163, 108], [4982, 289], [6908, 279], [7100, 87], [6961, 226], [6755, 432],
            [6551, 636], [7084, 103], [7184, 3], [7017, 170], [7010, 177], [6761, 426], [7171, 16], [7926, 5], [7716, 215], [7596, 335],
            [7175, 12], [6655, 532], [7045, 142], [7912, 19], [6186, 1745], [6572, 615], [7027, 160], [7140, 47], [7134, 53], [6890, 297],
            [6926, 261], [7380, 551], [7479, 452], [7136, 795], [7556, 375], [7423, 508], [7149, 782], [6919, 1012], [7152, 779], [7269, 662],
            [7239, 692], [6991, 940], [7516, 415], [7292, 639], [7379, 552], [7001, 930], [7279, 652], [7596, 335], [7306, 625], [7066, 865],
            [7622, 309], [910, 111], [840, 194], [962, 59], [975, 46], [1006, 15], [945, 76], [911, 110], [914, 120], [982, 39],
            [903, 118], [969, 52], [979, 42], [914, 107], [985, 36], [991, 30], [966, 55], [942, 79], [895, 126],
        ]
    else:
        # Unrecognised dataset names deliberately yield None rather than raising.
        counts = None
    return counts
def sample_datasets(data, dataset, task, n_way, m_support, n_query):
    """Draw a class-balanced support set and a random query set for one task.

    With ``boundary = obtain_distr_list(dataset)[task][0]``, indices
    ``[0, boundary)`` and ``[boundary, len(data))`` are treated as the two
    classes, and ``m_support`` indices are drawn without replacement from
    each range (then shuffled together). The query set is ``n_query``
    indices drawn without replacement from the samples not used as support.

    Args:
        data: collection supporting ``len()`` and indexing by an integer index
            tensor (presumably a tensor or a dataset object — TODO confirm).
        dataset: name understood by ``obtain_distr_list``; an unknown name makes
            that lookup return ``None``, so this function raises ``TypeError``.
        task: index of the task inside the dataset's distribution list.
        n_way: unused; kept for interface compatibility.
        m_support: number of support samples drawn from *each* class.
        n_query: number of query samples.

    Returns:
        ``(support_dataset, query_dataset)``: ``data`` indexed by the support
        and query index tensors respectively.

    Raises:
        ValueError: from ``random.sample`` when a class range holds fewer than
            ``m_support`` indices or fewer than ``n_query`` samples remain.
    """
    boundary = obtain_distr_list(dataset)[task][0]
    support_list = random.sample(range(0, boundary), m_support)
    support_list += random.sample(range(boundary, len(data)), m_support)
    random.shuffle(support_list)
    # A set makes each membership test O(1) instead of scanning the support list.
    chosen = set(support_list)
    remaining = [i for i in range(len(data)) if i not in chosen]
    query_list = random.sample(remaining, n_query)
    # Explicit long dtype: torch.tensor([]) would default to a float dtype,
    # which is rejected as an index when m_support or n_query is 0.
    support_dataset = data[torch.tensor(support_list, dtype=torch.long)]
    query_dataset = data[torch.tensor(query_list, dtype=torch.long)]
    return support_dataset, query_dataset
def sample_test_datasets(data, dataset, task, n_way, m_support, n_query):
    """Draw a class-balanced support set and use every other sample as query.

    With ``boundary = obtain_distr_list(dataset)[task][0]``, indices
    ``[0, boundary)`` and ``[boundary, len(data))`` are treated as the two
    classes, and ``m_support`` indices are drawn without replacement from
    each range (then shuffled together). All remaining indices, in random
    order, form the query set.

    Args:
        data: collection supporting ``len()`` and indexing by an integer index
            tensor (presumably a tensor or a dataset object — TODO confirm).
        dataset: name understood by ``obtain_distr_list``; an unknown name makes
            that lookup return ``None``, so this function raises ``TypeError``.
        task: index of the task inside the dataset's distribution list.
        n_way: unused; kept for interface compatibility.
        m_support: number of support samples drawn from *each* class.
        n_query: unused; the query set is always every non-support sample.

    Returns:
        ``(support_dataset, query_dataset)``: ``data`` indexed by the support
        and query index tensors respectively.

    Raises:
        ValueError: from ``random.sample`` when a class range holds fewer than
            ``m_support`` indices.
    """
    boundary = obtain_distr_list(dataset)[task][0]
    support_list = random.sample(range(0, boundary), m_support)
    support_list += random.sample(range(boundary, len(data)), m_support)
    random.shuffle(support_list)
    # A set makes each membership test O(1) instead of scanning the support list.
    chosen = set(support_list)
    query_list = [i for i in range(len(data)) if i not in chosen]
    random.shuffle(query_list)
    # Explicit long dtype: torch.tensor([]) would default to a float dtype,
    # which is rejected as an index when either list ends up empty.
    support_dataset = data[torch.tensor(support_list, dtype=torch.long)]
    query_dataset = data[torch.tensor(query_list, dtype=torch.long)]
    return support_dataset, query_dataset
def sample_premeta_datasets(data, dataset, task, n_way=2, m_support=10, n_query=128):
    """Return every sample of *data* in a random order.

    Args:
        data: collection supporting ``len()`` and indexing by an integer index
            tensor (presumably a tensor or a dataset object — TODO confirm).
        dataset, task, n_way, m_support, n_query: unused; accepted so the
            signature matches the other samplers in this module.

    Returns:
        ``data`` indexed by a random permutation of ``range(len(data))``.
    """
    order = list(range(len(data)))
    random.shuffle(order)
    # Explicit long dtype so an empty ``data`` still yields a valid index tensor
    # (torch.tensor([]) would otherwise default to a float dtype).
    return data[torch.tensor(order, dtype=torch.long)]