-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #42 from SLIIT-24-25J-047-Research/Development
Development
- Loading branch information
Showing
25 changed files
with
599 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
frontend/.env | ||
backend/.env | ||
backend/uploads | ||
backend/candidate_classification_service/model/final_trained_model.pth | ||
backend/node_modules | ||
frontend/node_modules |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Ignore Python virtual environment | ||
venv/ | ||
# Ignore Python cache files | ||
*.pyc | ||
*.pyo | ||
*.pyd | ||
*.py[cod] | ||
__pycache__/ | ||
|
||
# Ignore Jupyter Notebook checkpoints | ||
.ipynb_checkpoints/ | ||
|
||
# Ignore IDE specific files | ||
.idea/ | ||
.vscode/ | ||
|
||
# Ignore local configuration files (optional) | ||
.env | ||
pyvenv.cfg | ||
|
||
models |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# prebuilt base image with large dependencies | ||
FROM python-base:latest | ||
|
||
# Working directory | ||
WORKDIR /app | ||
|
||
# Install system dependencies for PyTorch | ||
RUN apt-get update && apt-get install -y \ | ||
libgl1-mesa-glx \ | ||
&& rm -rf /var/lib/apt/lists/* | ||
|
||
# Create the uploads directory | ||
RUN mkdir -p /app/uploads | ||
|
||
# Copy only the requirements.txt to leverage Docker caching | ||
COPY requirements.txt . | ||
|
||
# Install Python dependencies | ||
RUN pip install --no-cache-dir -r requirements.txt | ||
|
||
# Copy the entire application code | ||
COPY . . | ||
|
||
# Expose the application port | ||
EXPOSE 3003 | ||
|
||
# Command to run the Flask app | ||
CMD ["python", "app.py"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import os | ||
import torch | ||
from flask import Flask, request, jsonify | ||
from flask_cors import CORS | ||
from PIL import Image | ||
from torchvision import models, transforms | ||
import io | ||
import requests # To send the result to Node.js | ||
|
||
# Initialize Flask app | ||
app = Flask(__name__) | ||
CORS(app) | ||
|
||
# Define device and load the model | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
# Load the trained model | ||
model_path = './model/final_trained_model.pth' | ||
model = models.resnet18(weights=None) # My model is initiating to learning architecture from this ResNet18 architecture | ||
model.fc = torch.nn.Linear(model.fc.in_features, 4) # Adjust for 4 classes | ||
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True)) | ||
model = model.to(device) | ||
model.eval() | ||
|
||
# Define class names | ||
class_names = ["clean_casual", "clean_formal", "messy_casual", "messy_formal"] | ||
|
||
# Define image transformations (same as in training) | ||
transform = transforms.Compose([ | ||
transforms.Resize((224, 224)), | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | ||
]) | ||
|
||
# Prediction function | ||
def predict_image(image_bytes): | ||
image = Image.open(io.BytesIO(image_bytes)) # Load image from bytes | ||
image = transform(image).unsqueeze(0).to(device) # Apply transformation and add batch dimension | ||
with torch.no_grad(): | ||
outputs = model(image) | ||
_, predicted = torch.max(outputs, 1) # Get the predicted class | ||
return class_names[predicted.item()] | ||
|
||
# Create route for prediction | ||
@app.route('/classification-predict', methods=['POST']) | ||
def predict(): | ||
if 'file' not in request.files: | ||
return jsonify({"error": "No file part"}), 400 | ||
file = request.files['file'] | ||
if file.filename == '': | ||
return jsonify({"error": "No selected file"}), 400 | ||
|
||
try: | ||
image_bytes = file.read() # Read the file content as bytes | ||
prediction = predict_image(image_bytes) # Get prediction | ||
|
||
# Send the result to the Node.js backend (assuming API expects prediction and email) | ||
user_email = request.form.get("email") | ||
if user_email: | ||
response = requests.post( | ||
'http://localhost:5000/api/savePrediction', # Adjust the URL to your Node.js server | ||
json={'prediction': prediction, 'email': user_email} | ||
) | ||
if response.status_code == 200: | ||
return jsonify({"prediction": prediction, "status": "Prediction saved"}) | ||
else: | ||
return jsonify({"error": "Failed to save prediction in database"}), 500 | ||
else: | ||
return jsonify({"error": "Email is required"}), 400 | ||
|
||
except Exception as e: | ||
return jsonify({"error": str(e)}), 500 | ||
|
||
# Run the app | ||
if __name__ == '__main__': | ||
app.run(host='0.0.0.0', port=3003, debug=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Use the slim Python image as the base | ||
FROM python:3.12-slim | ||
|
||
# Set the working directory | ||
WORKDIR /app | ||
|
||
# Install necessary system packages | ||
RUN apt-get update && apt-get install -y --no-install-recommends \ | ||
build-essential \ | ||
libjpeg-dev \ | ||
zlib1g-dev \ | ||
&& rm -rf /var/lib/apt/lists/* | ||
|
||
# Install Python dependencies that are common for the service | ||
RUN pip install --no-cache-dir \ | ||
torch \ | ||
torchvision \ | ||
Pillow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
torch | ||
torchvision | ||
flask | ||
pillow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
const ImageClassification = require('../../models/candidate/imageClassification'); | ||
|
||
// Function to store classification result | ||
const storePrediction = async (email, prediction) => { | ||
try { | ||
const newClassification = new ImageClassification({ | ||
email: email, | ||
prediction: prediction, | ||
}); | ||
await newClassification.save(); | ||
return { success: true, message: 'Prediction saved successfully' }; | ||
} catch (error) { | ||
console.error('Error saving prediction:', error); | ||
return { success: false, message: 'Error saving prediction' }; | ||
} | ||
}; | ||
|
||
|
||
module.exports = { | ||
storePrediction, | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
const Prediction = require("../../models/candidate/Prediction"); | ||
|
||
exports.savePrediction = async (req, res) => { | ||
const { email, prediction } = req.body; | ||
|
||
if (!email || !prediction) { | ||
return res.status(400).json({ error: "Email and prediction are required" }); | ||
} | ||
|
||
try { | ||
const newPrediction = new Prediction({ email, prediction }); | ||
await newPrediction.save(); | ||
return res.status(200).json({ message: "Prediction saved successfully" }); | ||
} catch (error) { | ||
console.error(error); | ||
return res.status(500).json({ error: "Failed to save prediction" }); | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.