Skip to content

Commit 38c0a30

Browse files
authored
Merge pull request #122 from codingpot/serve-with-database
feat: serve with database instead of mock
2 parents a346185 + d934f2e commit 38c0a30

File tree

2 files changed

+52
-96
lines changed

2 files changed

+52
-96
lines changed

server/pkg/serv/serv.go

Lines changed: 4 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@ package serv
33
import (
44
"context"
55
"fmt"
6-
"time"
76

8-
"github.com/codingpot/pr12er/server/internal"
7+
"github.com/codingpot/pr12er/server/internal/data"
8+
"github.com/codingpot/pr12er/server/pkg/handlers"
99
"github.com/codingpot/pr12er/server/pkg/pr12er"
10-
"google.golang.org/protobuf/types/known/timestamppb"
1110
)
1211

1312
type Server struct {
@@ -17,98 +16,13 @@ type Server struct {
1716
var _ pr12er.Pr12ErServiceServer = (*Server)(nil)
1817

1918
func (s Server) GetDetail(_ context.Context, in *pr12er.GetDetailRequest) (*pr12er.GetDetailResponse, error) {
20-
// Returns details of this ID
21-
searchPRID := in.GetPrId()
22-
23-
resp := &pr12er.GetDetailResponse{
24-
Detail: &pr12er.Detail{
25-
PrId: searchPRID,
26-
Paper: []*pr12er.Paper{
27-
{
28-
PaperId: "1",
29-
// nolint: lll
30-
Abstract: "We propose a new framework for estimating generative models via an adversarial process, in which we simultaneously train two models: a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that a sample came from the training data rather than G. The training procedure for G is to maximize the probability of D making a mistake. This framework corresponds to a minimax two-player game. In the space of arbitrary functions G and D, a unique solution exists, with G recovering the training data distribution and D equal to 1/2 everywhere. In the case where G and D are defined by multilayer perceptrons, the entire system can be trained with backpropagation. There is no need for any Markov chains or unrolled approximate inference networks during either training or generation of samples. Experiments demonstrate the potential of the framework through qualitative and quantitative evaluation of the generated samples.",
31-
Repositories: []*pr12er.Repository{
32-
{
33-
Framework: pr12er.Framework_FRAMEWORK_TENSORFLOW,
34-
Owner: "goodfeli",
35-
Url: "https://github.com/tensorflow/tensorflow",
36-
},
37-
{
38-
Framework: pr12er.Framework_FRAMEWORK_PYTORCH,
39-
Owner: "eriklindernoren",
40-
Url: "https://github.com/pytorch/pytorch",
41-
},
42-
{
43-
Framework: pr12er.Framework_FRAMEWORK_TENSORFLOW,
44-
Owner: "google-research",
45-
Url: "https://github.com/tensorflow/tensorflow",
46-
},
47-
{
48-
Framework: pr12er.Framework_FRAMEWORK_PYTORCH,
49-
Owner: "eriklindernoren",
50-
Url: "https://github.com/pytorch/pytorch",
51-
},
52-
},
53-
},
54-
},
55-
SameAuthorPapers: []*pr12er.Paper{
56-
{
57-
Title: "On distinguishability criteria for estimating generative models",
58-
Authors: []string{"Ian J. Goodfellow"},
59-
PubDate: timestamppb.New(newDate(2015, 5, 21)),
60-
},
61-
},
62-
RelevantPapers: []*pr12er.Paper{
63-
{
64-
Title: "Learning to Efficiently Sample from Diffusion Probabilistic Models",
65-
Authors: []string{"Daniel Watson"},
66-
PubDate: timestamppb.New(newDate(2021, 6, 7)),
67-
},
68-
},
69-
},
70-
}
71-
72-
return resp, nil
73-
}
74-
75-
// A little helper function to create a Ymd.
76-
func newDate(y int, m time.Month, d int) time.Time {
77-
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
19+
return handlers.DetailResponseFromDB(in.GetPrId(), &data.DB)
7820
}
7921

8022
func (s Server) GetHello(_ context.Context, in *pr12er.HelloRequest) (*pr12er.HelloResponse, error) {
8123
return &pr12er.HelloResponse{Body: fmt.Sprintf("Hello %s", in.Body)}, nil
8224
}
8325

8426
func (s Server) GetVideos(_ context.Context, _ *pr12er.GetVideosRequest) (*pr12er.GetVideosResponse, error) {
85-
return getVideosFromDumpedPbtxt()
86-
}
87-
88-
func getVideosFromDumpedPbtxt() (*pr12er.GetVideosResponse, error) {
89-
var resp pr12er.GetVideosResponse
90-
metadataDump := internal.ReadPR12MetadataProtoText()
91-
92-
resp.Videos = make([]*pr12er.Video, 0, len(metadataDump.Metadata))
93-
94-
for _, metadata := range metadataDump.Metadata {
95-
video := &pr12er.Video{
96-
PrId: metadata.GetId(),
97-
Title: metadata.GetTitle(),
98-
Presenter: metadata.GetPresenter(),
99-
Keywords: metadata.GetKeywords(),
100-
101-
// TODO: Update category and number of likes.
102-
Category: pr12er.Category_CATEGORY_UNSPECIFIED,
103-
NumberOfLike: 0,
104-
}
105-
106-
videoMetadata := metadata.GetVideoMetadata()
107-
if len(videoMetadata) > 0 {
108-
video.Link = videoMetadata[0].Url
109-
}
110-
111-
resp.Videos = append(resp.Videos, video)
112-
}
113-
return &resp, nil
27+
return handlers.VideosResponseFromDB(&data.DB), nil
11428
}

server/pkg/serv/serv_test.go

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,52 @@ func TestGetVideos(t *testing.T) {
2929
assert.Greater(t, len(resp.Videos), 0)
3030
}
3131

32-
func TestServer_GetDetails(t *testing.T) {
33-
s := Server{}
34-
request := pr12er.GetDetailRequest{PrId: 1}
35-
got, err := s.GetDetail(context.Background(), &request)
36-
assert.NoError(t, err)
37-
assert.Equal(t, int32(1), got.GetDetail().GetPrId())
32+
func TestServer_GetDetail(t *testing.T) {
33+
type args struct {
34+
ctx context.Context
35+
req *pr12er.GetDetailRequest
36+
}
37+
tests := []struct {
38+
name string
39+
args args
40+
want *pr12er.GetDetailResponse
41+
wantErr bool
42+
}{
43+
{
44+
name: "Returns an error if the PR is not found with the given ID",
45+
args: args{
46+
ctx: context.Background(),
47+
req: &pr12er.GetDetailRequest{
48+
PrId: 0,
49+
},
50+
},
51+
want: nil,
52+
wantErr: true,
53+
},
54+
{
55+
name: "Returns a valid response if the PR is found",
56+
args: args{
57+
ctx: context.Background(),
58+
req: &pr12er.GetDetailRequest{PrId: 1},
59+
},
60+
want: &pr12er.GetDetailResponse{Detail: &pr12er.Detail{
61+
PrId: 1,
62+
}},
63+
wantErr: false,
64+
},
65+
}
66+
67+
for _, tt := range tests {
68+
t.Run(tt.name, func(t *testing.T) {
69+
s := Server{}
70+
got, err := s.GetDetail(tt.args.ctx, tt.args.req)
71+
if tt.wantErr {
72+
assert.Error(t, err)
73+
} else {
74+
assert.NoError(t, err)
75+
assert.Equal(t, tt.want.GetDetail().GetPrId(), got.GetDetail().GetPrId())
76+
assert.Greater(t, len(got.GetDetail().GetPaper()), 0)
77+
}
78+
})
79+
}
3880
}

0 commit comments

Comments
 (0)