-
Notifications
You must be signed in to change notification settings - Fork 2
/
spotify_air.py
136 lines (119 loc) · 4.25 KB
/
spotify_air.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import time

import cv2 as cv
import numpy
import torch
import torch.nn as nn
from selenium import webdriver
from selenium.common.exceptions import WebDriverException
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from torchvision.transforms import ToTensor, Grayscale, Resize, Compose
# All tensors and the model live on the CPU; no GPU is required.
device = torch.device(type="cpu")
# Gesture classifier applied to each webcam frame.
class GestureModel(nn.Module):
    """Small CNN that classifies a 28x28 grayscale image into 4 gestures.

    Tensor shapes after each stage (m = batch size):
        input:      (m, 1, 28, 28)
        conv1:      (m, 16, 28, 28)
        conv2:      (m, 32, 14, 14)
        conv3:      (m, 64, 6, 6)
        conv4:      (m, 128, 2, 2)
        classifier: (m, 4) -- softmax probabilities over dim 1

    The submodule names and the layer order inside each stage determine the
    ``state_dict`` keys (e.g. ``conv1.0.weight``, ``classifier.2.bias``), so
    they must stay in sync with the saved weights loaded into this model.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self._stage(1, 16, padding=1, pool=False)
        self.conv2 = self._stage(16, 32, padding=1, pool=True)
        self.conv3 = self._stage(32, 64, padding=0, pool=True)
        self.conv4 = self._stage(64, 128, padding=0, pool=True)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=0.2),
            nn.Linear(128 * 2 * 2, 4),
            nn.Softmax(dim=1),
        )

    @staticmethod
    def _stage(in_channels, out_channels, *, padding, pool):
        """Build one Conv3x3 -> ReLU -> BatchNorm [-> MaxPool2d(2)] stage."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ]
        if pool:
            layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)

    def forward(self, input):
        """Return per-class probabilities for a batch of images."""
        x = input
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        return self.classifier(x)
# --- Browser session --------------------------------------------------------
# Open Spotify's web player in Chrome and sign in.
# NOTE(review): the obfuscated "-scss" class names were copied from Spotify's
# generated markup and are likely to drift; re-check them if lookups fail.
browser = webdriver.Chrome()
browser.get("https://open.spotify.com/")
# Fixed delay so the page can finish loading before elements are queried.
time.sleep(4)

# The login button carries two classes. By.CLASS_NAME rejects compound
# (space-separated) class names, so match both classes with a CSS selector.
# find_element (singular) raises NoSuchElementException with a clear message
# instead of an IndexError on an empty result list.
login_button = browser.find_element(
    By.CSS_SELECTOR,
    "._3f37264be67c8f40fa9f76449afdb4bd-scss"
    "._1f2f8feb807c94d2a0a7737b433e19a8-scss",
)
login_button.click()
time.sleep(2)

# Credentials are read from the environment so they need not be hard-coded;
# the placeholders are used when the variables are unset.
username = browser.find_element(By.ID, "login-username")
username.send_keys(os.environ.get("SPOTIFY_USERNAME", "USER-NAME"))
password = browser.find_element(By.ID, "login-password")
password.send_keys(os.environ.get("SPOTIFY_PASSWORD", "PASSWORD"))

# Submit the login form and give the player time to load.
login = browser.find_element(By.ID, "login-button")
login.click()
time.sleep(3)
# --- Gesture model ----------------------------------------------------------
model = GestureModel()
# map_location remaps tensors saved from a GPU session onto `device`, so the
# checkpoint also loads on a machine without CUDA.
# NOTE(review): torch.load unpickles the file; only load checkpoints you
# trust (recent PyTorch supports weights_only=True to restrict this).
model.load_state_dict(torch.load("hand_gesture_model.pth", map_location=device))
model.to(device)
# eval(): BatchNorm uses its running statistics and Dropout is disabled.
model.eval()

# Per-frame preprocessing: HxWx3 uint8 image -> 3xHxW float in [0, 1]
# -> 1xHxW grayscale -> 1x28x28, matching GestureModel's expected input.
transform = Compose([
    ToTensor(),
    Grayscale(),
    Resize([28, 28]),
])
# --- Capture / inference loop -----------------------------------------------
# Minimum softmax probability the top class needs before it triggers an action.
CONFIDENCE_THRESHOLD = 0.5
# Seconds during which new gestures are ignored after a successful click, so a
# gesture held across many frames clicks once instead of once per frame.
ACTION_COOLDOWN_S = 1.0
# Gesture class index -> (CSS class of the player button, label for messages).
# Play (0) and pause (1) share one class: the player has a single toggle button.
GESTURE_ACTIONS = {
    0: ("_82ba3fb528bb730b297a91f46acd37a3-scss", "Play"),
    1: ("_82ba3fb528bb730b297a91f46acd37a3-scss", "Pause"),
    2: ("bf01b0d913b6bfffea0d4ffd7393c4af-scss", "Next"),
    3: ("bc13c597ccee51a09ec60253c3c51c75-scss", "Previous"),
}

video = cv.VideoCapture(0)
last_action_time = float("-inf")
try:
    while True:
        ok, frame = video.read()
        if not ok:
            # No frame (camera missing, busy or disconnected): stop cleanly
            # rather than crash inside the transform on a None frame.
            print("Could not read a frame from the camera")
            break
        # OpenCV returns BGR; torchvision's Grayscale weights channels as RGB.
        frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
        batch = torch.unsqueeze(transform(frame), 0).to(device)  # 1x1x28x28
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            output = model(batch)  # shape (1, 4): (batch, class)
        action = torch.argmax(output, dim=1).item()
        # Index both the batch and class dimensions to get a scalar probability.
        confidence = output[0, action].item()
        if confidence <= CONFIDENCE_THRESHOLD:
            continue
        now = time.monotonic()
        if now - last_action_time < ACTION_COOLDOWN_S:
            continue
        class_name, label = GESTURE_ACTIONS[action]
        try:
            browser.find_element(By.CLASS_NAME, class_name).click()
        except WebDriverException:
            # Missing, stale or obscured button: report it and keep running.
            print(f"{label} button not found")
        else:
            last_action_time = now
finally:
    # Release the camera and close the browser even on Ctrl+C or errors.
    video.release()
    browser.quit()