Standardise code and fix type error (#52)
* Fix some type errors

* Standardise code style

* Enable Python 3.5
Yibo-Chen13 authored Oct 22, 2024
1 parent 17d1d16 commit 8aad2aa
Showing 20 changed files with 4,510 additions and 3,900 deletions.
3 changes: 3 additions & 0 deletions .flake8
@@ -5,3 +5,6 @@ per-file-ignores =
proton_driver/bufferedreader.pyx: E225, E226, E227, E999
proton_driver/bufferedwriter.pyx: E225, E226, E227, E999
proton_driver/varint.pyx: E225, E226, E227, E999
# ignore example print warning.
example/*: T201, T001
exclude = venv,.conda,build
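For context, T201 and T001 are the codes the `flake8-print` plugin (newer and older releases, respectively) reports for `print` calls; the examples print their results on purpose, so they are exempted. A minimal sketch of the kind of file this exemption covers, with a hypothetical path and content:

```python
# example/print_demo.py (hypothetical) -- examples are expected to print
# their output, so flake8-print's T201/T001 findings are ignored under example/*.
from proton_driver import client

c = client.Client(host='127.0.0.1', port=8463)
print(c.execute('SELECT 1'))  # would be flagged T201 outside example/*
```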
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
@@ -45,6 +45,9 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
architecture: x64
env:
if: ${{ matrix.python-version == 3.5 }}
PIP_TRUSTED_HOST: "pypi.python.org pypi.org files.pythonhosted.org"
# - name: Login to Docker Hub
# uses: docker/login-action@v1
# with:
40 changes: 25 additions & 15 deletions example/bytewax/hackernews.py
@@ -19,11 +19,13 @@ class HNSource(SimplePollingSource):
def next_item(self):
return (
"GLOBAL_ID",
requests.get("https://hacker-news.firebaseio.com/v0/maxitem.json").json(),
requests.get(
"https://hacker-news.firebaseio.com/v0/maxitem.json"
).json(),
)


def get_id_stream(old_max_id, new_max_id) -> Tuple[str,list]:
def get_id_stream(old_max_id, new_max_id) -> Tuple[str, list]:
if old_max_id is None:
# Get the last 150 items on the first run.
old_max_id = new_max_id - 150
@@ -34,7 +36,7 @@ def download_metadata(hn_id) -> Optional[Tuple[str, dict]]:
# Given an hacker news id returned from the api, fetch metadata
# Try 3 times, waiting more and more, or give up
data = requests.get(
f"https://hacker-news.firebaseio.com/v0/item/{hn_id}.json"
f"https://hacker-news.firebaseio.com/v0/item/{hn_id}.json" # noqa
).json()

if data is None:
@@ -51,12 +53,7 @@ def recurse_tree(metadata, og_metadata=None) -> any:
parent_metadata = download_metadata(parent_id)
return recurse_tree(parent_metadata[1], og_metadata)
except KeyError:
return (metadata["id"],
{
**og_metadata,
"root_id":metadata["id"]
}
)
return (metadata["id"], {**og_metadata, "root_id": metadata["id"]})


def key_on_parent(key__metadata) -> tuple:
@@ -68,19 +65,32 @@ def format(id__metadata):
id, metadata = id__metadata
return json.dumps(metadata)


flow = Dataflow("hn_scraper")
max_id = op.input("in", flow, HNSource(timedelta(seconds=15)))
id_stream = op.stateful_map("range", max_id, lambda: None, get_id_stream).then(
op.flat_map, "strip_key_flatten", lambda key_ids: key_ids[1]).then(
op.redistribute, "redist")
id_stream = \
op.stateful_map("range", max_id, lambda: None, get_id_stream) \
.then(op.flat_map, "strip_key_flatten", lambda key_ids: key_ids[1]) \
.then(op.redistribute, "redist")

id_stream = op.filter_map("meta_download", id_stream, download_metadata)
split_stream = op.branch("split_comments", id_stream, lambda item: item[1]["type"] == "story")
split_stream = op.branch(
"split_comments", id_stream, lambda item: item[1]["type"] == "story"
)
story_stream = split_stream.trues
story_stream = op.map("format_stories", story_stream, format)
comment_stream = split_stream.falses
comment_stream = op.map("key_on_parent", comment_stream, key_on_parent)
comment_stream = op.map("format_comments", comment_stream, format)
op.inspect("stories", story_stream)
op.inspect("comments", comment_stream)
op.output("stories-out", story_stream, ProtonSink("hn_stories_raw", os.environ.get("PROTON_HOST","127.0.0.1")))
op.output("comments-out", comment_stream, ProtonSink("hn_comments_raw", os.environ.get("PROTON_HOST","127.0.0.1")))
op.output(
"stories-out",
story_stream,
ProtonSink("hn_stories_raw", os.environ.get("PROTON_HOST", "127.0.0.1")),
)
op.output(
"comments-out",
comment_stream,
ProtonSink("hn_comments_raw", os.environ.get("PROTON_HOST", "127.0.0.1")),
)
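The retry behaviour mentioned in `download_metadata`'s comment ("Try 3 times, waiting more and more, or give up") is collapsed out of this hunk. A minimal sketch of one way such a backoff could look, assuming plain `requests` and `time`; the helper name is illustrative, not the example's actual code:

```python
import time

import requests


def fetch_json_with_backoff(url, attempts=3):
    # Illustrative helper: retry the GET with growing delays, then give up.
    for attempt in range(attempts):
        try:
            data = requests.get(url, timeout=10).json()
            if data is not None:
                return data
        except requests.RequestException:
            pass
        if attempt < attempts - 1:
            time.sleep(2 ** attempt)  # back off 1s, then 2s before retrying
    return None
```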
23 changes: 12 additions & 11 deletions example/bytewax/proton.py
@@ -9,29 +9,32 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class _ProtonSinkPartition(StatelessSinkPartition):
def __init__(self, stream: str, host: str):
self.client=client.Client(host=host, port=8463)
self.stream=stream
sql=f"CREATE STREAM IF NOT EXISTS `{stream}` (raw string)"
self.client = client.Client(host=host, port=8463)
self.stream = stream
sql = f"CREATE STREAM IF NOT EXISTS `{stream}` (raw string)" # noqa
logger.debug(sql)
self.client.execute(sql)

def write_batch(self, items):
logger.debug(f"inserting data {items}")
rows=[]
rows = []
for item in items:
rows.append([item]) # single column in each row
rows.append([item]) # single column in each row
sql = f"INSERT INTO `{self.stream}` (raw) VALUES"
logger.debug(f"inserting data {sql}")
self.client.execute(sql,rows)
self.client.execute(sql, rows)


class ProtonSink(DynamicSink):
def __init__(self, stream: str, host: str):
self.stream = stream
self.host = host if host is not None and host != "" else "127.0.0.1"

"""Write each output item to Proton on that worker.

"""
Write each output item to Proton on that worker.
Items consumed from the dataflow must look like a string. Use a
proceeding map step to do custom formatting.
@@ -40,9 +43,7 @@ def __init__(self, stream: str, host: str):
Can support at-least-once processing. Messages from the resume
epoch will be duplicated right after resume.
"""

def build(self, worker_index, worker_count):
"""See ABC docstring."""
return _ProtonSinkPartition(self.stream, self.host)
return _ProtonSinkPartition(self.stream, self.host)
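As a usage sketch of the contract the docstring describes (stringify items with a map step, then write them out), assuming bytewax's testing source and that this module is importable as `proton`, as hackernews.py above does:

```python
# Minimal sketch, not part of this commit: items are JSON-encoded to strings
# by a map step before ProtonSink writes them into a `demo_raw` stream.
import json

import bytewax.operators as op
from bytewax.dataflow import Dataflow
from bytewax.testing import TestingSource

from proton import ProtonSink  # assumes this file is on the import path

flow = Dataflow("proton_sink_demo")
items = op.input("in", flow, TestingSource([{"id": 1}, {"id": 2}]))
strings = op.map("to_json", items, json.dumps)
op.output("out", strings, ProtonSink("demo_raw", "127.0.0.1"))
```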
62 changes: 41 additions & 21 deletions example/descriptive_pipeline/server/main.py
@@ -1,21 +1,27 @@
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, Request, BackgroundTasks
from fastapi import (
FastAPI,
WebSocket,
HTTPException,
WebSocketDisconnect,
Request,
BackgroundTasks,
)
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
import yaml
import queue
import threading
import asyncio
import json

from proton_driver import client

from .utils.logging import getLogger

logger = getLogger()


class Pipeline(BaseModel):
name: str
sqls: list[str]
sqls: list[str] # noqa


class Pipelines(BaseModel):
@@ -58,7 +64,11 @@ def pipeline_exist(self, name):
return False

def delete_pipeline(self, name):
updated_pipelines = [pipeline for pipeline in self.config.pipelines if pipeline.name != name]
updated_pipelines = [
pipeline
for pipeline in self.config.pipelines
if pipeline.name != name
]
self.config.pipelines = updated_pipelines
self.save()

@@ -73,11 +83,13 @@ def save(self):
yaml.dump(self.config, yaml_file)

def run_pipeline(self, name):
proton_client = client.Client(host=self.config.host,
port=self.config.port,
database=self.config.db,
user=self.config.user,
password=self.config.password)
proton_client = client.Client(
host=self.config.host,
port=self.config.port,
database=self.config.db,
user=self.config.user,
password=self.config.password,
)
pipeline = self.get_pipeline_by_name(name)
if pipeline is not None:
for query in pipeline.sqls[:-1]:
Expand All @@ -93,7 +105,7 @@ def conf(self):
return self.config


class Query():
class Query:
def __init__(self, sql, client):
self.sql = sql
self.lock = threading.Lock()
@@ -198,7 +210,7 @@ async def query_stream(name, request, background_tasks):
async def check_disconnect():
while True:
await asyncio.sleep(1)
disconnected = await request.is_disconnected();
disconnected = await request.is_disconnected()
if disconnected:
query.cancel()
logger.info('Client disconnected')
@@ -215,28 +227,34 @@ async def check_disconnect():
result = {}
for index, (name, t) in enumerate(header):
if t.startswith('date'):
result[name] = str(m[index]) # convert datetime type to string
# convert datetime type to string
result[name] = str(m[index])
else:
result[name] = m[index]
result_str = json.dumps(result).encode("utf-8") + b"\n"
yield result_str
except Exception as e:
query.cancel()
logger.info(f'query cancelled due to {e}' )
logger.info(f'query cancelled due to {e}')
break

if query.is_finshed():
break

await asyncio.sleep(0.1)


@app.get("/queries/{name}")
def query_pipeline(name: str, request: Request , background_tasks: BackgroundTasks):
def query_pipeline(
name: str, request: Request, background_tasks: BackgroundTasks
):
if not config_manager.pipeline_exist(name):
raise HTTPException(status_code=404, detail="pipeline not found")

return StreamingResponse(query_stream(name, request, background_tasks), media_type="application/json")
return StreamingResponse(
query_stream(name, request, background_tasks),
media_type="application/json",
)


@app.websocket("/queries/{name}")
@@ -258,10 +276,11 @@ async def websocket_endpoint(name: str, websocket: WebSocket):
result = {}
for index, (name, t) in enumerate(header):
if t.startswith('date'):
result[name] = str(m[index]) # convert datetime type to string
# convert datetime type to string
result[name] = str(m[index])
else:
result[name] = m[index]

await websocket.send_text(f'{json.dumps(result)}')
except Exception:
hasError = True
@@ -282,6 +301,7 @@ async def websocket_endpoint(name: str, websocket: WebSocket):
except Exception as e:
logger.exception(e)
finally:
query.cancel() # Ensure query cancellation even if an exception is raised
# Ensure query cancellation even if an exception is raised
query.cancel()
await websocket.close()
logger.debug('session closed')
27 changes: 17 additions & 10 deletions example/pandas/dataframe.py
@@ -8,24 +8,31 @@

# setup the test stream
c.execute("drop stream if exists test")
c.execute("""create stream test (
c.execute(
"""create stream test (
year int16,
first_name string
)""")
)"""
)
# add some data
df = pd.DataFrame.from_records([
{'year': 1994, 'first_name': 'Vova'},
{'year': 1995, 'first_name': 'Anja'},
{'year': 1996, 'first_name': 'Vasja'},
{'year': 1997, 'first_name': 'Petja'},
])
df = pd.DataFrame.from_records(
[
{'year': 1994, 'first_name': 'Vova'},
{'year': 1995, 'first_name': 'Anja'},
{'year': 1996, 'first_name': 'Vasja'},
{'year': 1997, 'first_name': 'Petja'},
]
)
c.insert_dataframe(
'INSERT INTO "test" (year, first_name) VALUES',
df,
settings=dict(use_numpy=True),
)
# or c.execute("INSERT INTO test(year, first_name) VALUES", df.to_dict('records'))
time.sleep(3) # wait for 3 sec to make sure data available in historical store
# or c.execute(
# "INSERT INTO test(year, first_name) VALUES", df.to_dict('records')
# )
# wait for 3 sec to make sure data available in historical store
time.sleep(3)

df = c.query_dataframe('SELECT * FROM table(test)')
print(df)
30 changes: 21 additions & 9 deletions example/streaming_query/car.py
@@ -1,34 +1,46 @@
"""
This example uses driver DB API.
In this example, a thread writes a huge list of data of car speed into database,
and another thread reads from the database to figure out which car is speeding.
In this example, a thread writes a huge list of data of car speed into
database, and another thread reads from the database to figure out which
car is speeding.
"""

import datetime
import random
import threading
import time

from proton_driver import connect

account='default:'
account = 'default:'


def create_stream():
with connect(f"proton://{account}@localhost:8463/default") as conn:
with connect(f"proton://{account}@localhost:8463/default") as conn: # noqa
with conn.cursor() as cursor:
cursor.execute("drop stream if exists cars")
cursor.execute("create stream if not exists car(id int64, speed float64)")
cursor.execute(
"create stream if not exists car(id int64, speed float64)"
)


def write_data(car_num: int):
car_begin_date = datetime.datetime(2022, 1, 1, 1, 0, 0)
for day in range(100):
car_begin_date += datetime.timedelta(days=1)
data = [(random.randint(0, car_num - 1), random.random() * 20 + 50,
car_begin_date
+ datetime.timedelta(milliseconds=i * 100)) for i in range(300000)]
data = [
(
random.randint(0, car_num - 1),
random.random() * 20 + 50,
car_begin_date + datetime.timedelta(milliseconds=i * 100),
)
for i in range(300000)
]
with connect(f"proton://{account}@localhost:8463/default") as conn:
with conn.cursor() as cursor:
cursor.executemany("insert into car (id, speed, _tp_time) values", data)
cursor.executemany(
"insert into car (id, speed, _tp_time) values", data
)
print(f"row count: {cursor.rowcount}")
time.sleep(10)

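The reader thread described in the module docstring ("another thread reads from the database to figure out which car is speeding") falls outside this hunk. A minimal sketch of what that side could look like with the same DB API, assuming a speed limit of 60 (the writer above generates speeds between 50 and 70):

```python
from proton_driver import connect

account = 'default:'


def read_data(speed_limit: float = 60.0):
    # Illustrative reader: a bare `select ... from car` is a streaming query
    # in Proton, so this loop keeps yielding rows as the writer inserts them.
    with connect(f"proton://{account}@localhost:8463/default") as conn:
        with conn.cursor() as cursor:
            cursor.execute(
                f"select id, speed from car where speed > {speed_limit}"
            )
            row = cursor.fetchone()
            while row is not None:
                print(f"car {row[0]} is speeding at {row[1]:.1f}")
                row = cursor.fetchone()
```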
