Skip to content

Commit 7370cd7

Browse files
authored
feat: implement CategoryFromVideo (#126)
1 parent 5ec13e9 commit 7370cd7

File tree

3 files changed

+125
-6
lines changed

3 files changed

+125
-6
lines changed

server/pkg/handlers/handlers.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"sort"
55

66
"github.com/codingpot/pr12er/server/internal/err"
7+
"github.com/codingpot/pr12er/server/pkg/handlers/prutils"
78
"github.com/codingpot/pr12er/server/pkg/pr12er"
89
)
910

@@ -19,7 +20,7 @@ func VideosResponseFromDB(db *pr12er.Database) *pr12er.GetVideosResponse {
1920
Title: dataVideo.GetVideoTitle(),
2021
Link: getYouTubeLinkFromID(dataVideo.GetVideoId()),
2122
Presenter: dataVideo.GetUploader(),
22-
Category: getCategory(data),
23+
Category: prutils.CategoryFromVideo(data),
2324
NumberOfLike: dataVideo.GetNumberOfLikes(),
2425
Keywords: getKeywords(data),
2526
NumberOfViews: dataVideo.GetNumberOfViews(),
@@ -48,11 +49,6 @@ func getKeywords(prVideo *pr12er.PrVideo) []string {
4849
return ret
4950
}
5051

51-
// TODO: Implement getCategory based on papers.
52-
func getCategory(prVideo *pr12er.PrVideo) pr12er.Category {
53-
return pr12er.Category_CATEGORY_UNSPECIFIED
54-
}
55-
5652
// getYouTubeLinkFromID returns the full URL.
5753
func getYouTubeLinkFromID(videoID string) string {
5854
return "https://youtu.be/" + videoID
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Package prutils contains utility functions for the PR model.
2+
package prutils
3+
4+
import (
5+
"strings"
6+
7+
"github.com/codingpot/pr12er/server/pkg/pr12er"
8+
log "github.com/sirupsen/logrus"
9+
)
10+
11+
// Keyword tables used to pick a category from a video title.
//
// NOTE: every keyword must be lowercase, because the title is lowercased
// before it is compared against these tables.
var (
	// visionKeywords are the substrings that map a title to the vision category.
	visionKeywords = []string{"vision", "detect"}

	// nlpKeywords are the substrings that map a title to the NLP category.
	nlpKeywords = []string{"text", "sentence"}

	// ocrKeywords are the substrings that map a title to the OCR category.
	ocrKeywords = []string{"ocr"}

	// audioKeywords are the substrings that map a title to the audio category.
	audioKeywords = []string{"audio"}

	// recommendationSystemKeywords are the substrings that map a title to the
	// recommendation-system category.
	recommendationSystemKeywords = []string{"recommend"}
)
35+
36+
// CategoryFromVideo TODO: Convert to more sophisticated algorithm.
37+
func CategoryFromVideo(prVideo *pr12er.PrVideo) pr12er.Category {
38+
title := strings.ToLower(prVideo.GetVideo().GetVideoTitle())
39+
40+
//nolint:gocritic
41+
if containsAnyElem(title, visionKeywords) {
42+
return pr12er.Category_CATEGORY_VISION
43+
} else if containsAnyElem(title, nlpKeywords) {
44+
return pr12er.Category_CATEGORY_NLP
45+
} else if containsAnyElem(title, ocrKeywords) {
46+
return pr12er.Category_CATEGORY_OCR
47+
} else if containsAnyElem(title, audioKeywords) {
48+
return pr12er.Category_CATEGORY_AUDIO
49+
} else if containsAnyElem(title, recommendationSystemKeywords) {
50+
return pr12er.Category_CATEGORY_RS
51+
}
52+
53+
return pr12er.Category_CATEGORY_UNSPECIFIED
54+
}
55+
56+
func containsAnyElem(title string, keywords []string) bool {
57+
for _, keyword := range keywords {
58+
if strings.Contains(title, keyword) {
59+
log.WithFields(log.Fields{
60+
"title": title,
61+
"keyword": keyword,
62+
}).Info("found category keyword")
63+
return true
64+
}
65+
}
66+
return false
67+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package prutils
2+
3+
import (
4+
"testing"
5+
6+
"github.com/codingpot/pr12er/server/pkg/pr12er"
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestCategoryFromVideo(t *testing.T) {
11+
type args struct {
12+
prVideo *pr12er.PrVideo
13+
}
14+
tests := []struct {
15+
name string
16+
args args
17+
want pr12er.Category
18+
}{
19+
{
20+
name: "Recommender Systems should return RS category",
21+
args: args{
22+
prVideo: &pr12er.PrVideo{Video: &pr12er.YouTubeVideo{
23+
VideoTitle: "PR-064: Wide & Deep Learning for Recommender Systems",
24+
}},
25+
},
26+
want: pr12er.Category_CATEGORY_RS,
27+
},
28+
{
29+
name: "Audio paper should return Audio category",
30+
args: args{
31+
prVideo: &pr12er.PrVideo{Video: &pr12er.YouTubeVideo{
32+
VideoTitle: "PR-067: Audio Super Resolution using Neural Nets",
33+
}},
34+
},
35+
want: pr12er.Category_CATEGORY_AUDIO,
36+
},
37+
{
38+
name: "Vision paper returns Vision category",
39+
args: args{
40+
prVideo: &pr12er.PrVideo{
41+
Video: &pr12er.YouTubeVideo{
42+
VideoTitle: "PR-084 MegDet: A Large Mini-Batch Object Detector",
43+
},
44+
},
45+
},
46+
want: pr12er.Category_CATEGORY_VISION,
47+
},
48+
}
49+
50+
for _, tt := range tests {
51+
t.Run(tt.name, func(t *testing.T) {
52+
got := CategoryFromVideo(tt.args.prVideo)
53+
assert.Equalf(t, tt.want, got, "want %s, got %s", tt.want, got)
54+
})
55+
}
56+
}

0 commit comments

Comments
 (0)