-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
94 lines (80 loc) · 3.02 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
from flask import Flask, request, jsonify
from flask_cors import cross_origin
from transformers import T5ForConditionalGeneration, AutoTokenizer
from utils import *
app = Flask(__name__)
# available models
# Model identifiers accepted by the /api/predict endpoint.
available_models = [
    "CodeVerbTLM-0.7B"
]
# inference types
# Supported input modalities; requests naming anything else are rejected with 400.
inference_types = [
    "Comment2Python",
    "Speech2Python",
    "Algo2Python",
]
# Load our model
# NOTE: the checkpoint is downloaded/loaded at import time, so starting this
# module blocks until the ~770M-parameter model is on the selected device.
checkpoint = "Salesforce/codet5p-770m-py"
# Prefer GPU when available; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Loading Model on Device: ", device)
# print_time comes from `utils` (wildcard import) — presumably a timing
# context manager; verify against utils.py.
with print_time('Loading Parameters: '):
    model = T5ForConditionalGeneration.from_pretrained(checkpoint).to(device)
with print_time('Fetching Tokenizer: '):
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
def preprocess_code(code):
    """Strip a leading triple-quoted docstring from generated code.

    The model sometimes echoes the prompt back as a ``\"\"\"...\"\"\"`` docstring;
    remove it so only executable code remains.

    Args:
        code: Generated Python source as a string.

    Returns:
        ``code`` unchanged when it contains no ``\"\"\"``; with the single
        unclosed ``\"\"\"`` removed when there is no closing delimiter; or with
        the whole docstring span (both delimiters included) removed otherwise.
    """
    # Find the first occurrence of """
    first_occurrence = code.find('"""')
    if first_occurrence == -1:
        # No extra docstring found, return the code as is
        return code
    # Find the next occurrence of """
    second_occurrence = code.find('"""', first_occurrence + 3)
    if second_occurrence == -1:
        # No closing docstring found, remove the starting docstring
        return code.replace('"""', '', 1)
    # BUG FIX: the original fell through here with no return and yielded None.
    # Remove the docstring span, delimiters included.
    return code[:first_occurrence] + code[second_occurrence + 3:]
# Root endpoint: reports API identity, status and the models on offer.
@app.route('/', methods=['GET'])
@cross_origin()
def home():
    """Return a small JSON health/info payload for this API instance."""
    info = {
        "API Name": "CodeVerb TLM 0.7M API",
        "API Version": "v1.0",
        "API Status": "Running",
        "Available Models": available_models,
    }
    headers = {'Content-Type': 'application/json; charset=utf-8'}
    return jsonify(info), 200, headers
# Inference endpoint: validate the request, run the model, return generated code.
@app.route('/api/predict', methods=['POST'])
@cross_origin()
def process_data():
    """Generate Python code from a natural-language query.

    Expects JSON of the form
    ``{"query": str, "model": str, "inference_type": str}``.

    Returns:
        200 with ``{"query", "result"}`` on success; 400 when a field is
        missing or names an unsupported model / inference type.
    """
    if request.method == 'POST':
        data = request.json or {}
        query = data.get('query')
        model_name = data.get('model')
        inference_type = data.get('inference_type')
        # BUG FIX: the original indexed data['...'] directly, so a missing
        # field raised KeyError and surfaced as a 500 instead of a 400.
        if query is None or model_name is None or inference_type is None:
            msg = {
                "error": "Request JSON must include 'query', 'model' and 'inference_type'."
            }
            return jsonify(msg), 400, {'Content-Type': 'application/json; charset=utf-8'}
        if inference_type not in inference_types:
            msg = {
                "error": "Inference type not available! Available inference types: {}".format(inference_types)
            }
            return jsonify(msg), 400, {'Content-Type': 'application/json; charset=utf-8'}
        if model_name not in available_models:
            msg = {
                "error": "Model not available! Available models: {}".format(available_models)
            }
            return jsonify(msg), 400, {'Content-Type': 'application/json; charset=utf-8'}
        # Preprocess input (preprocess_string comes from utils — wildcard import).
        query = preprocess_string(query)
        # Tokenize, generate, decode. Renamed `input` -> `input_ids` so the
        # builtin input() is not shadowed.
        input_ids = tokenizer.encode(query, return_tensors="pt").to(device)
        predicted_code = model.generate(input_ids, max_length=512)
        predicted_code = tokenizer.decode(predicted_code[0], skip_special_tokens=True)
        # Strip any echoed docstring from the generated code.
        predicted_code = preprocess_code(predicted_code)
        msg = {
            "query": query,
            "result": predicted_code
        }
        return jsonify(msg), 200, {'Content-Type': 'application/json; charset=utf-8'}
# Entry point: serve on all interfaces, port 8083, with the Flask dev server.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8083, debug=False)