
Commit

update literals interface
sapkotaruz11 committed Oct 28, 2024
1 parent b69f9a5 commit dbb648b
Showing 1 changed file with 52 additions and 9 deletions.
61 changes: 52 additions & 9 deletions dicee/knowledge_graph_embeddings.py
@@ -1,6 +1,7 @@
 from typing import List, Tuple, Set, Iterable, Dict, Union
 import torch
 from torch import optim
+from torch.nn import MSELoss
 from torch.utils.data import DataLoader
 from .abstracts import BaseInteractiveKGE
 from .dataset_classes import TriplePredictionDataset
@@ -1618,20 +1619,62 @@ def train(
f"On average Improvement: {first_avg_loss_per_triple - last_avg_loss_per_triple:.3f}"
)

def train_literals(self, path):
def train_literals(self, path, rel_to_predict = []):
"Funtion to train regression model for literals with pre-trained Embeddings"

self.literal_KG = KG(path_single_kg=path, backend="rdflib")
self.weight_dict = {}
dataset = self.literal_KG
for rel in dataset.relations_str:
rel_idx = dataset.relation_to_idx[rel]
filtered_rows = dataset.train_set[dataset.train_set[:, 1] == rel_idx]
h_idx, _, t_indx = filtered_rows.T.tolist()
head_entites = [dataset.idx_to_entity[idx] for idx in h_idx]
literal_values = [dataset.idx_to_entity[idx] for idx in t_indx]
head_embeddings = self.get_transductive_entity_embeddings(head_entites)
assert len(literal_values) == len(head_embeddings)
rel_name =rel.split(sep="#")[-1]
if rel_name in rel_to_predict:
rel_idx = dataset.relation_to_idx[rel]
filtered_rows = dataset.train_set[dataset.train_set[:, 1] == rel_idx]
h_idx, _, t_indx = filtered_rows.T.tolist()
head_entites = [dataset.idx_to_entity[idx] for idx in h_idx]
literal_values = [dataset.idx_to_entity[idx] for idx in t_indx]
head_embeddings = self.get_transductive_entity_embeddings(head_entites)
assert len(head_embeddings) == len(literal_values)
inputs = torch.tensor(head_embeddings)
weights = torch.randn(inputs.shape[1], requires_grad=True)
y = torch.tensor([int(x) for x in literal_values])

learning_rate = 0.1
epochs = 100
losses = []
loss_fn = MSELoss()
for epoch in range(epochs):
# Zero the gradients from previous iteration
if weights.grad is not None:
weights.grad.zero_()

# Step 4: Forward pass (compute predicted output yhat)
product = inputs * weights # Dot product of weights and inputs
yhat = product.sum(dim=1)
# # Step 5: Compute the loss (squared error)
#loss = torch.mean((yhat - y) ** 2)
loss = loss_fn(yhat, y.float())


# # Step 6: Backpropagation (compute gradients)
loss.backward()

# # Step 7: Update weights manually using gradient descent
with torch.no_grad(): # Temporarily disable gradient tracking for manual update
weights -= learning_rate * weights.grad

# # Step 8: Store and print the loss for tracking
losses.append(loss.item())
# if epoch % 10 == 0:
# print(f'for relation {rel_name}, Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')
self.weight_dict[rel] = weights



def predict_literals(self, h, r):
"Funtion to predict literals using pre-trained KGE models"
pass
h_embed = self.get_entity_embeddings([h])[0]
w = self.weight_dict[r]
prod_sum = torch.sum(h_embed * w)
literal_value = prod_sum.item()
return literal_value
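
For context, a minimal usage sketch of the new literal interface. This is a sketch only: the experiment directory, file path, entity IRI, and relation name below are hypothetical placeholders, and it assumes a pre-trained model loaded through dicee's interactive KGE class.

    from dicee import KGE

    # Load a pre-trained model from a (hypothetical) experiment directory.
    model = KGE(path="Experiments/2024-10-28_10-00-00")

    # Fit one linear regression per requested literal relation on top of the
    # frozen entity embeddings; entries in rel_to_predict are matched against
    # the fragment after '#' in each relation IRI.
    model.train_literals(path="KGs/literals.nt", rel_to_predict=["hasWeight"])

    # weight_dict is keyed by the full relation IRI, so pass the full IRI here.
    value = model.predict_literals(h="http://example.org/e1",
                                   r="http://example.org/ns#hasWeight")
    print(value)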
