-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
54 lines (41 loc) · 7.44 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from torch import cuda
# --- Input corpus -----------------------------------------------------------
# Directory passed around as the location of the Hansard debate XML files.
data_dir = "/usr/local/courses/lt2316-h19/a2/hansard/scrapedxml/debates/"

# --- Pretrained word embeddings ---------------------------------------------
PATH_TO_PRETRAINED_EMBEDDINGS = "/scratch/GoogleNews-vectors-negative300.bin.gz"

# --- CSV files (full set plus the train / test variants) --------------------
CSV_FILENAME = "parliamentary_speech_data.csv"
CSV_FILENAME_TRAIN = "parliamentary_speech_data_train.csv"
CSV_FILENAME_TEST = "parliamentary_speech_data_test.csv"
# Boundary label -> integer class id.  The position of each label in the
# tuple is its id, so "[SAME]" maps to 0 and "[CHANGE]" maps to 1.
BOUNDARY_TO_INT_MAPPING = {
    label: class_id for class_id, label in enumerate(("[SAME]", "[CHANGE]"))
}
# --- RNN: feature switches --------------------------------------------------
USE_ATTENTION = True
EQUALIZE_CLASS_COUNTS = True

# --- RNN: optimisation settings ---------------------------------------------
RNN_LEARNING_RATE = 5e-4
RNN_NUM_EPOCHS = 100
RNN_BATCH_SIZE = 256
RNN_HIDDEN_SIZE = 300

# --- RNN: data volume -------------------------------------------------------
RNN_NUM_RECORDS = 100_000
# NOTE(review): presumably extra "[SAME]" examples added on top of
# RNN_NUM_RECORDS -- confirm against the data-preparation code.
RNN_SAME_ADDITIONAL_RECORDS = 5_000
# File names for the RNN artefacts.  There is one (model 1, model 2,
# classifier) triple per variant; the "eq" / "attention" infixes mirror the
# EQ / ATTENTION parts of the constant names.

# Plain variant.
RNN_MODEL1 = "model_1.pkl"
RNN_MODEL2 = "model_2.pkl"
RNN_CLASSIFIER = "classifier.pkl"

# "EQ" variant.
RNN_EQ_MODEL1 = "model_eq_1.pkl"
RNN_EQ_MODEL2 = "model_eq_2.pkl"
RNN_EQ_CLASSIFIER = "classifier_eq.pkl"

# "ATTENTION" variant.
RNN_ATTENTION_MODEL1 = "model_attention_1.pkl"
RNN_ATTENTION_MODEL2 = "model_attention_2.pkl"
RNN_ATTENTION_CLASSIFIER = "classifier_attention.pkl"

# "EQ" + "ATTENTION" variant.
RNN_EQ_ATTENTION_MODEL1 = "model_eq_attention_1.pkl"
RNN_EQ_ATTENTION_MODEL2 = "model_eq_attention_2.pkl"
RNN_EQ_ATTENTION_CLASSIFIER = "classifier_eq_attention.pkl"
# Torch device string selected once at import time.
#
# The second GPU ("cuda:1") is preferred, as before.  Previously it was chosen
# whenever CUDA was available at all, which names a non-existent device on a
# single-GPU host; in that case fall back to "cuda:0".  Without CUDA the value
# is "cpu".
if not cuda.is_available():
    DEVICE = "cpu"
elif cuda.device_count() > 1:
    DEVICE = "cuda:1"
else:
    DEVICE = "cuda:0"
# Fixed list of floats; judging by the name it is the embedding substituted
# for words missing from the pretrained vocabulary -- confirm against the
# embedding lookup code.
# NOTE(review): presumably has the same dimensionality as the pretrained
# GoogleNews vectors (300) -- confirm before changing the embedding source.
UNKNOWN_WORD_VECTOR = [-0.505834698677063, 0.8371995091438293, -0.3907436430454254, 1.2388476133346558, 0.39628866314888, 0.659585177898407, 0.2056134194135666, -1.2870607376098633, 0.22113212943077087, -1.330844759941101, 0.8694334030151367, 0.41195404529571533, -0.17144230008125305, -0.6751164197921753, 2.0545945167541504, 1.5696337223052979, 0.8888554573059082, 0.6383231282234192, 0.34052929282188416, -1.718234896659851, 1.7102441787719727, -0.74288010597229, 1.2348617315292358, 0.20869804918766022, 0.010648425668478012, 0.7828048467636108, 0.4952743649482727, 0.7345136404037476, 0.7647744417190552, -0.08482004702091217, -1.0111994743347168, -0.8100836873054504, -0.24033965170383453, -1.593235731124878, -1.7683085203170776, -0.18870864808559418, 0.15959683060646057, 0.06499271094799042, 0.14582936465740204, 0.4639718234539032, 1.082289218902588, -0.4073246121406555, 1.0503937005996704, 0.8217169046401978, 0.04474393650889397, 1.134543538093567, -1.8783674240112305, -0.01754256710410118, 1.373317003250122, -0.49135923385620117, -0.667228639125824, -0.9872323274612427, -1.0467846393585205, -0.522704005241394, -1.0922895669937134, 0.37549442052841187, 1.4084978103637695, -0.4988190829753876, -0.4912785291671753, -0.9757516980171204, 0.9069097638130188, 1.9779735803604126, 0.8327620625495911, -0.27752694487571716, 2.6078107357025146, 1.2526928186416626, 0.486898273229599, 1.2384307384490967, 0.2881917953491211, -0.4037890136241913, -0.6822932958602905, 0.5154398083686829, -0.15835464000701904, 2.401808738708496, -0.044415660202503204, 2.055353879928589, 1.4222408533096313, 0.14033208787441254, 1.408102035522461, 0.5589955449104309, 0.42443642020225525, -0.5595348477363586, -0.02134212665259838, 1.5728685855865479, 0.0595572255551815, 0.6637532711029053, -1.410125970840454, 0.33802279829978943, -0.6599981188774109, 0.05207591876387596, -0.7007763385772705, -0.5083052515983582, -0.6373371481895447, -0.39122241735458374, -0.5458566546440125, -1.4461982250213623, 
0.37437334656715393, 0.2300325632095337, -1.6860687732696533, -0.8976618051528931, -0.3166378438472748, -1.253980278968811, -0.25334432721138, 0.2735156714916229, -0.19204142689704895, 0.38214144110679626, 1.202514410018921, -2.010162591934204, -0.8418667316436768, 0.9256077408790588, -0.7056060433387756, 0.30248740315437317, 1.7211133241653442, 0.876978874206543, -0.1790681779384613, -0.997778058052063, 0.1718505173921585, 1.4730232954025269, 0.08387551456689835, -0.3736158013343811, -0.3336387872695923, 0.8899230360984802, -1.0761618614196777, -1.2676701545715332, 1.2060940265655518, 1.1710679531097412, 0.9012223482131958, 0.15513092279434204, -0.8016363978385925, 1.3856252431869507, 0.27607378363609314, 0.9537136554718018, 0.3620387613773346, 1.901503562927246, 1.271773338317871, -1.2165992259979248, -1.1655319929122925, 0.19391410052776337, -0.008311061188578606, -0.48887336254119873, -0.44454896450042725, 0.5515300631523132, -0.38193953037261963, -0.647794246673584, -1.5365986824035645, -1.8549766540527344, -1.413175106048584, 1.5679494142532349, -1.7551692724227905, 0.05125558748841286, -1.3250762224197388, -0.4479677081108093, 0.36137405037879944, -0.7935487627983093, -0.1577877402305603, -0.4462861120700836, -0.41334351897239685, 0.24989688396453857, -0.48370274901390076, 0.5078774094581604, 1.4539228677749634, -0.1624590903520584, 1.335595726966858, -1.747597336769104, -2.0099356174468994, 1.3160730600357056, -0.7577613592147827, 0.11472854763269424, 0.14775584638118744, -0.1385658234357834, 0.5658499002456665, -0.459995836019516, -0.7670608758926392, -0.5251957178115845, -0.03897169977426529, 0.19235558807849884, -0.5747939944267273, 0.36217936873435974, 0.06927046179771423, 0.8604796528816223, -1.0001740455627441, -0.6298030614852905, 1.0134512186050415, -0.8840906620025635, -0.7409910559654236, 0.8349681496620178, -1.7573039531707764, 0.8475261926651001, -1.1452159881591797, -0.7934998273849487, -0.9209117293357849, -1.5573413372039795, 
0.3694569170475006, -0.6681967973709106, 1.1655405759811401, -0.8026101589202881, -0.5906173586845398, -0.6410574913024902, 0.5311210751533508, 0.07662922143936157, -0.6139146685600281, 0.2491513043642044, 1.1668953895568848, 0.5382435917854309, -0.21040044724941254, -1.203473687171936, 0.43246331810951233, 0.1738639771938324, 1.3473886251449585, -0.09564293920993805, -0.24091584980487823, -1.1073969602584839, -0.6354038119316101, -1.8128196001052856, 0.4310932457447052, 0.3172605335712433, -1.7118213176727295, -0.5851176977157593, 0.4524517357349396, -0.06275753676891327, 0.9286178946495056, 1.2695915699005127, 0.23088225722312927, 0.8585662841796875, 1.2107691764831543, 1.405938982963562, 0.024015391245484352, 0.8975788354873657, -0.9891437292098999, 1.7498891353607178, -2.0724780559539795, -0.8942360281944275, 1.2017576694488525, 0.7358479499816895, 0.9223952293395996, -0.2673342823982239, -1.2941982746124268, -1.1372345685958862, 0.42498579621315, 0.20963171124458313, 1.6649818420410156, 0.3427635431289673, -0.5924468040466309, -1.1568814516067505, -0.7269764542579651, 0.604738175868988, -0.9780421257019043, -0.2571362853050232, 1.031184434890747, 1.4818615913391113, -0.24192920327186584, -0.5791612267494202, -1.1024761199951172, 2.6458823680877686, 1.0979740619659424, -1.4632577896118164, 1.034663200378418, -0.5986494421958923, -2.043222188949585, 0.010517362505197525, 1.2263511419296265, -1.8005244731903076, -0.4504871368408203, 0.9241243004798889, -0.11459331214427948, 0.12185172736644745, -0.5425137877464294, 1.1841212511062622, 0.8931083083152771, 0.038081664592027664, 0.23679955303668976, 0.7889691591262817, 3.053788661956787, -0.04461027309298515, -0.05341833829879761, 0.07460398972034454, -0.3817829489707947, -1.1750118732452393, 0.4739631712436676, 0.6622761487960815, -0.8415477871894836, 0.4038948118686676, -0.8042656779289246, 1.8321316242218018, 0.2302185595035553, -0.37898436188697815, -0.7147290110588074, 0.1850944608449936, 1.288645625114441, 
-0.4520632028579712, 1.0726860761642456, 0.769446849822998, 1.165770173072815, -0.6193650364875793, 0.19318526983261108, 0.20301254093647003, -1.4767383337020874, -0.7551678419113159, 1.3037528991699219, -0.3005061745643616]
# --- BERT settings ----------------------------------------------------------

# File names of the BERT model artefacts.
BERT_TAR_FILE = "bert_model.tar.gz"
BERT_MODEL_FILE = "pytorch_model.bin"
BERT_CONFIG_FILE = "config.json"
BERT_VOCAB_FILE = "vocab.txt"

# Run sizing.
BERT_NUM_RECORDS = 1_000_000
BERT_MAX_SENT_LEN = 256
BERT_NUM_EPOCHS = 4
BERT_BATCH_SIZE = 16