Skip to content

Commit 7451287

Browse files
authored
Merge pull request #123 from codingpot/multi-request-youtube
feat: batch request youtube
2 parents 38c0a30 + b812181 commit 7451287

File tree

5 files changed

+16854
-16729
lines changed

5 files changed

+16854
-16729
lines changed

.golangci.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ linters:
3131
- wrapcheck
3232
- funlen
3333
- goimports # False positive with paperswithcode_go
34+
- lll # long line lint

metadata-manager/cmd/genmetadata.go

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ var genMetaCmd = &cobra.Command{
2929
RunE: generateMetadata,
3030
}
3131

32-
//nolint:funlen
3332
func generateMetadata(cmd *cobra.Command, args []string) error {
3433
apiKey := viper.GetString(envNameYouTubeAPIKey)
3534
mappingFile := viper.GetString(envNameMappingFile)
@@ -54,45 +53,86 @@ func generateMetadata(cmd *cobra.Command, args []string) error {
5453
PrIdToVideo: make(map[int32]*pr12er.PrVideo),
5554
}
5655

57-
youtubeService, err := youtube.NewService(context.Background(), option.WithAPIKey(apiKey))
56+
client, err := fetcherClient(apiKey)
5857
if err != nil {
5958
return err
6059
}
61-
client := fetcher.New(paperswithcode_go.NewClient(), youtubeService)
6260

6361
in := make(chan *pr12er.MappingTableRow, len(mappingTable.GetRows()))
6462
out := make(chan *pr12er.PrVideo, len(mappingTable.GetRows()))
6563

66-
for w := 0; w < workers; w++ {
67-
go func(id int, in <-chan *pr12er.MappingTableRow, out chan<- *pr12er.PrVideo) {
68-
for row := range in {
69-
prVideo, err := client.FetchPrVideo(row)
70-
if err != nil {
71-
log.WithError(err).Warn("FetchPrVideo has failed")
72-
// don't block so it still sends nil to out
73-
}
74-
out <- prVideo
75-
}
76-
}(w, in, out)
64+
startWorkers(workers, client, in, out)
65+
// required for multi youtube video fetching.
66+
videoIDToPrMap := beginTheJobAndPrepareVideoMap(mappingTable, in)
67+
68+
close(in)
69+
70+
waitUntilDatabaseUpdate(mappingTable, out, database)
71+
72+
close(out)
73+
74+
updateDatabaseWithYouTubeMetadata(client, videoIDToPrMap, database)
75+
76+
return io.DumpDatabase(database, databaseOutFile)
77+
}
78+
79+
func fetcherClient(apiKey string) (*fetcher.Fetcher, error) {
80+
youtubeService, err := youtube.NewService(context.Background(), option.WithAPIKey(apiKey))
81+
if err != nil {
82+
return nil, err
7783
}
84+
client := fetcher.New(paperswithcode_go.NewClient(), youtubeService)
85+
return client, nil
86+
}
7887

79-
for _, prRow := range mappingTable.GetRows() {
80-
in <- prRow
88+
func updateDatabaseWithYouTubeMetadata(client *fetcher.Fetcher, videoIDToPrMap map[string]int32, database *pr12er.Database) {
89+
videos, err := client.FetchYouTubeVideos(videoIDToPrMap /*batchSize=*/, 50)
90+
if err != nil {
91+
log.WithError(err).Panic("failed to fetch multi YT videos")
8192
}
8293

83-
close(in)
94+
for prID, video := range videos {
95+
database.GetPrIdToVideo()[prID].Video = video
96+
}
97+
}
8498

99+
func waitUntilDatabaseUpdate(mappingTable *pr12er.MappingTable, out chan *pr12er.PrVideo, database *pr12er.Database) {
85100
for range mappingTable.GetRows() {
86101
prVideo := <-out
87102

88103
if prVideo != nil {
89104
database.PrIdToVideo[prVideo.GetPrId()] = prVideo
90105
}
91106
}
107+
}
92108

93-
close(out)
109+
func beginTheJobAndPrepareVideoMap(mappingTable *pr12er.MappingTable, in chan *pr12er.MappingTableRow) map[string]int32 {
110+
videoIDToPrMap := make(map[string]int32)
111+
for _, prRow := range mappingTable.GetRows() {
112+
in <- prRow
94113

95-
return io.DumpDatabase(database, databaseOutFile)
114+
if prRow.GetYoutubeVideoId() != "" {
115+
videoIDToPrMap[prRow.GetYoutubeVideoId()] = prRow.GetPrId()
116+
}
117+
}
118+
return videoIDToPrMap
119+
}
120+
121+
func startWorkers(workers int, client *fetcher.Fetcher, in chan *pr12er.MappingTableRow, out chan *pr12er.PrVideo) {
122+
for w := 0; w < workers; w++ {
123+
go func(id int, in <-chan *pr12er.MappingTableRow, out chan<- *pr12er.PrVideo) {
124+
for row := range in {
125+
// only get papers information
126+
// we will run youtube multi fetch later
127+
prVideo, err := client.FetchOnlyPapers(row)
128+
if err != nil {
129+
log.WithError(err).Warn("FetchPrVideo has failed")
130+
// don't block so it still sends nil to out
131+
}
132+
out <- prVideo
133+
}
134+
}(w, in, out)
135+
}
96136
}
97137

98138
// nolint: gochecknoinits

metadata-manager/internal/fetcher/fetcher.go

Lines changed: 71 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
package fetcher
33

44
import (
5+
"context"
56
"time"
67

78
"github.com/codingpot/paperswithcode-go/v2"
89
"github.com/codingpot/pr12er/metadata-manager/internal/transform"
910
"github.com/codingpot/pr12er/server/pkg/pr12er"
10-
"github.com/sirupsen/logrus"
11+
log "github.com/sirupsen/logrus"
1112
"google.golang.org/api/youtube/v3"
1213
"google.golang.org/protobuf/types/known/timestamppb"
1314
)
@@ -25,7 +26,7 @@ func (f *Fetcher) fetchArxivPapersInfo(paperArxivIDs []string) ([]*pr12er.Paper,
2526
var pr12erPapers []*pr12er.Paper
2627

2728
for _, arxivID := range paperArxivIDs {
28-
logrus.WithField("arxivID", arxivID).Info("processing a paper")
29+
log.WithField("arxivID", arxivID).Info("processing a paper")
2930
params := paperswithcode_go.PaperListParamsDefault()
3031
params.ArxivID = arxivID
3132
papers, err := f.client.PaperList(params)
@@ -68,42 +69,88 @@ func (f *Fetcher) fetchArxivPapersInfo(paperArxivIDs []string) ([]*pr12er.Paper,
6869
return pr12erPapers, nil
6970
}
7071

71-
func (f *Fetcher) fetchYouTubeVideoInfo(videoID string) (*pr12er.YouTubeVideo, error) {
72-
logrus.WithField("videoID", videoID).Info("fetching YouTube video info")
73-
74-
part := []string{"contentDetails", "snippet", "statistics"}
75-
call := f.youtubeService.Videos.List(part).Id(videoID)
76-
resp, err := call.Do()
77-
if err != nil {
78-
return nil, err
72+
// FetchYouTubeVideos fetches YouTubeVideo and returns a map[PR-ID]Video.
73+
// Because we can't send 200+ IDs requests at once, we use a wrapper function to split by batchSize.
74+
// We need to return the map so that it can plug back to the correct PR video.
75+
func (f *Fetcher) FetchYouTubeVideos(videoIDToPr map[string]int32, batchSize int) (map[int32]*pr12er.YouTubeVideo, error) {
76+
videoIDs := make([]string, len(videoIDToPr))
77+
i := 0
78+
for videoID := range videoIDToPr {
79+
videoIDs[i] = videoID
80+
i++
7981
}
8082

81-
// make video information
82-
youTubeVideo := pr12er.YouTubeVideo{}
83-
youTubeVideo.VideoId = videoID
84-
if len(resp.Items) > 0 {
85-
youTubeVideo.VideoTitle = resp.Items[0].Snippet.Title
83+
ret := make(map[int32]*pr12er.YouTubeVideo)
84+
85+
for i := 0; i < len(videoIDs); i += batchSize {
86+
end := i + batchSize
87+
if len(videoIDs) < end {
88+
end = len(videoIDs)
89+
}
8690

87-
ts, err := time.Parse(time.RFC3339, resp.Items[0].Snippet.PublishedAt)
91+
// Get the batch response.
92+
videos, err := f.FetchMultiYouTubeVideoByIDs(videoIDs[i:end])
8893
if err != nil {
8994
return nil, err
9095
}
91-
youTubeVideo.PublishedDate = timestamppb.New(ts)
92-
youTubeVideo.NumberOfLikes = int64(resp.Items[0].Statistics.LikeCount)
93-
youTubeVideo.NumberOfViews = int64(resp.Items[0].Statistics.ViewCount)
94-
youTubeVideo.Uploader = resp.Items[0].Snippet.ChannelTitle
96+
97+
for _, video := range videos {
98+
prID := videoIDToPr[video.GetVideoId()]
99+
ret[prID] = video
100+
}
95101
}
96102

97-
return &youTubeVideo, nil
103+
return ret, nil
98104
}
99105

100-
// FetchPrVideo fetches YouTubeVideo and Papers information.
101-
func (f *Fetcher) FetchPrVideo(prRow *pr12er.MappingTableRow) (*pr12er.PrVideo, error) {
102-
video, err := f.fetchYouTubeVideoInfo(prRow.YoutubeVideoId)
106+
// FetchMultiYouTubeVideoByIDs is a low level function that returns videos by its IDs.
107+
// If there is a next page token, it will iterate each page.
108+
func (f *Fetcher) FetchMultiYouTubeVideoByIDs(videoIDs []string) ([]*pr12er.YouTubeVideo, error) {
109+
log.WithField("videoIDs", videoIDs).Info("fetching YouTube")
110+
111+
part := []string{"contentDetails", "snippet", "statistics"}
112+
113+
var ret []*pr12er.YouTubeVideo
114+
err := f.youtubeService.Videos.List(part).Id(videoIDs...).
115+
Pages(context.Background(), func(response *youtube.VideoListResponse) error {
116+
videos, err := handleResponse(response)
117+
if err != nil {
118+
return err
119+
}
120+
ret = append(ret, videos...)
121+
return nil
122+
})
103123
if err != nil {
104124
return nil, err
105125
}
106126

127+
return ret, nil
128+
}
129+
130+
func handleResponse(resp *youtube.VideoListResponse) ([]*pr12er.YouTubeVideo, error) {
131+
ret := make([]*pr12er.YouTubeVideo, len(resp.Items))
132+
133+
for i, item := range resp.Items {
134+
ts, err := time.Parse(time.RFC3339, item.Snippet.PublishedAt)
135+
if err != nil {
136+
return nil, err
137+
}
138+
139+
ret[i] = &pr12er.YouTubeVideo{
140+
VideoId: item.Id,
141+
VideoTitle: item.Snippet.Title,
142+
NumberOfLikes: int64(item.Statistics.LikeCount),
143+
NumberOfViews: int64(item.Statistics.ViewCount),
144+
PublishedDate: timestamppb.New(ts),
145+
Uploader: item.Snippet.ChannelTitle,
146+
}
147+
}
148+
149+
return ret, nil
150+
}
151+
152+
// FetchOnlyPapers fetches papers without video information.
153+
func (f *Fetcher) FetchOnlyPapers(prRow *pr12er.MappingTableRow) (*pr12er.PrVideo, error) {
107154
papers, err := f.fetchArxivPapersInfo(prRow.PaperArxivId)
108155
if err != nil {
109156
return nil, err
@@ -112,6 +159,5 @@ func (f *Fetcher) FetchPrVideo(prRow *pr12er.MappingTableRow) (*pr12er.PrVideo,
112159
return &pr12er.PrVideo{
113160
PrId: prRow.GetPrId(),
114161
Papers: papers,
115-
Video: video,
116162
}, nil
117163
}
Lines changed: 61 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
package fetcher
22

33
import (
4+
"context"
5+
"os"
46
"testing"
5-
"time"
67

78
"github.com/codingpot/paperswithcode-go/v2"
89
"github.com/codingpot/pr12er/server/pkg/pr12er"
10+
"github.com/google/go-cmp/cmp"
911
"github.com/stretchr/testify/assert"
10-
"google.golang.org/protobuf/types/known/timestamppb"
12+
"google.golang.org/api/option"
13+
"google.golang.org/api/youtube/v3"
14+
"google.golang.org/protobuf/testing/protocmp"
1115
)
1216

1317
func TestFetchArxivPapersInfo(t *testing.T) {
@@ -19,24 +23,63 @@ func TestFetchArxivPapersInfo(t *testing.T) {
1923
assert.Equal(t, 1, len(papers))
2024
}
2125

22-
func TestFetchYouTubeVideoInfo(t *testing.T) {
23-
t.Skip("TODO: Replace YOUTUBE API with a mock")
24-
youtubeID := "L3hz57whyNw"
26+
func TestFetcher_FetchMultipleYouTubeVideoIDs(t *testing.T) {
27+
t.SkipNow()
28+
type args struct {
29+
videoIDs []string
30+
}
2531

26-
ts, _ := time.Parse(time.RFC3339, "2017-04-22T05:36:37Z")
27-
expectedVideo := &pr12er.YouTubeVideo{
28-
VideoId: youtubeID,
29-
VideoTitle: "PR-001: Generative adversarial nets by Jaejun Yoo (2017/4/13)",
30-
PublishedDate: timestamppb.New(ts),
31-
Uploader: "Sung Kim",
32+
tests := []struct {
33+
name string
34+
args args
35+
want []*pr12er.YouTubeVideo
36+
wantErr bool
37+
}{
38+
{
39+
name: "Returns multiple vidoes",
40+
args: args{
41+
videoIDs: []string{
42+
"iQVvhLxGAt8",
43+
"Kgh88DLHHTo",
44+
},
45+
},
46+
want: []*pr12er.YouTubeVideo{
47+
{
48+
VideoId: "iQVvhLxGAt8",
49+
VideoTitle: "PR-326: VICReg: Variance-Invariance-Covariance Regularization for Self-Supervised Learning",
50+
Uploader: "만끽 MaanGeek",
51+
},
52+
{
53+
VideoId: "Kgh88DLHHTo",
54+
VideoTitle: "PR-325: Pixel-BERT: Aligning Image Pixels with Text by Deep Multi-Modal Transformers",
55+
Uploader: "Sunghoon Joo",
56+
},
57+
},
58+
wantErr: false,
59+
},
3260
}
61+
for _, tt := range tests {
62+
t.Run(tt.name, func(t *testing.T) {
63+
youtubeService, _ := youtube.NewService(context.Background(),
64+
option.WithAPIKey(os.Getenv("YOUTUBE_API_KEY")))
3365

34-
c := New(paperswithcode_go.NewClient(), nil)
35-
actualVideo, err := c.fetchYouTubeVideoInfo(youtubeID)
36-
assert.NoError(t, err)
66+
f := &Fetcher{
67+
client: paperswithcode_go.NewClient(),
68+
youtubeService: youtubeService,
69+
}
70+
got, err := f.FetchMultiYouTubeVideoByIDs(tt.args.videoIDs)
3771

38-
assert.Equal(t, expectedVideo.VideoId, actualVideo.VideoId)
39-
assert.Equal(t, expectedVideo.VideoTitle, actualVideo.VideoTitle)
40-
assert.Equal(t, expectedVideo.PublishedDate, actualVideo.PublishedDate)
41-
assert.Equal(t, expectedVideo.Uploader, actualVideo.Uploader)
72+
if tt.wantErr {
73+
assert.Error(t, err)
74+
} else {
75+
assert.NoError(t, err)
76+
if diff := cmp.Diff(tt.want, got,
77+
protocmp.IgnoreFields(&pr12er.YouTubeVideo{},
78+
"number_of_likes", "number_of_views", "published_date"),
79+
protocmp.Transform()); diff != "" {
80+
t.Error(diff)
81+
}
82+
}
83+
})
84+
}
4285
}

0 commit comments

Comments
 (0)