Skip to content

Commit 828bc01

Browse files
committed
improve log2pg: add CORS, return id etc
1 parent a4fa2f0 commit 828bc01

File tree

1 file changed

+64
-7
lines changed

1 file changed

+64
-7
lines changed

cmd/log2pg/main.go

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,60 @@ package main
22

33
import (
44
"database/sql"
5+
"encoding/json"
56
"flag"
67
"fmt"
78
"io"
89
"log"
910
"net"
1011
"net/http"
1112
"os"
13+
"strings"
1214

1315
"github.com/lib/pq"
1416
)
1517

1618
type LogReceiver struct {
1719
InsertStatement *sql.Stmt
1820
AddURL bool
21+
AllowedOrigins map[string]bool
22+
}
23+
24+
type ArrayVarType []string
25+
26+
func (vt *ArrayVarType) String() string {
27+
return strings.Join(*vt, ", ")
28+
}
29+
30+
func (vt *ArrayVarType) Set(v string) error {
31+
*vt = append(*vt, v)
32+
return nil
1933
}
2034

2135
func (lr LogReceiver) ServeHTTP(w http.ResponseWriter, r *http.Request) {
22-
if body, err := io.ReadAll(r.Body); err != nil {
36+
origin := r.Header.Get("Origin")
37+
allowOrigin := lr.AllowedOrigins == nil
38+
if !allowOrigin && lr.AllowedOrigins[origin] {
39+
allowOrigin = true
40+
}
41+
if !allowOrigin {
42+
log.Printf("Rejected %d bytes via %s to %s from %s %v",
43+
r.ContentLength, r.Method, r.RequestURI, r.RemoteAddr, r.Header)
44+
w.WriteHeader(http.StatusForbidden)
45+
return
46+
}
47+
if r.Method == "OPTIONS" && origin != "" && r.Header.Get("Access-Control-Request-Method") != "" {
48+
w.Header().Add("Access-Control-Allow-Origin", origin)
49+
w.Header().Add("Access-Control-Allow-Methods", "POST")
50+
if rqh := r.Header.Get("Access-Control-Request-Headers"); rqh != "" {
51+
w.Header().Add("Access-Control-Allow-Headers", rqh)
52+
}
53+
w.WriteHeader(http.StatusNoContent)
54+
return
55+
}
56+
if r.Body == http.NoBody {
57+
log.Printf("No body received from %s", r.RemoteAddr)
58+
} else if body, err := io.ReadAll(r.Body); err != nil {
2359
log.Printf("Could not read body")
2460
} else {
2561
var insertId int
@@ -28,13 +64,24 @@ func (lr LogReceiver) ServeHTTP(w http.ResponseWriter, r *http.Request) {
2864
insertArgs = append(insertArgs, r.RequestURI)
2965
}
3066
if err := lr.InsertStatement.QueryRow(insertArgs...).Scan(&insertId); err != nil {
31-
log.Printf("Error executing statement: %s", err)
67+
log.Printf("Error executing statement: %s, args: %#v", err, insertArgs)
3268
} else {
3369
log.Printf("Received log entry #%d from %s", insertId, r.RemoteAddr)
34-
w.WriteHeader(http.StatusNoContent)
70+
if origin != "" {
71+
w.Header().Set("Access-Control-Allow-Origin", origin)
72+
}
73+
w.Header().Set("Content-Type", "application/json")
74+
d, err := json.Marshal(insertId)
75+
if err != nil {
76+
log.Printf("error encoding %#v to response: %s", insertId, err)
77+
w.WriteHeader(http.StatusInternalServerError)
78+
return
79+
}
80+
if _, err := w.Write(d); err != nil {
81+
log.Printf("error writing response: %s", err)
82+
}
3583
return
3684
}
37-
3885
}
3986
w.WriteHeader(http.StatusBadRequest)
4087
}
@@ -44,6 +91,8 @@ func main() {
4491
dsn := flag.String("dsn", "sslmode=disable", "Postgresql DSN")
4592
dbTable := flag.String("table", "log", "Table in database receiving log")
4693
addUrl := flag.Bool("add-url", true, "Add RequestURI to 'url' column if it exists")
94+
var origins ArrayVarType
95+
flag.Var(&origins, "origin", "allowed origins (multi-arg)")
4796
tableCreateSql := flag.String("create", "CREATE TABLE \"%s\" (id BIGSERIAL PRIMARY KEY, stamp TIMESTAMPTZ DEFAULT now(), url text, src text, msg JSONB)", "Table creation SQL")
4897
flag.Parse()
4998

@@ -53,12 +102,12 @@ func main() {
53102
}
54103
returnValue := "id"
55104
if _, err := db.Exec(fmt.Sprintf("SELECT 1 FROM \"%s\" WHERE 1=0 AND src IS NOT NULL and msg IS NOT NULL", *dbTable)); err != nil {
56-
if pqErr, ok := err.(*pq.Error); ok && pqErr.Code == "42P01" {
105+
if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == "undefined_table" {
57106
if _, err := db.Exec(fmt.Sprintf(*tableCreateSql, *dbTable)); err != nil {
58107
log.Fatalf("Creating table %#v failed: %s\n", *dbTable, err)
59108
}
60109
log.Printf("Created table %#v in database", *dbTable)
61-
} else if pqErr, ok := err.(*pq.Error); ok && pqErr.Code == "42501" {
110+
} else if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == "insufficient_privilege" {
62111
log.Printf("Skipping table schema check, no permission to SELECT from %#v: %s", *dbTable, err)
63112
returnValue = "-1"
64113
} else {
@@ -93,7 +142,15 @@ func main() {
93142
if err != nil {
94143
log.Fatalf("Listen on %#v failed: %s", *listenFlag, err)
95144
}
96-
if err := http.Serve(ln, LogReceiver{InsertStatement: sqlInsert, AddURL: haveUrl}); err != nil {
145+
var originMap map[string]bool
146+
for _, origin := range origins {
147+
if originMap == nil {
148+
originMap = map[string]bool{}
149+
}
150+
originMap[origin] = true
151+
}
152+
153+
if err := http.Serve(ln, LogReceiver{InsertStatement: sqlInsert, AddURL: haveUrl, AllowedOrigins: originMap}); err != nil {
97154
log.Fatalf("Could not start HTTP server: %s", err)
98155
}
99156
}

0 commit comments

Comments
 (0)