-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
138 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
from qlib.data.dataset.handler import DataHandlerLP | ||
from qlib.data.dataset.processor import Processor | ||
from qlib.utils import get_callable_kwargs | ||
from qlib.data.dataset import processor as processor_module | ||
from inspect import getfullargspec | ||
from database_utils.db_utils import save_to_db, DuckDBManager | ||
|
||
def check_transform_proc(proc_l, fit_start_time, fit_end_time): | ||
new_l = [] | ||
for p in proc_l: | ||
if not isinstance(p, Processor): | ||
klass, pkwargs = get_callable_kwargs(p, processor_module) | ||
args = getfullargspec(klass).args | ||
if "fit_start_time" in args and "fit_end_time" in args: | ||
assert ( | ||
fit_start_time is not None and fit_end_time is not None | ||
), "Make sure `fit_start_time` and `fit_end_time` are not None." | ||
pkwargs.update( | ||
{ | ||
"fit_start_time": fit_start_time, | ||
"fit_end_time": fit_end_time, | ||
} | ||
) | ||
proc_config = {"class": klass.__name__, "kwargs": pkwargs} | ||
if isinstance(p, dict) and "module_path" in p: | ||
proc_config["module_path"] = p["module_path"] | ||
new_l.append(proc_config) | ||
else: | ||
new_l.append(p) | ||
return new_l | ||
|
||
|
||
_DEFAULT_LEARN_PROCESSORS = [ | ||
{"class": "DropnaLabel"}, | ||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}, | ||
] | ||
_DEFAULT_INFER_PROCESSORS = [ | ||
{"class": "ProcessInf", "kwargs": {}}, | ||
{"class": "ZScoreNorm", "kwargs": {}}, | ||
{"class": "Fillna", "kwargs": {}}, | ||
] | ||
|
||
|
||
class Alpha1(DataHandlerLP): | ||
def __init__( | ||
self, | ||
instruments="csi500", | ||
start_time=None, | ||
end_time=None, | ||
freq="day", | ||
infer_processors=[], | ||
learn_processors=_DEFAULT_LEARN_PROCESSORS, | ||
fit_start_time=None, | ||
fit_end_time=None, | ||
process_type=DataHandlerLP.PTYPE_A, | ||
filter_pipe=None, | ||
data_loader=None, | ||
inst_processors=None, | ||
**kwargs | ||
): | ||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) | ||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) | ||
|
||
data_loader = { | ||
"class": "QlibDataLoader", | ||
"kwargs": { | ||
"config": { | ||
"feature": self.get_feature_config(), | ||
"label": kwargs.pop("label", self.get_label_config()), | ||
}, | ||
"filter_pipe": filter_pipe, | ||
"freq": freq, | ||
"inst_processors": inst_processors, | ||
}, | ||
} | ||
super().__init__( | ||
instruments=instruments, | ||
start_time=start_time, | ||
end_time=end_time, | ||
data_loader=data_loader, | ||
infer_processors=infer_processors, | ||
learn_processors=learn_processors, | ||
process_type=process_type, | ||
**kwargs | ||
) | ||
# Print loaded data | ||
print("Loaded data:") | ||
df = self.data_loader.load(instruments=self.instruments, start_time=self.start_time, end_time=self.end_time) | ||
print(start_time,end_time,'debug') | ||
print(df.tail(20)) | ||
|
||
|
||
def get_feature_config(self): | ||
conf = { | ||
"kbar": {}, | ||
"price": { | ||
"windows": [0], | ||
"feature": ["OPEN", "HIGH", "LOW", "VWAP"], | ||
}, | ||
"rolling": {}, | ||
} | ||
return self.parse_config_to_fields(conf) | ||
|
||
def get_label_config(self): | ||
return ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"] | ||
|
||
@staticmethod | ||
def parse_config_to_fields(config): | ||
"""create factors from config | ||
config = { | ||
'kbar': {}, # whether to use some hard-code kbar features | ||
'price': { # whether to use raw price features | ||
'windows': [0, 1, 2, 3, 4], # use price at n days ago | ||
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use | ||
}, | ||
'volume': { # whether to use raw volume features | ||
'windows': [0, 1, 2, 3, 4], # use volume at n days ago | ||
}, | ||
'rolling': { # whether to use rolling operator based features | ||
'windows': [5, 10, 20, 30, 60], # rolling windows size | ||
'include': ['ROC', 'MA', 'STD'], # rolling operator to use | ||
#if include is None we will use default operators | ||
'exclude': ['RANK'], # rolling operator not to use | ||
} | ||
} | ||
""" | ||
names = ['ref'] | ||
fields = ["Ref($close, -2)/Ref($close, -1) - 1"] | ||
return fields, names | ||
|
||
|
||
class Alpha158vwap(Alpha1): | ||
def get_label_config(self): | ||
return ["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"] |