
Commit

update KGE literals
sapkotaruz11 committed Oct 28, 2024
1 parent dbb648b commit d87caf1
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions dicee/knowledge_graph_embeddings.py
@@ -1632,10 +1632,9 @@ def train_literals(self, path, rel_to_predict = []):
     filtered_rows = dataset.train_set[dataset.train_set[:, 1] == rel_idx]
     h_idx, _, t_indx = filtered_rows.T.tolist()
     head_entites = [dataset.idx_to_entity[idx] for idx in h_idx]
-    literal_values = [dataset.idx_to_entity[idx] for idx in t_indx]
-    head_embeddings = self.get_transductive_entity_embeddings(head_entites)
-    assert len(head_embeddings) == len(literal_values)
-    inputs = torch.tensor(head_embeddings)
+    literal_values = [float(dataset.idx_to_entity[idx]) for idx in t_indx]
+    inputs = self.get_entity_embeddings(head_entites)
+    assert len(inputs) == len(literal_values)
     weights = torch.randn(inputs.shape[1], requires_grad=True)
     y = torch.tensor([int(x) for x in literal_values])

@@ -1648,19 +1647,14 @@ def train_literals(self, path, rel_to_predict = []):
     if weights.grad is not None:
         weights.grad.zero_()

     # Step 4: Forward pass (compute predicted output yhat)
     product = inputs * weights  # element-wise product; summed below into per-row dot products
     yhat = product.sum(dim=1)

     # Step 5: Compute the loss
     # (squared-error alternative: loss = torch.mean((yhat - y) ** 2))
     loss = loss_fn(yhat, y.float())

     # Step 6: Backpropagation (compute gradients)
     loss.backward()

     # Step 7: Update weights manually using gradient descent
-    with torch.no_grad(): # Temporarily disable gradient tracking for manual update
+    with torch.no_grad():
         weights -= learning_rate * weights.grad

     # Step 8: Store and print the loss for tracking
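For context, the first hunk builds a literal-regression dataset from KG triples: it filters the training triples to one relation, maps the head indices back to entity labels, and parses the tail entries as numeric targets. A minimal, self-contained sketch of that step follows; the toy train_set, idx_to_entity contents, and rel_idx value are illustrative assumptions, not dicee's real data:

import numpy as np

# Toy stand-ins (assumptions for illustration, not dicee's real data):
train_set = np.array([[0, 5, 2],
                      [1, 5, 3],
                      [0, 4, 1]])  # (head, relation, tail) index triples
idx_to_entity = {0: "Berlin", 1: "Paris", 2: "3.6", 3: "2.1"}
rel_idx = 5  # the literal-valued relation to extract

# Same steps as the hunk: filter triples by relation, then split the columns.
filtered_rows = train_set[train_set[:, 1] == rel_idx]
h_idx, _, t_idx = filtered_rows.T.tolist()
head_entities = [idx_to_entity[i] for i in h_idx]
literal_values = [float(idx_to_entity[i]) for i in t_idx]
print(head_entities, literal_values)  # ['Berlin', 'Paris'] [3.6, 2.1]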
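The second hunk trains a linear probe on the frozen head-entity embeddings with a manual gradient-descent loop. Below is a runnable sketch of the same pattern; the synthetic data is an assumption, and so is treating loss_fn as torch.nn.MSELoss(), since the diff does not show how loss_fn is defined:

import torch

# Sketch of the diff's training loop: fit a linear probe on frozen entity
# embeddings to predict numeric literals via manual gradient descent.
torch.manual_seed(0)

n, dim = 128, 32
inputs = torch.randn(n, dim)                 # stand-in for entity embeddings
true_w = torch.randn(dim)
y = inputs @ true_w + 0.1 * torch.randn(n)   # stand-in for literal targets

weights = torch.randn(dim, requires_grad=True)
loss_fn = torch.nn.MSELoss()                 # assumption: an MSE-style loss
learning_rate = 0.01

for epoch in range(100):
    # Reset accumulated gradients before each step.
    if weights.grad is not None:
        weights.grad.zero_()

    # Forward pass: per-row dot product of embedding and weight vector.
    yhat = (inputs * weights).sum(dim=1)
    loss = loss_fn(yhat, y)

    # Backpropagation: compute d(loss)/d(weights).
    loss.backward()

    # Manual SGD update; disable gradient tracking while mutating weights.
    with torch.no_grad():
        weights -= learning_rate * weights.grad

    if epoch % 20 == 0:
        print(f"epoch {epoch}: loss {loss.item():.4f}")

One detail worth flagging: the context line y = torch.tensor([int(x) for x in literal_values]) truncates fractional literals even though the commit now parses them with float(...); keeping the targets as floats, as in the sketch above, preserves that precision.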
