forked from davila7/stable-diffusion-free-gpu
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
56 lines (42 loc) · 1.74 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Flask demo: one prompt produces a Stable Diffusion image plus a text
# continuation from microsoft/phi-2, served through an ngrok tunnel.
import base64
from io import BytesIO

import torch
from diffusers import StableDiffusionPipeline
from flask import Flask, render_template, request
from flask_ngrok import run_with_ngrok
from transformers import AutoModelForCausalLM, AutoTokenizer

# Create new tensors on the GPU by default (torch.set_default_device needs
# torch >= 2.0); both models below are used on CUDA.
torch.set_default_device("cuda")

# Load text model.
# NOTE(review): trust_remote_code=True allows executing Python code shipped in
# the Hub repo; consider pinning a `revision` so that code cannot change.
text_model_id = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(
    text_model_id, torch_dtype="auto", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(text_model_id, trust_remote_code=True)

# Load image model (weights in float16) and move the pipeline to the GPU.
image_model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(image_model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Start flask app and set to ngrok: run_with_ngrok wraps app.run() so that
# starting the server also opens a public ngrok tunnel.
app = Flask(__name__)
run_with_ngrok(app)
@app.route('/')
def initial():
    """Render the landing page (index.html) with no generated results."""
    page = render_template('index.html')
    return page
def _image_to_data_uri(image):
    """Encode *image* as a ``data:image/png;base64,...`` URI string.

    Args:
        image: An object with a ``save(fp, format=...)`` method, such as the
            first element of the diffusers pipeline's ``.images``.

    Returns:
        str: The PNG-encoded image, base64-encoded and prefixed so it can be
        used directly as an ``<img src>`` value.
    """
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    # Base64 output is pure ASCII, so decode it properly rather than slicing
    # the "b'...'" repr of the bytes object.
    encoded = base64.b64encode(buffered.getvalue()).decode("ascii")
    return "data:image/png;base64," + encoded


def _generate_text(prompt):
    """Sample one text sequence from the causal LM, seeded with *prompt*.

    Args:
        prompt: Non-empty prompt string.

    Returns:
        str: The decoded first returned sequence, special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    generated_output = model.generate(
        input_ids=inputs.input_ids.to("cuda"),
        # Pass the mask and an explicit pad id so generate() does not have to
        # infer them (and warn) from input_ids alone.
        attention_mask=inputs.attention_mask.to("cuda"),
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=1.0,
        # NOTE(review): max_length counts the prompt tokens too, and 2500 may
        # exceed the text model's context window; consider max_new_tokens.
        max_length=2500,
        num_return_sequences=1,
    )
    return tokenizer.decode(generated_output[0], skip_special_tokens=True)


@app.route('/submit-prompt', methods=['POST'])
def generate():
    """Handle the prompt form by rendering an image and generated text.

    Reads the ``prompt-input`` form field. If the field is missing, Flask's
    ``request.form`` lookup raises a 400 Bad Request. The prompt is run
    through the Stable Diffusion pipeline and the causal LM, and
    ``index.html`` is re-rendered with the results.

    Returns:
        The rendered ``index.html`` with ``generated_image`` (a PNG data URI),
        ``generated_text`` and ``prompt``. For a blank prompt, the page is
        re-rendered with only ``prompt`` and neither model is run.
    """
    # get the prompt input
    prompt = request.form['prompt-input']
    if not prompt.strip():
        # A blank prompt gives the tokenizer nothing useful to encode (likely
        # zero input ids, which generate() cannot start from); show the form
        # again instead of failing with a 500.
        return render_template('index.html', prompt=prompt)

    print(f"Generating an image of {prompt}")
    image = pipe(prompt).images[0]
    print("Image generated! Converting image ...")
    img_str = _image_to_data_uri(image)

    generated_text = _generate_text(prompt)

    print("Sending image and text ...")
    return render_template(
        'index.html',
        generated_image=img_str,
        generated_text=generated_text,
        prompt=prompt,
    )
def _serve():
    """Start the Flask development server.

    run_with_ngrok(app), applied where the app is created, is expected to
    have wrapped app.run so it also opens an ngrok tunnel.
    """
    app.run()


if __name__ == '__main__':
    _serve()