Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alanakbik committed Jul 8, 2021
1 parent e9c2e7c commit 274dc8e
Showing 1 changed file with 33 additions and 33 deletions.
66 changes: 33 additions & 33 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,36 +1600,36 @@ def randomly_split_into_two_datasets(dataset, length_of_first):
return [Subset(dataset, first_dataset), Subset(dataset, second_dataset)]


class Relation(DataPoint):
def __init__(self, head: Span, tail: Span):
super().__init__()
self.head = head
self.tail = tail

def to(self, device: str, pin_memory: bool = False):
self.head.to(device, pin_memory)
self.tail.to(device, pin_memory)

def clear_embeddings(self, embedding_names: List[str] = None):
self.head.clear_embeddings(embedding_names)
self.tail.clear_embeddings(embedding_names)

@property
def embedding(self):
return torch.cat([self.head.embedding, self.tail.embedding])

def __repr__(self):
return f"Relation:\n − Head {self.head}\n − Tail {self.tail}\n − Labels: {self.labels}\n"

def to_plain_string(self):
return f"Relation: Head {self.head} || Tail {self.tail} || Labels: {self.labels}\n"

def print_span_text(self):
return f"Relation: Head {self.head} || Tail {self.tail}\n"

def __len__(self):
return len(self.head) + len(self.tail)

@property
def span_indices(self):
return (self.head.tokens[0].idx, self.head.tokens[-1].idx, self.tail.tokens[0].idx, self.tail.tokens[-1].idx)
# class Relation(DataPoint):
# def __init__(self, head: Span, tail: Span):
# super().__init__()
# self.head = head
# self.tail = tail
#
# def to(self, device: str, pin_memory: bool = False):
# self.head.to(device, pin_memory)
# self.tail.to(device, pin_memory)
#
# def clear_embeddings(self, embedding_names: List[str] = None):
# self.head.clear_embeddings(embedding_names)
# self.tail.clear_embeddings(embedding_names)
#
# @property
# def embedding(self):
# return torch.cat([self.head.embedding, self.tail.embedding])
#
# def __repr__(self):
# return f"Relation:\n − Head {self.head}\n − Tail {self.tail}\n − Labels: {self.labels}\n"
#
# def to_plain_string(self):
# return f"Relation: Head {self.head} || Tail {self.tail} || Labels: {self.labels}\n"
#
# def print_span_text(self):
# return f"Relation: Head {self.head} || Tail {self.tail}\n"
#
# def __len__(self):
# return len(self.head) + len(self.tail)
#
# @property
# def span_indices(self):
# return (self.head.tokens[0].idx, self.head.tokens[-1].idx, self.tail.tokens[0].idx, self.tail.tokens[-1].idx)

0 comments on commit 274dc8e

Please sign in to comment.