-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_test_csv.py
71 lines (55 loc) · 2.3 KB
/
create_test_csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
from project.utils.create_input_df import CreateDataframe
import pandas as pd
# Number of consecutive days emitted per state by CreateTestCSV.generate().
PREDICTION_DAYS_COUNT = 26
# Path of the CSV written by CreateTestCSV.write_file() (relative to the CWD).
FILE_NAME = "Test.csv"
# Number of states iterated over; also the stride between days in forecast ids.
STATES_COUNT = 50
# CSV read in CreateTestCSV.__init__; its 'State' column supplies the state names.
STATE_CSV_FILE_PATH = "./project/data/daily_report_per_states/states/states.csv"
def get_forecast_id(date_day, state_id):
    """Return the flat forecast id for a (day, state) pair.

    Ids are laid out day-major: every day owns a contiguous run of
    ``STATES_COUNT`` ids, and ``state_id`` selects the slot inside that run.

    Args:
        date_day: Zero-based offset of the day within the forecast window.
        state_id: Zero-based index of the state.

    Returns:
        ``state_id + STATES_COUNT * date_day``.
    """
    day_base = STATES_COUNT * date_day
    return state_id + day_base
def get_test_df(self, state_id, attr):
    """Look up one state's column in the confirmed or death test data.

    Args:
        self: Object exposing ``test_data_confirmed``, ``test_data_death``
            and ``US_STATES`` (a ``CreateTestCSV`` instance in this file).
        state_id: Index into ``self.US_STATES``.
        attr: ``"Confirmed"`` selects the confirmed-cases data; any other
            value selects the deaths data.

    Returns:
        The entry of the selected data set keyed by the state's name.
    """
    use_confirmed = attr == "Confirmed"
    source = self.test_data_confirmed if use_confirmed else self.test_data_death
    state_name = self.US_STATES[state_id]
    return source[state_name]
class CreateTestCSV(object):
    """Build the ``Test.csv`` file from the held-out part of the input data.

    On construction the confirmed and deaths frames are loaded through
    ``CreateDataframe`` and cut down to the rows from ``TEST_START_INDEX``
    onward; ``generate()`` turns them into ``[forecast_id, confirmed, deaths]``
    rows and ``write_file()`` writes those rows out as CSV.
    """

    # Row position where the test split begins; rows before it are the
    # training input (the test days are 142, 143, ... per the original note).
    TEST_START_INDEX = 142

    def __init__(self):
        """Load the test slices of both data sets and the list of state names."""
        start = self.TEST_START_INDEX
        self.dataFrameFactory = CreateDataframe()
        self.test_data_confirmed = self.dataFrameFactory.get_final_df("Confirmed")[start:]
        self.test_data_death = self.dataFrameFactory.get_final_df("Deaths")[start:]

        # Testing input: column vector of the day indices that follow the
        # training input (i.e. 142, 143, ..., 167).
        # The frame above is already restricted to the test rows, so it must
        # not be sliced by TEST_START_INDEX a second time (doing so dropped
        # the very rows this attribute is meant to hold).
        self.days = np.array(self.test_data_confirmed["Days"]).reshape(-1, 1)

        states = pd.read_csv(STATE_CSV_FILE_PATH, engine="python")
        # State names in file order; get_test_df() indexes this by state_id.
        self.US_STATES = states["State"].tolist()

    def generate(self):
        """Collect the test values for every (day, state) pair.

        Returns:
            A list of ``[forecast_id, confirmed, deaths]`` lists, ordered by
            day first and state second, covering ``PREDICTION_DAYS_COUNT``
            days and ``STATES_COUNT`` states.
        """
        state_ids = range(STATES_COUNT)
        confirmed = [get_test_df(self, state_id, "Confirmed") for state_id in state_ids]
        deaths = [get_test_df(self, state_id, "Deaths") for state_id in state_ids]

        start = self.TEST_START_INDEX
        rows = []
        for day_offset in range(PREDICTION_DAYS_COUNT):
            # NOTE(review): `day` is used as an index label here, which assumes
            # the sliced frames kept their original integer index - confirm.
            day = start + day_offset
            for state_id in state_ids:
                forecast_id = get_forecast_id(day_offset, state_id)
                rows.append(
                    [forecast_id, confirmed[state_id][day], deaths[state_id][day]]
                )
        return rows

    def write_file(self, data, file_name=None):
        """Write the rows as a three-column CSV with a header line.

        Args:
            data: Iterable of rows; the first three items of each row are
                written as ForecastID, Confirmed and Deaths.
            file_name: Destination path. Defaults to the module-level
                ``FILE_NAME`` when omitted.
        """
        target = FILE_NAME if file_name is None else file_name
        # Mode "w" already truncates; the context manager guarantees the
        # handle is flushed and closed even if a row fails to serialize.
        with open(target, "w") as out:
            out.write("ForecastID,Confirmed,Deaths\n")
            out.writelines(
                ",".join((str(row[0]), str(row[1]), str(row[2]))) + "\n"
                for row in data
            )
def main():
    """Generate the test rows and write them to the output CSV file."""
    builder = CreateTestCSV()
    builder.write_file(builder.generate())


if __name__ == "__main__":
    main()