diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..80c0421 --- /dev/null +++ b/.flake8 @@ -0,0 +1,6 @@ +[flake8] +exclude = + __pycache__ + .github + .pytest_cache + .venv \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e5cb0b8..7ede33c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,7 +10,7 @@ on: jobs: test: - name: Test + name: Lint & Test runs-on: ubuntu-latest steps: @@ -22,8 +22,16 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Install dependencies for testing + run: | + if [ -f requirements-test.txt ]; then pip install -r requirements-test.txt; fi + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | - pip install pytest pytest-cov pytest \ No newline at end of file diff --git a/.vs-code/settings.json b/.vs-code/settings.json new file mode 100644 index 0000000..4bdddb4 --- /dev/null +++ b/.vs-code/settings.json @@ -0,0 +1,19 @@ +{ + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": true, + "[python]": { + "editor.codeActionsOnSave": { + "source.organizeImports": true + } + }, + "black-formatter.args": [ + "--line-length", + "88" + ], + "flake8.args": [ + "--max-line-length", + "88", + "--extend-ignore", + "E203" + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..9d6cb40 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "." + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/main.py b/main.py index de0f356..ea8042a 100644 --- a/main.py +++ b/main.py @@ -2,22 +2,23 @@ from pydantic import BaseModel from fastapi import FastAPI + app = FastAPI() -pipe = pipeline("text-classification", model="shahrukhx01/question-vs-statement-classifier") +model = "shahrukhx01/question-vs-statement-classifier" +pipe = pipeline("text-classification", model=model) + + +custom_labels = {"LABEL_0": "STATEMENT", "LABEL_1": "QUESTION"} -custom_labels = { - "LABEL_0": "STATEMENT", - "LABEL_1": "QUESTION" -} class Payload(BaseModel): - text: str + text: str + @app.post("/test") async def test(payload: Payload): - result = pipe(payload.text)[0] - - # Customize the label - result['label'] = custom_labels.get(result['label'], result['label']) + result = pipe(payload.text)[0] - return result \ No newline at end of file + # Customize the label + result["label"] = custom_labels.get(result["label"], result["label"]) + return result diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 0000000..d04dcac --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,4 @@ +pytest==8.3.3 +pytest-cov==5.0.0 +httpx==0.27.2 +flake8==7.1.1 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b4cd9da..c684453 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,3 @@ fastapi[standard]==0.114.1 pydantic==2.8.0 -transformers[torch]==4.44.2 - -# Test -pytest==8.3.3 -pytest-cov==5.0.0 -httpx==0.27.2 \ No newline at end of file +transformers[torch]==4.44.2 \ No newline at end of file diff --git a/test_main.py b/test_main.py index 44644ea..6ef88ff 100644 --- a/test_main.py +++ b/test_main.py @@ -3,8 +3,8 @@ client = TestClient(app) -def test_route(): - response = client.post('/test', json={ "text": "Is this a test?" }) - assert response.status_code == 200 - assert response.json()['label'] == 'QUESTION' +def test_route(): + response = client.post("/test", json={"text": "Is this a test?"}) + assert response.status_code == 200 + assert response.json()["label"] == "QUESTION"