Skip to content

Commit

Permalink
Add streaming and system messages support in Airforce
Browse files Browse the repository at this point in the history
  • Loading branch information
hlohaus committed Dec 14, 2024
1 parent a591c5d commit 315a2f2
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 139 deletions.
7 changes: 6 additions & 1 deletion etc/unittest/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,9 @@ async def provider_has_model(self, provider: Type[BaseProvider], model: str):
except (MissingRequirementsError, MissingAuthError):
return
if self.cache[provider.__name__]:
self.assertIn(model, self.cache[provider.__name__], provider.__name__)
self.assertIn(model, self.cache[provider.__name__], provider.__name__)

async def test_all_providers_working(self):
for model, providers in __models__.values():
for provider in providers:
self.assertTrue(provider.working, f"{provider.__name__} in {model.name}")
27 changes: 14 additions & 13 deletions g4f/Provider/Airforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from aiohttp import ClientSession
from typing import List
from requests.packages.urllib3.exceptions import InsecureRequestWarning

from ..typing import AsyncResult, Messages
from ..image import ImageResponse
from ..requests.raise_for_status import raise_for_status
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin

from .. import debug
Expand All @@ -32,7 +34,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
api_endpoint_imagine2 = "https://api.airforce/imagine2"

working = True
supports_stream = False
supports_stream = True
supports_system_message = True
supports_message_history = True

Expand Down Expand Up @@ -87,20 +89,19 @@ def get_models(cls):
debug.log(f"Error fetching text models: {e}")

return cls.models

@classmethod
async def check_api_key(cls, api_key: str) -> bool:
"""
Always returns True to allow all models.
"""
if not api_key or api_key == "null":
return True # No restrictions if no key.

headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"Accept": "*/*",
}

try:
async with ClientSession(headers=headers) as session:
async with session.get(f"https://api.airforce/check?key={api_key}") as response:
Expand Down Expand Up @@ -195,11 +196,13 @@ async def generate_text(
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
full_message = "\n".join([msg['content'] for msg in messages])
message_chunks = split_message(full_message, max_length=1000)

final_messages = []
for message in messages:
message_chunks = split_message(message["content"], max_length=1000)
final_messages.extend([{"role": message["role"], "content": chunk} for chunk in message_chunks])
data = {
"messages": [{"role": "user", "content": chunk} for chunk in message_chunks],
"messages": final_messages,
"model": model,
"max_tokens": max_tokens,
"temperature": temperature,
Expand All @@ -209,10 +212,9 @@ async def generate_text(

async with ClientSession(headers=headers) as session:
async with session.post(cls.api_endpoint_completions, json=data, proxy=proxy) as response:
response.raise_for_status()
await raise_for_status(response)

if stream:
buffer = [] # Buffer to collect partial responses
async for line in response.content:
line = line.decode('utf-8').strip()
if line.startswith('data: '):
Expand All @@ -222,12 +224,11 @@ async def generate_text(
if 'choices' in chunk and chunk['choices']:
delta = chunk['choices'][0].get('delta', {})
if 'content' in delta:
buffer.append(delta['content'])
chunk = cls._filter_response(delta['content'])
if chunk:
yield chunk
except json.JSONDecodeError:
continue
# Combine the buffered response and filter it
filtered_response = cls._filter_response(''.join(buffer))
yield filtered_response
else:
# Non-streaming response
result = await response.json()
Expand Down
51 changes: 30 additions & 21 deletions g4f/Provider/Copilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):

websocket_url = "wss://copilot.microsoft.com/c/api/chat?api-version=2"
conversation_url = f"{url}/c/api/conversations"

_access_token: str = None
_cookies: CookieJar = None

Expand Down Expand Up @@ -94,20 +94,20 @@ def create_completion(
) as session:
if cls._access_token is not None:
cls._cookies = session.cookies.jar
if cls._access_token is None:
try:
url = "https://copilot.microsoft.com/cl/eus-sc/collect"
headers = {
"Accept": "application/x-clarity-gzip",
"referrer": "https://copilot.microsoft.com/onboarding"
}
response = session.post(url, headers=headers, data=get_clarity())
clarity_token = json.loads(response.text.split(" ", maxsplit=1)[-1])[0]["value"]
debug.log(f"Copilot: Clarity Token: ...{clarity_token[-12:]}")
except Exception as e:
debug.log(f"Copilot: {e}")
else:
clarity_token = None
# if cls._access_token is None:
# try:
# url = "https://copilot.microsoft.com/cl/eus-sc/collect"
# headers = {
# "Accept": "application/x-clarity-gzip",
# "referrer": "https://copilot.microsoft.com/onboarding"
# }
# response = session.post(url, headers=headers, data=get_clarity())
# clarity_token = json.loads(response.text.split(" ", maxsplit=1)[-1])[0]["value"]
# debug.log(f"Copilot: Clarity Token: ...{clarity_token[-12:]}")
# except Exception as e:
# debug.log(f"Copilot: {e}")
# else:
# clarity_token = None
response = session.get("https://copilot.microsoft.com/c/api/user")
raise_for_status(response)
user = response.json().get('firstName')
Expand All @@ -121,6 +121,14 @@ def create_completion(
if return_conversation:
yield Conversation(conversation_id)
prompt = format_prompt(messages)
if len(prompt) > 10000:
if len(messages) > 6:
prompt = format_prompt(messages[:3]+messages[-3:])
elif len(messages) > 2:
prompt = format_prompt(messages[:2]+messages[-1:])
if len(prompt) > 10000:
prompt = messages[-1]["content"]
debug.log(f"Copilot: Trim messages to: {len(prompt)}")
debug.log(f"Copilot: Created conversation: {conversation_id}")
else:
conversation_id = conversation.conversation_id
Expand All @@ -138,14 +146,15 @@ def create_completion(
)
raise_for_status(response)
uploaded_images.append({"type":"image", "url": response.json().get("url")})
break

wss = session.ws_connect(cls.websocket_url)
if clarity_token is not None:
wss.send(json.dumps({
"event": "challengeResponse",
"token": clarity_token,
"method":"clarity"
}).encode(), CurlWsFlag.TEXT)
# if clarity_token is not None:
# wss.send(json.dumps({
# "event": "challengeResponse",
# "token": clarity_token,
# "method":"clarity"
# }).encode(), CurlWsFlag.TEXT)
wss.send(json.dumps({
"event": "send",
"conversationId": conversation_id,
Expand Down
2 changes: 2 additions & 0 deletions g4f/Provider/openai/har_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def __init__(self, arkURL, arkBx, arkHeader, arkBody, arkCookies, userAgent):
self.userAgent = userAgent

def get_har_files():
if not os.access(get_cookies_dir(), os.R_OK):
raise NoValidHarFileError("har_and_cookies dir is not readable")
harPath = []
for root, _, files in os.walk(get_cookies_dir()):
for file in files:
Expand Down
2 changes: 1 addition & 1 deletion g4f/gui/client/static/css/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ ul {
.buttons {
align-items: flex-start;
flex-wrap: wrap;
gap: 15px;
gap: 12px;
}

.mobile-sidebar {
Expand Down
6 changes: 3 additions & 3 deletions g4f/gui/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict):
}

def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
def log_handler(text: str):
def decorated_log(text: str):
debug.logs.append(text)
if debug.logging:
print(text)
debug.log_handler = log_handler
debug.log_handler(text)
debug.log = decorated_log
proxy = os.environ.get("G4F_PROXY")
provider = kwargs.get("provider")
model, provider_handler = get_model_and_provider(
Expand Down
Loading

0 comments on commit 315a2f2

Please sign in to comment.