Skip to content

Commit

Permalink
Merge pull request #6 from A3Data/features/API
Browse files Browse the repository at this point in the history
Features/api
  • Loading branch information
henrique-tostes-a3 authored Sep 26, 2024
2 parents 2870187 + 82009a0 commit 3ae163f
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 1 deletion.
56 changes: 56 additions & 0 deletions api/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Guia do Usuário API

1. Certifique-se de ter o ambiente virtual ativado.
2. Instale as dependências usando o Poetry.
3. Navegue até a pasta principal do repositório.
4. Execute a API com o comando:

```bash
python -m api.app.main
```

A API começará a rodar em [http://127.0.0.1:8000](http://127.0.0.1:8000) por padrão. Você pode usar ferramentas como curl, Postman ou um script Python para interagir com os endpoints.

Além disso, você pode acessar a documentação da API em [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs), que é gerada automaticamente pelo FastAPI.

## Endpoints da API

### 1. `/predict` (POST)

Prevê a espécie de uma flor de íris com base em uma única amostra de entrada.

**Solicitação:** Informe os valores neste modelo:

```json
{
"sepal_length": 0,
"sepal_width": 0,
"petal_length": 0,
"petal_width": 0
}
```

### 2. `/predict-batch` (POST)

Processa um arquivo CSV contendo várias amostras para previsões em lote.

**Solicitação:** Carregue um arquivo CSV com as seguintes colunas:

- `sepal_length`
- `sepal_width`
- `petal_length`
- `petal_width`

### 3. `/evaluate` (POST)

Processa um arquivo CSV e avalia previsões em relação a rótulos verdadeiros.

**Solicitação:** Carregue um arquivo CSV com as mesmas colunas acima, mais:

- `species` (o verdadeiro rótulo para avaliação)

## Estrutura do Código

- **main.py**: O ponto de entrada do aplicativo, inicializa o aplicativo FastAPI e inclui o roteador para manipular solicitações.
- **routes.py**: Contém as definições para os endpoints da API, manipulando solicitações de previsões e avaliações.
- **schemas.py**: Define os modelos de dados para validação de entrada e saída usando Pydantic.
Empty file added api/app/__init__.py
Empty file.
14 changes: 14 additions & 0 deletions api/app/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from fastapi import FastAPI
from .routes import router
import os

app = FastAPI()

app.include_router(router)

if __name__ == "__main__":
import uvicorn

host = os.getenv("HOST", "127.0.0.1")
port = int(os.getenv("PORT", 8000))
uvicorn.run(app, host=host, port=port)
88 changes: 88 additions & 0 deletions api/app/routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pandas as pd
from fastapi import APIRouter, HTTPException, File, UploadFile
from .schemas import PredictRequest, PredictionResponse
from config.settings import MODELS_DIR
from src.pipelines.predict import load_model, make_predictions, evaluate_predictions
from io import StringIO
import json

router = APIRouter()
model_path = MODELS_DIR / "model.joblib"


@router.post("/predict", response_model=PredictionResponse)
async def predict(input: PredictRequest):
data = input.to_dataframe()

try:
model = load_model(model_path)
prediction = make_predictions(model, data)
result = int(prediction[0])
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Erro ao fazer a previsão: {str(e)}"
)

return PredictionResponse(prediction=result)


@router.post("/predict-batch", response_model=dict)
async def predict_batch(file: UploadFile = File(...)):
"""Faz previsões em lote a partir de um arquivo CSV.
O arquivo CSV deve conter as seguintes colunas:
- sepal_length: Comprimento da sépala.
- sepal_width: Largura da sépala.
- petal_length: Comprimento da pétala.
- petal_width: Largura da pétala.
"""
try:
contents = await file.read()
df = pd.read_csv(StringIO(contents.decode("utf-8")))
model = load_model(model_path)

predictions = make_predictions(model, df)

return {"predictions": predictions.tolist()}

except Exception as e:
raise HTTPException(
status_code=500, detail=f"Erro ao processar o arquivo: {str(e)}"
)


@router.post("/evaluate", response_model=dict)
async def predict_batch_with_evaluation(file: UploadFile = File(...)):
"""Faz previsões em lote a partir de um arquivo CSV.
O arquivo CSV deve conter as seguintes colunas:
- sepal_length: Comprimento da sépala.
- sepal_width: Largura da sépala.
- petal_length: Comprimento da pétala.
- petal_width: Largura da pétala.
- species: Rótulo verdadeiro da espécie (para avaliação).
"""
try:
contents = await file.read()
df = pd.read_csv(StringIO(contents.decode("utf-8")))
model = load_model(model_path)

if "species" not in df.columns:
raise HTTPException(
status_code=400,
detail="Os dados devem conter a coluna 'species' para avaliação.",
)

true_labels = df.pop("species")
predictions = make_predictions(model, df)

report_dict = evaluate_predictions(predictions, true_labels)
report = json.dumps(report_dict, indent=4)
return {
"evaluation_report": json.loads(report),
}

except Exception as e:
raise HTTPException(
status_code=500, detail=f"Erro ao processar o arquivo: {str(e)}"
)
27 changes: 27 additions & 0 deletions api/app/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pandas as pd
from pydantic import BaseModel


class PredictRequest(BaseModel):
sepal_length: float
sepal_width: float
petal_length: float
petal_width: float

def to_dataframe(self) -> pd.DataFrame:
"""Prepara os dados como um DataFrame a partir do pedido."""
return pd.DataFrame(
[
[
self.sepal_length,
self.sepal_width,
self.petal_length,
self.petal_width,
]
],
columns=["sepal_length", "sepal_width", "petal_length", "petal_width"],
)


class PredictionResponse(BaseModel):
prediction: int
Binary file removed artifacts/models/svc_model.joblib
Binary file not shown.
4 changes: 3 additions & 1 deletion src/pipelines/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def evaluate_predictions(predictions: pd.Series, true_labels: pd.Series):
Returns:
str: Relatório de classificação.
"""
return classification_report(true_labels, predictions)
return classification_report(
true_labels, predictions, output_dict=True, zero_division=0
)


@app.command()
Expand Down

0 comments on commit 3ae163f

Please sign in to comment.