Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions Ganblr Evaluation/ganblr.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import torch
import torch.nn as nn
import pandas as pd
import logging

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

class GANBLR:
def __init__(self, input_dim):
logging.info(f"Initializing GANBLR with input dimension {input_dim}")
self.generator = self.build_generator(output_dim=input_dim)
self.discriminator = self.build_discriminator(input_dim=input_dim)
self.criterion = nn.BCELoss()
Expand Down Expand Up @@ -34,6 +38,7 @@ def build_discriminator(self, input_dim):
)

def fit(self, data):
logging.info(f"Starting training on data with {data.shape[0]} samples and {data.shape[1]} features")
# Convert data to a tensor
data_tensor = torch.tensor(data.values, dtype=torch.float32)

Expand All @@ -59,18 +64,18 @@ def fit(self, data):
d_loss.backward()
self.optimizer_D.step()

print(f"Epoch {epoch+1}/100: Generator Loss: {g_loss.item()}, Discriminator Loss: {d_loss.item()}")
logging.info(f"Epoch {epoch+1}/100: Generator Loss: {g_loss.item()}, Discriminator Loss: {d_loss.item()}")

def generate(self):
# Generate synthetic data
logging.info("Generating synthetic data...")
noise = torch.randn(1000, 100) # Example: Generate 1000 samples
synthetic_data = self.generator(noise).detach().numpy()
return pd.DataFrame(synthetic_data, columns=[f"Feature_{i}" for i in range(synthetic_data.shape[1])])

def save(self, path):
torch.save(self.generator.state_dict(), path)
print(f"Model saved to {path}")
logging.info(f"Model saved to {path}")

def load(self, path):
self.generator.load_state_dict(torch.load(path, weights_only=False))
print(f"Model loaded from {path}")
logging.info(f"Model loaded from {path}")
25 changes: 17 additions & 8 deletions Ganblr Evaluation/run_ganblr.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,37 @@
import argparse
import pandas as pd
import logging
from ganblr import GANBLR
import os

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

def train_model(input_file):
    """Train a GANBLR model on the CSV dataset at `input_file`.

    Loads the dataset, trains the model, and writes the generator
    checkpoint to "ganblr_model_checkpoint.pth".

    Args:
        input_file: Path to a CSV file of preprocessed training data.

    Returns:
        None. If the input file is missing, an error is logged and the
        function returns without training.
    """
    # Guard clause: fail fast with a clear log message instead of
    # letting pd.read_csv raise on a missing path.
    if not os.path.exists(input_file):
        logging.error(f"Dataset file {input_file} does not exist.")
        return

    logging.info(f"Loading dataset from {input_file}...")
    data = pd.read_csv(input_file)

    # Input dimension is the number of columns in the dataset.
    model = GANBLR(input_dim=data.shape[1])

    # Train the model
    logging.info("Training the GANBLR model...")
    model.fit(data)

    # Save the trained model
    model.save("ganblr_model_checkpoint.pth")
    logging.info("Training complete. Model saved.")

def generate_data(output_file):
print("Generating synthetic data using GANBLR...")
# Load the preprocessed dataset to infer the input dimensions
if not os.path.exists("preprocessed_real_dataset.csv"):
logging.error("Preprocessed dataset file does not exist.")
return

logging.info("Generating synthetic data using GANBLR...")
data = pd.read_csv("preprocessed_real_dataset.csv")
model = GANBLR(input_dim=data.shape[1])

Expand All @@ -31,8 +41,7 @@ def generate_data(output_file):

# Save the synthetic data
synthetic_data.to_csv(output_file, index=False)
print(f"Synthetic data saved to {output_file}.")

logging.info(f"Synthetic data saved to {output_file}.")

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train or Generate Data with GANBLR")
Expand Down