Skip to content

Commit

Permalink
Cleanups on conversations
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem committed Nov 11, 2024
1 parent 606d5f5 commit f99ed4b
Show file tree
Hide file tree
Showing 17 changed files with 531 additions and 166 deletions.
63 changes: 63 additions & 0 deletions js/sdk/__tests__/ConversationsIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import { r2rClient } from "../src/index";
const fs = require("fs");
import { describe, test, beforeAll, expect } from "@jest/globals";

const baseUrl = "http://localhost:7272";

describe("r2rClient V3 Collections Integration Tests", () => {
let client: r2rClient;
let conversationId: string;
let messageId: string;

beforeAll(async () => {
client = new r2rClient(baseUrl);
await client.users.login({
email: "admin@example.com",
password: "change_me_immediately",
});
});

test("List all conversations", async () => {
const response = await client.conversations.list();
expect(response.results).toBeDefined();
});

test("Create a conversation", async () => {
const response = await client.conversations.create();
conversationId = response.results.id;
expect(response.results).toBeDefined();
});

test("Add a message to a conversation", async () => {
const response = await client.conversations.addMessage({
id: conversationId,
content: "Hello, world!",
role: "user",
});
messageId = response.results.id;
expect(response.results).toBeDefined();
});

// TODO: This is throwing a 405? Why?
// test("Update a message in a conversation", async () => {
// const response = await client.conversations.updateMessage({
// id: conversationId,
// message_id: messageId,
// content: "Hello, world! How are you?",
// });
// expect(response.results).toBeDefined();
// });

test("List branches in a conversation", async () => {
const response = await client.conversations.listBranches({
id: conversationId,
});
console.log("List branches response: ", response);
expect(response.results).toBeDefined();
});

test("Delete a conversation", async () => {
const response = await client.conversations.delete({ id: conversationId });
expect(response.results).toBeDefined();
});
});
2 changes: 0 additions & 2 deletions js/sdk/__tests__/PromptsIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ const baseUrl = "http://localhost:7272";

describe("r2rClient V3 Collections Integration Tests", () => {
let client: r2rClient;
let collectionId: string;
let documentId: string;

beforeAll(async () => {
client = new r2rClient(baseUrl);
Expand Down
2 changes: 1 addition & 1 deletion js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ describe("r2rClient Integration Tests", () => {

test("Create conversation", async () => {
const createConversationResponse = await client.createConversation();
createdConversationId = createConversationResponse.results;
createdConversationId = createConversationResponse.results.id;
expect(createdConversationId).toBeDefined();
});

Expand Down
2 changes: 1 addition & 1 deletion js/sdk/__tests__/r2rV2ClientIntegrationUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ describe("r2rClient Integration Tests", () => {

test("Create conversation", async () => {
const createConversationResponse = await client.createConversation();
createdConversationId = createConversationResponse.results;
createdConversationId = createConversationResponse.results.id;
expect(createdConversationId).toBeDefined();
});

Expand Down
17 changes: 16 additions & 1 deletion js/sdk/src/r2rClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import FormData from "form-data";
import { BaseClient } from "./baseClient";

import { CollectionsClient } from "./v3/clients/collections";
import { ConversationsClient } from "./v3/clients/conversations";
import { DocumentsClient } from "./v3/clients/documents";
import { UsersClient } from "./v3/clients/users";
import { PromptsClient } from "./v3/clients/prompts";
Expand Down Expand Up @@ -33,15 +34,17 @@ import {

export class r2rClient extends BaseClient {
public readonly collections: CollectionsClient;
public readonly conversations: ConversationsClient;
public readonly documents: DocumentsClient;
public readonly users: UsersClient;
public readonly prompts: PromptsClient;

constructor(baseURL: string, anonymousTelemetry = true) {
super(baseURL, "", anonymousTelemetry);

this.documents = new DocumentsClient(this);
this.collections = new CollectionsClient(this);
this.conversations = new ConversationsClient(this);
this.documents = new DocumentsClient(this);
this.users = new UsersClient(this);
this.prompts = new PromptsClient(this);

Expand Down Expand Up @@ -726,6 +729,7 @@ export class r2rClient extends BaseClient {
* @param template The new template for the prompt.
* @param input_types The new input types for the prompt.
* @returns A promise that resolves to the response from the server.
* @deprecated Use `client.prompts.update` instead.
*/
@feature("updatePrompt")
async updatePrompt(
Expand Down Expand Up @@ -757,6 +761,7 @@ export class r2rClient extends BaseClient {
* @param name The name of the prompt.
* @param template The template for the prompt.
* @param input_types The input types for the prompt.
* @deprecated Use `client.prompts.create` instead.
*/
@feature("addPrompt")
async addPrompt(
Expand All @@ -781,6 +786,7 @@ export class r2rClient extends BaseClient {
* @param name The name of the prompt to retrieve.
* @param inputs Inputs for the prompt.
* @param prompt_override Override for the prompt template.
* @deprecated Use `client.prompts.retrieve` instead.
* @returns
*/
@feature("getPrompt")
Expand All @@ -805,6 +811,7 @@ export class r2rClient extends BaseClient {
/**
* Get all prompts from the system.
* @returns A promise that resolves to the response from the server.
* @deprecated Use `client.prompts.list` instead.
*/
@feature("getAllPrompts")
async getAllPrompts(): Promise<Record<string, any>> {
Expand All @@ -816,6 +823,7 @@ export class r2rClient extends BaseClient {
* Delete a prompt from the system.
* @param prompt_name The name of the prompt to delete.
* @returns A promise that resolves to the response from the server.
* @deprecated Use `client.prompts.delete` instead.
*/
@feature("deletePrompt")
async deletePrompt(prompt_name: string): Promise<Record<string, any>> {
Expand Down Expand Up @@ -1345,6 +1353,7 @@ export class r2rClient extends BaseClient {
* Get an overview of existing conversations.
* @param limit The maximum number of conversations to return.
* @param offset The offset to start listing conversations from.
* @deprecated use `client.conversations.list` instead
* @returns
*/
@feature("conversationsOverview")
Expand Down Expand Up @@ -1373,6 +1382,7 @@ export class r2rClient extends BaseClient {
* Get a conversation by its ID.
* @param conversationId The ID of the conversation to get.
* @param branchId The ID of the branch (optional).
* @deprecated use `client.conversations.retrieve` instead
* @returns A promise that resolves to the response from the server.
*/
@feature("getConversation")
Expand All @@ -1390,6 +1400,7 @@ export class r2rClient extends BaseClient {

/**
* Create a new conversation.
* @deprecated use `client.conversations.create` instead
* @returns A promise that resolves to the response from the server.
*/
@feature("createConversation")
Expand All @@ -1402,6 +1413,7 @@ export class r2rClient extends BaseClient {
* Add a message to an existing conversation.
* @param conversationId
* @param message
* @deprecated use `client.conversations.addMessage` instead
* @returns
*/
@feature("addMessage")
Expand All @@ -1426,6 +1438,7 @@ export class r2rClient extends BaseClient {
* Update a message in an existing conversation.
* @param message_id The ID of the message to update.
* @param message The updated message.
* @deprecated use `client.conversations.updateMessage` instead
* @returns A promise that resolves to the response from the server.
*/
@feature("updateMessage")
Expand All @@ -1442,6 +1455,7 @@ export class r2rClient extends BaseClient {
/**
* Get an overview of branches in a conversation.
* @param conversationId The ID of the conversation to get branches for.
* @deprecated use `client.conversations.listBranches` instead
* @returns A promise that resolves to the response from the server.
*/
@feature("branchesOverview")
Expand Down Expand Up @@ -1494,6 +1508,7 @@ export class r2rClient extends BaseClient {
/**
* Delete a conversation by its ID.
* @param conversationId The ID of the conversation to delete.
* @deprecated use `client.conversations.delete` instead
* @returns A promise that resolves to the response from the server.
*/
@feature("deleteConversation")
Expand Down
133 changes: 133 additions & 0 deletions js/sdk/src/v3/clients/conversations.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import { r2rClient } from "../../r2rClient";

export class ConversationsClient {
constructor(private client: r2rClient) {}

/**
* Create a new conversation.
* @returns
*/
async create(): Promise<any> {
return this.client.makeRequest("POST", "conversations");
}

/**
* List conversations with pagination and sorting options.
* @param ids List of conversation IDs to retrieve
* @param offset Specifies the number of objects to skip. Defaults to 0.
* @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.
* @returns
*/
async list(options?: {
ids?: string[];
offset?: number;
limit?: number;
}): Promise<any> {
const params: Record<string, any> = {
offset: options?.offset ?? 0,
limit: options?.limit ?? 100,
};

if (options?.ids && options.ids.length > 0) {
params.ids = options.ids;
}

return this.client.makeRequest("GET", "conversations", {
params,
});
}

/**
* Get detailed information about a specific conversation.
* @param id The ID of the conversation to retrieve
* @param branch_id The ID of the branch to retrieve
* @returns
*/
async retrieve(options: { id: string; branch_id?: string }): Promise<any> {
const params: Record<string, any> = {
branch_id: options.branch_id,
};

return this.client.makeRequest("GET", `conversations/${options.id}`, {
params,
});
}

/**
* Delete a conversation.
* @param id The ID of the conversation to delete
* @returns
*/
async delete(options: { id: string }): Promise<any> {
return this.client.makeRequest("DELETE", `conversations/${options.id}`);
}

/**
* Add a new message to a conversation.
* @param id The ID of the conversation to add the message to
* @param content The content of the message
* @param role The role of the message (e.g., "user" or "assistant")
* @param parent_id The ID of the parent message
* @param metadata Additional metadata to attach to the message
* @returns
*/
async addMessage(options: {
id: string;
content: string;
role: string;
parent_id?: string;
metadata?: Record<string, any>;
}): Promise<any> {
const data: Record<string, any> = {
content: options.content,
role: options.role,
...(options.parent_id && { parent_id: options.parent_id }),
...(options.metadata && { metadata: options.metadata }),
};

return this.client.makeRequest(
"POST",
`conversations/${options.id}/messages`,
{
data,
},
);
}

/**
* Update an existing message in a conversation.
* @param id The ID of the conversation containing the message
* @param message_id The ID of the message to update
* @param content The new content of the message
* @returns
*/
async updateMessage(options: {
id: string;
message_id: string;
content: string;
}): Promise<any> {
const data: Record<string, any> = {
content: options.content,
};

return this.client.makeRequest(
"POST",
`conversations/${options.id}/messages/${options.message_id}`,
{
data,
},
);
}

/**
* List all branches in a conversation.
* @param id The ID of the conversation to list branches for
* @returns
*/
async listBranches(options: { id: string }): Promise<any> {
return this.client.makeRequest(
"GET",
`conversations/${options.id}/branches`,
);
}
}
12 changes: 7 additions & 5 deletions py/core/base/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
AnalyticsResponse,
AppSettingsResponse,
CollectionResponse,
ConversationOverviewResponse,
ConversationResponse,
DocumentChunkResponse,
DocumentOverviewResponse,
LogResponse,
Expand All @@ -51,7 +51,7 @@
WrappedCollectionResponse,
WrappedCollectionsResponse,
WrappedConversationResponse,
WrappedConversationsOverviewResponse,
WrappedConversationsResponse,
WrappedDeleteResponse,
WrappedDocumentChunkResponse,
WrappedDocumentResponse,
Expand Down Expand Up @@ -118,14 +118,12 @@
"DocumentOverviewResponse",
"DocumentChunkResponse",
"CollectionResponse",
"ConversationOverviewResponse",
"WrappedPromptMessageResponse",
"WrappedServerStatsResponse",
"WrappedLogResponse",
"WrappedAnalyticsResponse",
"WrappedAppSettingsResponse",
"WrappedUserOverviewResponse",
"WrappedConversationResponse",
"WrappedDocumentChunkResponse",
"WrappedDocumentResponse",
"WrappedDocumentOverviewResponse",
Expand All @@ -135,13 +133,17 @@
"WrappedDocumentChunkResponse",
"WrappedAddUserResponse",
"WrappedUsersInCollectionResponse",
# Conversation Responses
"ConversationResponse",
"WrappedConversationResponse",
"WrappedConversationsResponse",
# Prompt Responses
"WrappedPromptResponse",
"WrappedPromptsResponse",
# TODO: This needs to be cleaned up
"WrappedUserCollectionResponse",
"WrappedDocumentChunkResponse",
"WrappedDeleteResponse",
"WrappedConversationsOverviewResponse",
"WrappedMessageResponse",
# Retrieval Responses
"CombinedSearchResponse",
Expand Down
Loading

0 comments on commit f99ed4b

Please sign in to comment.