diff --git a/eia/api.py b/eia/api.py
index ed44224..7e5185a 100644
--- a/eia/api.py
+++ b/eia/api.py
@@ -18,44 +18,73 @@ except ImportError:
     import json
 
-__all__ = ['Series', 'Geoset', 'Relation', 'Category', 'SeriesCategory', 'Updates', 'Search', 'BaseQuery']
+__all__ = ['Series', 'Geoset', 'Relation', 'Category',
+           'SeriesCategory', 'Updates', 'Search', 'BaseQuery']
 
 requests_cache.install_cache(backend="memory",
-                             expire_after = 600,
-                             ignored_parameters = "api_key")
+                             expire_after=600,
+                             ignored_parameters="api_key")
 # Monkey patch requests_cache, Ideally, we can make the backend
 # user-configureable
 
+
 def chunk(l, n):
     """Chunk a list into n parts."""
     for i in range(0, len(l), n):
-        yield l[i:i+n]
+        yield l[i:i + n]
+
 
 class BaseQuery(object):
 
     host = "https://api.eia.gov"
     endpoint = ""
 
-    def __init__(self, api_key, output="json"):
+    def __init__(self, api_key, output="json", consumers=4):
         self.api_key = api_key
-        self.default_params = {"api_key" : self.api_key, "out" : output}
+        self.default_params = {"api_key": self.api_key, "out": output}
         self.base_url = urljoin(self.host, self.endpoint)
+        # Subclasses that define parse() get a job queue, a results queue,
+        # and a pool of daemon consumer threads.
+        if hasattr(self, 'parse'):
+            self.queue = Queue()
+            self.results = Queue()
+            self.consumers = []
+            for n in range(consumers):
+                consumer = Thread(target=self._consume)
+                consumer.daemon = True
+                consumer.start()
+                self.consumers.append(consumer)
+
+    def _consume(self):
+        """Consume queued jobs, putting parsed results on the results queue."""
+        while True:
+            job = self.queue.get()
+            try:
+                result = self.parse(job)
+                self.results.put(result)
+            except Exception as e:
+                warn(e)
+                self.results.put(None)  # Keep put/get accounting aligned
+            self.queue.task_done()
 
     def get(self, **kwargs):
         params = dict(self.default_params)
         params.update(kwargs)
-        r = requests.get(self.base_url, params = params)
+        r = requests.get(self.base_url, params=params)
         r.raise_for_status()
         return r.json()
 
     def post(self, **data):
         r = requests.post(self.base_url, params=self.default_params,
                           data=data)
-        r.raise_for_status()
-        return r.json()
+        try:
+            r.raise_for_status()
+            return r.json()
+        except Exception as e:
+            print(e)
 
     def query(self, *args, **kwargs):
+        """If a 'parse' method is defined, this method should queue jobs."""
         raise NotImplementedError
 
+
 class Series(BaseQuery):
 
     """
@@ -72,39 +101,32 @@ class Series(BaseQuery):
     endpoint = "series/"
 
-    def parse(self, q, series_id, **kwargs):
-        """Thread target to put results on a queue."""
-        if len(series_id) + len(self.base_url) >= 1800:
-            q.put(self.post(series_id = series_id, **kwargs))
-        else:
-            q.put(self.get(series_id=series_id, **kwargs))
-
-    def query(self, *series_ids, **kwargs):
-        q = Queue(maxsize=1000)
-        threads = []
-        for c in chunk(list(set(series_ids)), 100):
-            t = Thread(target = self.parse,
-                       args = (q, ';'.join(c)), kwargs=kwargs)
-            t.daemon = True
-            t.start()
-            threads.append(t)
-        map(lambda x : x.join(), threads)
-        data = filter(lambda x : 'series' in x, [q.get() for _ in threads])
-        return [s for d in data for s in d['series']]
+    def parse(self, value):
+        return self.post(series_id=value)
 
-    def query_df(self, *series_ids, **kwargs):
-        """Serialize query results to DataFrame."""
-        data = self.query(*series_ids, **kwargs)
+    def query(self, *series_ids):
+        """Chunk series_ids into groups of 100 and queue them for the consumers."""
+        for i, C in enumerate(chunk(series_ids, 100)):
+            self.queue.put(';'.join(C))
+        self.queue.join()  # Wait for all queries to complete
+        output = [self.results.get() for _ in range(i + 1)]  # Deplete results
+        return [s for d in output if d and 'series' in d for s in d['series']]
+
+    def query_df(self, *series_ids):
+        """Serialize query results to a DataFrame."""
+        data = self.query(*series_ids)
         o = []
         for d in data:
             df = pd.DataFrame(d.pop('data'), columns=['period', 'value'])
-            # This is better for pandas >= 0.16.0
-            # o.append(df.assign(**d))
+            # This is better for pandas >= 0.16.0 : o.append(df.assign(**d))
             # However, do it this way for pandas < 0.16.0 support
             for k, v in d.items():
                 df[k] = v
             o.append(df)
-        return pd.concat(o, ignore_index=True)
+        if o:
+            return pd.concat(o, ignore_index=True)
+        else:
+            return pd.DataFrame()
+
 
 class Geoset(BaseQuery):
 
     """
@@ -122,8 +144,8 @@
     endpoint = "geoset/"
 
     def query(self, geoset_id, *regions, **kwargs):
-        data = self.get(geoset_id = geoset_id, regions = ','.join(regions),
-                        **kwargs)
+        data = self.get(geoset_id=geoset_id, regions=','.join(regions),
+                        **kwargs)
         return data['geoset']
 
     def query_df(self, geoset_id, *regions, **kwargs):
@@ -139,11 +161,13 @@ def query_df(self, geoset_id, *regions, **kwargs):
             df[k] = v
         return df
 
+
 class Relation(BaseQuery):
     """Not implemented for now, this does not appear to be a valid endpoint."""
     # endpoint = "relation/"
     pass
 
+
 class Category(BaseQuery):
 
     """
    category_id: optional, unique numerical id of the category to fetch. If missing, the API's root category is fetched.
     """
@@ -152,32 +176,23 @@
     endpoint = "category/"
 
     def query(self, category_id=None, **kwargs):
-        return self.get(category_id = category_id, **kwargs)['category']
+        return self.get(category_id=category_id, **kwargs)['category']
+
 
 class SeriesCategory(BaseQuery):
 
     endpoint = "series/categories/"
 
-    def parse(self, q, series_id, **kwargs):
-        """Thread target to put results on a queue."""
-        # IMPROVEMENT : calculate dry-run url len
-        if len(series_id) + len(self.base_url) >= 1800:
-            q.put(self.post(series_id = series_id, **kwargs))
-        else:
-            q.put(self.get(series_id=series_id, **kwargs))
-
-    def query(self, *series_ids, **kwargs):
-        q = Queue(maxsize=1000)
-        threads = []
-        for c in chunk(list(set(series_ids)), 100):
-            t = Thread(target = self.parse,
-                       args = (q, ';'.join(c)), kwargs=kwargs)
-            t.daemon = True
-            t.start()
-            threads.append(t)
-        map(lambda x : x.join(), threads)
-        data = filter(lambda x : 'series_categories' in x, [q.get() for _ in threads])
-        return [s for d in data for s in d['series_categories']]
+    def parse(self, value):
+        return self.post(series_id=value)
+
+    def query(self, *series_ids):
+        for i, C in enumerate(chunk(series_ids, 100)):
+            self.queue.put(';'.join(C))
+        self.queue.join()
+        output = [self.results.get() for _ in range(i + 1)]  # Deplete results
+        k = "series_categories"
+        return [s for d in output if d and k in d for s in d[k]]
 
     def query_df(self, *series_ids, **kwargs):
         results = self.query(*series_ids, **kwargs)
@@ -193,39 +208,28 @@ class Updates(BaseQuery):
     endpoint = "updates/"
 
-    def parse(self, q, **kwargs):
-        q.put(self.get(**kwargs))
-
-    def query(self, category_id = None, rows=50, firstrow=0, deep = False):
+    def parse(self, page_params):
+        return self.get(**page_params)
 
-        poll = self.get(category_id = category_id,
-                        deep=deep,
-                        rows=1,
-                        firstrow=0) # Check number of available rows
+    def query(self, category_id=None, rows=50, firstrow=0, deep=False):
+        poll = self.get(category_id=category_id,
+                        deep=deep,
+                        rows=1,
+                        firstrow=0)  # Check number of available rows
         n = rows or poll['data']['rows_available']
-        pages = int(np.ceil(n / 10000.))
-        params = {"category_id" : category_id,
-                  "deep" : deep,
-                  "rows" : 10000}
-        q = Queue(maxsize=1000)
-        threads = []
-        for page in range(pages):
+        params = {"category_id": category_id, "deep": deep, "rows": 10000}
+        for page in range(int(np.ceil(n / 10000.))):
             page_params = dict(params)
-            first = page*10000
-            page_params.update({"firstrow" : first})
+            first = page * 10000
+            page_params.update({"firstrow": first})
             if first + 10000 > rows:
                 rows = rows - first
-                page_params.update({"rows" : rows})
-            t = Thread(target = self.parse,
-                       args = (q, ), kwargs=page_params)
-            t.daemon = True
-            t.start()
-            threads.append(t)
-        map(lambda x : x.join(), threads)
-        data = filter(
-            lambda x : ('updates' in x and isinstance(x['updates'][0], dict)),
-            [q.get() for _ in threads])
-        return [s for d in data for s in d['updates']]
+                page_params.update({"rows": rows})
+            self.queue.put(page_params)
+        self.queue.join()
+        output = [self.results.get() for _ in range(page + 1)]
+        k = "updates"
+        return [s for d in output if d and k in d for s in d[k]]
 
     def query_df(self, *args, **kwargs):
         return pd.DataFrame(self.query(*args, **kwargs))
@@ -235,11 +239,11 @@ class Search(BaseQuery):
     endpoint = "search/"
 
-    def parse(self, q, *args, **kwargs):
-        q.put(self.get(*args, **kwargs)['response']['docs'])
+    def parse(self, val):
+        return self.get(**val)['response']['docs']
 
     def query(self, search_term, search_value,
-              rows_per_page=10, page_num=1, **kwargs):
+              rows_per_page=10, page_num=1):
         """
         search_term : str, one of ["series_id", "name", "last_updated"]
@@ -261,47 +265,43 @@ def query(self, search_term, search_value,
         """
         t, v = self.clean_search_params(search_term, search_value)
         if (rows_per_page == 0) or (rows_per_page == 'all'):
-            r = self.get(search_term = t, search_value = v, page_num=1, rows_per_page = 1, **kwargs)
+            r = self.get(search_term=t, search_value=v,
+                         page_num=1, rows_per_page=1)
             total = r['response']['numFound']
             # Chunk by 5000's for reliability/reasonable speed
             chunksize = 7500
-            pages = int(np.ceil(total/float(chunksize)))
-            params = {"search_term" : t, "search_value" : v}
-            params.update(kwargs)
-            threads = []
-            q = Queue(maxsize=1000)
-            for page in range(0, pages):
+            pages = int(np.ceil(total / float(chunksize)))
+            params = {"search_term": t, "search_value": v}
+            for page in range(pages):
                 rows = chunksize
-                first = page*rows
+                first = page * rows
                 if first + rows > total:
                     rows = total - first
                 page_params = dict(params)
-                page_params.update({"page_num" : page, "rows_per_page" : rows})
-                t = Thread(target=self.parse, args=(q,), kwargs=page_params)
-                t.daemon = True
-                t.start()
-                threads.append(t)
-            map(lambda x : x.join(), threads)
-            results = [q.get() for _ in threads]
-            results = [i for chunk in results for i in chunk]
-
-        else: # limit search results, paginate elsewhere
-            r = self.get(search_term = t, search_value = v, page_num=page_num, rows_per_page = rows_per_page, **kwargs)
+                page_params.update({"page_num": page, "rows_per_page": rows})
+                self.queue.put(page_params)
+            self.queue.join()
+            output = [self.results.get() for _ in range(pages)]
+            results = [i for chunk in output for i in chunk]
+        else:  # limit search results, paginate elsewhere
+            r = self.get(search_term=t, search_value=v,
+                         page_num=page_num, rows_per_page=rows_per_page)
             results = r['response']['docs']
         return results
 
     def clean_search_params(self, search_term, search_value):
-        t, v = search_term, search_value # Alias for convenience
-        assert t in ["series_id", "name", "last_updated"], "Invalid search term"
+        t, v = search_term, search_value  # Alias for convenience
+        assert t in ["series_id", "name",
+                     "last_updated"], "Invalid search term"
        if t in ["series_id", "name"]:
             if not re.match(re.escape('^"{}"$'.format(v.strip('"'))), v):
                 v = '"{}"'.format(search_value)
-        else: # search_term == "last_updated"
-            if len(list(v)) == 2: # Assume it's a tuple, list, or iterable
+        else:  # search_term == "last_updated"
+            if len(list(v)) == 2:  # Assume it's a tuple, list, or iterable
                 daterange = map(pd.to_datetime, list(v))
                 strfdaterange = map(
-                    lambda x : x.strftime('%Y-%m-%dT%H:%M:%SZ'),
+                    lambda x: x.strftime('%Y-%m-%dT%H:%M:%SZ'),
                     daterange)
                 v = "[{}]".format(" TO ".join(strfdaterange))
         # else : assume user has read documentation and is feeding in