feat: Add Platform management and refine Neo4j initialization
Introduce `Platform` class with relationships and methods in schema.py.
Refactor Neo4j initialization using neomodel in graph.py and update related
functions to handle platforms in concord.py.
sajz authored and Septimus4 committed Nov 8, 2024
1 parent 8ae5876 commit 3d373f6
Showing 10 changed files with 316 additions and 117 deletions.
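
The src/graph/schema.py diff is among the files not rendered below, so the new `Platform` node itself is not visible here. Judging only from how concord.py uses it (`Platform.nodes.get_or_none`, `Platform(platform_id=...).save()`, `platform.channels.connect(channel)`), a minimal neomodel sketch might look like the following; everything beyond those three usages is an assumption:

# Hypothetical sketch of the Platform node in src/graph/schema.py,
# inferred from its usage in concord.py; not the commit's actual code.
from neomodel import StructuredNode, StringProperty, RelationshipTo


class Platform(StructuredNode):
    platform_id = StringProperty(unique_index=True, required=True)
    # concord.py calls platform.channels.connect(channel)
    channels = RelationshipTo('Channel', 'HAS_CHANNEL')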
3 changes: 3 additions & 0 deletions .gitignore
@@ -152,3 +152,6 @@ cython_debug/
 /concord/bertopic_model.pkl
 /.idea/rust.xml
 /nltk_data/
+/concord/dataset_topic_messages.csv
+/topic_model
+/topic_visualization.html
35 changes: 0 additions & 35 deletions .idea/runConfigurations/Concord.xml

This file was deleted.

17 changes: 17 additions & 0 deletions .idea/runConfigurations/Server.xml

Some generated files are not rendered by default.

77 changes: 51 additions & 26 deletions src/bert/concord.py
@@ -1,47 +1,72 @@
 # concord.py
 
 from bert.pre_process import preprocess_documents
-from graph.schema import Topic
+from graph.schema import Topic, Channel, Platform
 
 
-def concord(
-    topic_model,
-    documents,
-):
-    # Load the dataset and limit to 100 documents
-    print(f"Loaded {len(documents)} documents.")
+def concord(bert_topic, channel_id, platform_id, documents):
+    platform, channel = platform_channel_handler(channel_id, platform_id)
 
-    # Preprocess the documents
+    # Load and preprocess documents
+    print(f"Loaded {len(documents)} documents.")
     print("Preprocessing documents...")
     documents = preprocess_documents(documents)
 
-    # Fit the model on the documents
+    # Fit the topic model
     print("Fitting the BERTopic model...")
-    topics, probs = topic_model.fit_transform(documents)
+    bert_topic.fit(documents)
+    topic_info = bert_topic.get_topic_info()
 
-    # Get topic information
-    topic_info = topic_model.get_topic_info()
-
-    # Print the main topics with importance scores
+    # Log main topics
    print("\nMain Topics with Word Importance Scores:")
     for index, row in topic_info.iterrows():
         topic_id = row['Topic']
         if topic_id == -1:
             continue  # Skip outliers
         topic_freq = row['Count']
-        topic_words = topic_model.get_topic(topic_id)
+        topic_words = bert_topic.get_topic(topic_id)
 
-        # Prepare a list of formatted word-score pairs
-        word_score_list = [
-            f"{word} ({score:.4f})" for word, score in topic_words
-        ]
+        # Create a list of word-score pairs
+        word_score_list = [{
+            "term": word,
+            "weight": score
+        } for word, score in topic_words]
 
-        # Join the pairs into a single string
-        word_score_str = ', '.join(word_score_list)
+        # Create or update a Topic node
+        topic = Topic.create_topic(name=f"Topic {topic_id}",
+                                   keywords=word_score_list,
+                                   bertopic_metadata={
+                                       "frequency": topic_freq
+                                   }).save()
+        topic.set_topic_embedding(bert_topic.topic_embeddings_[topic_id])
+        channel.associate_with_topic(topic, channel_score=0.5, trend="")
 
-        # Print the topic info and the word-score string
         print(f"\nTopic {topic_id} (Frequency: {topic_freq}):")
-        print(f"  {word_score_str}")
+        print(
+            f"  {', '.join([f'{word} ({score:.4f})' for word, score in topic_words])}"
+        )
 
-    print("\nTopic modeling completed.")
-    return len(documents), Topic.create_topic()
+    print("\nTopic modeling and channel update completed.")
+    return len(documents), None
+
+
+def platform_channel_handler(channel_id, platform_id):
+    platform = Platform.nodes.get_or_none(platform_id=platform_id)
+    if not platform:
+        print(
+            f"Platform with ID '{platform_id}' not found. Creating new platform..."
+        )
+        platform = Platform(platform_id=platform_id).save()
+    channel = Channel.nodes.get_or_none(channel_id=channel_id)
+    if not channel:
+        print(
+            f"Channel with ID '{channel_id}' not found. Creating new channel..."
+        )
+        channel = Channel.create_channel(
+            channel_id=channel_id,
+            name=f"Channel {channel_id}",
+            description="",
+            language="English",
+            activity_score=0.0,
+        ).save()
+    platform.channels.connect(channel)
+    return platform, channel
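
A hypothetical invocation of the new entry point, assuming a plain BERTopic instance and made-up IDs (none of this appears in the commit):

# Illustrative usage only; model settings, IDs, and corpus are invented.
from bertopic import BERTopic

from bert.concord import concord

docs = ["neo4j modeling tips", "bertopic on short chat messages"] * 50  # toy corpus
model = BERTopic()
n_docs, _ = concord(model, channel_id="chan-42", platform_id="telegram",
                    documents=docs)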
98 changes: 98 additions & 0 deletions src/bert/topic_update.py
@@ -0,0 +1,98 @@
# topic_update.py
from sklearn.metrics.pairwise import cosine_similarity
from datetime import datetime
from graph.schema import Topic, TopicUpdate, Channel

SIMILARITY_THRESHOLD = 0.8
AMPLIFY_INCREMENT = 0.1
DIMINISH_DECREMENT = 0.05
NEW_TOPIC_INITIAL_SCORE = 0.1


def compute_cosine_similarity(vector_a, vector_b):
    return cosine_similarity([vector_a], [vector_b])[0][0]


def update_channel_topics(channel_topics, new_topics, channel_id):
    initial_scores = {
        topic.topic_id: topic.topic_score
        for topic in channel_topics
    }
    topic_updates = []

    for new_topic in new_topics:
        print(
            f"\nProcessing new topic: {new_topic['name']} with weight {new_topic['weight']:.4f}"
        )
        similarities = {
            idx: compute_cosine_similarity(new_topic['embedding'],
                                           channel_topic.topic_embedding)
            for idx, channel_topic in enumerate(channel_topics)
        }
        print("Similarity scores:", similarities)

        topic_amplified = False
        for idx, similarity in similarities.items():
            if similarity >= SIMILARITY_THRESHOLD:
                channel_topic = channel_topics[idx]
                original_score = channel_topic.topic_score
                channel_topic.topic_score = min(
                    1, channel_topic.topic_score + AMPLIFY_INCREMENT)
                delta = channel_topic.topic_score - original_score
                channel_topic.updated_at = datetime.utcnow()
                channel_topic.save()
                print(
                    f"Amplifying topic '{channel_topic.name}' from {original_score:.4f} to "
                    f"{channel_topic.topic_score:.4f} (delta = {delta:.4f})")

                topic_update = TopicUpdate.create_topic_update(
                    keywords=channel_topic.keywords, score_delta=delta)
                topic_update.topic.connect(channel_topic)
                topic_updates.append(topic_update)

                topic_amplified = True

        if not topic_amplified:
            print(
                f"Creating new topic '{new_topic['name']}' with initial score {NEW_TOPIC_INITIAL_SCORE:.4f}"
            )
            topic_node = Topic(name=new_topic['name'],
                               topic_embedding=new_topic['embedding'],
                               topic_score=NEW_TOPIC_INITIAL_SCORE,
                               updated_at=datetime.utcnow()).save()
            topic_node.add_update(new_topic.get('keywords', []),
                                  NEW_TOPIC_INITIAL_SCORE)
            Channel.nodes.get(channel_id=channel_id).associate_with_topic(
                topic_node, NEW_TOPIC_INITIAL_SCORE,
                new_topic.get('keywords', []), 1, 'New')
            channel_topics.append(topic_node)

    for channel_topic in channel_topics:
        if channel_topic.name not in [nt['name'] for nt in new_topics]:
            original_score = channel_topic.topic_score
            channel_topic.topic_score = max(
                0, channel_topic.topic_score - DIMINISH_DECREMENT)
            delta = original_score - channel_topic.topic_score
            channel_topic.updated_at = datetime.utcnow()
            channel_topic.save()
            print(
                f"Diminishing topic '{channel_topic.name}' from {original_score:.4f} to "
                f"{channel_topic.topic_score:.4f} (delta = -{delta:.4f})")

            if delta != 0:
                topic_update = TopicUpdate.create_topic_update(
                    keywords=channel_topic.keywords, score_delta=-delta)
                topic_update.topic.connect(channel_topic)
                topic_updates.append(topic_update)

    print("\nUpdated Channel Topics:")
    print("{:<30} {:<15} {:<15}".format("Topic Name", "Initial Score",
                                        "Updated Score"))
    for topic in channel_topics:
        initial_score = initial_scores.get(topic.topic_id,
                                           NEW_TOPIC_INITIAL_SCORE)
        print("{:<30} {:<15.4f} {:<15.4f}".format(topic.name, initial_score,
                                                  topic.topic_score))

    return topic_updates
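
The amplify/diminish rules above are plain clamped increments; here is a standalone sketch of just that arithmetic, with invented scores and no Neo4j involved:

# Same clamped-increment rules as update_channel_topics, on bare floats.
AMPLIFY_INCREMENT = 0.1
DIMINISH_DECREMENT = 0.05


def amplify(score):
    # A sufficiently similar new topic pushes the score up, capped at 1.
    return min(1, score + AMPLIFY_INCREMENT)


def diminish(score):
    # A topic absent from the new batch decays toward 0.
    return max(0, score - DIMINISH_DECREMENT)


print(amplify(0.95))   # 1 (clamped: 0.95 + 0.1 exceeds the cap)
print(diminish(0.03))  # 0 (clamped: 0.03 - 0.05 would go negative)
print(amplify(0.50))   # 0.6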
8 changes: 5 additions & 3 deletions src/concord/server/app_lifespan.py
@@ -1,13 +1,15 @@
 # app_lifespan.py
 from contextlib import asynccontextmanager
 from fastapi import FastAPI
 
 from bert.model_manager import ModelManager
+from graph.graph import initialize_neo4j
 
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    # Startup: Initialize the BERTopic model
+    # Startup: Initialize the BERTopic model and Neo4j connection
     ModelManager.initialize_model()
+    initialize_neo4j()  # Initialize Neo4j connection
 
     yield
-    # Shutdown: Perform any necessary cleanup here
+    # Shutdown logic here (if needed)
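
For context, a lifespan context manager takes effect when it is passed to the FastAPI constructor; a sketch of that wiring, with an assumed import path (the actual app setup lives elsewhere under src/concord/server):

# Hypothetical wiring; the import path is assumed, not shown in this commit.
from fastapi import FastAPI

from concord.server.app_lifespan import lifespan

app = FastAPI(lifespan=lifespan)  # startup runs before requests, shutdown after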
39 changes: 8 additions & 31 deletions src/graph/graph.py
@@ -1,37 +1,14 @@
 # graph.py
 import os
+from neomodel import config
 
-from neo4j import GraphDatabase
+# Get the connection details from environment variables
+DATABASE_URL = os.getenv("DATABASE_URL",
+                         "localhost:7687")  # No `bolt://` prefix here
+NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
+NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "dev-password")
 
 
-# Initialize Neo4j driver
-def initialize_neo4j():
-    # Get uri username and password from ENV
-    user = os.environ.get("NEO4J_USERNAME")
-    password = os.environ.get("NEO4J_PASSWORD")
-    uri = os.environ.get("NEO4J_URI")
-
-    return GraphDatabase.driver(uri, auth=(user, password))
-
-
-# Function to store topics in Neo4j
-def store_topics_in_neo4j(model, batch_num):
-    """
-    Store topics and their relationships in Neo4j.
-    """
-    driver = initialize_neo4j()
-    with driver.session() as session:
-        topics = model.get_topics()
-        for topic_num, words in topics.items():
-            if topic_num == -1:
-                continue  # -1 is usually the outlier/noise topic
-            # Create Topic node
-            session.run(
-                "MERGE (t:Topic {id: $id}) "
-                "SET t.keywords = $keywords, t.batch = $batch",
-                id=topic_num,
-                keywords=words,
-                batch=batch_num,
-            )
-    driver.close()
-    print("Topics stored in Neo4j.")
+
+# Add the 'bolt://' prefix and format the URL with credentials
+config.DATABASE_URL = f"bolt://{NEO4J_USER}:{NEO4J_PASSWORD}@{DATABASE_URL}"
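
neomodel picks up config.DATABASE_URL when it first needs a connection, so this module-level assignment is the whole initialization. A sketch of what the assembled URL looks like for the defaults above (values illustrative):

# Illustrative only: reproduces the URL assembly from graph.py with its defaults.
import os

user = os.getenv("NEO4J_USER", "neo4j")
password = os.getenv("NEO4J_PASSWORD", "dev-password")
host = os.getenv("DATABASE_URL", "localhost:7687")  # host:port, no bolt:// prefix

print(f"bolt://{user}:{password}@{host}")  # bolt://neo4j:dev-password@localhost:7687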