forked from sci-assess/SciAssess
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathopenai_with_pdf.py
78 lines (67 loc) · 2.72 KB
/
openai_with_pdf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from typing import Any, Optional, Union
import os
import traceback
from openai import OpenAI
from evals.api import CompletionFn, CompletionResult
from evals.base import CompletionFnSpec
from evals.prompt.base import (
ChatCompletionPrompt,
CompletionPrompt,
OpenAICreateChatPrompt,
OpenAICreatePrompt,
Prompt,
)
from evals.record import record_sampling
from evals.utils.api_utils import (
openai_chat_completion_create_retrying,
openai_completion_create_retrying,
)
from evals.completion_fns.openai import OpenAIChatCompletionResult, OpenAICompletionResult
from .utils import extract_text, ErrorCompletionResult, call_without_throw, cache_to_disk
class OpenAIChatCompletionFnWithPDF(CompletionFnSpec):
    """Chat completion function that can append an attached file's text to the prompt.

    Each call sends the prompt to the chat completions endpoint via
    ``openai_chat_completion_create_retrying`` and records the sample with
    ``record_sampling``. When the caller passes ``file_name=...``, the text produced
    by ``extract_text`` for that file is appended to the content of the last message.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        n_ctx: Optional[int] = None,
        extra_options: Optional[dict] = None,
        **kwargs,
    ):
        """Store the client configuration used by every call.

        Args:
            model: Model name forwarded as ``model=`` to the completion call.
            api_base: Passed to the ``OpenAI`` client as ``base_url``.
            api_key: Passed to the ``OpenAI`` client as ``api_key``.
            n_ctx: Stored on the instance; not read anywhere in this class.
            extra_options: Extra keyword arguments merged into every completion
                call, taking precedence over per-call kwargs. ``None`` (the
                default) means no extra options.
            **kwargs: Accepted and ignored (presumably so extra spec fields do not
                break construction -- confirm against the registry).
        """
        self.model = model
        self.api_base = api_base
        self.api_key = api_key
        self.n_ctx = n_ctx
        # Fresh dict per instance: the former `= {}` default was a single dict
        # shared by every instance constructed without extra_options.
        self.extra_options = {} if extra_options is None else extra_options

    @call_without_throw
    def __call__(
        self,
        prompt: Union[str, OpenAICreateChatPrompt],
        **kwargs,
    ) -> OpenAIChatCompletionResult:
        """Run one chat completion, optionally with an attached file's text.

        Args:
            prompt: A ``Prompt`` instance, or a raw prompt (str, list of ints,
                list of str, or list of message dicts) that is wrapped in a
                ``ChatCompletionPrompt``.
            **kwargs: Forwarded to the completion call. The key ``file_name`` is
                consumed here: the extracted text of that file is appended to
                the last message. Keys in ``self.extra_options`` override any
                overlapping key given here.

        Returns:
            An ``OpenAIChatCompletionResult`` wrapping the raw response and the
            exact messages sent. Exception handling is delegated to the
            ``call_without_throw`` decorator (defined elsewhere).
        """
        if not isinstance(prompt, Prompt):
            # Element type used only in the failure message. Guarded so that a
            # non-list or empty prompt produces the AssertionError instead of an
            # IndexError/TypeError/KeyError raised while building the message.
            first_type = type(prompt[0]) if isinstance(prompt, list) and prompt else None
            assert (
                isinstance(prompt, str)
                or (isinstance(prompt, list) and all(isinstance(token, int) for token in prompt))
                or (isinstance(prompt, list) and all(isinstance(token, str) for token in prompt))
                or (isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt))
            ), (
                f"Got type {type(prompt)}, with val {first_type} for prompt, "
                "expected str or list[int] or list[str] or list[dict[str, str]]"
            )
            prompt = ChatCompletionPrompt(
                raw_prompt=prompt,
            )

        openai_create_prompt: OpenAICreateChatPrompt = prompt.to_formatted_prompt()
        if "file_name" in kwargs:
            # Pop so that file_name is not forwarded to the API call.
            openai_create_prompt = self._attach_file(openai_create_prompt, kwargs.pop("file_name"))

        result = openai_chat_completion_create_retrying(
            OpenAI(api_key=self.api_key, base_url=self.api_base),
            model=self.model,
            messages=openai_create_prompt,
            **{**kwargs, **self.extra_options},
        )
        result = OpenAIChatCompletionResult(raw_data=result, prompt=openai_create_prompt)
        record_sampling(prompt=result.prompt, sampled=result.get_completions())
        return result

    @staticmethod
    def _attach_file(messages: OpenAICreateChatPrompt, file_name) -> OpenAICreateChatPrompt:
        """Return a copy of *messages* whose last message has the file's text appended.

        Neither the input list nor its message dicts are modified. This matters
        because ``to_formatted_prompt()`` may return the caller's own message list,
        and an in-place append would then leak file text into a reused prompt.
        Assumes the last message's ``content`` is a str (TODO confirm for
        multi-part content).
        """
        attached_file_content = "\nThe file is as follows:\n\n" + "".join(extract_text(file_name))
        *head, last = messages
        return [*head, {**last, "content": last["content"] + attached_file_content}]