final_audio_inference.py
import csv
import os
import socket

import joblib
import numpy as np
import torch.nn.functional as F
from msclap import CLAP


def get_labels(csv_path='datasets/home_labels.csv'):
    label2id = {}
    id2label = {}
    with open(csv_path, mode='r') as file:
        csv_reader = csv.reader(file)
        for i, row in enumerate(csv_reader):
            class_name = row[0]
            label2id[class_name] = i
            id2label[i] = class_name
    class_labels = list(label2id.keys())
    return label2id, id2label, class_labels
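
# Illustrative layout of datasets/home_labels.csv (an assumption; the real file may
# list different or additional classes). Each row holds one class name in its first
# column, and the row index becomes the label id, e.g.:
#
#   Crying
#   Gunshot
#   Glass breaking
#   Dog barking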


def get_predictions(audio_filename, window_size=2, sample_rate=16000):
    label2id, id2label, class_labels = get_labels()

    print("Loading model...")
    # Load model (choose between versions '2022' or '2023').
    # The model weights are downloaded automatically if `model_fp` is not specified.
    clap_model = CLAP(version='2023', use_cuda=False)

    print("Extracting text embeddings...")
    # Extract text embeddings for the class prompts
    text_embeddings = clap_model.get_text_embeddings([f"This is a sound of {c}" for c in class_labels])

    print("Listening...")
    bind_ip = '0.0.0.0'  # Listen on all network interfaces
    bind_port = 50007
    # REMOTE_IP = '192.168.1.154'
    REMOTE_IP = "172.26.128.166"

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        server_socket.bind((bind_ip, bind_port))
        while True:
            server_socket.listen(1)
            print("Server is listening for connections...")

            # Accept a client connection
            conn, addr = server_socket.accept()
            print(f"Connection attempted from {addr[0]}")
            # Check if the connection is from the allowed IP
            # if addr[0] == allowed_ip:
            #     print(f"Connected by {addr}")

            # Receive the audio file sent by the client
            with open('received_output.wav', 'wb') as f:
                while True:
                    data = conn.recv(1024)
                    if not data:
                        break
                    f.write(data)
            print("File received successfully.")

            # Extract audio embeddings
            audio_embeddings = clap_model.get_audio_embeddings([audio_filename])

            # Compute similarity between audio and text embeddings
            similarities = clap_model.compute_similarity(audio_embeddings, text_embeddings)
            similarity = F.softmax(similarities, dim=1)
            values, indices = similarity[0].topk(5)

            detected = False
            clap_results = []
            output_label = ""
            print("\nCLAP predictions:")
            for value, index in zip(values, indices):
                index = index.item()
                value = round(value.item() * 100, 4)
                clap_results.append(index)
                clap_results.append(value)
                print(id2label[index], value)
                if id2label[index] in ["Crying", "Gunshot", "Glass breaking"] and value > 50:
                    print(f"ALERT: {id2label[index]} detected!")
                    detected = True
                    output_label = id2label[index]
                    break

            # Ensemble: if CLAP alone did not flag an alert class with high confidence,
            # fall back to the trained random-forest ensemble. Its input stacks the
            # video-branch results (vclip_results, placeholder values here since only
            # audio is available) with the CLAP top-5 (index, score) pairs.
            if not detected:
                vclip_results = [-1, 0] * 5
                trained_ensemble = joblib.load('trained_RF_ensemble.joblib')
                X_test = np.expand_dims(np.hstack([vclip_results, clap_results]), axis=0)
                y_pred = trained_ensemble.predict(X_test)
                print("\nEnsemble prediction:", id2label[y_pred[0]])
                output_label = id2label[y_pred[0]]
                if id2label[y_pred[0]] in ["Crying", "Gunshot", "Glass breaking"]:
                    print(f"ALERT: {id2label[y_pred[0]]} detected!")

            # Clean up the received file
            os.remove(audio_filename)

            # Send the predicted label back to the client
            conn.sendall(output_label.encode('utf-8'))
            conn.shutdown(socket.SHUT_WR)
            conn.close()

            # if detected:
            #     filename = f"recording_{int(time.time())}.wav"
            #     os.rename(temp_filename, filename)
            #     print(f"Alert! Glass breaking detected in {filename}")
            # else:
            #     os.remove(temp_filename)


if __name__ == "__main__":
    audio_filename = "received_output.wav"
    get_predictions(audio_filename, window_size=2, sample_rate=16000)
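
The server above expects the client to stream raw WAV bytes and then half-close the connection; only then does its receive loop finish, run inference, and send the predicted label back on the same socket. The snippet below is a minimal client sketch under those assumptions. `send_audio_for_prediction` is an illustrative name, and the default address simply reuses the `REMOTE_IP` value from the code, so adjust both to match the actual deployment.

import socket


def send_audio_for_prediction(wav_path, server_ip="172.26.128.166", port=50007):
    """Stream a WAV file to the inference server and return the predicted label."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.connect((server_ip, port))
        # Stream the file in small chunks
        with open(wav_path, "rb") as f:
            while True:
                chunk = f.read(1024)
                if not chunk:
                    break
                sock.sendall(chunk)
        # Half-close the socket so the server's recv() loop sees EOF and runs inference
        sock.shutdown(socket.SHUT_WR)
        # Read the label the server sends back
        label = sock.recv(1024).decode("utf-8")
    print("Server prediction:", label)
    return label


# Example (assuming the server is running and reachable at the default address):
# send_audio_for_prediction("output.wav")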