Skip to content

Commit

Permalink
Fixes None trainer error
Browse files Browse the repository at this point in the history
  • Loading branch information
Demirrr committed Oct 28, 2024
1 parent 9fcff4d commit e3dff25
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 33 deletions.
5 changes: 0 additions & 5 deletions dicee/read_preprocess_save_load_kg/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,6 @@ def start(self) -> None:
self.kg.re_vocab = executor.submit(get_re_vocab, data, self.kg.path_for_serialization + '/re_vocab.p')
self.kg.ee_vocab = executor.submit(get_ee_vocab, data, self.kg.path_for_serialization + '/ee_vocab.p')

# TODO: Deprecate it
self.kg.constraints = executor.submit(create_constraints, self.kg.train_set,
self.kg.path_for_serialization + '/constraints.p')
self.kg.domain_constraints_per_rel, self.kg.range_constraints_per_rel = None, None

# string containing
assert isinstance(self.kg.raw_train_set, pd.DataFrame) or isinstance(self.kg.raw_train_set, pl.DataFrame)

Expand Down
18 changes: 0 additions & 18 deletions dicee/read_preprocess_save_load_kg/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,24 +119,6 @@ def pandas_dataframe_indexer(df_pandas: pd.DataFrame, idx_entity: pd.DataFrame,
del relation_to_index
return df_pandas

def dept_index_triples_with_pandas(train_set, entity_to_idx: dict, relation_to_idx: dict) -> pd.core.frame.DataFrame:
"""
:param train_set: pandas dataframe
:param entity_to_idx: a mapping from str to integer index
:param relation_to_idx: a mapping from str to integer index
:param num_core: number of cores to be used
:return: indexed triples, i.e., pandas dataframe
"""
n, d = train_set.shape
train_set['subject'] = train_set['subject'].apply(lambda x: entity_to_idx.get(x))
train_set['relation'] = train_set['relation'].apply(lambda x: relation_to_idx.get(x))
train_set['object'] = train_set['object'].apply(lambda x: entity_to_idx.get(x))
# train_set = train_set.dropna(inplace=True)
if isinstance(train_set, pd.core.frame.DataFrame):
assert (n, d) == train_set.shape
else:
raise KeyError('Wrong type training data')
return train_set

def apply_reciprical_or_noise(add_reciprical: bool, eval_model: str, df: object = None, info: str = None):
""" (1) Add reciprocal triples (2) Add noisy triples """
Expand Down
17 changes: 7 additions & 10 deletions dicee/trainer/dice_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,12 @@ def initialize_trainer(args, callbacks):
if args.trainer == 'torchCPUTrainer':
print('Initializing TorchTrainer CPU Trainer...', end='\t')
return TorchTrainer(args, callbacks=callbacks)
if args.trainer == 'MP':
elif args.trainer == 'MP':
print('Initializing MPTrainer...', end='\t')
return MP(args, callbacks=callbacks)

elif args.trainer == 'torchDDP':
if torch.cuda.is_available():
print('Initializing TorchDDPTrainer GPU', end='\t')
return TorchDDPTrainer(args, callbacks=callbacks)
else:
print('Initializing TorchTrainer CPU Trainer', end='\t')
return TorchTrainer(args, callbacks=callbacks)

assert torch.cuda.is_available()
print('Initializing TorchDDPTrainer GPU', end='\t')
elif args.trainer == 'PL':
print('Initializing Pytorch-lightning Trainer', end='\t')
kwargs = vars(args)
Expand Down Expand Up @@ -134,7 +128,10 @@ def initialize_trainer(args, callbacks):
detect_anomaly=False,
barebones=False)
else:
raise KeyError("Please choose a valid trainer from ['MP','torchDDP', 'torchCPUTrainer', 'PL']")
print('Initializing TorchTrainer CPU Trainer...', end='\t')
return TorchTrainer(args, callbacks=callbacks)




def get_callbacks(args):
Expand Down

0 comments on commit e3dff25

Please sign in to comment.