Merge pull request #7 from leonvanbokhorst/training-refine
Optimize batch size calculation and limit maximum size
leonvanbokhorst authored Dec 25, 2024
2 parents d165baf + 9b54c7b commit fbbde0f
Showing 6 changed files with 892 additions and 71 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -20,6 +20,7 @@ wheels/
.installed.cfg
*.egg
.venv/
cache/

# Testing
.coverage
86 changes: 27 additions & 59 deletions src/topic_drift/data_loader.py
@@ -6,6 +6,7 @@
from tqdm.auto import tqdm
from topic_drift.data_types import ConversationData
import hashlib
from datasets import load_dataset


def get_cache_path() -> Path:
@@ -21,7 +22,7 @@ def get_cache_key(repo_id: str) -> str:


def load_from_huggingface(
repo_id: str = "leonvanbokhorst/topic-drift",
repo_id: str = "leonvanbokhorst/topic-drift-v2",
token: str = None,
use_cache: bool = True,
force_reload: bool = False,
@@ -36,62 +37,29 @@ def load_from_huggingface(
Returns:
ConversationData object containing the loaded conversations
Raises:
ValueError: If no token is provided and none found in environment
Exception: If download fails or no data found
"""
# Get token from env if not provided
token = token or os.getenv("HF_TOKEN")
if not token:
raise ValueError(
"Hugging Face token not found. Set HF_TOKEN in .env or pass token parameter"
)

cache_path = get_cache_path() / f"{get_cache_key(repo_id)}.jsonl"

# Check cache first
if use_cache and not force_reload and cache_path.exists():
print(f"Loading from cache: {cache_path}")
with open(cache_path, "r") as f:
conversations = [json.loads(line) for line in f]
return ConversationData(conversations=conversations)

print(f"Downloading from Hugging Face: {repo_id}")
api = HfApi(token=token)

# Create temp directory for file operations
with tempfile.TemporaryDirectory() as tmp_dir:
try:
api.snapshot_download(
repo_id=repo_id,
repo_type="dataset",
local_dir=tmp_dir,
token=token,
ignore_patterns=[".*"],
)

# Load conversations from JSONL
conversations = []
jsonl_path = Path(tmp_dir) / "conversations.jsonl"
if not jsonl_path.exists():
raise Exception(f"No data found in {repo_id}")

print("Reading conversations from downloaded file...")
with open(jsonl_path, "r") as f:
conversations.extend(
json.loads(line)
for line in tqdm(f, desc="Loading conversations")
)
# Update cache if enabled
if use_cache:
print(f"Updating cache: {cache_path}")
with open(cache_path, "w") as f:
for conv in conversations:
json.dump(conv, f)
f.write("\n")

return ConversationData(conversations=conversations)

except Exception as e:
raise Exception(f"Failed to load data from {repo_id}: {str(e)}") from e
print(f"Loading dataset from Hugging Face: {repo_id}")
dataset = load_dataset(repo_id)

if dataset is None:
raise ValueError("Failed to load dataset from Hugging Face")

# Convert dataset to conversation format
conversations = []
for split in ['train', 'validation', 'test']:
for example in dataset[split]:
conversation = {
'turns': example['conversation'],
'speakers': example['speakers'],
'topic_markers': example['topic_markers'],
'transition_points': example['transition_points'],
'quality_score': example.get('quality_score', 1.0)
}
conversations.append(conversation)

print(f"Loaded {len(conversations)} conversations")
print(f"Train size: {len(dataset['train'])}")
print(f"Validation size: {len(dataset['validation'])}")
print(f"Test size: {len(dataset['test'])}")

return ConversationData(conversations=conversations)
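
For context, a minimal usage sketch of the updated loader. It assumes only what the diff shows: `load_from_huggingface` now reads the `topic-drift-v2` dataset via `datasets.load_dataset` and returns a `ConversationData` built from a flattened list of conversations; the `token`, `use_cache`, and `force_reload` parameters are no longer referenced in the new body, and the sketch assumes `ConversationData` exposes that list as a `conversations` attribute.

from topic_drift.data_loader import load_from_huggingface

# Loads the topic-drift-v2 dataset (using the local Hugging Face cache when
# available), flattens the train/validation/test splits into one list, and
# wraps it in ConversationData.
data = load_from_huggingface()
print(f"Total conversations: {len(data.conversations)}")
print(data.conversations[0]["turns"][:2])  # first two turns of one conversation
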
17 changes: 13 additions & 4 deletions src/topic_drift/data_prep.py
@@ -46,10 +46,13 @@ def split_data(
Returns:
DataSplit object containing train/val/test tensors
"""
# Convert tensors to numpy for sklearn
embeddings_np = embeddings.numpy()
labels_np = labels.numpy()

# Convert labels to discrete bins for stratification
n_bins = 10 # Number of bins for stratification
binned_labels = np.floor(labels.numpy() * n_bins).astype(int)
binned_labels = np.floor(labels_np * n_bins).astype(int)

# Count samples in each bin
unique_bins, bin_counts = np.unique(binned_labels, return_counts=True)
@@ -59,7 +62,13 @@
use_stratify = min_samples >= 2
stratify = binned_labels if use_stratify else None

stratify=stratify,
# First split out test set
train_val_emb, test_emb, train_val_labels, test_labels = train_test_split(
embeddings_np,
labels_np,
test_size=test_size,
random_state=random_state,
stratify=stratify
)

# Then split remaining data into train and validation
@@ -75,7 +84,7 @@
train_val_labels,
test_size=val_size / (1 - test_size), # Adjust for remaining data
random_state=random_state,
stratify=stratify_remaining,
stratify=stratify_remaining
)

return DataSplit(
@@ -84,7 +93,7 @@
val_embeddings=torch.from_numpy(val_emb).float(),
val_labels=torch.from_numpy(val_labels).float(),
test_embeddings=torch.from_numpy(test_emb).float(),
test_labels=torch.from_numpy(test_labels).float(),
test_labels=torch.from_numpy(test_labels).float()
)


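One detail worth noting in the second split above: `test_size=val_size / (1 - test_size)` rescales the validation fraction so it stays relative to the full dataset rather than to the remaining train+validation pool. A quick arithmetic check, using hypothetical values since the actual defaults are outside the visible hunks:

# Hypothetical defaults for illustration only; not taken from this diff.
test_size, val_size = 0.15, 0.15
second_split = val_size / (1 - test_size)            # ≈ 0.1765 of the remaining 85%
val_share_of_total = (1 - test_size) * second_split  # = 0.15, the intended fraction
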
(Diffs for the remaining 3 changed files are not loaded here.)
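
The training-code change behind the commit title (optimizing the batch size calculation and capping its maximum) is among the remaining changed files whose diffs are not shown above. The sketch below is a generic illustration of that pattern, not the code from this commit; every name and constant is hypothetical.

# Hypothetical sketch only; not the batch-size code from this commit.
MAX_BATCH_SIZE = 512  # assumed upper bound on the batch size

def compute_batch_size(n_samples: int, target_batches: int = 100) -> int:
    """Scale the batch size with the dataset size, but never exceed the cap."""
    return max(1, min(n_samples // target_batches, MAX_BATCH_SIZE))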
