From ec144cc3bdf47f17a790fc91fcd4da2bae0b056e Mon Sep 17 00:00:00 2001 From: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com> Date: Sat, 20 Jul 2024 14:37:33 -0400 Subject: [PATCH] Use newly release firefunction-v2 and perform additional verification step after query selection --- core/models.py | 5 + core/prompts.py | 66 +++++++++-- core/server.py | 191 ++++++++++++++++++++++++------ extension/public/js/background.js | 14 ++- extension/public/js/linkFinder.js | 8 +- extension/src/App.js | 18 +-- 6 files changed, 236 insertions(+), 66 deletions(-) diff --git a/core/models.py b/core/models.py index cdd61e2..31800b5 100644 --- a/core/models.py +++ b/core/models.py @@ -6,6 +6,11 @@ class URL(BaseModel): url: str +class ScrapedURLs(BaseModel): + urls: List[str] + source_url: str + + class SourceDocument(BaseModel): service: str url: str diff --git a/core/prompts.py b/core/prompts.py index 43d5da8..a5fc887 100644 --- a/core/prompts.py +++ b/core/prompts.py @@ -1,9 +1,10 @@ from langchain.prompts import StringPromptTemplate from pydantic import BaseModel, validator -PROMPT = """ +RAG_PROMPT = """ <|system|> -You are an expert lawyer analyzing terms of service agreements. Given a statement about the service and 4 pieces of text extracted from its documents, pick the number of the text that directly answers the query in its entirety. Output a valid JSON object containing the choice of text and concise reasoning. If none of the texts can explicitly answer the statement, return 0. If there is a text that answers the question, set the "answer" field to true. In all other cases, set it to false. +You are an expert lawyer analyzing terms of service agreements for a website (called "service") Given a query statement and 4 pieces of text extracted from the service's documents, pick the number of the text that directly answers the query in its entirety. Output a valid JSON object containing the choice of text and concise reasoning. If none of the texts can explicitly answer the statement, return 0. If there is a text that answers the question, set the "answer" field to true. In all other cases, set it to false. DO NOT IMPLY ANYTHING NOT GIVEN IN THE TEXT. + Here are some examples: Given the statement "You sign away all moral rights", which of the following texts, if any, answer it fully? @@ -41,9 +42,6 @@ * Location information * Log data * Information from cookie data and similar technologies (To find out more about how we use cookies, please see our Cookie Policy) -* Device information -* Usage data and inferences -* User choices ``` 2) ``` @@ -55,9 +53,6 @@ When we use cookies to learn about your behavior on or off of our services, we or our partners will obtain consent that we may need under applicable law. To find out more about how we use cookies, please see our Cookie Policy. -Additional Info for EEA, Swiss and UK Data Subjects: Legal bases we rely on -where we use your information -The below section only applies for residents in the EEA, Switzerland, and UK. ``` 4) ``` @@ -81,7 +76,7 @@ }} <|user|> -Given the statement "{query}", which text provides enough context to explicitly answer the entire statement? Do not infer or imply anything not provided in the texts. Answer with a single JSON object as demonstrated above. +Given the statement "{query}", which text provides enough context to explicitly answer the entire statement? Answer with a single JSON object as demonstrated above. DO NOT IMPLY ANYTHING NOT GIVEN IN THE TEXT. 1) ``` {result1} @@ -102,7 +97,56 @@ <|assistant|> """ -n_results = 4 +DOC_PROMPT = """ +<|user|> +Respond with a JSON object with all the URLs that are likely to contain the terms and conditions, +user agreements, cookie policy, privacy policy etc. for {source} like so: +{{ + "valid_urls": ["https://example.com/terms", "https://example.com/legal/cookies"] +}} +Here are the URLs. +{urls} + +<|assistant|> +""" + +VERIFY_PROMPT = """ +<|user|> +Given a statement about the service {service} and a piece of text that answers it, respond with a JSON object indicating if the statement is true or false like so: +{{ + "statement": bool +}} +Statement: +{statement} +Text: +{text} + +<|assistant|> +""" + + +class VerifyStatementPromptTemplate(StringPromptTemplate, BaseModel): + def format(self, **kwargs) -> str: + prompt = VERIFY_PROMPT.format( + service=kwargs["service"], + statement=kwargs["case"], + text=kwargs["text"] + ) + return prompt + + +class DocClassifierPromptTemplate(StringPromptTemplate, BaseModel): + """ + Determine from the title and source domain of a document discovered by the linkFinder content script + whether is is likely to be a terms and conditions document or not + """ + def format(self, **kwargs) -> str: + prompt = DOC_PROMPT.format( + urls=kwargs["urls"], + source=kwargs["source"] + ) + return prompt + class RAGQueryPromptTemplate(StringPromptTemplate, BaseModel): """ @@ -111,7 +155,7 @@ class RAGQueryPromptTemplate(StringPromptTemplate, BaseModel): """ def format(self, **kwargs) -> str: - prompt = PROMPT.format( + prompt = RAG_PROMPT.format( query=kwargs["query"], result1=kwargs["results"][0], result2=kwargs["results"][1], diff --git a/core/server.py b/core/server.py index 8859b50..e3484b1 100644 --- a/core/server.py +++ b/core/server.py @@ -1,5 +1,4 @@ from fastapi import FastAPI, HTTPException -from typing import List import json import asyncio import chromadb @@ -21,8 +20,16 @@ ) from fastapi.middleware.cors import CORSMiddleware -from prompts import RAGQueryPromptTemplate -from models import URL, SourceDocument, LLMQuery +from prompts import ( + RAGQueryPromptTemplate, + DocClassifierPromptTemplate, + VerifyStatementPromptTemplate +) +from models import ( + ScrapedURLs, + SourceDocument, + LLMQuery +) app = FastAPI() # storage = VectorStore() @@ -59,10 +66,14 @@ # Decent functionality, poor accuracy "zephyr-7b": "accounts/fireworks/models/zephyr-7b-beta", # Expensive and capable Mixtral finetune - "firefunction": "accounts/fireworks/models/firefunction-v1" + "firefunction-v1": "accounts/fireworks/models/firefunction-v1", + "firefunction-v2": "accounts/fireworks/models/firefunction-v2" + } -llm = Fireworks( - model=fireworks_models["firefunction"], + +# llm for rag queries +query_llm = Fireworks( + model=fireworks_models["firefunction-v2"], model_kwargs={ "temperature": 0.1, "max_tokens": 150, @@ -95,9 +106,60 @@ } ) +# llm for verifying query statement +verifier_llm = Fireworks( + model=fireworks_models["firefunction-v2"], + model_kwargs={ + "temperature": 0.1, + "max_tokens": 100, + "top_p": 1.0, + "response_format": { + "type": "json_object", + "schema": """{ + "type": "object", + "properties": { + "statement": { + "type": "boolean" + } + }, + "required": [ + "statement" + ] + }""" + } + } +) + +# llm for determining documents +doc_classifer_llm = Fireworks( + model=fireworks_models["firefunction-v2"], + model_kwargs={ + "temperature": 0.1, + "max_tokens": 500, + "top_p": 1.0, + "response_format": { + "type": "json_object", + "schema": """{ + "type" : "object", + "properties": { + "valid_urls": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": [ + "valid_urls" + ] + }""" + } + } +) + @app.post("/add", status_code=200) -async def add_src_document(src_doc: SourceDocument): +async def add_src_doc(src_doc: SourceDocument): """ Gets a SourceDocument object from user's POSTrequest body containing the raw HTML of the page, parses i, chunks it @@ -181,7 +243,7 @@ async def add_src_document(src_doc: SourceDocument): } -async def scrape_raw_document_from_url(browser, url, service): +async def scrape_src_doc(browser, url, service): try: page = await browser.new_page() await page.goto(url) @@ -190,14 +252,14 @@ async def scrape_raw_document_from_url(browser, url, service): # where the service would be "github.com" but source doc links # are in "docs.github.com" name = await page.title() - src_doc = SourceDocument( - service=service, - url=url, - name=name, - text=html - ) try: - await add_src_document(src_doc) + src_doc = SourceDocument( + service=service, + url=url, + name=name, + text=html + ) + await add_src_doc(src_doc) return True except HTTPException: return False @@ -205,21 +267,48 @@ async def scrape_raw_document_from_url(browser, url, service): return False +async def classify_urls(urls, service): + """ + Uses the LLM to identify which scraped URLs are likely + to contain the T&C documents or not before we add them to the DB + """ + template = DocClassifierPromptTemplate( + input_variables=[ + "urls", + "source" + ] + ) + # Format the URLs first + urls = "\n".join(urls) + prompt = template.format( + urls=urls, + source=service + ) + llm_response = doc_classifer_llm(prompt) + try: + response = json.loads(llm_response) + return response["valid_urls"] + except Exception: + return None + + @app.post("/add_from_url", status_code=200) -async def add_src_document_from_url(urls: List[URL]): +async def add_src_doc_from_url(scraped_urls: ScrapedURLs): """ - Gets a URL to a resource and retrieves the raw document + Handles the URLs that have been scraped from the current page """ - # Assuming all the docs will have the same domain - service = tldextract.extract(urls[0].url).registered_domain + + service = tldextract.extract(scraped_urls.source_url).registered_domain + urls = scraped_urls.urls + valid_urls = await classify_urls(urls, service) succeeded = 0 async with async_playwright() as p: browser = await p.firefox.launch(headless=True) - for url in urls: - if await scrape_raw_document_from_url(browser, url.url, service): + for url in valid_urls: + if await scrape_src_doc(browser, url, service): succeeded += 1 return { - "message": f"Processed {succeeded}/{len(urls)}\ + "message": f"Processed {succeeded}/{len(valid_urls)}\ discovered document URLs from {service}", "service": service } @@ -236,11 +325,11 @@ async def make_query(query: LLMQuery): "results": [] } # For each case, search the vector database for results - for q in query.tosdr_cases: + for case in query.tosdr_cases: result = {} query_response = await asyncio.to_thread( db.similarity_search, - query=q["text"], + query=case["text"], k=4, filter={"service": query.service}, include=["documents", "metadatas"] @@ -252,7 +341,7 @@ async def make_query(query: LLMQuery): continue # For each returned text from the vector store, insert into prompt, # send to model and parse response - template = RAGQueryPromptTemplate( + query_template = RAGQueryPromptTemplate( input_variables=[ "query", "result1", @@ -261,16 +350,16 @@ async def make_query(query: LLMQuery): "result4" ] ) - prompt = template.format( - query=q["text"], + query_prompt = query_template.format( + query=case["text"], results=[doc.page_content for doc in query_response] ) - print("="*100) - print(prompt) - print("="*100) + # print("="*100) + # print(query_prompt) + # print("="*100) - llm_response = llm(prompt) - print(llm_response) + llm_response = query_llm(query_prompt) + # print(llm_response) try: response = json.loads(llm_response) # Extract the choice @@ -279,7 +368,7 @@ async def make_query(query: LLMQuery): source_text = chosen_doc.page_content if choice != 0 else "" # TODO: Fix field duplication later result["source_text"] = source_text - result["tosdr_case"] = q + result["tosdr_case"] = case result["source_doc"] = chosen_doc.metadata["name"] result["source_url"] = chosen_doc.metadata["url"] result["source_service"] = chosen_doc.metadata["service"] @@ -288,11 +377,37 @@ async def make_query(query: LLMQuery): if source_text: result["error"] = None else: - # Model chose 0 - result["error"] = 1 + # Model chose 0, none of the texts are relevant + result["error"] = "irrelevant" except json.JSONDecodeError: - print("Error decoding response from model") - result["error"] = 2 - extension_response["results"].append(result) + result["error"] = "json" + if not result["error"] and result["answer"]: + # Verify the statement if there is no error and a source text + # has been picked + verify_template = VerifyStatementPromptTemplate( + input_variables=[ + "service", + "case", + "text" + ] + ) + verify_prompt = verify_template.format( + service=query.service, + case=result["tosdr_case"]["text"], + text=result["source_text"] + ) + llm_response = verifier_llm(verify_prompt) + # print("="*100) + # print(verify_prompt) + # print("="*100) + # print(llm_response) + try: + response = json.loads(llm_response) + check = response["statement"] + if check: + # Only append it to results if the statement actually appleis + extension_response["results"].append(result) + except json.JSONDecodeError: + print("Error") return extension_response diff --git a/extension/public/js/background.js b/extension/public/js/background.js index 1bef477..fda2c6f 100644 --- a/extension/public/js/background.js +++ b/extension/public/js/background.js @@ -133,20 +133,23 @@ chrome.runtime.onMessage.addListener((msg, sender, sendResponse) => { injectGetContent(false); } - if (msg.action === 'addQueries') { - console.log('[!] addQueries event received'); + if (msg.action === 'updateQueries') { + console.log('[!] updateQueries event received'); tosdr_cases = msg.data; } - if (msg.action === 'retrieveContent') { - console.log('[!] retrieveContent event received'); + if (msg.action === 'scrapedURLs') { + console.log('[!] scrapedURLs event received'); // send it to our backend server fetch(url_upload_endpoint, { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(msg.urls.map(url => ({ url }))) + body: JSON.stringify({ + 'urls': msg.urls, + 'source_url': msg.source_url + }) }) .then(response => { if (!response.ok) { @@ -172,7 +175,6 @@ chrome.runtime.onMessage.addListener((msg, sender, sendResponse) => { action: 'backendResponse', type: 'upload_url', error: true, - service: data.service, message: error.message }); console.log("Error in fetching: ", error); diff --git a/extension/public/js/linkFinder.js b/extension/public/js/linkFinder.js index 0ab5cfa..c18cd90 100644 --- a/extension/public/js/linkFinder.js +++ b/extension/public/js/linkFinder.js @@ -10,7 +10,8 @@ function removeQueryAndHash(url) { function findLegalLinks() { const links = Array.from(document.links); - console.log(links) + const currUrl = window.location.href + // simple regex test let legalLinks = links.filter(link => { return /terms|privacy|legal|policy/i.test(link.href); }); @@ -18,8 +19,9 @@ function findLegalLinks() { legalLinks = removeDuplicates(legalLinks) console.log(legalLinks); chrome.runtime.sendMessage({ - action: 'retrieveContent', - urls: legalLinks + action: 'scrapedURLs', + urls: legalLinks, + source_url: currUrl }) } diff --git a/extension/src/App.js b/extension/src/App.js index a93e929..db327d6 100644 --- a/extension/src/App.js +++ b/extension/src/App.js @@ -43,7 +43,7 @@ function App() { }, { 'text': 'This service can be used without providing a user profile', - 'rating': 'good' + 'rating': 'positive' } ], 'checked': false @@ -51,6 +51,10 @@ function App() { { 'name': '👁️ Tracking and data collection', 'cases': [ + { + 'text': 'This service does not track your IP address', + 'rating': 'positive' + }, { 'text': 'This service tracks you on other websites', 'rating': 'negative' @@ -64,7 +68,7 @@ function App() { 'rating': 'warning' }, { - 'text': 'Tracking via third-party cookies for other purposes without your consent', + 'text': 'This service tracks you via third-party cookies for other purposes without your consent', 'rating': 'negative' }, { @@ -76,7 +80,7 @@ function App() { 'rating': 'negative' }, { - 'text': 'Your biometric data is collected', + 'text': 'This service collects your biometric data', 'rating': 'negative' }, { @@ -84,7 +88,7 @@ function App() { 'rating': 'negative' }, { - 'text': 'The cookies used only collect anonymous, aggregated data that cannot be linked to a unique identity.', + 'text': 'This service uses cookies that only collect anonymous, aggregated data that cannot be linked to a unique identity.', 'rating': 'positive' } ], @@ -120,7 +124,7 @@ function App() { 'name': '⚖️ Legal rights', 'cases': [ { - 'text': 'Terms may be changed any time at their discretion, without notice to you', + 'text': 'This service may change their terms at any time at their discretion, without notice to you', 'rating': 'warning' }, { @@ -256,7 +260,7 @@ function App() { const handleSubmitQueryCategories = () => { browser.runtime.sendMessage({ - action: 'addQueries', + action: 'updateQueries', data: queryCategories .filter( query => query.checked === true @@ -342,7 +346,6 @@ function App() { {results.map((item, index) => ( <> { - item.answer && (
{item.reason}