Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Multiple Choice type questions #21

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
193 changes: 135 additions & 58 deletions backend/server.py
Original file line number Diff line number Diff line change
@@ -1,114 +1,191 @@
import http.server
import json
import random
import socketserver
import urllib.parse
import requests
import torch
from models.modelC.distractor_generator import DistractorGenerator
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import json

IP='127.0.0.1'
PORT=8000
IP = "127.0.0.1"
PORT = 8000

def summarize(text):
summarizer=pipeline('summarization')
return summarizer(text,max_length=110)[0]['summary_text']


def generate_question(context,answer,model_path, tokenizer_path):
def summarize(text):
summarizer = pipeline("summarization")
return summarizer(text, max_length=110)[0]["summary_text"]

def get_distractors_conceptnet(word, context):
word = word.lower()
original_word = word
if len(word.split()) > 0:
word = word.replace(" ", "_")
distractor_list = []
# context_sentences = context.split(".")
try:
relationships = ["/r/PartOf", "/r/IsA", "/r/HasA"]

for rel in relationships:
url = f"http://api.conceptnet.io/query?node=/c/en/{word}/n&rel={rel}&start=/c/en/{word}&limit=5"
if context:
url += f"&context={context}"

obj = requests.get(url).json()

for edge in obj["edges"]:
word2 = edge["end"]["label"]
if (
word2 not in distractor_list
and original_word.lower() not in word2.lower()
):
distractor_list.append(word2)
return distractor_list

except json.decoder.JSONDecodeError as e:
print(f"Error decoding JSON from ConceptNet API: {e}")
return distractor_list
except requests.RequestException as e:
print(f"Error making request to ConceptNet API: {e}")
return distractor_list


def generate_question(context, answer, model_path, tokenizer_path):
model = T5ForConditionalGeneration.from_pretrained(model_path)
tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

input_text=f'answer: {answer} context: {context}'
input_text = f"answer: {answer} context: {context}"

inputs=tokenizer.encode_plus(
inputs = tokenizer.encode_plus(
input_text,
padding='max_length',
padding="max_length",
truncation=True,
max_length=512,
return_tensors='pt'
return_tensors="pt",
)

input_ids=inputs['input_ids'].to(device)
attention_mask=inputs['attention_mask'].to(device)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

with torch.no_grad():
output=model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_length=32
output = model.generate(
input_ids=input_ids, attention_mask=attention_mask, max_length=32
)

generated_question = tokenizer.decode(output[0], skip_special_tokens=True)
return generated_question

def generate_keyphrases(abstract, model_path,tokenizer_path):
device= torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def generate_keyphrases(abstract, model_path, tokenizer_path):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = T5ForConditionalGeneration.from_pretrained(model_path)
tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
model.to(device)
# tokenizer.to(device)
input_text=f'detect keyword: abstract: {abstract}'
input_ids=tokenizer.encode(input_text, truncation=True,padding='max_length',max_length=512,return_tensors='pt').to(device)
output=model.generate(input_ids)
keyphrases= tokenizer.decode(output[0],skip_special_tokens=True).split(',')
return [x.strip() for x in keyphrases if x != '']
input_text = f"detect keyword: abstract: {abstract}"
input_ids = tokenizer.encode(
input_text,
truncation=True,
padding="max_length",
max_length=512,
return_tensors="pt",
).to(device)
output = model.generate(input_ids)
keyphrases = tokenizer.decode(output[0], skip_special_tokens=True).split(",")
return [x.strip() for x in keyphrases if x != ""]


def generate_qa(text):
def generate_qa(self, text, question_type):
modelA, modelB = "./models/modelA", "./models/modelB"
tokenizerA, tokenizerB = "t5-base", "t5-base"
if question_type == "text":
text_summary = text
answers = generate_keyphrases(text_summary, modelA, tokenizerA)
qa = {}
for answer in answers:
question = generate_question(text_summary, answer, modelB, tokenizerB)
qa[question] = answer

# text_summary=summarize(text)
text_summary=text

return qa

modelA, modelB='./models/modelA','./models/modelB'
# tokenizerA, tokenizerB= './tokenizers/tokenizerA', './tokenizers/tokenizerB'
tokenizerA, tokenizerB= 't5-base', 't5-base'
elif question_type == "mcq":
text_summary = text

answers=generate_keyphrases(text_summary, modelA, tokenizerA)
answers = generate_keyphrases(text_summary, modelA, tokenizerA)

qa={}
for answer in answers:
question= generate_question(text_summary, answer, modelB, tokenizerB)
qa[question]=answer

return qa

qa = {}
for answer in answers:
question = generate_question(text_summary, answer, modelB, tokenizerB)
conceptnet_distractors = get_distractors_conceptnet(answer, text_summary)
t5_distractors = self.distractor_generator.generate(
5, answer, question, text_summary
)

dist_temp = list(set(conceptnet_distractors + t5_distractors))
dist = [x for x in dist_temp if x.lower() != answer.lower()]
print(conceptnet_distractors)

if len(dist) < 1:
distractors = []
qa[question] = answer
else:
distractors = random.sample(dist, min(3, len(dist)))
options = distractors + [answer]
random.shuffle(options)

formatted_question = f"{question} Options: {', '.join(options)}"

qa[formatted_question] = answer

return qa


class QARequestHandler(http.server.BaseHTTPRequestHandler):

def do_POST(self):
def do_OPTIONS(self):
self.send_response(200)
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Methods", "POST, OPTIONS")
self.send_header("Access-Control-Allow-Headers", "Content-Type")
self.send_header("Content-Length", "0")
self.end_headers()

def do_POST(self):
self.send_response(200)
self.send_header("Content-type", "text/plain")
self.end_headers()

content_length=int(self.headers["Content-Length"])
post_data=self.rfile.read(content_length).decode('utf-8')

# parsed_data=urllib.parse.parse_qs(post_data)
content_length = int(self.headers["Content-Length"])
post_data = self.rfile.read(content_length).decode("utf-8")
parsed_data = json.loads(post_data)
if self.path == "/":
input_text = parsed_data.get("input_text")
question_type = self.headers.get("Question-Type", "text")

qa = generate_qa(self, input_text, question_type)

input_text=parsed_data.get('input_text')

qa=generate_qa(input_text)
self.wfile.write(json.dumps(qa).encode("utf-8"))
self.wfile.flush()


class CustomRequestHandler(QARequestHandler):
def __init__(self, *args, **kwargs):
self.distractor_generator = kwargs.pop("distractor_generator")
super().__init__(*args, **kwargs)

self.wfile.write(json.dumps(qa).encode("utf-8"))
self.wfile.flush()

def main():
with socketserver.TCPServer((IP, PORT), QARequestHandler) as server:
print(f'Server started at http://{IP}:{PORT}')
distractor_generator = DistractorGenerator()
with socketserver.TCPServer(
(IP, PORT),
lambda x, y, z: CustomRequestHandler(
x, y, z, distractor_generator=distractor_generator
),
) as server:
print(f"Server started at http://{IP}:{PORT}")
server.serve_forever()

if __name__=="__main__":
main()


if __name__ == "__main__":
main()
5 changes: 1 addition & 4 deletions extension/html/text_input.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
<html>
<head>
<title>EduAid: Text Input</title>
<!-- <link href="https://fonts.googleapis.com/css?family=Roboto:400,500" rel="stylesheet">
<link rel="stylesheet" href="./popup.css"> -->
<script src="../pdfjs-3.9.179-dist/build/pdf.js"></script>
<link href='https://fonts.googleapis.com/css?family=Inter' rel='stylesheet'>
<link rel="stylesheet" href="../styles/text_input.css">
Expand All @@ -24,11 +22,10 @@ <h3>Generate QnA</h3>
<button id="back-button">Back</button>
<button id="next-button">Next</button>
</div>
<!-- ******************* -->
<button id="mcq-button">Generate MCQ</button>
<div id="loading-screen" class="loading-screen">
<div class="loading-spinner"></div>
</div>
<!-- ****************** -->
</main>
<script src="../js/text_input.js"></script>
</body>
Expand Down
111 changes: 58 additions & 53 deletions extension/js/question_generation.js
Original file line number Diff line number Diff line change
@@ -1,56 +1,61 @@
document.addEventListener("DOMContentLoaded", function(){
const saveButton= document.getElementById("save-button");
const backButton= document.getElementById("back-button");
const viewQuestionsButton = document.getElementById("view-questions-button");
const qaPairs=JSON.parse(localStorage.getItem("qaPairs"));
const modalClose= document.querySelector("[data-close-modal]");
const modal=document.querySelector("[data-modal]");


viewQuestionsButton.addEventListener("click", function(){
const modalQuestionList = document.getElementById("modal-question-list");
modalQuestionList.innerHTML = ""; // Clear previous content

for (const [question, answer] of Object.entries(qaPairs)) {
const questionElement = document.createElement("li");
questionElement.textContent = `Question: ${question}, Answer: ${answer}`;
modalQuestionList.appendChild(questionElement)
document.addEventListener("DOMContentLoaded", function () {
const saveButton = document.getElementById("save-button");
const backButton = document.getElementById("back-button");
const viewQuestionsButton = document.getElementById("view-questions-button");
const qaPairs = JSON.parse(localStorage.getItem("qaPairs"));
const modalClose = document.querySelector("[data-close-modal]");
const modal = document.querySelector("[data-modal]");

viewQuestionsButton.addEventListener("click", function () {
const modalQuestionList = document.getElementById("modal-question-list");
modalQuestionList.innerHTML = "";

for (const [question, answer] of Object.entries(qaPairs)) {
const questionElement = document.createElement("li");
if (question.includes("Options:")) {
const options = question.split("Options: ")[1].split(", ");
const formattedOptions = options.map(
(opt, index) => `${String.fromCharCode(97 + index)}) ${opt}`
);
questionElement.textContent = `Question: ${
question.split(" Options:")[0]
}\n${formattedOptions.join("\n")}`;
} else {
questionElement.textContent = `Question: ${question}\n\nAnswer: ${answer}\n`;
}
modal.showModal();
});

modalClose.addEventListener("click", function(){
modal.close();
});
saveButton.addEventListener("click", async function(){
let textContent= "EduAid Generated QnA:\n\n";
modalQuestionList.appendChild(questionElement);
}
modal.showModal();
});

for (const [question,answer] of Object.entries(qaPairs)){
textContent+= `Question: ${question}\nAnswer: ${answer}\n\n`;
}
const blob = new Blob([textContent], { type: "text/plain" });

// Create a URL for the Blob
const blobUrl = URL.createObjectURL(blob);

// Create a temporary <a> element to trigger the download
const downloadLink = document.createElement("a");
downloadLink.href = blobUrl;
downloadLink.download = "questions_and_answers.txt";
downloadLink.style.display = "none";

// Append the <a> element to the document
document.body.appendChild(downloadLink);

// Simulate a click on the link to trigger the download
downloadLink.click();

// Clean up: remove the temporary <a> element and revoke the Blob URL
document.body.removeChild(downloadLink);
URL.revokeObjectURL(blobUrl);
});

backButton.addEventListener("click", function(){
window.location.href="../html/text_input.html"
});
});
modalClose.addEventListener("click", function () {
modal.close();
});
saveButton.addEventListener("click", async function () {
let textContent = "EduAid Generated QnA:\n\n";

for (const [question, answer] of Object.entries(qaPairs)) {
textContent += `Question: ${question}\nAnswer: ${answer}\n\n`;
}
const blob = new Blob([textContent], { type: "text/plain" });

const blobUrl = URL.createObjectURL(blob);

const downloadLink = document.createElement("a");
downloadLink.href = blobUrl;
downloadLink.download = "questions_and_answers.txt";
downloadLink.style.display = "none";

document.body.appendChild(downloadLink);

downloadLink.click();

document.body.removeChild(downloadLink);
URL.revokeObjectURL(blobUrl);
});

backButton.addEventListener("click", function () {
window.location.href = "../html/text_input.html";
});
});
Loading