-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Completed Text Generation Models with raw responses
- Loading branch information
1 parent
a18a332
commit ff267d0
Showing
12 changed files
with
173 additions
and
45 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1 @@ | ||
class CloudflareAI { | ||
late String apiToken; | ||
late String accountId; | ||
|
||
CloudflareAI({required apiToken, required accountId}); | ||
} | ||
export 'text_generation/text_generation.dart'; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,18 @@ | ||
import 'package:dio/dio.dart'; | ||
|
||
class NetworkService { | ||
Dio dio = Dio(); | ||
final Dio _dio = Dio(); | ||
|
||
Future<Map<String, dynamic>> post(String url, String apiKey, Map data) async { | ||
Response res = await _dio.post( | ||
url, | ||
options: Options( | ||
headers: { | ||
"Authorization": "Bearer $apiKey", | ||
}, | ||
), | ||
data: data, | ||
); | ||
return res.data; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
enum Role { system, user, assistant } | ||
|
||
class ChatModel { | ||
late Role role; | ||
late String message; | ||
|
||
ChatModel({ | ||
required Role role, | ||
required String messsage, | ||
}); | ||
|
||
ChatModel.fromJson(data) { | ||
role = data['role']; | ||
message = data['message']; | ||
} | ||
|
||
Map toJson() { | ||
return { | ||
"role": role.name, | ||
"message": message, | ||
}; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
class TextGenerationResponseModel { | ||
late ResultModel result; | ||
late bool success; | ||
|
||
TextGenerationResponseModel({required this.result, required this.success}); | ||
|
||
TextGenerationResponseModel.fromJson(Map<String, dynamic> json) { | ||
result = ResultModel.fromJson(json['result']); | ||
success = json['success']; | ||
} | ||
|
||
Map<String, dynamic> toJson() { | ||
final Map<String, dynamic> data = <String, dynamic>{}; | ||
data['result'] = result.toJson(); | ||
data['success'] = success; | ||
return data; | ||
} | ||
} | ||
|
||
class ResultModel { | ||
String? response; | ||
|
||
ResultModel({this.response}); | ||
|
||
ResultModel.fromJson(Map<String, dynamic> json) { | ||
response = json['response']; | ||
} | ||
|
||
Map<String, dynamic> toJson() { | ||
final Map<String, dynamic> data = <String, dynamic>{}; | ||
data['response'] = response; | ||
return data; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import 'package:cloudflare_ai/src/text_generation/raw_response_model.dart'; | ||
|
||
import '../services/network_service.dart'; | ||
import 'text_generation_models.dart'; | ||
|
||
class TextGenerationModel { | ||
late String accountId; | ||
late String apiKey; | ||
late TextGenerationModelsEnum model; | ||
late bool raw; | ||
NetworkService networkService = NetworkService(); | ||
late String baseUrl; | ||
|
||
TextGenerationModel({ | ||
required this.accountId, | ||
required this.apiKey, | ||
required this.model, | ||
this.raw = true, | ||
}) { | ||
baseUrl = "https://api.cloudflare.com/client/v4/accounts/$accountId/ai/run"; | ||
} | ||
|
||
Future<TextGenerationResponseModel> generateText(String prompt) async { | ||
Map<String, dynamic> res = | ||
await networkService.post("$baseUrl/${model.value}", apiKey, { | ||
"prompt": prompt, | ||
"raw": raw, | ||
}); | ||
TextGenerationResponseModel response = | ||
TextGenerationResponseModel.fromJson(res); | ||
return response; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
enum TextGenerationModelsEnum { | ||
LLAMA_2_7B("@cf/meta/llama-2-7b-chat-fp16"), | ||
LLAMA_2_7B_INT8("@cf/meta/llama-2-7b-chat-int8"), | ||
MISTRAL_7B("@cf/mistral/mistral-7b-instruct-v0.1"), | ||
CODE_LLAMA_7B("@hf/thebloke/codellama-7b-instruct-awq"), | ||
CODE_LLAMA_2_13B("@hf/thebloke/llama-2-13b-chat-awq"), | ||
ZEPHYR_7B("@hf/thebloke/zephyr-7b-beta-awq"), | ||
MISTRAL_7B_AWQ_V01("@hf/thebloke/mistral-7b-instruct-v0.1-awq"), | ||
MISTRAL_7B_AWQ_V02("@hf/mistral/mistral-7b-instruct-v0.2"), | ||
OPENHERMES_MISTRAL_7B("@hf/thebloke/openhermes-2.5-mistral-7b-awq"), | ||
NEURAL_CHAT_7B("@hf/thebloke/neural-chat-7b-v3-1-awq"), | ||
LLAMA_GUARD_7B("@hf/thebloke/llamaguard-7b-awq"), | ||
DEEPSEEK_CODER_6_7_BASE("@hf/thebloke/deepseek-coder-6.7b-base-awq"), | ||
DEEPSEEK_CODER_6_7_INSTRUCT("@hf/thebloke/deepseek-coder-6.7b-instruct-awq"), | ||
DEEPSEEK_MATH_7B_BASE("@@cf/deepseek-ai/deepseek-math-7b-base"), | ||
DEEPSEEK_MATH_7B_INSTRUCT("@cf/deepseek-ai/deepseek-math-7b-instruct"), | ||
OPENCHAT_3_5("@cf/openchat/openchat-3.5-0106"), | ||
PHI_2("@cf/phi/phi-2"), | ||
TINYLAMA_1_1B("@cf/tinyllama/tinyllama-1.1b-chat-v1.0"), | ||
DISCOLM_GERMAN_7B("@cf/thebloke/discolm-german-7b-v1-awq"), | ||
QWEN_1_5_0_5B_CHAT("@cf/qwen/qwen-1.5.0.5b-chat"), | ||
QWEN1_5_1_8B_CHAT("@cf/qwen/qwen1.5-1.8b-chat"), | ||
QWEN_1_5_7B_CHAT_AWQ("@cf/qwen/qwen1.5-7b-chat-awq"), | ||
QWEN_1_5_14B_CHAT_AWQ("@cf/qwen/qwen1.5-14b-chat-awq"), | ||
FALCON_7B_INSTRUCT("@cf/tiiuae/falcon-7b-instruct"), | ||
GEMMA_2B_IT_LORA("@cf/google/gemma-2b-it-lora"), | ||
GEMMA_7B_IT("@hf/google/gemma-7b-it"), | ||
GEMMA_7B_IT_LORA("@cf/google/gemma-7b-it-lora"), | ||
HERMES_2_PRO_7B("@hf/nousresearch/hermes-2-pro-mistral-7b"), | ||
LLAMA_2_7B_CHAT_HF_LORA("@cf/meta-llama/llama-2-7b-chat-hf-lora"), | ||
LLAMA_3_8B_INSTRUCT("@hf/meta-llama/meta-llama-3-8b-instruct"), | ||
UNA_CYBERTRON_7B_V2_BF16("@cf/fblgit/una-cybertron-7b-v2-bf16"), | ||
STARLING_LM_7B_BETA("@hf/nexusflow/starling-lm-7b-beta"), | ||
SQL_CODER_7B_2("@cf/defog/sqlcoder-7b-2"); | ||
|
||
const TextGenerationModelsEnum(this.value); | ||
final String value; | ||
} |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,34 @@ | ||
import 'package:cloudflare_ai/cloudflare_ai.dart'; | ||
import 'package:cloudflare_ai/src/text_generation/raw_response_model.dart'; | ||
import 'package:cloudflare_ai/src/text_generation/text_generation_models.dart'; | ||
import 'package:dotenv/dotenv.dart'; | ||
import 'package:test/test.dart'; | ||
|
||
void main() { | ||
group('A group of tests', () { | ||
|
||
group('Text Generation Models', () { | ||
final env = DotEnv()..load(); | ||
String accountId = env['accountId'] ?? ""; | ||
String apiKey = env['apiKey'] ?? ""; | ||
test("Gemma 7B IT: Generate Content", () async { | ||
TextGenerationModel model = TextGenerationModel( | ||
accountId: accountId, | ||
apiKey: apiKey, | ||
model: TextGenerationModelsEnum.GEMMA_7B_IT, | ||
); | ||
TextGenerationResponseModel res = await model.generateText("Hello!"); | ||
expect(res.result.response, isNotNull); | ||
expect(res.success, true); | ||
}); | ||
|
||
test("Falcon 7B Instruct: Generate Content", () async { | ||
TextGenerationModel model = TextGenerationModel( | ||
accountId: accountId, | ||
apiKey: apiKey, | ||
model: TextGenerationModelsEnum.FALCON_7B_INSTRUCT, | ||
); | ||
TextGenerationResponseModel res = await model.generateText("Hello!"); | ||
expect(res.result.response, isNotNull); | ||
expect(res.success, true); | ||
}); | ||
}); | ||
} |