Skip to content

Commit 326c475

Browse files
committed
add test cases to ai extractors #86
1 parent 6473e78 commit 326c475

File tree

4 files changed

+52
-5
lines changed

4 files changed

+52
-5
lines changed
File renamed without changes.

txt2stix/__init__.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,30 @@
11
from .stix import txt2stixBundler
2-
from .txt2stix import extract_all
2+
from .txt2stix import extract_all
3+
from pathlib import Path
4+
5+
6+
# Cached location of the includes directory; stays None until it is first
# resolved by get_include_path() or overridden via set_include_path().
INCLUDES_PATH = None


def get_include_path():
    """Return the directory that holds the ``includes`` data files.

    The result is cached in the module-level ``INCLUDES_PATH``; once that
    global holds a truthy value it is returned unchanged on every call.

    On the first call the path defaults to ``<two levels above this
    file>/includes``. If an ``includes`` subpackage can be imported from
    this package and exposes a ``__file__``, the directory containing that
    file is used instead.

    Returns:
        The resolved includes directory (a ``Path`` unless a different
        value was stored in ``INCLUDES_PATH`` beforehand).
    """
    global INCLUDES_PATH

    if INCLUDES_PATH:
        return INCLUDES_PATH

    # Fallback location: an "includes" directory two levels above this file.
    INCLUDES_PATH = Path(__file__).parent.parent / "includes"
    try:
        from . import includes
    except ImportError:
        # No importable ``includes`` subpackage: keep the fallback. Only
        # ImportError is handled so KeyboardInterrupt/SystemExit and real
        # bugs are no longer silently swallowed.
        pass
    else:
        # Namespace packages have ``__file__`` set to None (or missing).
        package_file = getattr(includes, "__file__", None)
        if package_file:
            INCLUDES_PATH = Path(package_file).parent
    return INCLUDES_PATH
22+
23+
def set_include_path(path):
    """Override the includes directory stored in ``INCLUDES_PATH``.

    Args:
        path: New includes directory, as a ``str`` or path-like object.
            A falsy value (``None`` or ``""``) clears the override.

    The value is normalised to ``pathlib.Path`` so that code reading
    ``INCLUDES_PATH`` can rely on ``Path`` operations such as ``/`` and
    ``.glob()`` regardless of whether the caller passed a string.
    """
    global INCLUDES_PATH
    INCLUDES_PATH = Path(path) if path else None
26+
27+
28+
# Names exported by ``from txt2stix import *``. ``set_include_path`` is listed
# alongside ``get_include_path`` because the two form one public API.
__all__ = [
    'txt2stixBundler', 'extract_all', 'get_include_path', 'set_include_path',
]

txt2stix/ai_extractor/utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import io
2+
import json
23
import logging
34

45
import dotenv
@@ -22,11 +23,11 @@ class Relationship(BaseModel):
2223
relationship_type: str = Field(description='is a description of the relationship between target and source.')
2324

2425
# Model pairing a list of ``Extraction`` items with a ``success`` flag.
# NOTE(review): comments are used instead of a docstring in case the model's
# docstring ends up in a generated schema -- confirm before adding one.
class ExtractionList(BaseModel):
    # Falls back to a fresh empty list when no value is supplied.
    extractions: list[Extraction] = Field(default_factory=list)
    # Declared without a default.
    success: bool
2728

2829
# Model pairing a list of ``Relationship`` items with a ``success`` flag.
# NOTE(review): comments are used instead of a docstring in case the model's
# docstring ends up in a generated schema -- confirm before adding one.
class RelationshipList(BaseModel):
    # Falls back to a fresh empty list when no value is supplied.
    relationships: list[Relationship] = Field(default_factory=list)
    # Declared without a default.
    success: bool
3132

3233

@@ -50,8 +51,16 @@ def get_extractors_str(extractors):
5051
print(f"- {extractor.prompt_helper}", file=buffer)
5152
if extractor.prompt_conversion:
5253
print(f"- {extractor.prompt_conversion}", file=buffer)
54+
if extractor.prompt_positive_examples:
55+
print(f"- Here are some examples that MATCH: {json.dumps(extractor.prompt_positive_examples)}", file=buffer)
56+
if extractor.prompt_negative_examples:
57+
# The negative-example line must serialise the *negative* examples.
print(f"- Here are some examples that DO NOT MATCH: {json.dumps(extractor.prompt_negative_examples)}", file=buffer)
5358
print("</extractor>", file=buffer)
5459
print("\n"*2, file=buffer)
60+
61+
logging.debug("======== extractors ======")
62+
logging.debug(buffer.getvalue())
63+
logging.debug("======== extractors end ======")
5564
return buffer.getvalue()
5665

5766

txt2stix/extractions.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@ class Extractor(NamedDict):
2121
prompt_extraction_extra = None
2222

2323

24-
def __init__(self, key, dct, include_path=None):
24+
def __init__(self, key, dct, include_path=None, test_cases: dict[str, list[str]]=None):
2525
super().__init__(dct)
2626
self.extraction_key = key
2727
self.slug = key
28+
if test_cases:
29+
self.prompt_negative_examples = test_cases.get('test_negative_examples', [])
30+
self.prompt_positive_examples = test_cases.get('test_positive_examples', [])
2831
if self.file and not Path(self.file).is_absolute() and include_path:
2932
self.file = Path(include_path) / self.file
3033

@@ -38,7 +41,14 @@ def load(self):
3841

3942
def parse_extraction_config(include_path: Path):
    """Build the extractor registry from the YAML files under *include_path*.

    Every ``extractions/*/config.yaml`` file is merged into one mapping
    (keys from files visited later replace earlier ones), and each entry
    is wrapped in an ``Extractor``. An entry's ``test_cases`` value is
    used as the lookup key into the mapping returned by
    ``load_test_cases_config``; entries without a match get ``None``.

    Args:
        include_path: Directory containing the ``extractions`` folder.

    Returns:
        A dict mapping each extraction key to its ``Extractor``.
    """
    # A None result from the loader is treated the same as "no test cases".
    test_cases = load_test_cases_config(include_path) or {}

    config = {}
    for config_file in include_path.glob("extractions/*/config.yaml"):
        # Close each handle as soon as it is parsed instead of leaking it.
        with config_file.open() as stream:
            # An empty YAML document loads as None, which dict.update() rejects.
            config.update(yaml.safe_load(stream) or {})

    return {
        key: Extractor(
            key,
            entry,
            include_path,
            test_cases=test_cases.get(entry.get('test_cases')),
        )
        for key, entry in config.items()
    }
49+
50+
def load_test_cases_config(include_path: Path) -> dict[str, dict[str, list[str]]]:
    """Load the optional ``test_cases.yaml`` located directly in *include_path*.

    Args:
        include_path: Directory expected to contain ``test_cases.yaml``.

    Returns:
        The parsed YAML content, or an empty dict when the file does not
        exist or contains no document.
    """
    config_file = include_path / 'test_cases.yaml'
    if not config_file.exists():
        return {}
    # The context manager closes the handle instead of leaving it to the GC.
    with config_file.open() as stream:
        # safe_load() yields None for an empty file; callers expect a dict.
        return yaml.safe_load(stream) or {}

0 commit comments

Comments
 (0)