feat: Add Platform management and refine Neo4j initialization
Introduce a `Platform` class with relationships and methods in schema.py. Refactor the Neo4j initialization in graph.py to use neomodel, and update the related functions in concord.py to handle platforms.
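Neither schema.py nor graph.py is rendered in this excerpt, so the following is only a rough sketch of what the new `Platform` node and the neomodel initialization could look like, inferred purely from how they are used in concord.py below; the relationship name and connection URL are assumptions.

```python
# Hypothetical sketch only: schema.py and graph.py are not shown in this
# diff, so everything here is inferred from how Platform is used in
# concord.py (Platform.nodes.get_or_none, Platform(platform_id=...).save(),
# platform.channels.connect(channel)).
from neomodel import RelationshipTo, StringProperty, StructuredNode, config

# graph.py equivalent: point neomodel at the Neo4j instance once, before
# any node class is queried or saved (the URL is a placeholder).
config.DATABASE_URL = "bolt://neo4j:password@localhost:7687"


class Platform(StructuredNode):
    # Unique external identifier, matching Platform.nodes.get_or_none(platform_id=...)
    platform_id = StringProperty(unique_index=True, required=True)

    # Relationship traversed by platform.channels.connect(channel) in concord.py;
    # the HAS_CHANNEL relationship type is an assumption.
    channels = RelationshipTo("Channel", "HAS_CHANNEL")
```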
Showing 11 changed files with 371 additions and 154 deletions.
(One deleted file and some generated files are not rendered in this excerpt.)
concord.py:

```diff
@@ -1,47 +1,72 @@
 # concord.py
 
 from bert.pre_process import preprocess_documents
-from graph.schema import Topic
+from graph.schema import Topic, Channel, Platform
 
 
-def concord(
-    topic_model,
-    documents,
-):
-    # Load the dataset and limit to 100 documents
-    print(f"Loaded {len(documents)} documents.")
+def concord(bert_topic, channel_id, platform_id, documents):
+    platform, channel = platform_channel_handler(channel_id, platform_id)
 
-    # Preprocess the documents
+    # Load and preprocess documents
+    print(f"Loaded {len(documents)} documents.")
     print("Preprocessing documents...")
     documents = preprocess_documents(documents)
 
-    # Fit the model on the documents
+    # Fit the topic model
     print("Fitting the BERTopic model...")
-    topics, probs = topic_model.fit_transform(documents)
-
-    # Get topic information
-    topic_info = topic_model.get_topic_info()
+    bert_topic.fit(documents)
+    topic_info = bert_topic.get_topic_info()
 
-    # Print the main topics with importance scores
+    # Log main topics
     print("\nMain Topics with Word Importance Scores:")
     for index, row in topic_info.iterrows():
         topic_id = row['Topic']
         if topic_id == -1:
             continue  # Skip outliers
         topic_freq = row['Count']
-        topic_words = topic_model.get_topic(topic_id)
+        topic_words = bert_topic.get_topic(topic_id)
 
-        # Prepare a list of formatted word-score pairs
-        word_score_list = [
-            f"{word} ({score:.4f})" for word, score in topic_words
-        ]
-
-        # Join the pairs into a single string
-        word_score_str = ', '.join(word_score_list)
+        # Create a list of word-score pairs
+        word_score_list = [{
+            "term": word,
+            "weight": score
+        } for word, score in topic_words]
 
-        # Print the topic info and the word-score string
+        # Create or update a Topic node
+        topic = Topic.create_topic(name=f"Topic {topic_id}",
+                                   keywords=word_score_list,
+                                   bertopic_metadata={
+                                       "frequency": topic_freq
+                                   }).save()
+        topic.set_topic_embedding(bert_topic.topic_embeddings_[topic_id])
+        channel.associate_with_topic(topic, channel_score=0.5, trend="")
+
         print(f"\nTopic {topic_id} (Frequency: {topic_freq}):")
-        print(f"  {word_score_str}")
+        print(
+            f"  {', '.join([f'{word} ({score:.4f})' for word, score in topic_words])}"
+        )
 
-    print("\nTopic modeling completed.")
-    return len(documents), Topic.create_topic()
+    print("\nTopic modeling and channel update completed.")
+    return len(documents), None
+
+
+def platform_channel_handler(channel_id, platform_id):
+    platform = Platform.nodes.get_or_none(platform_id=platform_id)
+    if not platform:
+        print(
+            f"Platform with ID '{platform_id}' not found. Creating new platform..."
+        )
+        platform = Platform(platform_id=platform_id).save()
+    channel = Channel.nodes.get_or_none(channel_id=channel_id)
+    if not channel:
+        print(
+            f"Channel with ID '{channel_id}' not found. Creating new channel..."
+        )
+        channel = Channel.create_channel(
+            channel_id=channel_id,
+            name=f"Channel {channel_id}",
+            description="",
+            language="English",
+            activity_score=0.0,
+        ).save()
+    platform.channels.connect(channel)
+    return platform, channel
```
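For orientation, here is a hypothetical call into the refactored `concord()` (not part of the commit). It assumes the model object comes from the `bertopic` package, whose `fit`, `get_topic_info`, `get_topic`, and `topic_embeddings_` members match how `bert_topic` is used above; the IDs and documents are made up.

```python
# Hypothetical driver for the refactored concord() above; the IDs and
# documents are invented for illustration.
from bertopic import BERTopic

from concord import concord

# A real corpus needs enough documents for BERTopic's UMAP/HDBSCAN stages
# to fit; these three lines only illustrate the call shape.
documents = [
    "Bitcoin spiked after the ETF announcement.",
    "Ethereum gas fees dropped sharply this week.",
    "A new DeFi protocol passed a third-party audit.",
]

doc_count, _ = concord(
    bert_topic=BERTopic(),
    channel_id="channel-123",   # resolved or created by platform_channel_handler
    platform_id="telegram",
    documents=documents,
)
print(f"Processed {doc_count} documents.")
```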
topic_update.py (new file, +98 lines):

```python
# topic_update.py
from sklearn.metrics.pairwise import cosine_similarity
from datetime import datetime
from graph.schema import Topic, TopicUpdate, Channel

SIMILARITY_THRESHOLD = 0.8
AMPLIFY_INCREMENT = 0.1
DIMINISH_DECREMENT = 0.05
NEW_TOPIC_INITIAL_SCORE = 0.1


def compute_cosine_similarity(vector_a, vector_b):
    return cosine_similarity([vector_a], [vector_b])[0][0]


def update_channel_topics(channel_topics, new_topics, channel_id):
    initial_scores = {
        topic.topic_id: topic.topic_score
        for topic in channel_topics
    }
    topic_updates = []

    for new_topic in new_topics:
        print(
            f"\nProcessing new topic: {new_topic['name']} with weight {new_topic['weight']:.4f}"
        )
        similarities = {
            idx: compute_cosine_similarity(new_topic['embedding'],
                                           channel_topic.topic_embedding)
            for idx, channel_topic in enumerate(channel_topics)
        }
        print("Similarity scores:", similarities)

        topic_amplified = False
        for idx, similarity in similarities.items():
            if similarity >= SIMILARITY_THRESHOLD:
                channel_topic = channel_topics[idx]
                original_score = channel_topic.topic_score
                channel_topic.topic_score = min(
                    1, channel_topic.topic_score + AMPLIFY_INCREMENT)
                delta = channel_topic.topic_score - original_score
                channel_topic.updated_at = datetime.utcnow()
                channel_topic.save()
                print(
                    f"Amplifying topic '{channel_topic.name}' from {original_score:.4f} to "
                    f"{channel_topic.topic_score:.4f} (delta = {delta:.4f})")

                topic_update = TopicUpdate.create_topic_update(
                    keywords=channel_topic.keywords, score_delta=delta)
                topic_update.topic.connect(channel_topic)
                topic_updates.append(topic_update)

                topic_amplified = True

        if not topic_amplified:
            print(
                f"Creating new topic '{new_topic['name']}' with initial score {NEW_TOPIC_INITIAL_SCORE:.4f}"
            )
            topic_node = Topic(name=new_topic['name'],
                               topic_embedding=new_topic['embedding'],
                               topic_score=NEW_TOPIC_INITIAL_SCORE,
                               updated_at=datetime.utcnow()).save()
            topic_node.add_update(new_topic.get('keywords', []),
                                  NEW_TOPIC_INITIAL_SCORE)
            Channel.nodes.get(channel_id=channel_id).associate_with_topic(
                topic_node, NEW_TOPIC_INITIAL_SCORE,
                new_topic.get('keywords', []), 1, 'New')
            channel_topics.append(topic_node)

    for channel_topic in channel_topics:
        if channel_topic.name not in [nt['name'] for nt in new_topics]:
            original_score = channel_topic.topic_score
            channel_topic.topic_score = max(
                0, channel_topic.topic_score - DIMINISH_DECREMENT)
            delta = original_score - channel_topic.topic_score
            channel_topic.updated_at = datetime.utcnow()
            channel_topic.save()
            print(
                f"Diminishing topic '{channel_topic.name}' from {original_score:.4f} to "
                f"{channel_topic.topic_score:.4f} (delta = -{delta:.4f})")

            if delta != 0:
                topic_update = TopicUpdate.create_topic_update(
                    keywords=channel_topic.keywords, score_delta=-delta)
                topic_update.topic.connect(channel_topic)
                topic_updates.append(topic_update)

    print("\nUpdated Channel Topics:")
    print("{:<30} {:<15} {:<15}".format("Topic Name", "Initial Score",
                                        "Updated Score"))
    for topic in channel_topics:
        initial_score = initial_scores.get(topic.topic_id,
                                           NEW_TOPIC_INITIAL_SCORE)
        print("{:<30} {:<15.4f} {:<15.4f}".format(topic.name, initial_score,
                                                  topic.topic_score))

    return topic_updates
```
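To make the scoring rule concrete, here is a small standalone example of the same threshold test and `min`/`max` clamping used in `update_channel_topics`; the vectors are toy values, not real topic embeddings.

```python
# Standalone illustration of the amplify/diminish rule from
# update_channel_topics, using plain dicts instead of Topic nodes.
from sklearn.metrics.pairwise import cosine_similarity

SIMILARITY_THRESHOLD = 0.8
AMPLIFY_INCREMENT = 0.1
DIMINISH_DECREMENT = 0.05

existing = {"name": "crypto markets", "embedding": [0.9, 0.1, 0.0], "score": 0.95}
incoming_embedding = [0.85, 0.15, 0.05]  # toy vector, close to the existing topic

similarity = cosine_similarity([existing["embedding"]], [incoming_embedding])[0][0]
if similarity >= SIMILARITY_THRESHOLD:
    # Matched: amplify, with min() capping the score at 1 as in the module
    existing["score"] = min(1, existing["score"] + AMPLIFY_INCREMENT)
else:
    # Unmatched: diminish, with max() flooring the score at 0
    existing["score"] = max(0, existing["score"] - DIMINISH_DECREMENT)

# similarity is ~0.996 here, so the topic is amplified and capped: 0.95 -> 1.00
print(f"similarity={similarity:.4f}, updated score={existing['score']:.2f}")
```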