Skip to content

Commit 556e2e6

Browse files
committed
Edit entity extraction modules
1 parent 3b5e48b commit 556e2e6

File tree

3 files changed

+31
-26
lines changed

3 files changed

+31
-26
lines changed

interact.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def interact(self):
9393
def post_process(self, prediction, u_ent_features):
9494
if prediction == 0:
9595
return True
96-
attr_list = [9, 11, 6, 1]
96+
attr_list = [9, 12, 6, 1]
9797
if all(u_ent_featur == 1 for u_ent_featur in u_ent_features) and prediction in attr_list:
9898
return True
9999
else:
@@ -102,7 +102,7 @@ def post_process(self, prediction, u_ent_features):
102102
def action_post_process(self, prediction, u_entities):
103103
attr_mapping_dict = {
104104
9: '<cuisine>',
105-
11: '<location>',
105+
12: '<location>',
106106
6: '<party_size>',
107107
1: '<rest_type>'
108108
}

modules/actions.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,22 @@
33

44
'''
55
Action Templates
6-
'api_call <cuisine> <location> <party_size> <rest_type>',
7-
'가격의 범위는 어느정도로 생각하세요',
8-
'감사합니다',
9-
'네 또 변경하실게 있나요',
10-
'다른 리스트를 보여드릴게요',
11-
'또 도와드릴게 있나요',
12-
'몇명의 인원으로 예약하실 건가요',
13-
'안녕하세요 어떻게 도와드릴까요',
14-
'알겠습니다',
15-
'어떤 종류의 요리를 좋아하나요',
16-
'예약을 진행해드리도록 하겠습니다',
17-
'위치는 <info_address> 입니다',
18-
'위치는 어디에 있어야 하나요',
19-
'이 리스트는 어떤가요: <restaurant>',
20-
'전화번호는 <info_phone> 입니다',
21-
'좋아요 몇 가지 리스트를 보여드릴게요'
6+
0. api_call <cuisine> <location> <party_size> <rest_type>',
7+
1. 가격의 범위는 어느정도로 생각하세요',
8+
2. 감사합니다',
9+
3. 네 또 변경하실게 있나요',
10+
4. 다른 리스트를 보여드릴게요',
11+
5. 또 도와드릴게 있나요',
12+
6. 몇명의 인원으로 예약하실 건가요',
13+
7. 안녕하세요 어떻게 도와드릴까요',
14+
8. 알겠습니다',
15+
9. 어떤 종류의 요리를 좋아하나요',
16+
10. 예약을 진행해드리도록 하겠습니다',
17+
11. 위치는 <info_address> 입니다',
18+
12. 위치는 어디에 있어야 하나요',
19+
13. 이 리스트는 어떤가요: <restaurant>',
20+
14. 전화번호는 <info_phone> 입니다',
21+
15. 좋아요 몇 가지 리스트를 보여드릴게요'
2222
'''
2323

2424

modules/entities.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self):
1717
# constants
1818
self.party_sizes = ['한명', '두명', '세명', '셋', '네명', '넷', '다섯', '다섯명', '여섯', '여섯명', '일곱', '여덟']
1919

20-
self.locations = ['방콕', '베이징', '붐베이', '하노이', '파리', '로마', '런던', '마드리드', '서울', '도쿄']
20+
self.locations = ['방콕', '베이징', '붐베이', '하노이', '파리', '로마', '런던', '마드리드', '서울', '도쿄', 'LA']
2121

2222
self.cuisines = ['영국','중국','프랑스', '이탈리아', '인도', '일식', '일본', '한식', '한국', '스페인', '타이', '베트남']
2323

@@ -26,23 +26,28 @@ def __init__(self):
2626
self.EntType = Enum('Entity Type', '<party_size> <location> <cuisine> <rest_type> <non_ent>')
2727

2828
def ent_type(self, ent):
29+
# entity = [word for word in locations if word in input_str]
2930
if ent.startswith(tuple(self.party_sizes)):
30-
return self.EntType['<party_size>'].name
31+
entity_word = [word for word in self.party_sizes if word in ent][0]
32+
return self.EntType['<party_size>'].name, entity_word
3133
elif ent.startswith(tuple(self.locations)):
32-
return self.EntType['<location>'].name
34+
entity_word = [word for word in self.locations if word in ent][0]
35+
return self.EntType['<location>'].name, entity_word
3336
elif ent.startswith(tuple(self.cuisines)):
34-
return self.EntType['<cuisine>'].name
37+
entity_word = [word for word in self.cuisines if word in ent][0]
38+
return self.EntType['<cuisine>'].name, entity_word
3539
elif ent.startswith(tuple(self.rest_types)):
36-
return self.EntType['<rest_type>'].name
40+
entity_word = [word for word in self.rest_types if word in ent][0]
41+
return self.EntType['<rest_type>'].name, entity_word
3742
else:
38-
return ent
43+
return ent, None
3944

4045
def extract_entities(self, utterance, update=True, is_test=False):
4146
tokenized = []
4247
for word in utterance.split(' '):
43-
entity = self.ent_type(word)
48+
entity, entity_word = self.ent_type(word)
4449
if word != entity and update:
45-
self.entities[entity] = word
50+
self.entities[entity] = entity_word
4651
tokenized.append(entity)
4752
tokenized_str = ' '.join(tokenized)
4853

0 commit comments

Comments
 (0)