-
Notifications
You must be signed in to change notification settings - Fork 507
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #756 from b4s36t4/feat/mistral-guardrail
feat: add mistral as a new guardrail provider
- Loading branch information
Showing
5 changed files
with
307 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import { | ||
HookEventType, | ||
PluginContext, | ||
PluginHandler, | ||
PluginParameters, | ||
} from '../types'; | ||
import { getText, post } from '../utils'; | ||
|
||
interface MistralResponse { | ||
id: string; | ||
model: string; | ||
results: [ | ||
{ | ||
categories: { | ||
sexual: boolean; | ||
hate_and_discrimination: boolean; | ||
violence_and_threats: boolean; | ||
dangerous_and_criminal_content: boolean; | ||
selfharm: boolean; | ||
health: boolean; | ||
financial: boolean; | ||
law: boolean; | ||
pii: boolean; | ||
}; | ||
category_score: { | ||
sexual: number; | ||
hate_and_discrimination: number; | ||
violence_and_threats: number; | ||
dangerous_and_criminal_content: number; | ||
selfharm: number; | ||
health: number; | ||
financial: number; | ||
law: number; | ||
pii: number; | ||
}; | ||
}, | ||
]; | ||
} | ||
|
||
type GuardrailFunction = keyof MistralResponse['results'][0]['categories']; | ||
|
||
export const mistralGuardrailHandler: PluginHandler = async ( | ||
context: PluginContext, | ||
parameters: PluginParameters, | ||
eventType: HookEventType, | ||
_options | ||
) => { | ||
let error = null; | ||
let verdict = true; | ||
let data = null; | ||
|
||
const creds = parameters.credentials as Record<string, string>; | ||
if (!creds.apiKey) { | ||
return { | ||
error: 'Mistral API key not provided.', | ||
verdict: false, | ||
data: null, | ||
}; | ||
} | ||
|
||
let model = 'mistral-moderation-latest'; | ||
|
||
if (parameters.model) { | ||
// Model can be passed dynamically | ||
model = parameters.model; | ||
} | ||
|
||
const checks = parameters.categories as GuardrailFunction[]; | ||
|
||
const text = getText(context, eventType); | ||
const messages = | ||
eventType === 'beforeRequestHook' | ||
? context.request?.json?.messages | ||
: context.response?.json?.messages; | ||
|
||
// should contain text or should contain messages array | ||
if ( | ||
(!text && !Array.isArray(messages)) || | ||
(Array.isArray(messages) && messages.length === 0) | ||
) { | ||
return { | ||
error: 'Mistral: Invalid Request body', | ||
verdict: false, | ||
data: null, | ||
}; | ||
} | ||
|
||
// Use conversation guardrail if it's a chatcomplete and before hook | ||
const shouldUseConversation = | ||
eventType === 'beforeRequestHook' && context.requestType === 'chatComplete'; | ||
const url = shouldUseConversation | ||
? 'https://api.mistral.ai/v1/chat/moderations' | ||
: 'https://api.mistral.ai/v1/moderations'; | ||
|
||
try { | ||
const request = await post<MistralResponse>( | ||
url, | ||
{ | ||
model: model, | ||
...(!shouldUseConversation && { input: [text] }), | ||
...(shouldUseConversation && { input: [messages] }), | ||
}, | ||
{ | ||
headers: { | ||
Authorization: `Bearer ${creds.apiKey}`, | ||
'Content-Type': 'application/json', | ||
}, | ||
} | ||
); | ||
|
||
const categories: Record<GuardrailFunction, boolean> = | ||
request.results[0]?.categories ?? {}; | ||
const categoriesFlagged = Object.keys(categories).filter((category) => { | ||
if ( | ||
checks.includes(category as GuardrailFunction) && | ||
!!categories[category as GuardrailFunction] | ||
) { | ||
return true; | ||
} | ||
return false; | ||
}); | ||
|
||
if (categoriesFlagged.length > 0) { | ||
verdict = false; | ||
data = { flagged_categories: categoriesFlagged }; | ||
} | ||
} catch (err) { | ||
error = err; | ||
verdict = true; | ||
} | ||
|
||
return { error, verdict, data }; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
{ | ||
"id": "mistral", | ||
"description": "Mistral Content Moderation classifier leverages the most relevant policy categories for effective guardrails and introduces a pragmatic approach to LLM safety by addressing model-generated harms such as unqualified advice and PII", | ||
"credentials": { | ||
"type": "object", | ||
"properties": { | ||
"apiKey": { | ||
"type": "string", | ||
"label": "API Key", | ||
"description": "Find your API key in the Mistral la-plateforme", | ||
"encrypted": true | ||
} | ||
}, | ||
"required": ["apiKey"] | ||
}, | ||
"functions": [ | ||
{ | ||
"name": "Moderate Content", | ||
"id": "moderateContent", | ||
"type": "guardrail", | ||
"supportedHooks": ["beforeRequestHook", "afterRequestHook"], | ||
"description": [ | ||
{ | ||
"type": "subHeading", | ||
"text": "Checks if the content passes the mentioned content moderation checks." | ||
} | ||
], | ||
"parameters": { | ||
"type": "object", | ||
"properties": { | ||
"categories": { | ||
"type": "array", | ||
"label": "Moderation Checks", | ||
"description": [ | ||
{ | ||
"type": "subHeading", | ||
"text": "Select the categories that should NOT be allowed in the content. (Checked via OpenAI moderation API)" | ||
} | ||
], | ||
"items": { | ||
"type": "string", | ||
"enum": [ | ||
"sexual", | ||
"hate_and_discrimination", | ||
"violence_and_threats", | ||
"dangerous_and_criminal_content", | ||
"selfharm", | ||
"health", | ||
"financial", | ||
"law", | ||
"pii" | ||
], | ||
"default": [ | ||
"selfharm", | ||
"pii", | ||
"sexual", | ||
"hate_and_discrimination" | ||
] | ||
} | ||
} | ||
}, | ||
"required": ["categories"] | ||
} | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import { PluginContext } from '../types'; | ||
import testCreds from './.creds.json'; | ||
import { mistralGuardrailHandler } from './index'; | ||
|
||
function getParameters() { | ||
return { | ||
credentials: testCreds, | ||
}; | ||
} | ||
|
||
describe('mistral guardrail handler', () => { | ||
it('should fail if the apiKey is invalid', async () => { | ||
const eventType = 'beforeRequestHook'; | ||
const context = { | ||
request: { text: 'this is a test string for moderations' }, | ||
}; | ||
const parameters = JSON.parse(JSON.stringify(getParameters())); | ||
parameters.credentials.apiKey = 'invalid-api-key'; | ||
|
||
const result = await mistralGuardrailHandler( | ||
context as unknown as PluginContext, | ||
parameters, | ||
eventType, | ||
{ env: {} } | ||
); | ||
|
||
expect(result).toBeDefined(); | ||
expect(result.verdict).toBe(true); | ||
expect(result.error).toBeDefined(); | ||
expect(result.data).toBeNull(); | ||
}); | ||
|
||
it('should success and return the flagged categories', async () => { | ||
const eventType = 'beforeRequestHook'; | ||
const context = { | ||
request: { | ||
text: 'my name is John Doe and my email is john.doe@example.com', | ||
}, | ||
}; | ||
const parameters = JSON.parse(JSON.stringify(getParameters())); | ||
parameters.categories = ['pii']; | ||
|
||
const result = await mistralGuardrailHandler( | ||
context as unknown as PluginContext, | ||
parameters, | ||
eventType, | ||
{ env: {} } | ||
); | ||
|
||
expect(result).toBeDefined(); | ||
expect(result.verdict).toBe(false); | ||
expect(result.error).toBeDefined(); | ||
expect(result.data).toMatchObject({ flagged_categories: ['pii'] }); | ||
}); | ||
|
||
it('should include the multiple flagged categories in the response', async () => { | ||
const eventType = 'beforeRequestHook'; | ||
const context = { | ||
request: { | ||
text: 'my name is John Doe and my email is john.doe@example.com. I am a financial advisor and I suggest you to invest in the stock market in company A.', | ||
}, | ||
}; | ||
const parameters = JSON.parse(JSON.stringify(getParameters())); | ||
parameters.categories = ['pii', 'financial']; | ||
|
||
const result = await mistralGuardrailHandler( | ||
context as unknown as PluginContext, | ||
parameters, | ||
eventType, | ||
{ env: {} } | ||
); | ||
|
||
expect(result).toBeDefined(); | ||
expect(result.verdict).toBe(false); | ||
expect(result.error).toBeDefined(); | ||
expect(result.data).toMatchObject({ | ||
flagged_categories: ['financial', 'pii'], | ||
}); | ||
}); | ||
|
||
it('should fail if the request body is invalid', async () => { | ||
const eventType = 'beforeRequestHook'; | ||
const context = { | ||
request: { text: 'this is safe string without any flagged categories' }, | ||
}; | ||
|
||
const parameters = JSON.parse(JSON.stringify(getParameters())); | ||
parameters.categories = ['pii', 'financial']; | ||
|
||
const result = await mistralGuardrailHandler( | ||
context as unknown as PluginContext, | ||
parameters, | ||
eventType, | ||
{ env: {} } | ||
); | ||
|
||
expect(result).toBeDefined(); | ||
expect(result.verdict).toBe(true); | ||
expect(result.error).toBeDefined(); | ||
expect(result.data).toBeNull(); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters