feat: Add Platform management and refine Neo4j initialization
Introduce `Platform` class with relationships and methods in schema.py. Refactor Neo4j initialization using neomodel in graph.py and update related functions to handle platforms in concord.py.
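The schema.py diff itself is not rendered on this page. As a rough orientation only, a neomodel `Platform` node of the kind the commit message describes might look like the sketch below. Only `platform_id` and a `channels` relationship are confirmed by how concord.py uses the class (`Platform(platform_id=...).save()` and `platform.channels.connect(channel)`); the property types and the relationship label are assumptions.

# Hypothetical sketch of the Platform node added to graph/schema.py.
# Only `platform_id` and the `channels` relationship are confirmed by
# the concord.py diff; everything else here is an assumption.
from neomodel import RelationshipTo, StringProperty, StructuredNode


class Platform(StructuredNode):
    platform_id = StringProperty(unique_index=True, required=True)

    # concord.py calls `platform.channels.connect(channel)`, so a
    # relationship to Channel must exist; the label is assumed.
    channels = RelationshipTo('Channel', 'HAS_CHANNEL')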
Showing 10 changed files with 316 additions and 117 deletions.
One changed file was deleted, and some generated files are not rendered below.
concord.py
@@ -1,47 +1,72 @@
 # concord.py

 from bert.pre_process import preprocess_documents
-from graph.schema import Topic
+from graph.schema import Topic, Channel, Platform


-def concord(
-    topic_model,
-    documents,
-):
-    # Load the dataset and limit to 100 documents
-    print(f"Loaded {len(documents)} documents.")
+def concord(bert_topic, channel_id, platform_id, documents):
+    platform, channel = platform_channel_handler(channel_id, platform_id)

-    # Preprocess the documents
+    # Load and preprocess documents
+    print(f"Loaded {len(documents)} documents.")
     print("Preprocessing documents...")
     documents = preprocess_documents(documents)

-    # Fit the model on the documents
+    # Fit the topic model
     print("Fitting the BERTopic model...")
-    topics, probs = topic_model.fit_transform(documents)
-
-    # Get topic information
-    topic_info = topic_model.get_topic_info()
+    bert_topic.fit(documents)
+    topic_info = bert_topic.get_topic_info()

-    # Print the main topics with importance scores
+    # Log main topics
     print("\nMain Topics with Word Importance Scores:")
     for index, row in topic_info.iterrows():
         topic_id = row['Topic']
         if topic_id == -1:
             continue  # Skip outliers
         topic_freq = row['Count']
-        topic_words = topic_model.get_topic(topic_id)
+        topic_words = bert_topic.get_topic(topic_id)

-        # Prepare a list of formatted word-score pairs
-        word_score_list = [
-            f"{word} ({score:.4f})" for word, score in topic_words
-        ]
-
-        # Join the pairs into a single string
-        word_score_str = ', '.join(word_score_list)
+        # Create a list of word-score pairs
+        word_score_list = [{
+            "term": word,
+            "weight": score
+        } for word, score in topic_words]

-        # Print the topic info and the word-score string
+        # Create or update a Topic node
+        topic = Topic.create_topic(name=f"Topic {topic_id}",
+                                   keywords=word_score_list,
+                                   bertopic_metadata={
+                                       "frequency": topic_freq
+                                   }).save()
+        topic.set_topic_embedding(bert_topic.topic_embeddings_[topic_id])
+        channel.associate_with_topic(topic, channel_score=0.5, trend="")
+
         print(f"\nTopic {topic_id} (Frequency: {topic_freq}):")
-        print(f"  {word_score_str}")
+        print(
+            f"  {', '.join([f'{word} ({score:.4f})' for word, score in topic_words])}"
+        )

-    print("\nTopic modeling completed.")
-    return len(documents), Topic.create_topic()
+    print("\nTopic modeling and channel update completed.")
+    return len(documents), None
+
+
+def platform_channel_handler(channel_id, platform_id):
+    platform = Platform.nodes.get_or_none(platform_id=platform_id)
+    if not platform:
+        print(
+            f"Platform with ID '{platform_id}' not found. Creating new platform..."
+        )
+        platform = Platform(platform_id=platform_id).save()
+    channel = Channel.nodes.get_or_none(channel_id=channel_id)
+    if not channel:
+        print(
+            f"Channel with ID '{channel_id}' not found. Creating new channel..."
+        )
+        channel = Channel.create_channel(
+            channel_id=channel_id,
+            name=f"Channel {channel_id}",
+            description="",
+            language="English",
+            activity_score=0.0,
+        ).save()
+    platform.channels.connect(channel)
+    return platform, channel
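For orientation, here is a hedged sketch of how `concord` might be invoked. It assumes BERTopic is installed, the Neo4j connection is already configured, and `concord` is importable from a module of the same name; the import path, the IDs, and the documents are illustrative, not taken from the commit.

# Hypothetical call site for concord(); IDs and documents are made up.
from bertopic import BERTopic

from concord import concord

documents = [
    "Bitcoin hits a new all-time high",
    "Ethereum gas fees drop sharply",
    # ... enough real messages for BERTopic to find stable topics
]

doc_count, _ = concord(BERTopic(), "channel-123", "platform-abc", documents)
print(f"Processed {doc_count} documents.")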
topic_update.py (new file)
@@ -0,0 +1,98 @@
+# topic_update.py
+from sklearn.metrics.pairwise import cosine_similarity
+from datetime import datetime
+from graph.schema import Topic, TopicUpdate, Channel
+
+SIMILARITY_THRESHOLD = 0.8
+AMPLIFY_INCREMENT = 0.1
+DIMINISH_DECREMENT = 0.05
+NEW_TOPIC_INITIAL_SCORE = 0.1
+
+
+def compute_cosine_similarity(vector_a, vector_b):
+    return cosine_similarity([vector_a], [vector_b])[0][0]
+
+
+def update_channel_topics(channel_topics, new_topics, channel_id):
+    initial_scores = {
+        topic.topic_id: topic.topic_score
+        for topic in channel_topics
+    }
+    topic_updates = []
+
+    for new_topic in new_topics:
+        print(
+            f"\nProcessing new topic: {new_topic['name']} with weight {new_topic['weight']:.4f}"
+        )
+        similarities = {
+            idx:
+            compute_cosine_similarity(new_topic['embedding'],
+                                      channel_topic.topic_embedding)
+            for idx, channel_topic in enumerate(channel_topics)
+        }
+        print("Similarity scores:", similarities)
+
+        topic_amplified = False
+        for idx, similarity in similarities.items():
+            if similarity >= SIMILARITY_THRESHOLD:
+                channel_topic = channel_topics[idx]
+                original_score = channel_topic.topic_score
+                channel_topic.topic_score = min(
+                    1, channel_topic.topic_score + AMPLIFY_INCREMENT)
+                delta = channel_topic.topic_score - original_score
+                channel_topic.updated_at = datetime.utcnow()
+                channel_topic.save()
+                print(
+                    f"Amplifying topic '{channel_topic.name}' from {original_score:.4f} to "
+                    f"{channel_topic.topic_score:.4f} (delta = {delta:.4f})")
+
+                topic_update = TopicUpdate.create_topic_update(
+                    keywords=channel_topic.keywords, score_delta=delta)
+                topic_update.topic.connect(channel_topic)
+                topic_updates.append(topic_update)
+
+                topic_amplified = True
+
+        if not topic_amplified:
+            print(
+                f"Creating new topic '{new_topic['name']}' with initial score {NEW_TOPIC_INITIAL_SCORE:.4f}"
+            )
+            topic_node = Topic(name=new_topic['name'],
+                               topic_embedding=new_topic['embedding'],
+                               topic_score=NEW_TOPIC_INITIAL_SCORE,
+                               updated_at=datetime.utcnow()).save()
+            topic_node.add_update(new_topic.get('keywords', []),
+                                  NEW_TOPIC_INITIAL_SCORE)
+            Channel.nodes.get(channel_id=channel_id).associate_with_topic(
+                topic_node, NEW_TOPIC_INITIAL_SCORE,
+                new_topic.get('keywords', []), 1, 'New')
+            channel_topics.append(topic_node)
+
+    for channel_topic in channel_topics:
+        if channel_topic.name not in [nt['name'] for nt in new_topics]:
+            original_score = channel_topic.topic_score
+            channel_topic.topic_score = max(
+                0, channel_topic.topic_score - DIMINISH_DECREMENT)
+            delta = original_score - channel_topic.topic_score
+            channel_topic.updated_at = datetime.utcnow()
+            channel_topic.save()
+            print(
+                f"Diminishing topic '{channel_topic.name}' from {original_score:.4f} to "
+                f"{channel_topic.topic_score:.4f} (delta = -{delta:.4f})")
+
+            if delta != 0:
+                topic_update = TopicUpdate.create_topic_update(
+                    keywords=channel_topic.keywords, score_delta=-delta)
+                topic_update.topic.connect(channel_topic)
+                topic_updates.append(topic_update)
+
+    print("\nUpdated Channel Topics:")
+    print("{:<30} {:<15} {:<15}".format("Topic Name", "Initial Score",
+                                        "Updated Score"))
+    for topic in channel_topics:
+        initial_score = initial_scores.get(topic.topic_id,
+                                           NEW_TOPIC_INITIAL_SCORE)
+        print("{:<30} {:<15.4f} {:<15.4f}".format(topic.name, initial_score,
+                                                  topic.topic_score))
+
+    return topic_updates
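The amplify/diminish rule above is easy to check in isolation. A minimal sketch using only the constants from this file and scikit-learn; the toy vectors and starting score are made up:

# Standalone illustration of the scoring arithmetic; no graph objects.
from sklearn.metrics.pairwise import cosine_similarity

SIMILARITY_THRESHOLD = 0.8
AMPLIFY_INCREMENT = 0.1
DIMINISH_DECREMENT = 0.05

existing = {"score": 0.5, "embedding": [0.9, 0.1, 0.0]}   # toy data
incoming = {"embedding": [0.85, 0.15, 0.05]}              # toy data

sim = cosine_similarity([incoming["embedding"]], [existing["embedding"]])[0][0]
if sim >= SIMILARITY_THRESHOLD:
    # Similar topic reappeared: amplify, capped at 1.
    existing["score"] = min(1, existing["score"] + AMPLIFY_INCREMENT)
else:
    # Topic absent from the new batch: diminish, floored at 0.
    existing["score"] = max(0, existing["score"] - DIMINISH_DECREMENT)

# Here sim is about 0.996, so the score is amplified from 0.5 to 0.6.
print(f"similarity={sim:.4f}, new score={existing['score']:.4f}")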
app_lifespan.py
@@ -1,13 +1,15 @@
 # app_lifespan.py
 from contextlib import asynccontextmanager
 from fastapi import FastAPI

 from bert.model_manager import ModelManager
+from graph.graph import initialize_neo4j


 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    # Startup: Initialize the BERTopic model
+    # Startup: Initialize the BERTopic model and Neo4j connection
     ModelManager.initialize_model()
+    initialize_neo4j()  # Initialize Neo4j connection

     yield
-    # Shutdown: Perform any necessary cleanup here
+    # Shutdown logic here (if needed)
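For context, a lifespan handler like this is passed to the FastAPI constructor, which runs the code before `yield` at startup and the code after it at shutdown. This is standard FastAPI usage; the `app_lifespan` import path is taken from the file's header comment:

# Hypothetical wiring of the lifespan handler into the application.
from fastapi import FastAPI

from app_lifespan import lifespan

app = FastAPI(lifespan=lifespan)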
graph.py
@@ -1,37 +1,14 @@
 # graph.py
 import os
+from neomodel import config

-from neo4j import GraphDatabase
-
-
-# Initialize Neo4j driver
-def initialize_neo4j():
-    # Get uri username and password from ENV
-    user = os.environ.get("NEO4J_USERNAME")
-    password = os.environ.get("NEO4J_PASSWORD")
-    uri = os.environ.get("NEO4J_URI")
-
-    return GraphDatabase.driver(uri, auth=(user, password))
-
-
-# Function to store topics in Neo4j
-def store_topics_in_neo4j(model, batch_num):
-    """
-    Store topics and their relationships in Neo4j.
-    """
-    driver = initialize_neo4j()
-    with driver.session() as session:
-        topics = model.get_topics()
-        for topic_num, words in topics.items():
-            if topic_num == -1:
-                continue  # -1 is usually the outlier/noise topic
-            # Create Topic node
-            session.run(
-                "MERGE (t:Topic {id: $id}) "
-                "SET t.keywords = $keywords, t.batch = $batch",
-                id=topic_num,
-                keywords=words,
-                batch=batch_num,
-            )
-    driver.close()
-    print("Topics stored in Neo4j.")
+# Get the connection details from environment variables
+DATABASE_URL = os.getenv("DATABASE_URL",
+                         "localhost:7687")  # No `bolt://` prefix here
+NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
+NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "dev-password")
+
+# Add the 'bolt://' prefix and format the URL with credentials
+config.DATABASE_URL = f"bolt://{NEO4J_USER}:{NEO4J_PASSWORD}@{DATABASE_URL}"
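With neomodel, setting `config.DATABASE_URL` at module import time is generally sufficient; the library opens the connection when the first query runs. A minimal smoke test, assuming a reachable Neo4j instance and the `graph` package layout implied by the imports in the other files; the `"demo"` ID is an arbitrary example, not a value from the commit:

# Minimal connectivity check under the assumptions stated above.
import graph.graph  # noqa: F401  (importing runs the config assignment)
from graph.schema import Platform

platform = Platform.nodes.get_or_none(platform_id="demo")  # "demo" is made up
print(platform)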