@@ -2,24 +2,60 @@ package main
2
2
3
3
import (
4
4
"database/sql"
5
+ "encoding/json"
5
6
"flag"
6
7
"fmt"
7
8
"io"
8
9
"log"
9
10
"net"
10
11
"net/http"
11
12
"os"
13
+ "strings"
12
14
13
15
"github.com/lib/pq"
14
16
)
15
17
16
18
type LogReceiver struct {
17
19
InsertStatement * sql.Stmt
18
20
AddURL bool
21
+ AllowedOrigins map [string ]bool
22
+ }
23
+
24
+ type ArrayVarType []string
25
+
26
+ func (vt * ArrayVarType ) String () string {
27
+ return strings .Join (* vt , ", " )
28
+ }
29
+
30
+ func (vt * ArrayVarType ) Set (v string ) error {
31
+ * vt = append (* vt , v )
32
+ return nil
19
33
}
20
34
21
35
func (lr LogReceiver ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
22
- if body , err := io .ReadAll (r .Body ); err != nil {
36
+ origin := r .Header .Get ("Origin" )
37
+ allowOrigin := lr .AllowedOrigins == nil
38
+ if ! allowOrigin && lr .AllowedOrigins [origin ] {
39
+ allowOrigin = true
40
+ }
41
+ if ! allowOrigin {
42
+ log .Printf ("Rejected %d bytes via %s to %s from %s %v" ,
43
+ r .ContentLength , r .Method , r .RequestURI , r .RemoteAddr , r .Header )
44
+ w .WriteHeader (http .StatusForbidden )
45
+ return
46
+ }
47
+ if r .Method == "OPTIONS" && origin != "" && r .Header .Get ("Access-Control-Request-Method" ) != "" {
48
+ w .Header ().Add ("Access-Control-Allow-Origin" , origin )
49
+ w .Header ().Add ("Access-Control-Allow-Methods" , "POST" )
50
+ if rqh := r .Header .Get ("Access-Control-Request-Headers" ); rqh != "" {
51
+ w .Header ().Add ("Access-Control-Allow-Headers" , rqh )
52
+ }
53
+ w .WriteHeader (http .StatusNoContent )
54
+ return
55
+ }
56
+ if r .Body == http .NoBody {
57
+ log .Printf ("No body received from %s" , r .RemoteAddr )
58
+ } else if body , err := io .ReadAll (r .Body ); err != nil {
23
59
log .Printf ("Could not read body" )
24
60
} else {
25
61
var insertId int
@@ -28,13 +64,24 @@ func (lr LogReceiver) ServeHTTP(w http.ResponseWriter, r *http.Request) {
28
64
insertArgs = append (insertArgs , r .RequestURI )
29
65
}
30
66
if err := lr .InsertStatement .QueryRow (insertArgs ... ).Scan (& insertId ); err != nil {
31
- log .Printf ("Error executing statement: %s" , err )
67
+ log .Printf ("Error executing statement: %s, args: %#v " , err , insertArgs )
32
68
} else {
33
69
log .Printf ("Received log entry #%d from %s" , insertId , r .RemoteAddr )
34
- w .WriteHeader (http .StatusNoContent )
70
+ if origin != "" {
71
+ w .Header ().Set ("Access-Control-Allow-Origin" , origin )
72
+ }
73
+ w .Header ().Set ("Content-Type" , "application/json" )
74
+ d , err := json .Marshal (insertId )
75
+ if err != nil {
76
+ log .Printf ("error encoding %#v to response: %s" , insertId , err )
77
+ w .WriteHeader (http .StatusInternalServerError )
78
+ return
79
+ }
80
+ if _ , err := w .Write (d ); err != nil {
81
+ log .Printf ("error writing response: %s" , err )
82
+ }
35
83
return
36
84
}
37
-
38
85
}
39
86
w .WriteHeader (http .StatusBadRequest )
40
87
}
@@ -44,6 +91,8 @@ func main() {
44
91
dsn := flag .String ("dsn" , "sslmode=disable" , "Postgresql DSN" )
45
92
dbTable := flag .String ("table" , "log" , "Table in database receiving log" )
46
93
addUrl := flag .Bool ("add-url" , true , "Add RequestURI to 'url' column if it exists" )
94
+ var origins ArrayVarType
95
+ flag .Var (& origins , "origin" , "allowed origins (multi-arg)" )
47
96
tableCreateSql := flag .String ("create" , "CREATE TABLE \" %s\" (id BIGSERIAL PRIMARY KEY, stamp TIMESTAMPTZ DEFAULT now(), url text, src text, msg JSONB)" , "Table creation SQL" )
48
97
flag .Parse ()
49
98
@@ -53,12 +102,12 @@ func main() {
53
102
}
54
103
returnValue := "id"
55
104
if _ , err := db .Exec (fmt .Sprintf ("SELECT 1 FROM \" %s\" WHERE 1=0 AND src IS NOT NULL and msg IS NOT NULL" , * dbTable )); err != nil {
56
- if pqErr , ok := err .(* pq.Error ); ok && pqErr .Code == "42P01 " {
105
+ if pqErr , ok := err .(* pq.Error ); ok && pqErr .Code . Name () == "undefined_table " {
57
106
if _ , err := db .Exec (fmt .Sprintf (* tableCreateSql , * dbTable )); err != nil {
58
107
log .Fatalf ("Creating table %#v failed: %s\n " , * dbTable , err )
59
108
}
60
109
log .Printf ("Created table %#v in database" , * dbTable )
61
- } else if pqErr , ok := err .(* pq.Error ); ok && pqErr .Code == "42501 " {
110
+ } else if pqErr , ok := err .(* pq.Error ); ok && pqErr .Code . Name () == "insufficient_privilege " {
62
111
log .Printf ("Skipping table schema check, no permission to SELECT from %#v: %s" , * dbTable , err )
63
112
returnValue = "-1"
64
113
} else {
@@ -93,7 +142,15 @@ func main() {
93
142
if err != nil {
94
143
log .Fatalf ("Listen on %#v failed: %s" , * listenFlag , err )
95
144
}
96
- if err := http .Serve (ln , LogReceiver {InsertStatement : sqlInsert , AddURL : haveUrl }); err != nil {
145
+ var originMap map [string ]bool
146
+ for _ , origin := range origins {
147
+ if originMap == nil {
148
+ originMap = map [string ]bool {}
149
+ }
150
+ originMap [origin ] = true
151
+ }
152
+
153
+ if err := http .Serve (ln , LogReceiver {InsertStatement : sqlInsert , AddURL : haveUrl , AllowedOrigins : originMap }); err != nil {
97
154
log .Fatalf ("Could not start HTTP server: %s" , err )
98
155
}
99
156
}
0 commit comments