Commit
Completed Text Generation Models with raw responses
MananGandhi1810 committed May 16, 2024
1 parent a18a332 commit ff267d0
Showing 12 changed files with 173 additions and 45 deletions.
Empty file added .github/workflows/main.yml
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@
# Avoid committing pubspec.lock for library packages; see
# https://dart.dev/guides/libraries/private-files#pubspeclock.
pubspec.lock
.env*
7 changes: 1 addition & 6 deletions lib/src/cloudflare_ai_base.dart
@@ -1,6 +1 @@
class CloudflareAI {
  late String apiToken;
  late String accountId;

  CloudflareAI({required apiToken, required accountId});
}
export 'text_generation/text_generation.dart';
15 changes: 14 additions & 1 deletion lib/src/services/network_service.dart
@@ -1,5 +1,18 @@
import 'package:dio/dio.dart';

class NetworkService {
  Dio dio = Dio();
  final Dio _dio = Dio();

  Future<Map<String, dynamic>> post(String url, String apiKey, Map data) async {
    Response res = await _dio.post(
      url,
      options: Options(
        headers: {
          "Authorization": "Bearer $apiKey",
        },
      ),
      data: data,
    );
    return res.data;
  }
}
23 changes: 23 additions & 0 deletions lib/src/text_generation/chat_model.dart
@@ -0,0 +1,23 @@
enum Role { system, user, assistant }

class ChatModel {
  late Role role;
  late String message;

  ChatModel({
    required this.role,
    required this.message,
  });

  ChatModel.fromJson(Map<String, dynamic> data) {
    role = Role.values.byName(data['role']);
    message = data['message'];
  }

  Map<String, dynamic> toJson() {
    return {
      "role": role.name,
      "message": message,
    };
  }
}
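
Note: ChatModel is added here but not yet wired into TextGenerationModel. A minimal sketch (assuming the class compiles as shown above) of how a conversation could be serialized for a later chat endpoint:

import 'dart:convert';

void main() {
  // Build a short conversation and serialize it with ChatModel.toJson().
  final conversation = [
    ChatModel(role: Role.system, message: "You are a helpful assistant."),
    ChatModel(role: Role.user, message: "Hello!"),
  ];
  print(jsonEncode(conversation.map((m) => m.toJson()).toList()));
  // [{"role":"system","message":"You are a helpful assistant."},{"role":"user","message":"Hello!"}]

  // Round-trip a single message back from JSON.
  final reply = ChatModel.fromJson({"role": "assistant", "message": "Hi there!"});
  print(reply.role); // Role.assistant
}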
34 changes: 34 additions & 0 deletions lib/src/text_generation/raw_response_model.dart
@@ -0,0 +1,34 @@
class TextGenerationResponseModel {
  late ResultModel result;
  late bool success;

  TextGenerationResponseModel({required this.result, required this.success});

  TextGenerationResponseModel.fromJson(Map<String, dynamic> json) {
    result = ResultModel.fromJson(json['result']);
    success = json['success'];
  }

  Map<String, dynamic> toJson() {
    final Map<String, dynamic> data = <String, dynamic>{};
    data['result'] = result.toJson();
    data['success'] = success;
    return data;
  }
}

class ResultModel {
  String? response;

  ResultModel({this.response});

  ResultModel.fromJson(Map<String, dynamic> json) {
    response = json['response'];
  }

  Map<String, dynamic> toJson() {
    final Map<String, dynamic> data = <String, dynamic>{};
    data['response'] = response;
    return data;
  }
}
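
For reference, a minimal parsing sketch; the sample payload is shaped after what TextGenerationResponseModel expects (a result.response string plus a success flag), not a verbatim Cloudflare response:

void main() {
  // Payload shape assumed from the model above; the real API envelope also
  // carries fields (e.g. errors/messages) that this model does not map.
  final Map<String, dynamic> json = {
    "result": {"response": "Hello! How can I help you today?"},
    "success": true,
  };

  final parsed = TextGenerationResponseModel.fromJson(json);
  print(parsed.result.response); // Hello! How can I help you today?
  print(parsed.success);         // true
}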
33 changes: 33 additions & 0 deletions lib/src/text_generation/text_generation.dart
@@ -0,0 +1,33 @@
import 'package:cloudflare_ai/src/text_generation/raw_response_model.dart';

import '../services/network_service.dart';
import 'text_generation_models.dart';

class TextGenerationModel {
  late String accountId;
  late String apiKey;
  late TextGenerationModelsEnum model;
  late bool raw;
  NetworkService networkService = NetworkService();
  late String baseUrl;

  TextGenerationModel({
    required this.accountId,
    required this.apiKey,
    required this.model,
    this.raw = true,
  }) {
    baseUrl = "https://api.cloudflare.com/client/v4/accounts/$accountId/ai/run";
  }

  Future<TextGenerationResponseModel> generateText(String prompt) async {
    Map<String, dynamic> res =
        await networkService.post("$baseUrl/${model.value}", apiKey, {
      "prompt": prompt,
      "raw": raw,
    });
    TextGenerationResponseModel response =
        TextGenerationResponseModel.fromJson(res);
    return response;
  }
}
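
Putting the pieces together, a minimal usage sketch mirroring the tests below (the credentials are placeholders, not real values):

Future<void> main() async {
  // Placeholder credentials; real ones come from your Cloudflare account.
  final model = TextGenerationModel(
    accountId: "YOUR_ACCOUNT_ID",
    apiKey: "YOUR_API_TOKEN",
    model: TextGenerationModelsEnum.MISTRAL_7B,
    raw: true, // raw already defaults to true in the constructor above
  );

  final res = await model.generateText("Write one sentence about Dart.");
  if (res.success) {
    print(res.result.response);
  }
}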
38 changes: 38 additions & 0 deletions lib/src/text_generation/text_generation_models.dart
@@ -0,0 +1,38 @@
enum TextGenerationModelsEnum {
  LLAMA_2_7B("@cf/meta/llama-2-7b-chat-fp16"),
  LLAMA_2_7B_INT8("@cf/meta/llama-2-7b-chat-int8"),
  MISTRAL_7B("@cf/mistral/mistral-7b-instruct-v0.1"),
  CODE_LLAMA_7B("@hf/thebloke/codellama-7b-instruct-awq"),
  CODE_LLAMA_2_13B("@hf/thebloke/llama-2-13b-chat-awq"),
  ZEPHYR_7B("@hf/thebloke/zephyr-7b-beta-awq"),
  MISTRAL_7B_AWQ_V01("@hf/thebloke/mistral-7b-instruct-v0.1-awq"),
  MISTRAL_7B_AWQ_V02("@hf/mistral/mistral-7b-instruct-v0.2"),
  OPENHERMES_MISTRAL_7B("@hf/thebloke/openhermes-2.5-mistral-7b-awq"),
  NEURAL_CHAT_7B("@hf/thebloke/neural-chat-7b-v3-1-awq"),
  LLAMA_GUARD_7B("@hf/thebloke/llamaguard-7b-awq"),
  DEEPSEEK_CODER_6_7_BASE("@hf/thebloke/deepseek-coder-6.7b-base-awq"),
  DEEPSEEK_CODER_6_7_INSTRUCT("@hf/thebloke/deepseek-coder-6.7b-instruct-awq"),
  DEEPSEEK_MATH_7B_BASE("@cf/deepseek-ai/deepseek-math-7b-base"),
  DEEPSEEK_MATH_7B_INSTRUCT("@cf/deepseek-ai/deepseek-math-7b-instruct"),
  OPENCHAT_3_5("@cf/openchat/openchat-3.5-0106"),
  PHI_2("@cf/phi/phi-2"),
  TINYLAMA_1_1B("@cf/tinyllama/tinyllama-1.1b-chat-v1.0"),
  DISCOLM_GERMAN_7B("@cf/thebloke/discolm-german-7b-v1-awq"),
  QWEN_1_5_0_5B_CHAT("@cf/qwen/qwen1.5-0.5b-chat"),
  QWEN1_5_1_8B_CHAT("@cf/qwen/qwen1.5-1.8b-chat"),
  QWEN_1_5_7B_CHAT_AWQ("@cf/qwen/qwen1.5-7b-chat-awq"),
  QWEN_1_5_14B_CHAT_AWQ("@cf/qwen/qwen1.5-14b-chat-awq"),
  FALCON_7B_INSTRUCT("@cf/tiiuae/falcon-7b-instruct"),
  GEMMA_2B_IT_LORA("@cf/google/gemma-2b-it-lora"),
  GEMMA_7B_IT("@hf/google/gemma-7b-it"),
  GEMMA_7B_IT_LORA("@cf/google/gemma-7b-it-lora"),
  HERMES_2_PRO_7B("@hf/nousresearch/hermes-2-pro-mistral-7b"),
  LLAMA_2_7B_CHAT_HF_LORA("@cf/meta-llama/llama-2-7b-chat-hf-lora"),
  LLAMA_3_8B_INSTRUCT("@hf/meta-llama/meta-llama-3-8b-instruct"),
  UNA_CYBERTRON_7B_V2_BF16("@cf/fblgit/una-cybertron-7b-v2-bf16"),
  STARLING_LM_7B_BETA("@hf/nexusflow/starling-lm-7b-beta"),
  SQL_CODER_7B_2("@cf/defog/sqlcoder-7b-2");

  const TextGenerationModelsEnum(this.value);
  final String value;
}
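
Each member simply pairs a Dart-friendly name with the Workers AI model identifier that TextGenerationModel appends to the /ai/run endpoint, e.g.:

void main() {
  const model = TextGenerationModelsEnum.GEMMA_7B_IT;
  print(model.name);  // GEMMA_7B_IT
  print(model.value); // @hf/google/gemma-7b-it
}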
34 changes: 0 additions & 34 deletions lib/src/text_to_text/response_model.dart

This file was deleted.

2 changes: 0 additions & 2 deletions lib/src/text_to_text/text_to_text.dart

This file was deleted.

1 change: 1 addition & 0 deletions pubspec.yaml
@@ -9,6 +9,7 @@ environment:
# Add regular dependencies here.
dependencies:
  dio: ^5.4.3+1
  dotenv: ^4.2.0
  # path: ^1.8.0

dev_dependencies:
30 changes: 28 additions & 2 deletions test/cloudflare_ai_test.dart
@@ -1,8 +1,34 @@
import 'package:cloudflare_ai/cloudflare_ai.dart';
import 'package:cloudflare_ai/src/text_generation/raw_response_model.dart';
import 'package:cloudflare_ai/src/text_generation/text_generation_models.dart';
import 'package:dotenv/dotenv.dart';
import 'package:test/test.dart';

void main() {
  group('A group of tests', () {

  group('Text Generation Models', () {
    final env = DotEnv()..load();
    String accountId = env['accountId'] ?? "";
    String apiKey = env['apiKey'] ?? "";
    test("Gemma 7B IT: Generate Content", () async {
      TextGenerationModel model = TextGenerationModel(
        accountId: accountId,
        apiKey: apiKey,
        model: TextGenerationModelsEnum.GEMMA_7B_IT,
      );
      TextGenerationResponseModel res = await model.generateText("Hello!");
      expect(res.result.response, isNotNull);
      expect(res.success, true);
    });

    test("Falcon 7B Instruct: Generate Content", () async {
      TextGenerationModel model = TextGenerationModel(
        accountId: accountId,
        apiKey: apiKey,
        model: TextGenerationModelsEnum.FALCON_7B_INSTRUCT,
      );
      TextGenerationResponseModel res = await model.generateText("Hello!");
      expect(res.result.response, isNotNull);
      expect(res.success, true);
    });
  });
}
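
The tests read accountId and apiKey via package:dotenv, so running them locally assumes a .env file at the package root (kept out of version control by the new .env* ignore rule) along the lines of:

accountId=your_cloudflare_account_id
apiKey=your_cloudflare_api_token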
