-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlibpairgenerator.py
99 lines (85 loc) · 3.65 KB
/
libpairgenerator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import random
import datetime
import pytz
import csv
import os
class PairGenerator:
    """Generate random ``(inputs, timestamp, target)`` pairs and save them as CSV.

    Each pair holds ``num_inputs`` random values in [-1, 1], an ISO-8601 UTC
    timestamp between 1 and 24 hours before "now", and an integer target
    obtained by XOR-ing the inputs after rounding each one.

    NOTE(review): a later class with the same name in this module shadows this
    definition at import time.

    Args:
        num_inputs: Number of input values per pair; must be >= 1 when pairs
            are generated.
        num_pairs: Number of pairs produced per ``generate_pairs`` call.
        int_values: If True, inputs are integers drawn from {-1, 0, 1};
            otherwise floats drawn uniformly from [-1, 1].
        seed: Seed for the private random generator. Defaults to 42, the value
            that used to be hard-coded, so default output is unchanged.
    """

    def __init__(self, num_inputs, num_pairs, int_values=False, seed=42):
        self.num_inputs = num_inputs
        self.num_pairs = num_pairs
        self.int_values = int_values
        # A private Random instance makes the inputs reproducible and
        # independent of the global ``random`` module state.
        self.gen = random.Random(seed)

    def generate_pairs(self):
        """Return a list of ``num_pairs`` tuples ``(inputs, timestamp, target)``.

        Raises:
            ValueError: If ``num_inputs`` is less than 1.
        """
        if self.num_inputs < 1:
            raise ValueError("Number of inputs should be greater than 0.")
        return [self._generate_pair() for _ in range(self.num_pairs)]

    def _generate_pair(self):
        """Draw a single ``(inputs, timestamp, target)`` tuple."""
        # Inputs are drawn before the timestamp: both consume self.gen, so
        # swapping the two calls would change the generated data.
        input_values = self._generate_input_values()
        timestamp = self._generate_timestamp()
        return input_values, timestamp, self._calculate_xor(input_values)

    def _generate_input_values(self):
        """Return ``num_inputs`` random values in [-1, 1] (ints if ``int_values``)."""
        draw = self.gen.randint if self.int_values else self.gen.uniform
        return [draw(-1, 1) for _ in range(self.num_inputs)]

    def _generate_timestamp(self):
        """Return an ISO-8601 UTC timestamp 1 to 24 hours before the current time.

        Only the offset is seeded; the wall-clock "now" makes the value
        non-reproducible between runs.
        """
        now = datetime.datetime.now(datetime.timezone.utc)
        return (now - datetime.timedelta(hours=self.gen.uniform(1, 24))).isoformat()

    def _calculate_xor(self, input_values):
        """Return the bitwise XOR of ``input_values`` after rounding each to an int.

        ``input_values`` must be non-empty (guaranteed by ``generate_pairs``).

        NOTE(review): inputs lie in [-1, 1], so rounded values can be -1, and
        Python's XOR on negative ints gives targets such as ``-1 ^ 1 == -2``.
        ``round`` also rounds halves to even (``round(0.5) == 0``). Confirm
        this is the intended target encoding.
        """
        xor_value = round(input_values[0])
        for val in input_values[1:]:
            xor_value ^= round(val)
        return xor_value

    def save_to_file(self, filename, append=False):
        """Generate a fresh batch of pairs and write them to ``filename`` as CSV.

        Each data row is ``[inputs, timestamp, target]``. ``inputs`` is the
        input values joined with ", " into a single (quoted) CSV field.

        Args:
            filename: Destination path.
            append: If True, add rows to the end of the file instead of
                overwriting it. The header row is written when overwriting,
                or when appending to a file that is missing or empty.
        """
        pairs = self.generate_pairs()
        # Decide before opening: mode 'a' creates the file, which would hide
        # the fact that it was missing. Appending to a new/empty file
        # previously produced a CSV without a header row.
        write_header = (
            not append
            or not os.path.exists(filename)
            or os.path.getsize(filename) == 0
        )
        mode = 'a' if append else 'w'
        with open(filename, mode, newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(['input', 'timestamp', 'target'])
            writer.writerows(
                [', '.join(map(str, input_values)), timestamp, expected_output]
                for input_values, timestamp, expected_output in pairs
            )
import random
import datetime
import pytz
import csv
class PairGenerator:
    """Generate random ``(inputs, timestamp, targets)`` pairs and save them as CSV.

    Each pair holds ``num_inputs`` random values in [-1, 1], an ISO-8601 UTC
    timestamp between 1 and 24 hours before "now", and a list of
    ``num_outputs`` targets. Every target equals the XOR of the rounded inputs.

    Args:
        num_inputs: Number of input values per pair; must be >= 1 when pairs
            are generated.
        num_pairs: Number of pairs produced per ``generate_pairs`` call.
        num_outputs: Number of (identical) target values per pair; must be
            >= 1 when pairs are generated.
        int_values: If True, inputs are integers drawn from {-1, 0, 1};
            otherwise floats drawn uniformly from [-1, 1].
        seed: Seed for the private random generator. Defaults to 42, the value
            that used to be hard-coded, so default output is unchanged.
    """

    def __init__(self, num_inputs, num_pairs, num_outputs, int_values=False, seed=42):
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_pairs = num_pairs
        self.int_values = int_values
        # A private Random instance makes the inputs reproducible and
        # independent of the global ``random`` module state.
        self.gen = random.Random(seed)

    def generate_pairs(self):
        """Return a list of ``num_pairs`` tuples ``(inputs, timestamp, targets)``.

        Raises:
            ValueError: If ``num_inputs`` or ``num_outputs`` is less than 1.
        """
        if self.num_inputs < 1 or self.num_outputs < 1:
            raise ValueError("Number of inputs and outputs should be greater than 0.")
        return [self._generate_pair() for _ in range(self.num_pairs)]

    def _generate_pair(self):
        """Draw a single ``(inputs, timestamp, targets)`` tuple."""
        # Inputs are drawn before the timestamp: both consume self.gen, so
        # swapping the two calls would change the generated data.
        input_values = self._generate_input_values()
        timestamp = self._generate_timestamp()
        return input_values, timestamp, self._calculate_xor(input_values)

    def _generate_input_values(self):
        """Return ``num_inputs`` random values in [-1, 1] (ints if ``int_values``)."""
        draw = self.gen.randint if self.int_values else self.gen.uniform
        return [draw(-1, 1) for _ in range(self.num_inputs)]

    def _generate_timestamp(self):
        """Return an ISO-8601 UTC timestamp 1 to 24 hours before the current time.

        Only the offset is seeded; the wall-clock "now" makes the value
        non-reproducible between runs.
        """
        now = datetime.datetime.now(datetime.timezone.utc)
        return (now - datetime.timedelta(hours=self.gen.uniform(1, 24))).isoformat()

    def _calculate_xor(self, input_values):
        """Return ``[x] * num_outputs``, where x is the XOR of the rounded inputs.

        ``input_values`` must be non-empty (guaranteed by ``generate_pairs``).

        NOTE(review): inputs lie in [-1, 1], so rounded values can be -1, and
        Python's XOR on negative ints gives targets such as ``-1 ^ 1 == -2``.
        ``round`` also rounds halves to even (``round(0.5) == 0``). Confirm
        this is the intended target encoding.
        """
        xor_value = round(input_values[0])
        for val in input_values[1:]:
            xor_value ^= round(val)
        return [xor_value] * self.num_outputs

    def _header(self):
        """Return the CSV header row with one target column per output.

        A single output keeps the plain ``target`` column name; multiple
        outputs get ``target_0``, ``target_1``, ... so that the header has the
        same number of fields as each data row.
        """
        if self.num_outputs == 1:
            targets = ['target']
        else:
            targets = [f'target_{i}' for i in range(self.num_outputs)]
        return ['input', 'timestamp'] + targets

    def save_to_file(self, filename, append=False):
        """Generate a fresh batch of pairs and write them to ``filename`` as CSV.

        Each data row is ``[inputs, timestamp, *targets]``. ``inputs`` is the
        input values joined with ", " into a single (quoted) CSV field.

        Args:
            filename: Destination path.
            append: If True, add rows to the end of the file instead of
                overwriting it. The header row is written when overwriting,
                or when appending to a file that is missing or empty.

        Raises:
            ValueError: Propagated from ``generate_pairs`` for invalid sizes.
        """
        pairs = self.generate_pairs()
        # Decide before opening: mode 'a' creates the file, which would hide
        # the fact that it was missing. Appending to a new/empty file
        # previously produced a CSV without a header row.
        write_header = (
            not append
            or not os.path.exists(filename)
            or os.path.getsize(filename) == 0
        )
        mode = 'a' if append else 'w'
        with open(filename, mode, newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(self._header())
            writer.writerows(
                [', '.join(map(str, input_values)), timestamp] + expected_output
                for input_values, timestamp, expected_output in pairs
            )
# Example usage. The guard keeps importing this module from writing
# output.csv into the current directory; running it as a script still does.
if __name__ == "__main__":
    generator = PairGenerator(num_inputs=3, num_pairs=5, num_outputs=2, int_values=False)
    generator.save_to_file('output.csv')