Skip to content

Commit

Permalink
add test cases to ai extractors #86
Browse files Browse the repository at this point in the history
  • Loading branch information
fqrious committed Nov 22, 2024
1 parent 6473e78 commit 326c475
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 5 deletions.
File renamed without changes.
30 changes: 29 additions & 1 deletion txt2stix/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,30 @@
from .stix import txt2stixBundler
from .txt2stix import extract_all
from .txt2stix import extract_all
from pathlib import Path


INCLUDES_PATH = None
def get_include_path():
global INCLUDES_PATH

if INCLUDES_PATH:
return INCLUDES_PATH

from pathlib import Path
MODULE_PATH = Path(__file__).parent.parent
INCLUDES_PATH = MODULE_PATH/"includes"
try:
from . import includes
INCLUDES_PATH = Path(includes.__file__).parent
except:
pass
return INCLUDES_PATH

def set_include_path(path):
global INCLUDES_PATH
INCLUDES_PATH = path


__all__ = [
'txt2stixBundler', 'extract_all', 'get_include_path'
]
13 changes: 11 additions & 2 deletions txt2stix/ai_extractor/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
import json
import logging

import dotenv
Expand All @@ -22,11 +23,11 @@ class Relationship(BaseModel):
relationship_type: str = Field(description='is a description of the relationship between target and source.')

class ExtractionList(BaseModel):
extractions: list[Extraction]
extractions: list[Extraction] = Field(default_factory=list)
success: bool

class RelationshipList(BaseModel):
relationships: list[Relationship]
relationships: list[Relationship] = Field(default_factory=list)
success: bool


Expand All @@ -50,8 +51,16 @@ def get_extractors_str(extractors):
print(f"- {extractor.prompt_helper}", file=buffer)
if extractor.prompt_conversion:
print(f"- {extractor.prompt_conversion}", file=buffer)
if extractor.prompt_positive_examples:
print(f"- Here are some examples that MATCH: {json.dumps(extractor.prompt_positive_examples)}", file=buffer)
if extractor.prompt_negative_examples:
print(f"- Here are some examples that DO NOT MATCH: {json.dumps(extractor.prompt_positive_examples)}", file=buffer)
print("</extractor>", file=buffer)
print("\n"*2, file=buffer)

logging.debug("======== extractors ======")
logging.debug(buffer.getvalue())
logging.debug("======== extractors end ======")
return buffer.getvalue()


Expand Down
14 changes: 12 additions & 2 deletions txt2stix/extractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ class Extractor(NamedDict):
prompt_extraction_extra = None


def __init__(self, key, dct, include_path=None):
def __init__(self, key, dct, include_path=None, test_cases: dict[str, list[str]]=None):
super().__init__(dct)
self.extraction_key = key
self.slug = key
if test_cases:
self.prompt_negative_examples = test_cases.get('test_negative_examples', [])
self.prompt_positive_examples = test_cases.get('test_positive_examples', [])
if self.file and not Path(self.file).is_absolute() and include_path:
self.file = Path(include_path) / self.file

Expand All @@ -38,7 +41,14 @@ def load(self):

def parse_extraction_config(include_path: Path):
config = {}
test_cases = load_test_cases_config(include_path)
for p in include_path.glob("extractions/*/config.yaml"):
config.update(yaml.safe_load(p.open()))

return {k: Extractor(k, v, include_path) for k, v in config.items()}
return {k: Extractor(k, v, include_path, test_cases=test_cases.get(v.get('test_cases'))) for k, v in config.items()}

def load_test_cases_config(include_path: Path) -> dict[str, dict[str, list[str]]]:
config_file = include_path/'test_cases.yaml'
if not config_file.exists():
return {}
return yaml.safe_load(config_file.open())

0 comments on commit 326c475

Please sign in to comment.