diff --git a/active_learning_basics.py b/active_learning_basics.py index 9d7003b..fefb611 100644 --- a/active_learning_basics.py +++ b/active_learning_basics.py @@ -259,8 +259,8 @@ def train_model(training_data, validation_data = "", evaluation_data = "", num_l # with an equal number of items from each label shuffle(training_data) #randomize the order of the training data - related = [row for row in training_data if '1' in row[2]] - not_related = [row for row in training_data if '0' in row[2]] + related = [row for row in training_data if '1' in row] + not_related = [row for row in training_data if '0' in row] epoch_data = related[:select_per_epoch] epoch_data += not_related[:select_per_epoch]