diff --git a/module/ai/iml.go b/module/ai/iml.go index 427155b1..201ff398 100644 --- a/module/ai/iml.go +++ b/module/ai/iml.go @@ -7,17 +7,17 @@ import ( "net/http" "sort" "time" - + "github.com/APIParkLab/APIPark/service/service" - + ai_key_dto "github.com/APIParkLab/APIPark/module/ai-key/dto" - + ai_key "github.com/APIParkLab/APIPark/service/ai-key" - + "github.com/eolinker/go-common/auto" - + ai_api "github.com/APIParkLab/APIPark/service/ai-api" - + model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime" "github.com/APIParkLab/APIPark/gateway" ai_dto "github.com/APIParkLab/APIPark/module/ai/dto" @@ -30,7 +30,7 @@ import ( ) func newKey(key *ai_key.Key) *gateway.DynamicRelease { - + return &gateway.DynamicRelease{ BasicItem: &gateway.BasicItem{ ID: fmt.Sprintf("%s-%s", key.Provider, key.ID), @@ -91,7 +91,7 @@ func (i *imlProviderModule) Sort(ctx context.Context, input *ai_dto.Sort) error if !has { continue } - + l, has := providerMap[id] if !has { continue @@ -139,7 +139,7 @@ func (i *imlProviderModule) Sort(ctx context.Context, input *ai_dto.Sort) error return err } return i.syncGateway(ctx, cluster.DefaultClusterID, offlineReleases, false) - + }) } @@ -176,7 +176,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto. return nil, nil, fmt.Errorf("create default key error:%v", err) } } - + p, has := model_runtime.GetProvider(l.Id) if !has { continue @@ -185,7 +185,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto. if err != nil { return nil, nil, fmt.Errorf("get provider keys error:%v", err) } - + keysStatus := make([]*ai_dto.KeyStatus, 0, len(keys)) for _, k := range keys { status := ai_key_dto.ToKeyStatus(k.Status) @@ -204,7 +204,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto. sort.Slice(keysStatus, func(i, j int) bool { return keysStatus[i].Priority < keysStatus[j].Priority }) - + providers = append(providers, &ai_dto.ConfiguredProviderItem{ Id: l.Id, Name: l.Name, @@ -248,7 +248,7 @@ func (i *imlProviderModule) SimpleProviders(ctx context.Context) ([]*ai_dto.Simp return nil, err } providers := model_runtime.Providers() - + providerMap := utils.SliceToMap(list, func(e *ai.Provider) string { return e.Id }) @@ -315,7 +315,7 @@ func (i *imlProviderModule) SimpleConfiguredProviders(ctx context.Context) ([]*a Name: model.ID(), }, } - + items = append(items, item) } sort.Slice(items, func(i, j int) bool { @@ -426,7 +426,7 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr if info.Priority == 0 { info.Priority = maxPriority } - + return &ai_dto.Provider{ Id: info.Id, Name: info.Name, @@ -445,12 +445,12 @@ func (i *imlProviderModule) LLMs(ctx context.Context, driver string) ([]*ai_dto. if !has { return nil, nil, fmt.Errorf("ai provider not found") } - + llms, has := p.ModelsByType(model_runtime.ModelTypeLLM) if !has { return nil, nil, fmt.Errorf("ai provider not found") } - + items := make([]*ai_dto.LLMItem, 0, len(llms)) for _, v := range llms { items = append(items, &ai_dto.LLMItem{ @@ -478,7 +478,7 @@ func (i *imlProviderModule) LLMs(ctx context.Context, driver string) ([]*ai_dto. Logo: p.Logo(), }, nil } - + return items, &ai_dto.ProviderItem{ Id: info.Id, Name: info.Name, @@ -558,19 +558,19 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, if err != nil { return err } - - if input.Enable != nil { - status = 0 - if *input.Enable { - status = 1 - } - pInfo.Status = &status - } + + //if input.Enable != nil { + // status = 0 + // if *input.Enable { + // status = 1 + // } + // pInfo.Status = &status + //} err = i.providerService.Save(ctx, id, pInfo) if err != nil { return err } - + if *pInfo.Status == 0 { return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ { @@ -614,11 +614,11 @@ func (i *imlProviderModule) getAiProviders(ctx context.Context) ([]*gateway.Dyna if err != nil { return nil, err } - + providers := make([]*gateway.DynamicRelease, 0, len(list)) for _, l := range list { // 获取当前供应商所有Key信息 - + driver, has := model_runtime.GetProvider(l.Id) if !has { return nil, fmt.Errorf("provider not found: %s", l.Id) @@ -653,7 +653,7 @@ func (i *imlProviderModule) initGateway(ctx context.Context, clusterId string, c if err != nil { return err } - + for _, p := range providers { client, err := clientDriver.Dynamic(p.Resource) if err != nil { @@ -664,7 +664,7 @@ func (i *imlProviderModule) initGateway(ctx context.Context, clusterId string, c return err } } - + return nil } @@ -694,7 +694,7 @@ func (i *imlProviderModule) syncGateway(ctx context.Context, clusterId string, r return err } } - + return nil } @@ -727,9 +727,9 @@ func (i *imlAIApiModule) APIs(ctx context.Context, keyword string, providerId st Name: s.Name, }) serviceTeamMap[s.Id] = s.Team - + } - + modelItems := utils.SliceToSlice(p.Models(), func(e model_runtime.IModel) *ai_dto.BasicInfo { return &ai_dto.BasicInfo{ Id: e.ID(), @@ -752,7 +752,7 @@ func (i *imlAIApiModule) APIs(ctx context.Context, keyword string, providerId st if err != nil { return nil, nil, 0, err } - + if len(apis) <= 0 { return nil, condition, 0, nil } @@ -767,10 +767,10 @@ func (i *imlAIApiModule) APIs(ctx context.Context, keyword string, providerId st if err != nil { return nil, nil, 0, err } - + apiItems := utils.SliceToSlice(results, func(e *ai_api.APIUse) *ai_dto.APIItem { info := apiMap[e.API] - + delete(apiMap, e.API) return &ai_dto.APIItem{ Id: e.API, @@ -814,5 +814,8 @@ func (i *imlAIApiModule) APIs(ctx context.Context, keyword string, providerId st for i := offset; i < offset+size && i < len(sortApis); i++ { apiItems = append(apiItems, sortApis[i]) } - - total := int64(len(apis)) \ No newline at end of file + + total := int64(len(apis)) + return apiItems, condition, total, nil + } +}