Skip to content

Commit

Permalink
keep idx_eve and idx_sta
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwq0 committed Oct 17, 2024
1 parent 737551f commit e005a6f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
8 changes: 6 additions & 2 deletions adloc/adloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ def forward(
if phase_time is not None:
phase_time_ = phase_time[phase_type == type]
resisudal[phase_type == type] = phase_time_ - t_
loss += torch.sum(F.huber_loss(t_, phase_time_, reduction="none") * phase_weight_)
# loss += torch.sum(F.huber_loss(t_, phase_time_, reduction="none") * phase_weight_)
# loss += torch.sum(F.l1_loss(t_, phase_time_, reduction="none") * phase_weight_)
loss += torch.sum(torch.abs(t_ - phase_time_) * phase_weight_)

return {"phase_time": pred_time, "residual": resisudal, "loss": loss}

Expand Down Expand Up @@ -312,7 +314,9 @@ def forward(

if phase_time is not None:
phase_time_ = phase_time[phase_type == type]
loss += torch.sum(F.huber_loss(t_, phase_time_, reduction="none") * phase_weight_)
# loss += torch.sum(F.huber_loss(t_, phase_time_, reduction="none") * phase_weight_)
# loss += torch.sum(F.l1_loss(t_, phase_time_, reduction="none") * phase_weight_)
loss += torch.sum(torch.abs(t_ - phase_time_) * phase_weight_)

if loss == 0.0:
return None
Expand Down
6 changes: 3 additions & 3 deletions adloc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,10 @@ def __getitem__(self, i):
if len(idx) == 0:
return None # skip empty batch

idx1_eve = self.pairs["event_index1"][idx]
idx2_eve = self.pairs["event_index2"][idx]
idx1_eve = self.pairs["idx_eve1"][idx]
idx2_eve = self.pairs["idx_eve2"][idx]
idx_eve = np.stack([idx1_eve, idx2_eve], axis=1)
idx_sta = self.pairs["station_index"][idx]
idx_sta = self.pairs["idx_sta"][idx]
phase_weight = self.pairs["phase_score"][idx]
phase_type = self.pairs["phase_type"][idx]
phase_time = self.pairs["phase_dtime"][idx]
Expand Down

0 comments on commit e005a6f

Please sign in to comment.