Skip to content
This repository has been archived by the owner on Dec 19, 2023. It is now read-only.

Commit

Permalink
update data toolchains for classification task
Browse files Browse the repository at this point in the history
  • Loading branch information
Congyuwang committed Aug 10, 2021
1 parent 25d0cbc commit 23c7e89
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .idea/AlphaNetV3.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='alphanet',
version='0.0.12',
version='0.0.13',
packages=['alphanet'],
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
2 changes: 1 addition & 1 deletion src/alphanet/alphanet.iml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.9 (base)" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Python 3.8 (AlphaNetV3)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PackageRequirementsSettings">
Expand Down
2 changes: 2 additions & 0 deletions src/alphanet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,8 @@ def __full_tensor_generation__(data,

# Drop every (stock, time) history window that contains missing data
label_nan = _tf.math.is_nan(label_all)
if _tf.rank(label_all) == 2:
label_nan = _tf.math.reduce_any(label_nan, axis=1)
data_nan = _tf.math.is_nan(data_all)
nan_series_time_index = _tf.math.reduce_any(
_tf.math.reduce_any(data_nan, axis=2),
Expand Down
164 changes: 164 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,139 @@ def __get_last_batches__(self, start_basis, end, n, history=30, step=2):
step=step)


class TestDataModuleClassification(unittest.TestCase):
    """
    Test the data ranges of data and label for `alphanet.data` for classification task.

    The batches produced by ``TrainValData`` are compared against windows
    rebuilt by hand from the fully expanded (code x date) table.
    """

    @classmethod
    def setUpClass(cls):
        """Load the classification fixtures and fetch the reference batches once."""
        (cls.data,
         cls.full_data,
         cls.codes,
         cls.trading_dates) = __test_data_classification__()
        # Random integer start date; it is compared directly against
        # ``trading_dates`` below (presumably yyyymmdd integers -- confirm).
        cls.test_date = np.random.randint(20110101, 20121231)
        print("getting batches for {} (classification)".format(cls.test_date))
        (cls.first_batch_train,
         cls.first_batch_val,
         cls.last_batch_train,
         cls.last_batch_val,
         cls.dates_info) = cls.__get_batches__(cls.data, cls.test_date)
        # Index of the first trading date that is not earlier than test_date.
        cls.start_basis = np.min(np.where(cls.trading_dates >= cls.test_date))

    def test_dates_info(self):
        """The validation start/end dates must bound the validation dates_list."""
        validation = self.dates_info["validation"]
        self.assertEqual(
            min(validation["dates_list"]),
            validation["start_date"],
            "validation dates_list incorrect"
        )
        self.assertEqual(
            max(validation["dates_list"]),
            validation["end_date"],
            "validation dates_list incorrect"
        )

    def test_first_batch_of_training_dataset(self):
        """Compare the first training batch with hand-built windows."""
        data_label = self.__get_first_batches__(self.start_basis, 0, 120)
        for k, name in enumerate(["data", "label"]):
            expected = data_label[k][:len(self.first_batch_train[k])]
            self.assertTrue(
                __is_all_close__(expected, self.first_batch_train[k]),
                "first batch of training {} "
                "(start {}): failure".format(name, self.test_date)
            )

    def test_last_batch_of_training_dataset(self):
        """Compare the last training batch with hand-built windows."""
        # NOTE(review): 1200 presumably mirrors the training length used by
        # TrainValData's defaults -- confirm against alphanet.data.
        data_label = self.__get_last_batches__(self.start_basis, 1200, 120)
        for k, name in enumerate(["data", "label"]):
            # The batch length is taken from the data tensor (index 0) for
            # both the data and the label comparison.
            expected = data_label[k][-len(self.last_batch_train[0]):]
            self.assertTrue(
                __is_all_close__(expected, self.last_batch_train[k]),
                "last batch of training {} "
                "(start {}): failure".format(name, self.test_date)
            )

    def test_first_batch_of_validation_dataset(self):
        """Compare the first validation batch with hand-built windows."""
        # NOTE(review): the 1210 - 29 offset presumably derives from
        # TrainValData's default gap/history settings -- confirm.
        data_label = self.__get_first_batches__(self.start_basis, 1210 - 29, 120)
        for k, name in enumerate(["data", "label"]):
            expected = data_label[k][:len(self.first_batch_val[k])]
            self.assertTrue(
                __is_all_close__(expected, self.first_batch_val[k]),
                "first batch of validation {} "
                "(start {}): failure".format(name, self.test_date)
            )

    def test_last_batch_of_validation_dataset(self):
        """Compare the last validation batch with hand-built windows."""
        data_label = self.__get_last_batches__(self.start_basis, 1510 - 1, 120)
        for k, name in enumerate(["data", "label"]):
            # The batch length is taken from the data tensor (index 0) for
            # both the data and the label comparison.
            expected = data_label[k][-len(self.last_batch_val[0]):]
            self.assertTrue(
                __is_all_close__(expected, self.last_batch_val[k]),
                "last batch of validation {} "
                "(start {}): failure".format(name, self.test_date)
            )

    @classmethod
    def __get_batches__(cls, data, start_date):
        """
        Return the first and last 500-sized batches of the train and
        validation datasets, plus the dates info.

        :param data: input passed to ``TrainValData``
        :param start_date: start date passed to ``TrainValData.get``
        :return: (first_train, first_val, last_train, last_val, dates_info)
        """
        train_val_generator = TrainValData(data)
        train, val, dates_info = train_val_generator.get(start_date)
        first_train = next(iter(train.batch(500)))
        first_val = next(iter(val.batch(500)))
        last_train = None
        last_val = None

        # Exhaust each dataset; the last batch seen is the one kept.
        for b in iter(train.batch(500)):
            last_train = b

        for b in iter(val.batch(500)):
            last_val = b

        return first_train, first_val, last_train, last_val, dates_info

    def __get_n_batches__(self,
                          start_date_index,
                          end_date_index,
                          n=2,
                          step=2,
                          class_count=3):
        """
        Rebuild data/label windows by hand from ``self.full_data``.

        For each of the ``n`` offsets ``0, step, ..., step * (n - 1)`` and
        each code in ``self.codes`` (offset-major order), the rows of that
        code whose date lies between ``trading_dates[start_date_index + offset]``
        and ``trading_dates[end_date_index + offset]`` (both inclusive) form
        one window.

        :param start_date_index: index into ``self.trading_dates`` of the
            first date of the first window
        :param end_date_index: index into ``self.trading_dates`` of the
            last date of the first window (inclusive)
        :param n: number of window offsets
        :param step: distance, in trading dates, between successive offsets
        :param class_count: unused; kept so existing callers keep working
        :return: (data_list, label_list); windows whose data contain any
            null value are skipped
        """
        data_list = []
        label_list = []
        # One-hot labels do not depend on the window, so compute them once
        # instead of once per (offset, code) pair.
        df_label = pd.get_dummies(self.full_data["10日回报率"])
        running_index = [(start_date_index + day, end_date_index + day, co)
                         for day in range(0, step * n, step)
                         for co in self.codes]
        for start, end, co in tqdm(running_index):
            start_date = self.trading_dates[start]
            end_date = self.trading_dates[end]
            which_ones = (
                (self.full_data["代码"] == co)
                & (self.full_data["日期"] <= end_date)
                & (self.full_data["日期"] >= start_date)
            )
            df = self.full_data.loc[which_ones, :]
            labels = df_label.loc[which_ones, :]
            # Feature columns start at position 3.
            dt = df.iloc[:, 3:].values
            # The label of a window is the one-hot row of its last date.
            lb = labels.iloc[-1].values
            if not pd.isnull(dt).any():
                data_list.append(dt)
                label_list.append(lb)

        return data_list, label_list

    def __get_first_batches__(self, start_basis, start, n, history=30, step=2):
        """
        Build ``n`` window offsets whose first window starts at
        ``start_basis + start`` and spans ``history`` trading dates.

        :param start_basis: base index into ``self.trading_dates``
        :param start: offset of the first window relative to ``start_basis``
        :param n: number of window offsets
        :param history: number of trading dates per window
        :param step: distance between successive offsets
        """
        return self.__get_n_batches__(start_basis + start,
                                      start_basis + start + history - 1,
                                      n=n,
                                      step=step)

    def __get_last_batches__(self, start_basis, end, n, history=30, step=2):
        """
        Build ``n`` window offsets whose last window ends just before
        ``start_basis + end``.

        :param start_basis: base index into ``self.trading_dates``
        :param end: exclusive
        :param n: number of window offsets
        :param history: number of trading dates per window
        :param step: distance between successive offsets
        """
        return self.__get_n_batches__(start_basis + end - history - step * (n - 1),
                                      start_basis + end - 1 - step * (n - 1),
                                      n=n,
                                      step=step)


class TestMetrics(unittest.TestCase):

def test_up_down_accuracy(self):
Expand Down Expand Up @@ -431,6 +564,37 @@ def __test_data__():
return stock_data, full_csi, codes, trading_dates


def __test_data_classification__(class_count=3):
    """
    Prepare the fixtures for the classification tests.

    Reads the test archive, replaces the "10日回报率" column by its per-date
    quantile class (``class_count`` bins), and builds one ``TimeSeriesData``
    per stock code with one-hot labels.

    :param class_count: number of quantile bins passed to ``pd.qcut``
    :return: (stock_data, full_csi, codes, trading_dates), where ``full_csi``
        is the left merge of every (code, date) pair with the loaded table
    :raises FileNotFoundError: if the archive is found at neither location
    """
    # Prepare test data: the archive location depends on the working directory.
    candidates = ("./tests/test_data/test_data.zip",
                  "./test_data/test_data.zip")
    archive = next((p for p in candidates if os.path.exists(p)), None)
    if archive is None:
        raise FileNotFoundError("test data missing")
    df = pd.read_csv(archive, dtype={"代码": "category"})
    codes = df.代码.cat.categories

    def to_quantile_class(returns):
        # Integer bin index of each value within its own date group.
        return pd.qcut(returns, class_count, labels=False)

    df["10日回报率"] = df.groupby("日期")["10日回报率"].transform(to_quantile_class)
    one_hot = pd.get_dummies(df["10日回报率"])

    stock_data = []
    for code in codes:
        rows = df.代码 == code
        frame = df.loc[rows, :]
        stock_data.append(TimeSeriesData(dates=frame["日期"].values,
                                         data=frame.iloc[:, 3:].values,
                                         labels=one_hot.loc[rows, :].values))

    # Fill in every stock/date combination; used to build batches by hand
    # for comparison in the tests.
    trading_dates = df["日期"].unique()
    trading_dates.sort()
    keys = ["代码", "日期"]
    full_index = pd.DataFrame([[s, d] for s in codes for d in trading_dates])
    full_index.columns = keys
    full_csi = full_index.merge(df,
                                how="left",
                                left_on=keys,
                                right_on=keys)
    return stock_data, full_csi, codes, trading_dates


def __is_all_close__(data1, data2, **kwargs):
    """
    Return True when every element of ``data1`` is close to ``data2``.

    :param data1: first array-like
    :param data2: second array-like, broadcastable against ``data1``
    :param kwargs: tolerance options forwarded to ``np.allclose``
        (``rtol``, ``atol``, ``equal_nan``)
    :return: plain ``bool``
    """
    return np.allclose(data1, data2, **kwargs)

Expand Down

0 comments on commit 23c7e89

Please sign in to comment.