Skip to content

Commit

Permalink
Create probabilities for all chromosome frames
Browse files Browse the repository at this point in the history
  • Loading branch information
slowikj committed Jan 19, 2020
1 parent ddfee57 commit 8069bb5
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
30 changes: 28 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np

from test_data_processor import get_test_data_frames, split_frames_by_is_valid
from train_data_processor import prepare_features_labels, create_features
from train_model import train
from train_model import train_model

alphabet = ["A", "C", "G", "T"]
k = 4
Expand All @@ -9,14 +11,38 @@
frame_length = 1500
step = 750


def get_probabilities_for_1_list(valid_frames, invalid_frames, y_test_proba):
prob_for_1 = [0] * (len(valid_frames) + len(invalid_frames))
print(len(prob_for_1))
for valid_frame_ind in range(len(valid_frames)):
prob_for_1_ind = int(valid_frames[valid_frame_ind].begin / step)
prob_for_1[prob_for_1_ind] = y_test_proba[valid_frame_ind][1]

mean_prob = np.mean(y_test_proba[:, 1], axis=0)

for invalid_frame_ind in range(len(invalid_frames)):
prob_for_1_ind: int = int(invalid_frames[invalid_frame_ind].begin / step)
prob_for_1[prob_for_1_ind] = mean_prob

return np.array(prob_for_1)


if __name__ == "__main__":
X_train, y_train = prepare_features_labels(train_filenames_labels, alphabet=alphabet, k=k)
clf = train(X_train, y_train)
clf = train_model(X_train, y_train)

test_frames = get_test_data_frames(test_filename, frame_length=frame_length, step=step)
valid_frames, invalid_frames = split_frames_by_is_valid(test_frames)

X_test = create_features(valid_frames,
alphabet=alphabet,
k=k)
y_test_proba = clf.predict_proba(X_test)
print(y_test_proba)

print(len(valid_frames))
print(len(invalid_frames))
print("------------")
res = get_probabilities_for_1_list(valid_frames, invalid_frames, y_test_proba)
print(res)
4 changes: 2 additions & 2 deletions test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def get_test_data_frames(file_path, frame_length, step):

def split_frames_by_is_valid(frames):
return (
filter(lambda f: f.is_valid, frames),
filter(lambda f: not f.is_valid, frames)
list(filter(lambda f: f.is_valid, frames)),
list(filter(lambda f: not f.is_valid, frames))
)


Expand Down
2 changes: 1 addition & 1 deletion train_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sklearn.ensemble import RandomForestClassifier


def train(X, y):
def train_model(X, y):
clf = RandomForestClassifier(max_depth=5, random_state=123)
clf.fit(X, y)
return clf

0 comments on commit 8069bb5

Please sign in to comment.