From 3fab8aa6f72b6f7c581ccd06a6e9564665782d17 Mon Sep 17 00:00:00 2001 From: yjh1021317464 Date: Wed, 4 Dec 2019 11:12:08 +0800 Subject: [PATCH 1/2] 1. Fix crash caused by saving model to a file in non-existent subdirectory. 2. Fix crash caused by using cPickle.load/dump a file which is not opened with binary mode. --- SS_dataset.py | 2 +- chat.py | 2 +- compute_dialogue_embeddings.py | 4 ++-- convert-text2dict.py | 2 +- convert-wordemb-dict2emb-matrix.py | 2 +- create-text-file-for-tests.py | 6 +++--- dialog_encdec.py | 4 ++-- evaluate.py | 2 +- generate_encodings.py | 4 ++-- model.py | 2 ++ sample.py | 2 +- split-examples-by-token.py | 2 +- train.py | 5 +++-- 13 files changed, 21 insertions(+), 18 deletions(-) diff --git a/SS_dataset.py b/SS_dataset.py index 5c465e9..d632250 100644 --- a/SS_dataset.py +++ b/SS_dataset.py @@ -78,7 +78,7 @@ def __init__(self, self.exit_flag = False def load_files(self): - self.data = cPickle.load(open(self.dialogue_file, 'r')) + self.data = cPickle.load(open(self.dialogue_file, 'rb')) self.data_len = len(self.data) logger.debug('Data len is %d' % self.data_len) diff --git a/chat.py b/chat.py index 06f1223..a862164 100644 --- a/chat.py +++ b/chat.py @@ -86,7 +86,7 @@ def main(): state_path = args.model_prefix + "_state.pkl" model_path = args.model_prefix + "_model.npz" - with open(state_path) as src: + with open(state_path, "rb") as src: state.update(cPickle.load(src)) logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s") diff --git a/compute_dialogue_embeddings.py b/compute_dialogue_embeddings.py index 02b49a6..99f2d1a 100644 --- a/compute_dialogue_embeddings.py +++ b/compute_dialogue_embeddings.py @@ -109,7 +109,7 @@ def main(): state_path = args.model_prefix + "_state.pkl" model_path = args.model_prefix + "_model.npz" - with open(state_path) as src: + with open(state_path, "rb") as src: state.update(cPickle.load(src)) logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s") @@ -172,7 +172,7 @@ def main(): dialogue_encodings.append(encs[i]) # Save encodings to disc - cPickle.dump(dialogue_encodings, open(args.output + '.pkl', 'w')) + cPickle.dump(dialogue_encodings, open(args.output + '.pkl', 'wb')) if __name__ == "__main__": main() diff --git a/convert-text2dict.py b/convert-text2dict.py index cf70f2f..da16852 100644 --- a/convert-text2dict.py +++ b/convert-text2dict.py @@ -47,7 +47,7 @@ def safe_pickle(obj, filename): if args.dict != "": # Load external dictionary assert os.path.isfile(args.dict) - vocab = dict([(x[0], x[1]) for x in cPickle.load(open(args.dict, "r"))]) + vocab = dict([(x[0], x[1]) for x in cPickle.load(open(args.dict, "rb"))]) # Check consistency assert '' in vocab diff --git a/convert-wordemb-dict2emb-matrix.py b/convert-wordemb-dict2emb-matrix.py index d0772b3..9a27c51 100644 --- a/convert-wordemb-dict2emb-matrix.py +++ b/convert-wordemb-dict2emb-matrix.py @@ -111,7 +111,7 @@ def edits1(word): # Load model dictionary -model_dict = cPickle.load(open(args.model_dictionary, 'r')) +model_dict = cPickle.load(open(args.model_dictionary, 'rb')) str_to_idx = dict([(tok, tok_id) for tok, tok_id, _, _ in model_dict]) i_dim = len(str_to_idx.keys()) diff --git a/create-text-file-for-tests.py b/create-text-file-for-tests.py index 95b3359..d13c209 100644 --- a/create-text-file-for-tests.py +++ b/create-text-file-for-tests.py @@ -75,13 +75,13 @@ def main(): # Load state file state = prototype_state() state_path = args.model_prefix + "_state.pkl" - with open(state_path) as src: + with open(state_path, "rb") as src: state.update(cPickle.load(src)) # Load dictionary # Load dictionaries to convert str to idx and vice-versa - raw_dict = cPickle.load(open(state['dictionary'], 'r')) + raw_dict = cPickle.load(open(state['dictionary'], 'rb')) str_to_idx = dict([(tok, tok_id) for tok, tok_id, _, _ in raw_dict]) idx_to_str = dict([(tok_id, tok) for tok, tok_id, freq, _ in raw_dict]) @@ -95,7 +95,7 @@ def main(): # Is it a pickle file? Then process using model dictionaries.. if args.test_file[len(args.test_file)-4:len(args.test_file)] == '.pkl': - test_dialogues = cPickle.load(open(args.test_file, 'r')) + test_dialogues = cPickle.load(open(args.test_file, 'rb')) for test_dialogueid,test_dialogue in enumerate(test_dialogues): if test_dialogueid % 100 == 0: print 'test_dialogue', test_dialogueid diff --git a/dialog_encdec.py b/dialog_encdec.py index 3ce90e0..11b5fa0 100644 --- a/dialog_encdec.py +++ b/dialog_encdec.py @@ -1614,7 +1614,7 @@ def __init__(self, state): self.rng = numpy.random.RandomState(state['seed']) # Load dictionary - raw_dict = cPickle.load(open(self.dictionary, 'r')) + raw_dict = cPickle.load(open(self.dictionary, 'rb')) # Probabilities for each term in the corpus used for noise contrastive estimation (NCE) self.noise_probs = [x[2] for x in sorted(raw_dict, key=operator.itemgetter(1))] @@ -1674,7 +1674,7 @@ def __init__(self, state): if self.initialize_from_pretrained_word_embeddings == True: # Load pretrained word embeddings from pickled file logger.debug("Loading pretrained word embeddings") - pretrained_embeddings = cPickle.load(open(self.pretrained_word_embeddings_file, 'r')) + pretrained_embeddings = cPickle.load(open(self.pretrained_word_embeddings_file, 'rb')) # Check all dimensions match from the pretrained embeddings assert(self.idim == pretrained_embeddings[0].shape[0]) diff --git a/evaluate.py b/evaluate.py index 5d05819..855505e 100644 --- a/evaluate.py +++ b/evaluate.py @@ -63,7 +63,7 @@ def main(): state_path = args.model_prefix + "_state.pkl" model_path = args.model_prefix + "_model.npz" - with open(state_path) as src: + with open(state_path, "rb") as src: state.update(cPickle.load(src)) logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s") diff --git a/generate_encodings.py b/generate_encodings.py index 8cbe89f..331743c 100644 --- a/generate_encodings.py +++ b/generate_encodings.py @@ -238,7 +238,7 @@ def get_all_encodings(model, encoding_func, sentenceDict, max_length, nb_sent_ba #print "end", encodingDict[keys][0,1950:] print "----> Dummping the encodings..." - cPickle.dump(encodingDict, open(outputName + ".pkl", "w")) + cPickle.dump(encodingDict, open(outputName + ".pkl", "wb")) print "\tL----> Done." return encodingDict @@ -254,7 +254,7 @@ def init(path): state_path = path + "_state.pkl" model_path = path + "_model.npz" - with open(state_path) as src: + with open(state_path, "rb") as src: state.update(cPickle.load(src)) logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s") diff --git a/model.py b/model.py index cf35c5a..da825f8 100644 --- a/model.py +++ b/model.py @@ -1,6 +1,7 @@ import logging import numpy import theano +import os logger = logging.getLogger(__name__) # This is the list of strings required to ignore, if we're going to take a pretrained HRED model @@ -19,6 +20,7 @@ def save(self, filename): Save the model to file `filename` """ vals = dict([(x.name, x.get_value()) for x in self.params]) + os.makedirs(os.path.split(filename)[0]) numpy.savez(filename, **vals) def load(self, filename, parameter_strings_to_ignore=[]): diff --git a/sample.py b/sample.py index 52f0114..40eb22a 100755 --- a/sample.py +++ b/sample.py @@ -71,7 +71,7 @@ def main(): state_path = args.model_prefix + "_state.pkl" model_path = args.model_prefix + "_model.npz" - with open(state_path) as src: + with open(state_path, "rb") as src: state.update(cPickle.load(src)) logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s") diff --git a/split-examples-by-token.py b/split-examples-by-token.py index f719733..6f1983f 100644 --- a/split-examples-by-token.py +++ b/split-examples-by-token.py @@ -64,7 +64,7 @@ def magicsplit(l, *splitters): raise Exception("Input file not found!") logger.info("Loading dialogue corpus") -data = cPickle.load(open(args.input, 'r')) +data = cPickle.load(open(args.input, 'rb')) data_len = len(data) logger.info('Corpus loaded... Data len is %d' % data_len) diff --git a/train.py b/train.py index 61c58f5..820a12c 100644 --- a/train.py +++ b/train.py @@ -65,8 +65,9 @@ def save(model, timings, post_fix = ''): start = time.time() s = signal.signal(signal.SIGINT, signal.SIG_IGN) + os.makedirs(model.state['save_dir']) model.save(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + post_fix + 'model.npz') - cPickle.dump(model.state, open(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + post_fix + 'state.pkl', 'w')) + cPickle.dump(model.state, open(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + post_fix + 'state.pkl', 'wb')) numpy.savez(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + post_fix + 'timing.npz', **timings) signal.signal(signal.SIGINT, s) @@ -135,7 +136,7 @@ def main(args): if os.path.isfile(state_file) and os.path.isfile(timings_file): logger.debug("Loading previous state") - state = cPickle.load(open(state_file, 'r')) + state = cPickle.load(open(state_file, 'rb')) timings = dict(numpy.load(open(timings_file, 'r'))) for x, y in timings.items(): timings[x] = list(y) From f8e6e4ec571df4e23a16c5d38dbd14a68e351802 Mon Sep 17 00:00:00 2001 From: yjh1021317464 Date: Sat, 7 Dec 2019 01:36:08 +0800 Subject: [PATCH 2/2] Re-fix crash caused by saving model to a file in non-existent subdirectory --- train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index 820a12c..feb8410 100644 --- a/train.py +++ b/train.py @@ -64,6 +64,9 @@ def save(model, timings, post_fix = ''): # ignore keyboard interrupt while saving start = time.time() s = signal.signal(signal.SIGINT, signal.SIG_IGN) + + if not os.path.exists(model.state['save_dir']): + os.makedirs(model.state['save_dir']) os.makedirs(model.state['save_dir']) model.save(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + post_fix + 'model.npz') @@ -137,7 +140,7 @@ def main(args): logger.debug("Loading previous state") state = cPickle.load(open(state_file, 'rb')) - timings = dict(numpy.load(open(timings_file, 'r'))) + timings = dict(numpy.load(open(timings_file, 'rb'))) for x, y in timings.items(): timings[x] = list(y)