
feat: ability to mimic /v1/models/ api (#68)
* add: ability to mimic `/v1/models/` api

* fix: missing json dependency

* fix: accidentally remove dependency

* fix: linter error caught by gh workflow
john-theo authored Sep 27, 2023
1 parent 80330ec commit c40e27c
Showing 2 changed files with 81 additions and 0 deletions.
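The new endpoint lets OpenAI-style clients discover the proxied Azure deployments. Below is a minimal sketch of querying it from Go, assuming the proxy is running locally on port 8080 and api_base is "/v1"; both values are assumptions, adjust them to your configuration.

package main

import (
    "fmt"
    "io"
    "net/http"
)

func main() {
    // Hypothetical local setup: the address and "/v1" prefix are assumptions.
    resp, err := http.Get("http://localhost:8080/v1/models")
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    body, err := io.ReadAll(resp.Body)
    if err != nil {
        panic(err)
    }
    fmt.Println(string(body)) // e.g. {"data":[...],"object":"list"}
}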
80 changes: 80 additions & 0 deletions azure/proxy.go
@@ -2,6 +2,7 @@ package azure

import (
    "bytes"
    "encoding/json"
    "fmt"
    "github.com/stulzq/azure-openai-proxy/util"
    "io"
@@ -21,6 +22,85 @@ func ProxyWithConverter(requestConverter RequestConverter) gin.HandlerFunc {
    }
}

type DeploymentInfo struct {
    Data   []map[string]interface{} `json:"data"`
    Object string                   `json:"object"`
}

func ModelProxy(c *gin.Context) {
    // Create a channel to receive the results of each request
    results := make(chan []map[string]interface{}, len(ModelDeploymentConfig))

    // Send a request for each deployment in the map
    for _, deployment := range ModelDeploymentConfig {
        go func(deployment DeploymentConfig) {
            // Create the request
            req, err := http.NewRequest(http.MethodGet, deployment.Endpoint+"/openai/deployments?api-version=2022-12-01", nil)
            if err != nil {
                log.Printf("error creating request for deployment %s: %v", deployment.DeploymentName, err)
                results <- nil
                return
            }

            // Set the auth header
            req.Header.Set(AuthHeaderKey, deployment.ApiKey)

            // Send the request
            client := &http.Client{}
            resp, err := client.Do(req)
            if err != nil {
                log.Printf("error sending request for deployment %s: %v", deployment.DeploymentName, err)
                results <- nil
                return
            }
            defer resp.Body.Close()
            if resp.StatusCode != http.StatusOK {
                log.Printf("unexpected status code %d for deployment %s", resp.StatusCode, deployment.DeploymentName)
                results <- nil
                return
            }

            // Read the response body
            body, err := io.ReadAll(resp.Body)
            if err != nil {
                log.Printf("error reading response body for deployment %s: %v", deployment.DeploymentName, err)
                results <- nil
                return
            }

            // Parse the response body as JSON
            var deploymentInfo DeploymentInfo
            err = json.Unmarshal(body, &deploymentInfo)
            if err != nil {
                log.Printf("error parsing response body for deployment %s: %v", deployment.DeploymentName, err)
                results <- nil
                return
            }
            results <- deploymentInfo.Data
        }(deployment)
    }

    // Wait for all requests to finish and collect the results
    var allResults []map[string]interface{}
    for i := 0; i < len(ModelDeploymentConfig); i++ {
        result := <-results
        if result != nil {
            allResults = append(allResults, result...)
        }
    }
    var info = DeploymentInfo{Data: allResults, Object: "list"}
    combinedResults, err := json.Marshal(info)
    if err != nil {
        log.Printf("error marshalling results: %v", err)
        util.SendError(c, err)
        return
    }

    // Set the response headers and body
    c.Header("Content-Type", "application/json")
    c.String(http.StatusOK, string(combinedResults))
}

// Proxy Azure OpenAI
func Proxy(c *gin.Context, requestConverter RequestConverter) {
    if c.Request.Method == http.MethodOptions {
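For reference, ModelProxy aggregates the deployment entries returned by each configured Azure endpoint and wraps them in the same envelope OpenAI's /v1/models uses. Here is a minimal sketch of that envelope with a made-up deployment entry; the real fields come from Azure's /openai/deployments response and may differ.

package main

import (
    "encoding/json"
    "fmt"
)

// Mirrors the DeploymentInfo struct added in azure/proxy.go.
type DeploymentInfo struct {
    Data   []map[string]interface{} `json:"data"`
    Object string                   `json:"object"`
}

func main() {
    // Hypothetical entry standing in for one deployment's data.
    info := DeploymentInfo{
        Data:   []map[string]interface{}{{"id": "gpt-35-turbo", "object": "deployment"}},
        Object: "list",
    }
    out, _ := json.Marshal(info)
    fmt.Println(string(out)) // {"data":[{"id":"gpt-35-turbo","object":"deployment"}],"object":"list"}
}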
1 change: 1 addition & 0 deletions cmd/router.go
@@ -17,6 +17,7 @@ func registerRoute(r *gin.Engine) {
    })
    apiBase := viper.GetString("api_base")
    stripPrefixConverter := azure.NewStripPrefixConverter(apiBase)
    r.GET(stripPrefixConverter.Prefix+"/models", azure.ModelProxy)
    templateConverter := azure.NewTemplateConverter("/openai/deployments/{{.DeploymentName}}/embeddings")
    apiBasedRouter := r.Group(apiBase)
    {
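With api_base set to "/v1" (an assumption; the real value comes from configuration), this registration serves GET /v1/models via azure.ModelProxy. A small sketch of exercising a route wired the same way, using a stub handler in place of ModelProxy so it runs without Azure credentials:

package main

import (
    "fmt"
    "net/http"
    "net/http/httptest"

    "github.com/gin-gonic/gin"
)

func main() {
    r := gin.New()
    apiBase := "/v1" // assumption; router.go reads this from viper's "api_base"
    // Stub standing in for azure.ModelProxy.
    r.GET(apiBase+"/models", func(c *gin.Context) {
        c.Header("Content-Type", "application/json")
        c.String(http.StatusOK, `{"data":[],"object":"list"}`)
    })

    req := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
    w := httptest.NewRecorder()
    r.ServeHTTP(w, req)
    fmt.Println(w.Code, w.Body.String()) // 200 {"data":[],"object":"list"}
}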
