-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassify-trace.py
86 lines (65 loc) · 2.57 KB
/
classify-trace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import itertools
import sys

import joblib
import lightgbm as lgb
from sklearn.base import BaseEstimator
# Ground-truth label file (one label per line) read by load_true_labels().
# This is a relative path, so it is resolved against the current working directory.
LABEL_FILE = 'apt_trace_labels.txt'
def load_model(model_file):
    """Load a joblib-persisted object and check that it is a supported model type.

    This script uses it for both the classifier and the vectorizer. Any object
    that is an sklearn ``BaseEstimator`` or a LightGBM ``Booster`` is accepted.

    Args:
        model_file: Path to a file previously written with ``joblib.dump``.

    Returns:
        The deserialized object.

    Exits:
        If loading fails or the object has an unexpected type, prints an error
        to stderr (keeping stdout clean for predictions) and exits with status 1.
    """
    # SECURITY: joblib.load unpickles the file and can execute arbitrary code.
    # Only load model files that come from a trusted source.
    try:
        model = joblib.load(model_file)
    except Exception as e:  # missing file, corrupt pickle, version mismatch, ...
        print(f"Error loading the model: {e}", file=sys.stderr)
        sys.exit(1)
    if not isinstance(model, (BaseEstimator, lgb.Booster)):
        print(
            "Error loading the model: "
            "Loaded model is not a valid sklearn or LightGBM model.",
            file=sys.stderr,
        )
        sys.exit(1)
    return model
def _remove_consecutive_duplicates(trace):
"""Helper function to remove consecutive duplicates from a trace"""
calls = trace.split(',')
result = [calls[0]]
for i in range(1, len(calls)):
if calls[i] != calls[i - 1]:
result.append(calls[i])
return ','.join(result)
def remove_repeated_calls(traces):
    """Collapse consecutive duplicate API calls in every trace.

    Args:
        traces: Iterable of comma-separated call strings.

    Returns:
        A new list, in input order, in which each trace has had every run of
        identical adjacent calls reduced to one call.
    """
    return list(map(_remove_consecutive_duplicates, traces))
def preprocess_traces(input_file, vectorizer, *, encoding=None):
    """Read traces from a file, collapse repeated calls, and vectorize them.

    Each line of ``input_file`` is treated as one trace of comma-separated API
    calls. Surrounding whitespace is stripped and runs of identical consecutive
    calls are collapsed via ``remove_repeated_calls``. The result is then passed
    to ``vectorizer.transform``.

    Args:
        input_file: Path to the trace file, one trace per line.
        vectorizer: Fitted object exposing ``transform(list_of_str)``.
        encoding: Text encoding of ``input_file``. ``None`` (the default) uses
            the platform's preferred encoding, as plain ``open`` does.

    Returns:
        Tuple ``(X, n)``. ``X`` holds one row per trace and ``n`` is the number
        of traces read.
    """
    with open(input_file, 'r', encoding=encoding) as f:
        traces = [line.strip() for line in f]

    traces = remove_repeated_calls(traces)
    X = vectorizer.transform(traces)
    # transform() is expected to return a sparse matrix, which is densified here.
    # Transformers that already return a dense array are passed through unchanged.
    if hasattr(X, 'toarray'):
        X = X.toarray()
    return X, len(traces)
def classify_traces(model, traces):
    """Predict a label for each row of the feature matrix.

    Args:
        model: An sklearn estimator or a LightGBM ``Booster``, as returned by
            ``load_model``.
        traces: Feature matrix with one row per trace, as produced by
            ``preprocess_traces``.

    Returns:
        Array of predicted labels. For a ``Booster``, the raw ``predict`` output
        is converted to integer class labels:
        - A 1-D output (assumed to be binary-class probabilities) is rounded.
        - A 2-D output (per-class scores, e.g. multiclass) is reduced with
          ``argmax`` over the class axis.

    Exits:
        If the model has no usable ``predict``, prints an error to stderr and
        exits with status 1.
    """
    try:
        if isinstance(model, lgb.Booster):
            raw = model.predict(traces)
            if raw.ndim == 2:
                # Per-class scores: choose the highest-scoring class per row.
                return raw.argmax(axis=1)
            # 1-D output, assumed to be binary probabilities. NumPy rounds half
            # to even, so exactly 0.5 maps to 0.
            return raw.round().astype(int)
        return model.predict(traces)
    except AttributeError as e:
        print(f"Error during prediction: {e}", file=sys.stderr)
        sys.exit(1)
def load_true_labels(num_samples, label_file=None):
    """Read up to ``num_samples`` ground-truth labels, one per line.

    Args:
        num_samples: Maximum number of labels to read. ``None`` reads all of them.
        label_file: Path to the label file. Defaults to the module-level
            ``LABEL_FILE`` ('apt_trace_labels.txt').

    Returns:
        List of label strings with surrounding whitespace stripped, in file
        order. If the file is shorter, fewer than ``num_samples`` labels are
        returned.

    Raises:
        ValueError: If ``num_samples`` is negative. Slicing with a negative
            count would silently drop labels from the end instead.
    """
    if num_samples is not None and num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")
    if label_file is None:
        label_file = LABEL_FILE
    with open(label_file, 'r') as f:
        # islice stops reading after num_samples lines instead of loading the
        # whole file and slicing afterwards.
        return [line.strip() for line in itertools.islice(f, num_samples)]
if __name__ == '__main__':
    if len(sys.argv) != 2:
        # Send the usage error to stderr so stdout carries only predictions.
        print("Usage: python classify-trace.py <input_file>", file=sys.stderr)
        sys.exit(1)

    input_file = sys.argv[1]

    # Both artifacts are read by relative path from the current working directory.
    model = load_model('best_model.joblib')
    vectorizer = load_model('vectorizer.joblib')

    # The trace count is only needed to compare against true labels, which
    # this entry point does not do.
    traces, _num_samples = preprocess_traces(input_file, vectorizer)

    predictions = classify_traces(model, traces)

    # One prediction per line, in the same order as the input traces.
    for prediction in predictions:
        print(prediction)