+ You may now close this page and return to the application. +
+ + +` + +var _ Getter = Browser{} diff --git a/oauth2client/localapp.go b/oauth2client/localapp.go new file mode 100644 index 0000000..d287329 --- /dev/null +++ b/oauth2client/localapp.go @@ -0,0 +1,57 @@ +package oauth2client + +import ( + "context" + "fmt" + + "golang.org/x/oauth2" +) + +// LocalAppSource implements oauth2.TokenSource for +// OAuth2 client apps that have the client app +// credentials (Client ID and Secret) available +// locally. The OAuth2 provider is accessed directly +// using the OAuth2Config field value. +// +// LocalAppSource values can be ephemeral. +type LocalAppSource struct { + // OAuth2Config is the OAuth2 configuration. + OAuth2Config *oauth2.Config + + // AuthCodeGetter is how the auth code + // is obtained. If not set, a default + // oauth2code.Browser is used. + AuthCodeGetter Getter +} + +// Config returns an OAuth2 config. +func (s LocalAppSource) Config() *oauth2.Config { + return s.OAuth2Config +} + +// Token obtains a token using s.OAuth2Config. +func (s LocalAppSource) Token() (*oauth2.Token, error) { + if s.OAuth2Config == nil { + return nil, fmt.Errorf("missing OAuth2Config") + } + if s.AuthCodeGetter == nil { + s.AuthCodeGetter = Browser{} + } + + cfg := s.Config() + + stateVal := State() + authURL := cfg.AuthCodeURL(stateVal, oauth2.AccessTypeOffline) + + code, err := s.AuthCodeGetter.Get(stateVal, authURL) + if err != nil { + return nil, fmt.Errorf("getting code via browser: %v", err) + } + + ctx := context.WithValue(context.Background(), + oauth2.HTTPClient, httpClient) + + return cfg.Exchange(ctx, code) +} + +var _ App = LocalAppSource{} diff --git a/oauth2client/oauth2.go b/oauth2client/oauth2.go new file mode 100644 index 0000000..39e3203 --- /dev/null +++ b/oauth2client/oauth2.go @@ -0,0 +1,62 @@ +package oauth2client + +import ( + mathrand "math/rand" + "net/http" + "time" + + "golang.org/x/oauth2" +) + +func init() { + mathrand.Seed(time.Now().UnixNano()) +} + +// Getter is a type that can get an OAuth2 auth 
code. +// It must enforce that the state parameter of the +// redirected request matches expectedStateVal. +type Getter interface { + Get(expectedStateVal, authCodeURL string) (code string, err error) +} + +// State returns a random string suitable as a state value. +func State() string { + return randString(14) +} + +// randString is not safe for cryptographic use. +func randString(n int) string { + const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[mathrand.Intn(len(letterBytes))] + } + return string(b) +} + +type ( + // OAuth2Info contains information for obtaining an auth code. + OAuth2Info struct { + StateValue string + AuthCodeURL string + } + + // App provides methods for obtaining an + // OAuth2 config and an initial token. + App interface { + oauth2.TokenSource + Config() *oauth2.Config + } +) + +// httpClient is the HTTP client to use for OAuth2 requests. +var httpClient = &http.Client{ + Timeout: 10 * time.Second, +} + +// DefaultRedirectURL is the default URL to +// which to redirect clients after a code +// has been obtained. Redirect URLs may +// have to be registered with your OAuth2 +// provider. 
+const DefaultRedirectURL = "http://localhost:8008/oauth2-redirect" diff --git a/oauth2client/oauth2proxy/cmd/oauth2proxy/main.go b/oauth2client/oauth2proxy/cmd/oauth2proxy/main.go new file mode 100644 index 0000000..0e5afd4 --- /dev/null +++ b/oauth2client/oauth2proxy/cmd/oauth2proxy/main.go @@ -0,0 +1,75 @@ +package main + +import ( + "flag" + "log" + "net/http" + + "github.com/BurntSushi/toml" + "github.com/mholt/timeliner/oauth2client/oauth2proxy" + "golang.org/x/oauth2" +) + +func init() { + flag.StringVar(&credentialsFile, "credentials", credentialsFile, "The path to the file containing the OAuth2 app credentials for each provider") + flag.StringVar(&addr, "addr", addr, "The address to listen on") + flag.StringVar(&basePath, "path", basePath, "The base path on which to serve the proxy endpoints") +} + +var ( + credentialsFile = "credentials.toml" + addr = ":7233" + basePath = "/oauth2" +) + +func main() { + flag.Parse() + + if credentialsFile == "" { + log.Fatal("[FATAL] No credentials file specified (use -credentials)") + } + if addr == "" { + log.Fatal("[FATAL] No address specified (use -addr)") + } + + // decode app credentials + var creds oauth2Credentials + md, err := toml.DecodeFile(credentialsFile, &creds) + if err != nil { + log.Fatalf("[FATAL] Decoding credentials file: %v", err) + } + if len(md.Undecoded()) > 0 { + log.Fatalf("[FATAL] Unrecognized key(s) in credentials file: %+v", md.Undecoded()) + } + + // convert them into oauth2.Configs (the structure of + // oauth2.Config as TOML is too verbose for my taste) + oauth2Configs := make(map[string]oauth2.Config) + for id, prov := range creds.Providers { + oauth2Configs[id] = oauth2.Config{ + ClientID: prov.ClientID, + ClientSecret: prov.ClientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: prov.AuthURL, + TokenURL: prov.TokenURL, + }, + } + log.Println("Provider:", id) + } + + log.Println("Serving OAuth2 proxy on", addr) + + p := oauth2proxy.New(basePath, oauth2Configs) + http.ListenAndServe(addr, p) 
+} + +type oauth2Credentials struct { + Providers map[string]oauth2ProviderConfig `toml:"providers"` +} + +type oauth2ProviderConfig struct { + ClientID string `toml:"client_id"` + ClientSecret string `toml:"client_secret"` + AuthURL string `toml:"auth_url"` + TokenURL string `toml:"token_url"` +} diff --git a/oauth2client/oauth2proxy/proxy.go b/oauth2client/oauth2proxy/proxy.go new file mode 100644 index 0000000..f593775 --- /dev/null +++ b/oauth2client/oauth2proxy/proxy.go @@ -0,0 +1,191 @@ +package oauth2proxy + +import ( + "encoding/json" + "io" + "io/ioutil" + "net/http" + "net/url" + "path" + "strings" + + "github.com/mholt/timeliner/oauth2client" + "golang.org/x/oauth2" +) + +// New returns a new OAuth2 proxy that serves its endpoints +// under the given basePath and which replaces credentials +// and endpoints with those found in the configs given in +// the providers map. +// +// The map value does not use pointers, so that temporary +// manipulations of the value can occur without modifying +// the original template value. 
+func New(basePath string, providers map[string]oauth2.Config) http.Handler { + basePath = path.Join("/", basePath) + + proxy := oauth2Proxy{providers: providers} + + mux := http.NewServeMux() + mux.HandleFunc(path.Join(basePath, "auth-code-url"), proxy.handleAuthCodeURL) + mux.HandleFunc(path.Join(basePath, "proxy")+"/", proxy.handleOAuth2) + + return mux +} + +type oauth2Proxy struct { + providers map[string]oauth2.Config +} + +func (proxy oauth2Proxy) handleAuthCodeURL(w http.ResponseWriter, r *http.Request) { + providerID := r.FormValue("provider") + redir := r.FormValue("redirect") + scopes := r.URL.Query()["scope"] + + oauth2CfgCopy, ok := proxy.providers[providerID] + if !ok { + http.Error(w, "unknown service ID", http.StatusBadRequest) + return + } + + // augment the template config with parameters specific to this + // request (this is why it's important that the configs aren't + // pointers; we should be mutating only copies here) + oauth2CfgCopy.Scopes = scopes + oauth2CfgCopy.RedirectURL = redir + + stateVal := oauth2client.State() + url := oauth2CfgCopy.AuthCodeURL(stateVal, oauth2.AccessTypeOffline) + + info := oauth2client.OAuth2Info{ + StateValue: stateVal, + AuthCodeURL: url, + } + + json.NewEncoder(w).Encode(info) +} + +func (proxy oauth2Proxy) handleOAuth2(w http.ResponseWriter, r *http.Request) { + // knead the URL into its two parts: the service + // ID and which endpoint to proxy to + // reqURL := strings.TrimPrefix(r.URL.Path, basePath+"/proxy") + // reqURL = path.Clean(strings.TrimPrefix(reqURL, "/")) + + // we want the last two components of the path + urlParts := strings.Split(r.URL.Path, "/") + if len(urlParts) < 2 { + http.Error(w, "bad path length", http.StatusBadRequest) + return + } + + providerID := urlParts[len(urlParts)-2] + whichEndpoint := urlParts[len(urlParts)-1] + + // get the OAuth2 config matching the service ID + oauth2Config, ok := proxy.providers[providerID] + if !ok { + http.Error(w, "unknown service: "+providerID, 
http.StatusBadRequest) + return + } + + // figure out which endpoint we'll use for upstream + var upstreamEndpoint string + switch whichEndpoint { + case "auth": + upstreamEndpoint = oauth2Config.Endpoint.AuthURL + case "token": + upstreamEndpoint = oauth2Config.Endpoint.TokenURL + } + + // read the body so we can replace values if necessary + // (don't use r.ParseForm because we need to keep body + // and query string distinct) + reqBodyBytes, err := ioutil.ReadAll(r.Body) //http.MaxBytesReader(w, r.Body, 64*1024)) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // if the request body is form-encoded, replace any + // credential placeholders with the real credentials + var upstreamBody io.Reader + if strings.Contains(r.Header.Get("Content-Type"), "x-www-form-urlencoded") { + bodyForm, err := url.ParseQuery(string(reqBodyBytes)) + if err != nil { + http.Error(w, "error parsing request body", http.StatusBadRequest) + return + } + replaceCredentials(bodyForm, oauth2Config) + upstreamBody = strings.NewReader(bodyForm.Encode()) + } + + // now do the same thing for the query string + qs := r.URL.Query() + replaceCredentials(qs, oauth2Config) + + // make outgoing URL + upstreamURL, err := url.Parse(upstreamEndpoint) + if err != nil { + http.Error(w, "bad upstream URL", http.StatusInternalServerError) + return + } + upstreamURL.RawQuery = qs.Encode() + + // set the real credentials -- this has to be done + // carefully because apparently a lot of OAuth2 + // providers are broken (against RFC 6749), so + // the downstream OAuth2 client lib must be sure + // to set the credentials in the right place, and + // we should be sure to mirror that behavior; + // this means that even though the downstream may + // not have the real client ID and secret, they + // need to provide SOMETHING as bogus placeholder + // values to signal to us where to put the real + // credentials + if r.Header.Get("Authorization") != "" { + 
r.SetBasicAuth(oauth2Config.ClientID, oauth2Config.ClientSecret) + } + + // prepare the request to upstream + upstream, err := http.NewRequest(r.Method, upstreamURL.String(), upstreamBody) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + upstream.Header = r.Header + delete(upstream.Header, "Content-Length") + + // perform the upstream request + resp, err := http.DefaultClient.Do(upstream) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer resp.Body.Close() + + // copy the upstream headers to the response downstream + for key, vals := range resp.Header { + for _, val := range vals { + w.Header().Add(key, val) + } + } + + // carry over the status code + w.WriteHeader(resp.StatusCode) + + // copy the response body downstream + _, err = io.Copy(w, resp.Body) + if err != nil { + http.Error(w, "writing body: "+err.Error(), http.StatusBadGateway) + return + } +} + +func replaceCredentials(form url.Values, oauth2Config oauth2.Config) { + if form.Get("client_id") != "" { + form.Set("client_id", oauth2Config.ClientID) + } + if form.Get("client_secret") != "" { + form.Set("client_secret", oauth2Config.ClientSecret) + } +} diff --git a/oauth2client/remoteapp.go b/oauth2client/remoteapp.go new file mode 100644 index 0000000..377c6d8 --- /dev/null +++ b/oauth2client/remoteapp.go @@ -0,0 +1,192 @@ +package oauth2client + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "golang.org/x/oauth2" +) + +// RemoteAppSource implements oauth2.TokenSource for +// OAuth2 client apps that have their credentials +// (Client ID and Secret, as well as endpoint info) +// stored remotely. Thus, this type obtains tokens +// through a remote proxy that presumably has the +// client app credentials, which it will replace +// before proxying to the provider. +// +// RemoteAppSource values can be ephemeral. +type RemoteAppSource struct { + // How to obtain the auth URL. 
+ // Default: DirectAuthURLMode + AuthURLMode AuthURLMode + + // The URL to the proxy server (its + // address + base path). + ProxyURL string + + // The ID of the OAuth2 provider. + ProviderID string + + // The scopes for which to obtain + // authorization. + Scopes []string + + // The URL to redirect to to finish + // the ceremony. + RedirectURL string + + // How the auth code is obtained. + // If not set, a default + // oauth2code.Browser is used. + AuthCodeGetter Getter +} + +// Config returns an OAuth2 config. +func (s RemoteAppSource) Config() *oauth2.Config { + redirURL := s.RedirectURL + if redirURL == "" { + redirURL = DefaultRedirectURL + } + + return &oauth2.Config{ + ClientID: "placeholder", + ClientSecret: "placeholder", + RedirectURL: redirURL, + Scopes: s.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: s.ProxyURL + "/proxy/" + s.ProviderID + "/auth", + TokenURL: s.ProxyURL + "/proxy/" + s.ProviderID + "/token", + }, + } +} + +// Token obtains a token. +func (s RemoteAppSource) Token() (*oauth2.Token, error) { + if s.AuthCodeGetter == nil { + s.AuthCodeGetter = Browser{} + } + if s.AuthURLMode == "" { + s.AuthURLMode = DirectAuthURLMode + } + + cfg := s.Config() + + // obtain a state value and auth URL + var stateVal, authURL string + var err error + switch s.AuthURLMode { + case DirectAuthURLMode: + stateVal, authURL, err = s.getDirectAuthURLFromProxy() + case ProxiedAuthURLMode: + stateVal, authURL, err = s.getProxiedAuthURL(cfg) + default: + return nil, fmt.Errorf("unknown AuthURLMode: %s", s.AuthURLMode) + } + if err != nil { + return nil, err + } + + // now obtain the code + code, err := s.AuthCodeGetter.Get(stateVal, authURL) + if err != nil { + return nil, fmt.Errorf("getting code via browser: %v", err) + } + + // and complete the ceremony + ctx := context.WithValue(context.Background(), + oauth2.HTTPClient, httpClient) + + return cfg.Exchange(ctx, code) +} + +// getDirectAuthURLFromProxy returns an auth URL that goes directly to the +// OAuth2 
provider server, but it gets that URL by querying the proxy server +// for what it should be ("DirectAuthURLMode"). +func (s RemoteAppSource) getDirectAuthURLFromProxy() (state string, authURL string, err error) { + redirURL := s.RedirectURL + if redirURL == "" { + redirURL = DefaultRedirectURL + } + + v := url.Values{ + "provider": {s.ProviderID}, + "scope": s.Scopes, + "redirect": {redirURL}, + } + + proxyURL := strings.TrimSuffix(s.ProxyURL, "/") + resp, err := http.Get(proxyURL + "/auth-code-url?" + v.Encode()) + if err != nil { + return "", "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", "", fmt.Errorf("requesting auth code URL from proxy: HTTP %d: %s", + resp.StatusCode, resp.Status) + } + + var info OAuth2Info + err = json.NewDecoder(resp.Body).Decode(&info) + if err != nil { + return "", "", err + } + + return info.StateValue, info.AuthCodeURL, nil +} + +// getProxiedAuthURL returns an auth URL that goes to the remote proxy ("ProxiedAuthURLMode"). +func (s RemoteAppSource) getProxiedAuthURL(cfg *oauth2.Config) (state string, authURL string, err error) { + state = State() + authURL = cfg.AuthCodeURL(state, oauth2.AccessTypeOffline) + return +} + +// AuthURLMode describes what kind of auth URL a +// RemoteAppSource should obtain. +type AuthURLMode string + +const ( + // DirectAuthURLMode queries the remote proxy to get + // an auth URL that goes directly to the OAuth2 provider + // web page the user must go to in order to obtain + // authorization. Although this mode incurs one extra + // HTTP request (that is not part of the OAuth2 spec, + // it is purely our own), it is perhaps more robust in + // more environments, since the browser will access the + // auth provider's site directly, meaning that any HTML + // or JavaScript on the page that expects HTTPS or a + // certain hostname will be able to function correctly. 
+ DirectAuthURLMode AuthURLMode = "direct" + + // ProxiedAuthURLMode makes an auth URL that goes to + // the remote proxy, not directly to the provider. + // This is perhaps a "purer" approach than + // DirectAuthURLMode, but it may not work if HTML or + // JavaScript on the provider's auth page expects + // a certain scheme or hostname in the page's URL. + // This mode usually works when the proxy is running + // over HTTPS, but this mode may break depending on + // the provider, when the proxy uses HTTP (which + // should only be in dev environments of course). + // + // For example, Google's OAuth2 page will try to set a + // secure-context cookie using JavaScript, which fails + // if the auth page is proxied through a plaintext HTTP + // localhost endpoint, which is what we do during + // development for convenience; the lack of HTTPS caused + // the page to reload infinitely because, even though + // the request was reverse-proxied, the JS on the page + // expected HTTPS. (See my self-congratulatory tweet: + // https://twitter.com/mholt6/status/1078518306045231104) + // Using DirectAuthURLMode is the easiest way around + // this problem. + ProxiedAuthURLMode AuthURLMode = "proxied" +) + +var _ App = RemoteAppSource{} diff --git a/persons.go b/persons.go new file mode 100644 index 0000000..bf980ee --- /dev/null +++ b/persons.go @@ -0,0 +1,75 @@ +package timeliner + +import ( + "database/sql" + "fmt" +) + +// getPerson returns the person mapped to userID on service. +// If the person does not exist, it is created. +func (t *Timeline) getPerson(dataSourceID, userID, name string) (Person, error) { + // first, load the person + var p Person + err := t.db.QueryRow(`SELECT persons.id, persons.name + FROM persons, person_identities + WHERE person_identities.data_source_id=? + AND person_identities.user_id=? 
+ AND persons.id = person_identities.person_id + LIMIT 1`, dataSourceID, userID).Scan(&p.ID, &p.Name) + if err == sql.ErrNoRows { + // person does not exist; create this mapping - TODO: do in a transaction + p = Person{Name: name} + res, err := t.db.Exec(`INSERT INTO persons (name) VALUES (?)`, p.Name) + if err != nil { + return Person{}, fmt.Errorf("adding new person: %v", err) + } + p.ID, err = res.LastInsertId() + if err != nil { + return Person{}, fmt.Errorf("getting person ID: %v", err) + } + _, err = t.db.Exec(`INSERT INTO person_identities + (person_id, data_source_id, user_id) VALUES (?, ?, ?)`, + p.ID, dataSourceID, userID) + if err != nil { + return Person{}, fmt.Errorf("adding new person identity mapping: %v", err) + } + } else if err != nil { + return Person{}, fmt.Errorf("selecting person identity: %v", err) + } + + // now get all the person's identities + rows, err := t.db.Query(`SELECT id, person_id, data_source_id, user_id + FROM person_identities WHERE person_id=?`, p.ID) + if err != nil { + return Person{}, fmt.Errorf("selecting person's known identities: %v", err) + } + defer rows.Close() + for rows.Next() { + var ident PersonIdentity + err := rows.Scan(&ident.ID, &ident.PersonID, &ident.DataSourceID, &ident.UserID) + if err != nil { + return Person{}, fmt.Errorf("loading person's identity: %v", err) + } + p.Identities = append(p.Identities, ident) + } + if err = rows.Err(); err != nil { + return Person{}, fmt.Errorf("scanning identity rows: %v", err) + } + + return p, nil +} + +// Person represents a person. +type Person struct { + ID int64 + Name string + Identities []PersonIdentity +} + +// PersonIdentity is a way to map a user ID on a service to a person. 
+type PersonIdentity struct { + ID int64 + PersonID string + DataSourceID string + UserID string +} diff --git a/processing.go b/processing.go new file mode 100644 index 0000000..0c63e9c --- /dev/null +++ b/processing.go @@ -0,0 +1,544 @@ +package timeliner + +import ( + "bytes" + "crypto/sha256" + "database/sql" + "encoding/base64" + "fmt" + "io" + "log" + "os" + "sync" + "time" +) + +// beginProcessing starts workers to process items that are +// obtained from ac. It returns a WaitGroup which blocks until +// all workers have finished, and a channel into which the +// service should pipe its items. +func (wc *WrappedClient) beginProcessing(cc concurrentCuckoo, reprocess, integrity bool) (*sync.WaitGroup, chan<- *ItemGraph) { + wg := new(sync.WaitGroup) + ch := make(chan *ItemGraph) + + const workers = 2 // TODO: Make configurable? + for i := 0; i < workers; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + for ig := range ch { + if ig == nil { + continue + } + _, err := wc.processItemGraph(ig, &recursiveState{ + timestamp: time.Now(), + reprocess: reprocess, + integrityCheck: integrity, + seen: make(map[*ItemGraph]int64), + idmap: make(map[string]int64), + cuckoo: cc, + }) + if err != nil { + log.Printf("[ERROR][%s/%s] Processing item graph: %v", + wc.ds.ID, wc.acc.UserID, err) + } + } + }(i) + } + + return wg, ch +} + +type recursiveState struct { + timestamp time.Time + reprocess bool + integrityCheck bool + seen map[*ItemGraph]int64 // value is the item's row ID + idmap map[string]int64 // map an item's service ID to the row ID -- TODO: I don't love this... any better way? 
+ + // the cuckoo filter pointer lives for + // the duration of the entire operation; + // it is often nil, but if it is set, + // then the service-produced ID of each + // item should be added to the filter so + // that a prune can take place when the + // entire operation is complete + cuckoo concurrentCuckoo +} + +func (wc *WrappedClient) processItemGraph(ig *ItemGraph, state *recursiveState) (int64, error) { + // don't visit a node twice + if igID, ok := state.seen[ig]; ok { + return igID, nil + } + + var igRowID int64 + + if ig.Node == nil { + // mark this node as visited + state.seen[ig] = 0 + } else { + // process root node + var err error + igRowID, err = wc.processSingleItemGraphNode(ig.Node, state) + if err != nil { + return 0, fmt.Errorf("processing node of item graph: %v", err) + } + + // mark this node as visited + state.seen[ig] = igRowID + + // map individual items to their row IDs + state.idmap[ig.Node.ID()] = igRowID + + // process all connected nodes + if ig.Edges != nil { + for connectedIG, relations := range ig.Edges { + // if node not yet visited, process it now + connectedIGRowID, visited := state.seen[connectedIG] + if !visited { + connectedIGRowID, err = wc.processItemGraph(connectedIG, state) + if err != nil { + return igRowID, fmt.Errorf("processing node of item graph: %v", err) + } + state.seen[connectedIG] = connectedIGRowID + } + + // store this item's ID for later + state.idmap[connectedIG.Node.ID()] = connectedIGRowID + + // insert relations to this connected node into DB + for _, rel := range relations { + _, err = wc.tl.db.Exec(`INSERT INTO relationships + (from_item_id, to_item_id, directed, label) + VALUES (?, ?, ?, ?)`, + igRowID, connectedIGRowID, !rel.Bidirectional, rel.Label) + if err != nil { + return igRowID, fmt.Errorf("storing item relationship: %v (from_item=%d to_item=%d directed=%t label=%v)", + err, igRowID, connectedIGRowID, !rel.Bidirectional, rel.Label) + } + } + } + } + } + + // process collections, if any + for _, 
coll := range ig.Collections { + // attach the item's row ID to each item in the collection + // to speed up processing; we won't have to query the database + // again for items that were already processed from the graph + for i, it := range coll.Items { + coll.Items[i].itemRowID = state.idmap[it.Item.ID()] + } + + err := wc.processCollection(coll, state.timestamp) + if err != nil { + return 0, fmt.Errorf("processing collection: %v (original_id=%s)", err, coll.OriginalID) + } + } + + // process raw relations, if any + for _, rr := range ig.Relations { + // get each item's row ID from their data source item ID + fromItemRowID, err := wc.itemRowIDFromOriginalID(rr.FromItemID) + if err == sql.ErrNoRows { + continue // item does not exist in timeline; skip this relation + } + if err != nil { + return 0, fmt.Errorf("querying 'from' item row ID: %v", err) + } + toItemRowID, err := wc.itemRowIDFromOriginalID(rr.ToItemID) + if err == sql.ErrNoRows { + continue // item does not exist in timeline; skip this relation + } + if err != nil { + return 0, fmt.Errorf("querying 'to' item row ID: %v", err) + } + + // store the relation + _, err = wc.tl.db.Exec(`INSERT INTO relationships + (from_item_id, to_item_id, directed, label) + VALUES (?, ?, ?, ?)`, + fromItemRowID, toItemRowID, rr.Bidirectional, rr.Label) + if err != nil { + return 0, fmt.Errorf("storing raw item relationship: %v (from_item=%d to_item=%d directed=%t label=%v)", + err, fromItemRowID, toItemRowID, !rr.Bidirectional, rr.Label) + } + } + + return igRowID, nil +} + +// TODO: is this function useful? 
+func (wc *WrappedClient) processSingleItemGraphNode(it Item, state *recursiveState) (int64, error) { + if itemID := it.ID(); itemID != "" && state.cuckoo.Filter != nil { + state.cuckoo.Lock() + state.cuckoo.InsertUnique([]byte(itemID)) + state.cuckoo.Unlock() + } + + return wc.storeItemFromService(it, state.timestamp, state.reprocess, state.integrityCheck) +} + +func (wc *WrappedClient) storeItemFromService(it Item, timestamp time.Time, reprocess, integrity bool) (int64, error) { + if it == nil { + return 0, nil + } + + // process this item only one at a time + itemOriginalID := it.ID() + itemLockID := fmt.Sprintf("%s_%d_%s", wc.ds.ID, wc.acc.ID, itemOriginalID) + itemLocks.Lock(itemLockID) + defer itemLocks.Unlock(itemLockID) + + // if there is a data file, prepare to download it + // and get its file name; but don't actually begin + // downloading it until after it is in the DB, since + // we need to know, if we encounter this item later, + // whether it was downloaded successfully; if not, + // like if the download was interrupted and we didn't + // have a chance to clean up, we can overwrite any + // existing file by that name. 
+ rc, err := it.DataFileReader() + if err != nil { + return 0, fmt.Errorf("getting item's data file content stream: %v", err) + } + if rc != nil { + defer rc.Close() + } + + // if the item is already in our DB, load it + var ir ItemRow + if itemOriginalID != "" { + ir, err = wc.loadItemRow(wc.acc.ID, itemOriginalID) + if err != nil { + return 0, fmt.Errorf("checking for item in database: %v", err) + } + if ir.ID > 0 { + // already have it + + if !wc.shouldProcessExistingItem(it, ir, reprocess, integrity) { + return ir.ID, nil + } + + // at this point, we will be replacing the existing + // file, so move it temporarily as a safe measure, + // and also because our filename-generator will not + // allow a file to be overwritten, but we want to + // replace the existing file in this case + if ir.DataFile != nil && rc != nil { + origFile := wc.tl.fullpath(*ir.DataFile) + bakFile := wc.tl.fullpath(*ir.DataFile + ".bak") + err = os.Rename(origFile, bakFile) + if err != nil && !os.IsNotExist(err) { + return 0, fmt.Errorf("temporarily moving data file: %v", err) + } + + // if this function returns with an error, + // restore the original file in case it was + // partially written or something; otherwise + // delete the old file altogether + defer func() { + if err == nil { + err := os.Remove(bakFile) + if err != nil && !os.IsNotExist(err) { + log.Printf("[ERROR] Deleting data file backup: %v", err) + } + } else { + err := os.Rename(bakFile, origFile) + if err != nil && !os.IsNotExist(err) { + log.Printf("[ERROR] Restoring original data file from backup: %v", err) + } + } + }() + } + } + } + + var dataFileName *string + var datafile *os.File + if rc != nil { + datafile, dataFileName, err = wc.tl.openUniqueCanonicalItemDataFile(it, wc.ds.ID) + if err != nil { + return 0, fmt.Errorf("opening output data file: %v", err) + } + defer datafile.Close() + } + + // prepare the item's DB row values + err = wc.fillItemRow(&ir, it, timestamp, dataFileName) + if err != nil { + return 0, 
fmt.Errorf("assembling item for storage: %v", err) + } + + // TODO: Insert modified time too, if edited locally? + // TODO: On conflict, maybe we just want to ignore -- make this configurable... + _, err = wc.tl.db.Exec(`INSERT INTO items + (account_id, original_id, person_id, timestamp, stored, + class, mime_type, data_text, data_file, data_hash, metadata, + latitude, longitude) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT (account_id, original_id) DO UPDATE + SET person_id=?, timestamp=?, stored=?, class=?, mime_type=?, data_text=?, + data_file=?, data_hash=?, metadata=?, latitude=?, longitude=?`, + ir.AccountID, ir.OriginalID, ir.PersonID, ir.Timestamp.Unix(), ir.Stored.Unix(), + ir.Class, ir.MIMEType, ir.DataText, ir.DataFile, ir.DataHash, ir.metaGob, + ir.Latitude, ir.Longitude, + ir.PersonID, ir.Timestamp.Unix(), ir.Stored.Unix(), ir.Class, ir.MIMEType, ir.DataText, + ir.DataFile, ir.DataHash, ir.metaGob, ir.Latitude, ir.Longitude) + if err != nil { + return 0, fmt.Errorf("storing item in database: %v (item_id=%v)", err, ir.OriginalID) + } + + // get the item's row ID (this works regardless of whether + // the last query was an insert or an update) + var itemRowID int64 + err = wc.tl.db.QueryRow(`SELECT id FROM items + WHERE account_id=? AND original_id=? 
LIMIT 1`,
+		ir.AccountID, ir.OriginalID).Scan(&itemRowID)
+	if err != nil && err != sql.ErrNoRows {
+		return 0, fmt.Errorf("getting item row ID: %v", err)
+	}
+
+	// if there is a data file, download it and compute its checksum;
+	// then record the checksum on the item's row so a later run can tell
+	// the download completed (NOTE(review): only data_hash is written
+	// here — the file name appears to be stored elsewhere; confirm)
+	if rc != nil && dataFileName != nil {
+		h := sha256.New()
+		err := wc.tl.downloadItemFile(rc, datafile, h)
+		if err != nil {
+			return 0, fmt.Errorf("downloading data file: %v (item_id=%v)", err, itemRowID)
+		}
+
+		// now that download is complete, compute its hash
+		dfHash := h.Sum(nil)
+		b64hash := base64.StdEncoding.EncodeToString(dfHash)
+
+		// if the exact same file (byte-for-byte) already exists,
+		// delete this copy and reuse the existing one
+		err = wc.tl.replaceWithExisting(dataFileName, b64hash, itemRowID)
+		if err != nil {
+			return 0, fmt.Errorf("replacing data file with identical existing file: %v", err)
+		}
+
+		// save the file's hash to confirm it was downloaded successfully
+		_, err = wc.tl.db.Exec(`UPDATE items SET data_hash=? WHERE id=?`, // TODO: LIMIT 1...
+			b64hash, itemRowID)
+		if err != nil {
+			// best-effort cleanup: without the hash recorded, the file
+			// would look like an incomplete download on the next run
+			log.Printf("[ERROR][%s/%s] Updating item's data file hash in DB: %v; cleaning up data file: %s (item_id=%d)",
+				wc.ds.ID, wc.acc.UserID, err, datafile.Name(), itemRowID)
+			os.Remove(wc.tl.fullpath(*dataFileName))
+		}
+	}
+
+	return itemRowID, nil
+}
+
+// shouldProcessExistingItem reports whether an item that is already in
+// the DB (dbItem) should be processed again. Reasons to reprocess: a
+// failed integrity check, an incomplete prior download, a changed
+// service-side hash, or an explicit user request (reprocess). A locally
+// modified item is never overwritten.
+func (wc *WrappedClient) shouldProcessExistingItem(it Item, dbItem ItemRow, reprocess, integrity bool) bool {
+	// if integrity check is enabled and checksum mismatches, always reprocess
+	if integrity && dbItem.DataFile != nil && dbItem.DataHash != nil {
+		datafile, err := os.Open(wc.tl.fullpath(*dbItem.DataFile))
+		if err != nil {
+			log.Printf("[ERROR][%s/%s] Integrity check: opening existing data file: %v; reprocessing (item_id=%d)",
+				wc.ds.ID, wc.acc.UserID, err, dbItem.ID)
+			return true
+		}
+		defer datafile.Close()
+		h := sha256.New()
+		_, err = io.Copy(h, datafile)
+		if err != nil {
+			log.Printf("[ERROR][%s/%s] Integrity check: reading existing data file: %v; reprocessing (item_id=%d)",
+				wc.ds.ID, wc.acc.UserID, err, dbItem.ID)
+			return true
+		}
+		b64hash := base64.StdEncoding.EncodeToString(h.Sum(nil))
+		if b64hash != *dbItem.DataHash {
+			log.Printf("[ERROR][%s/%s] Integrity check: checksum mismatch: expected %s, got %s; reprocessing (item_id=%d)",
+				wc.ds.ID, wc.acc.UserID, *dbItem.DataHash, b64hash, dbItem.ID)
+			return true
+		}
+	}
+
+	// if modified locally, do not overwrite changes
+	if dbItem.Modified != nil {
+		return false
+	}
+
+	// if a data file is expected, but no completed file exists
+	// (i.e. its hash is missing), then reprocess to allow download
+	// to complete successfully this time
+	if dbItem.DataFile != nil && dbItem.DataHash == nil {
+		return true
+	}
+
+	// if service reports hashes/etags and we see that it
+	// has changed, reprocess
+	if serviceHash := it.DataFileHash(); serviceHash != nil &&
+		dbItem.Metadata != nil &&
+		dbItem.Metadata.ServiceHash != nil &&
+		!bytes.Equal(serviceHash, dbItem.Metadata.ServiceHash) {
+		return true
+	}
+
+	// finally, if the user wants to reprocess anyway, then do so
+	return reprocess
+}
+
+// fillItemRow populates ir from the item it, the processing timestamp,
+// and the (possibly nil) canonical data file name.
+func (wc *WrappedClient) fillItemRow(ir *ItemRow, it Item, timestamp time.Time, canonicalDataFileName *string) error {
+	// unpack the item's information into values to use in the row
+
+	ownerID, ownerName := it.Owner()
+	if ownerID == nil {
+		ownerID = &wc.acc.UserID // assume current account
+	}
+	if ownerName == nil {
+		empty := ""
+		ownerName = &empty
+	}
+	person, err := wc.tl.getPerson(wc.ds.ID, *ownerID, *ownerName)
+	if err != nil {
+		return fmt.Errorf("getting person associated with item: %v", err)
+	}
+
+	txt, err := it.DataText()
+	if err != nil {
+		return fmt.Errorf("getting item text: %v", err)
+	}
+
+	loc, err := it.Location()
+	if err != nil {
+		return fmt.Errorf("getting item location data: %v", err)
+	}
+	if loc == nil {
+		loc = new(Location) // avoid nil pointer dereference below
+	}
+
+	// metadata (optional) needs to be gob-encoded
+	metadata, err := it.Metadata()
+	if err != nil {
+		return fmt.Errorf("getting item metadata: %v", err)
+	}
+	if serviceHash := it.DataFileHash(); serviceHash != nil {
+		// metadata may be nil (it is optional); allocate it before
+		// setting the field to avoid a nil pointer dereference and so
+		// the service's hash is not lost
+		if metadata == nil {
+			metadata = new(Metadata)
+		}
+		metadata.ServiceHash = serviceHash
+	}
+	var metaGob []byte
+	if metadata != nil {
+		metaGob, err = metadata.encode() // use special encoding method for massive space savings
+		if err != nil {
+			return fmt.Errorf("gob-encoding metadata: %v", err)
+		}
+	}
+
+	ir.AccountID = wc.acc.ID
+	ir.OriginalID = it.ID()
+	ir.PersonID = person.ID
+	ir.Timestamp = it.Timestamp()
+	ir.Stored = timestamp
+	ir.Class =
it.Class()
+	ir.MIMEType = it.DataFileMIMEType()
+	ir.DataText = txt
+	ir.DataFile = canonicalDataFileName
+	ir.Metadata = metadata
+	ir.metaGob = metaGob
+	ir.Location = *loc
+
+	return nil
+}
+
+// processCollection upserts the collection's own row, then links each of
+// its items to it (storing any item that has not been stored yet).
+func (wc *WrappedClient) processCollection(coll Collection, timestamp time.Time) error {
+	_, err := wc.tl.db.Exec(`INSERT INTO collections
+		(account_id, original_id, name) VALUES (?, ?, ?)
+		ON CONFLICT (account_id, original_id)
+		DO UPDATE SET name=?`,
+		wc.acc.ID, coll.OriginalID, coll.Name,
+		coll.Name)
+	if err != nil {
+		return fmt.Errorf("inserting collection: %v", err)
+	}
+
+	// get the collection's row ID, regardless of whether it was inserted or updated
+	var collID int64
+	err = wc.tl.db.QueryRow(`SELECT id FROM collections
+		WHERE account_id=? AND original_id=? LIMIT 1`,
+		wc.acc.ID, coll.OriginalID).Scan(&collID)
+	if err != nil {
+		return fmt.Errorf("getting existing collection's row ID: %v", err)
+	}
+
+	// now add all the items
+	// (TODO: could batch this for faster inserts)
+	for _, cit := range coll.Items {
+		if cit.itemRowID == 0 {
+			itID, err := wc.storeItemFromService(cit.Item, timestamp, false, false) // never reprocess or check integrity here
+			if err != nil {
+				return fmt.Errorf("adding item from collection to storage: %v", err)
+			}
+			cit.itemRowID = itID
+		}
+
+		// the query has exactly three placeholders; passing a fourth
+		// argument (cit.Position was duplicated here before) makes
+		// Exec fail with an argument-count error
+		_, err = wc.tl.db.Exec(`INSERT OR IGNORE INTO collection_items
+			(item_id, collection_id, position)
+			VALUES (?, ?, ?)`,
+			cit.itemRowID, collID, cit.Position)
+		if err != nil {
+			return fmt.Errorf("adding item to collection: %v", err)
+		}
+	}
+
+	return nil
+}
+
+// loadItemRow loads the item row identified by (accountID, originalID).
+// If no such row exists, it returns the zero ItemRow and a nil error.
+func (wc *WrappedClient) loadItemRow(accountID int64, originalID string) (ItemRow, error) {
+	var ir ItemRow
+	var metadataGob []byte
+	var ts, stored int64 // will convert from Unix timestamp
+	var modified *int64
+	err := wc.tl.db.QueryRow(`SELECT
+			id, account_id, original_id, person_id, timestamp, stored,
+			modified, class, mime_type, data_text, data_file, data_hash,
+			metadata, latitude, longitude
+		FROM items WHERE account_id=? AND original_id=? LIMIT 1`, accountID, originalID).Scan(
+		&ir.ID, &ir.AccountID, &ir.OriginalID, &ir.PersonID, &ts, &stored,
+		&modified, &ir.Class, &ir.MIMEType, &ir.DataText, &ir.DataFile, &ir.DataHash,
+		&metadataGob, &ir.Latitude, &ir.Longitude)
+	if err == sql.ErrNoRows {
+		return ItemRow{}, nil
+	}
+	if err != nil {
+		return ItemRow{}, fmt.Errorf("loading item: %v", err)
+	}
+
+	// the metadata is gob-encoded; decode it into the struct — but only
+	// if there is anything to decode: items without metadata store NULL
+	// here, and decoding an empty blob would fail
+	if len(metadataGob) > 0 {
+		ir.Metadata = new(Metadata)
+		err = ir.Metadata.decode(metadataGob)
+		if err != nil {
+			return ItemRow{}, fmt.Errorf("gob-decoding metadata: %v", err)
+		}
+	}
+
+	ir.Timestamp = time.Unix(ts, 0)
+	ir.Stored = time.Unix(stored, 0)
+	if modified != nil {
+		modTime := time.Unix(*modified, 0)
+		ir.Modified = &modTime
+	}
+
+	return ir, nil
+}
+
+// itemRowIDFromOriginalID returns an item's row ID from the ID
+// associated with the data source of wc, along with its original
+// item ID from that data source. If the item does not exist,
+// sql.ErrNoRows will be returned.
+func (wc *WrappedClient) itemRowIDFromOriginalID(originalID string) (int64, error) {
+	var rowID int64
+	err := wc.tl.db.QueryRow(`SELECT items.id
+		FROM items, accounts
+		WHERE items.original_id=?
+			AND accounts.data_source_id=?
+			AND items.account_id = accounts.id
+		LIMIT 1`, originalID, wc.ds.ID).Scan(&rowID)
+	return rowID, err
+}
+
+// itemLocks is used to ensure that an item
+// is not processed twice at the same time.
+var itemLocks = newMapMutex()
diff --git a/ratelimit.go b/ratelimit.go
new file mode 100644
index 0000000..0491053
--- /dev/null
+++ b/ratelimit.go
@@ -0,0 +1,59 @@
+package timeliner
+
+import (
+	"net/http"
+	"time"
+)
+
+// RateLimit describes a rate limit.
+type RateLimit struct {
+	RequestsPerHour int
+	BurstSize       int
+
+	ticker *time.Ticker
+	token  chan struct{}
+}
+
+// NewRateLimitedRoundTripper adds rate limiting to rt based on the rate
+// limiting policy registered by the data source associated with acc.
+func (acc Account) NewRateLimitedRoundTripper(rt http.RoundTripper) http.RoundTripper {
+	rlKey := acc.DataSourceID + "_" + acc.UserID
+
+	rl, ok := acc.t.rateLimiters[rlKey]
+
+	if !ok && acc.ds.RateLimit.RequestsPerHour > 0 {
+		// start from the data source's configured policy; rl is the
+		// zero value after a map miss, so reading rl.BurstSize before
+		// this assignment would always yield 0 and silently disable
+		// the configured burst
+		rl = acc.ds.RateLimit
+
+		// interval between requests = 3600s / requests-per-hour;
+		// convert from float only once so sub-second intervals
+		// (RequestsPerHour > 3600) don't truncate to 0, which would
+		// make time.NewTicker panic
+		secondsBetweenReqs := 60.0 / (float64(rl.RequestsPerHour) / 60.0)
+		reqInterval := time.Duration(secondsBetweenReqs * float64(time.Second))
+
+		rl.ticker = time.NewTicker(reqInterval)
+		rl.token = make(chan struct{}, rl.BurstSize)
+
+		// pre-fill the token bucket so the full burst is available immediately
+		for i := 0; i < cap(rl.token); i++ {
+			rl.token <- struct{}{}
+		}
+		// replenish one token per tick; NOTE(review): the ticker is
+		// never stopped and the map entry is never removed, so each
+		// limiter lives for the life of the process — confirm that is
+		// intended
+		go func() {
+			for range rl.ticker.C {
+				rl.token <- struct{}{}
+			}
+		}()
+
+		acc.t.rateLimiters[rlKey] = rl
+	}
+
+	return rateLimitedRoundTripper{
+		RoundTripper: rt,
+		token:        rl.token,
+	}
+}
+
+// rateLimitedRoundTripper wraps an http.RoundTripper so that each
+// request first consumes a token from the rate limiter's bucket.
+type rateLimitedRoundTripper struct {
+	http.RoundTripper
+	token <-chan struct{}
+}
+
+// RoundTrip waits for a rate-limit token (if a limiter is configured),
+// then performs the request with the wrapped RoundTripper.
+func (rt rateLimitedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	// token is nil when no rate limit was configured for the account;
+	// receiving from a nil channel would block forever, so skip the
+	// wait in that case instead of deadlocking every request
+	if rt.token != nil {
+		<-rt.token
+	}
+	return rt.RoundTripper.RoundTrip(req)
+}
+
+// NOTE(review): this package-level map appears unused — the method
+// above stores limiters on acc.t.rateLimiters; confirm and remove if
+// truly dead.
+var rateLimiters = make(map[string]RateLimit)
diff --git a/timeliner.go b/timeliner.go
new file mode 100644
index 0000000..4510a6f
--- /dev/null
+++ b/timeliner.go
@@ -0,0 +1,133 @@
+// Timeliner - A personal data aggregation utility
+// Copyright (C) 2019 Matthew Holt
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see