forked from russs123/brawler_tut
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: bedrock_flask_server.py
82 lines (59 loc) · 1.99 KB
/
bedrock_flask_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import boto3
from flask import Flask, jsonify, request
from flask_cors import CORS
app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})
# Update if needed
bedrock_runtime = boto3.client(
service_name="bedrock-runtime",
region_name="us-east-1",
)
def generate_conversation(model_id, system_prompts, messages):
"""
Sends messages to a model.
Args:
model_id (str): The model ID to use.
system_prompts (JSON) : The system prompts for the model to use.
messages (JSON) : The messages to send to the model.
Returns:
response (JSON): The conversation that the model generated.
"""
print(f"Generating message with model {model_id}")
# Inference parameters to use.
temperature = 0.7
# Base inference parameters to use.
inference_config = {"temperature": temperature}
# Send the message.
response = bedrock_runtime.converse(
modelId=model_id,
messages=messages,
system=system_prompts,
inferenceConfig=inference_config,
)
# Log token usage.
token_usage = response["usage"]
print(f"Input tokens: {token_usage['inputTokens']}")
print(f"Output tokens: {token_usage['outputTokens']}")
print(f"Total tokens: {token_usage['totalTokens']}")
print(f"Stop reason: {response['stopReason']}")
text_response = response["output"]["message"]["content"][0]["text"]
return text_response
@app.route("/invoke_model", methods=["POST"])
def invoke_model():
data = request.json
model = data["model"]
system_prompt = data["system_prompt"]
prompt = data["prompt"]
system_prompts = [{"text": system_prompt}]
message_1 = {
"role": "user",
"content": [{"text": prompt}],
}
messages = [message_1]
print(system_prompts)
print(model)
results = generate_conversation(model, system_prompts, messages)
print(results)
return jsonify({"actions": results})
if __name__ == "__main__":
app.run(host="0.0.0.0", port=8080)