From 586afa0cd99baba0e6e44251776955367aa33435 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sat, 24 Jan 2026 04:52:06 +0000 Subject: [PATCH] fix: improve voice filtering (#8) (thanks @joelbdavies) --- CHANGELOG.md | 1 + cmd/speak.go | 22 +++++++++++------- cmd/speak_test.go | 36 +++++++++++++++++++++--------- cmd/voices.go | 19 ++++++++++++++-- cmd/voices_test.go | 18 +++++++++++++++ internal/elevenlabs/client.go | 9 ++------ internal/elevenlabs/client_test.go | 11 ++++----- 7 files changed, 81 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 439feca..67f3bb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## 0.2.2 - Unreleased ### Fixed - Voice ID resolution respects `--voice-id` and avoids misclassifying long names; `--rate` now overrides `--speed` validation. (#7, thanks @joelbdavies) +- Voice name matching now uses exact/substring checks without falling back to unrelated voices; voice search is handled client-side. (#8, thanks @joelbdavies) ## 0.2.1 - 2026-01-01 ### Fixed diff --git a/cmd/speak.go b/cmd/speak.go index ebf3ca8..e043407 100644 --- a/cmd/speak.go +++ b/cmd/speak.go @@ -432,7 +432,7 @@ func resolveVoice(ctx context.Context, client *elevenlabs.Client, voiceInput str if voiceInput == "" { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - voices, err := client.ListVoices(ctx, "") + voices, err := client.ListVoices(ctx) if err != nil { return "", fmt.Errorf("voice not specified and failed to fetch voices: %w", err) } @@ -445,7 +445,7 @@ func resolveVoice(ctx context.Context, client *elevenlabs.Client, voiceInput str if voiceInput == "?" { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - voices, err := client.ListVoices(ctx, "") + voices, err := client.ListVoices(ctx) if err != nil { return "", err } @@ -474,7 +474,7 @@ func resolveVoice(ctx context.Context, client *elevenlabs.Client, voiceInput str } ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - voices, err := client.ListVoices(ctx, voiceInput) + voices, err := client.ListVoices(ctx) if err != nil { return "", err } @@ -490,22 +490,28 @@ func resolveVoice(ctx context.Context, client *elevenlabs.Client, voiceInput str ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - voices, err := client.ListVoices(ctx, voiceInput) + voices, err := client.ListVoices(ctx) if err != nil { return "", err } voiceInputLower := strings.ToLower(voiceInput) + + // First, check for exact match (case-insensitive) for _, v := range voices { if strings.ToLower(v.Name) == voiceInputLower { fmt.Fprintf(os.Stderr, "using voice %s (%s)\n", v.Name, v.VoiceID) return v.VoiceID, nil } } - if len(voices) > 0 { - v := voices[0] - fmt.Fprintf(os.Stderr, "using closest voice match %s (%s)\n", v.Name, v.VoiceID) - return v.VoiceID, nil + + // Then, check for substring match (case-insensitive) + for _, v := range voices { + if strings.Contains(strings.ToLower(v.Name), voiceInputLower) { + fmt.Fprintf(os.Stderr, "using voice %s (%s)\n", v.Name, v.VoiceID) + return v.VoiceID, nil + } } + return "", fmt.Errorf("voice %q not found; try 'sag voices' or -v '?'", voiceInput) } diff --git a/cmd/speak_test.go b/cmd/speak_test.go index 443a3f9..c6bff15 100644 --- a/cmd/speak_test.go +++ b/cmd/speak_test.go @@ -235,7 +235,7 @@ func TestResolveVoiceLooksLikeIDNoMatchPassesThrough(t *testing.T) { } } -func TestResolveVoiceClosestMatch(t *testing.T) { +func TestResolveVoiceNoMatch(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { if _, err := w.Write([]byte(`{"voices":[{"voice_id":"id1","name":"Near","category":"premade"}]}`)); err != nil { t.Fatalf("write response: %v", err) @@ -243,19 +243,37 @@ func TestResolveVoiceClosestMatch(t *testing.T) { })) defer srv.Close() + client := elevenlabs.NewClient("key", srv.URL) + _, err := resolveVoice(context.Background(), client, "nothing-match", false) + if err == nil { + t.Fatalf("expected error for non-matching voice") + } + if !strings.Contains(err.Error(), "not found") { + t.Fatalf("expected 'not found' error, got %q", err.Error()) + } +} + +func TestResolveVoicePartialMatch(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + if _, err := w.Write([]byte(`{"voices":[{"voice_id":"id1","name":"Sarah","category":"premade"},{"voice_id":"id2","name":"Roger - Casual","category":"premade"}]}`)); err != nil { + t.Fatalf("write response: %v", err) + } + })) + defer srv.Close() + restore, read := captureStderr(t) defer restore() client := elevenlabs.NewClient("key", srv.URL) - id, err := resolveVoice(context.Background(), client, "nothing-match", false) + id, err := resolveVoice(context.Background(), client, "roger", false) if err != nil { t.Fatalf("resolveVoice error: %v", err) } - if id != "id1" { - t.Fatalf("expected closest id1, got %q", id) + if id != "id2" { + t.Fatalf("expected id2 for partial match 'roger', got %q", id) } - if out := read(); !strings.Contains(out, "using closest voice match") { - t.Fatalf("expected closest match notice, got %q", out) + if out := read(); !strings.Contains(out, "using voice") { + t.Fatalf("expected 'using voice' notice, got %q", out) } } @@ -440,11 +458,7 @@ func captureStderr(t *testing.T) (restore func(), read func() string) { func TestResolveVoiceByName(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // ensure search param contains name - if !strings.Contains(r.URL.RawQuery, "search=roger") { - t.Fatalf("expected search param to contain 'roger', got %s", r.URL.RawQuery) - } - if _, err := w.Write([]byte(`{"voices":[{"voice_id":"id-roger","name":"Roger","category":"premade"}]}`)); err != nil { + if _, err := w.Write([]byte(`{"voices":[{"voice_id":"id-sarah","name":"Sarah","category":"premade"},{"voice_id":"id-roger","name":"Roger","category":"premade"}]}`)); err != nil { t.Fatalf("write response: %v", err) } })) diff --git a/cmd/voices.go b/cmd/voices.go index 87d7dfe..984292e 100644 --- a/cmd/voices.go +++ b/cmd/voices.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "strings" "text/tabwriter" "time" @@ -33,10 +34,13 @@ func init() { ctx, cancel := context.WithTimeout(cmd.Context(), 30*time.Second) defer cancel() - voices, err := client.ListVoices(ctx, opts.search) + voices, err := client.ListVoices(ctx) if err != nil { return err } + if opts.search != "" { + voices = filterVoicesByName(voices, opts.search) + } if opts.limit > 0 && len(voices) > opts.limit { voices = voices[:opts.limit] @@ -55,7 +59,18 @@ func init() { }, } - cmd.Flags().StringVar(&opts.search, "search", "", "Filter voices by name (server-side when supported)") + cmd.Flags().StringVar(&opts.search, "search", "", "Filter voices by name (client-side)") cmd.Flags().IntVar(&opts.limit, "limit", opts.limit, "Maximum rows to display (0 = all)") rootCmd.AddCommand(cmd) } + +func filterVoicesByName(voices []elevenlabs.Voice, search string) []elevenlabs.Voice { + searchLower := strings.ToLower(search) + filtered := make([]elevenlabs.Voice, 0, len(voices)) + for _, v := range voices { + if strings.Contains(strings.ToLower(v.Name), searchLower) { + filtered = append(filtered, v) + } + } + return filtered +} diff --git a/cmd/voices_test.go b/cmd/voices_test.go index d157b23..ab45d75 100644 --- a/cmd/voices_test.go +++ b/cmd/voices_test.go @@ -7,6 +7,8 @@ import ( "net/http/httptest" "os" "testing" + + "github.com/steipete/sag/internal/elevenlabs" ) func TestVoicesCommand(t *testing.T) { @@ -43,6 +45,22 @@ func TestVoicesCommand(t *testing.T) { _ = os.Unsetenv("ELEVENLABS_API_KEY") } +func TestFilterVoicesByName(t *testing.T) { + voices := []elevenlabs.Voice{ + {VoiceID: "id1", Name: "Sarah"}, + {VoiceID: "id2", Name: "Roger - Casual"}, + {VoiceID: "id3", Name: "ROGUE"}, + } + + filtered := filterVoicesByName(voices, "rog") + if len(filtered) != 2 { + t.Fatalf("expected 2 voices, got %d", len(filtered)) + } + if filtered[0].VoiceID != "id2" || filtered[1].VoiceID != "id3" { + t.Fatalf("unexpected filter order: %+v", filtered) + } +} + func captureStdoutVoices(t *testing.T) (restore func(), read func() string) { t.Helper() orig := os.Stdout diff --git a/internal/elevenlabs/client.go b/internal/elevenlabs/client.go index b7ec2c1..85d9cdd 100644 --- a/internal/elevenlabs/client.go +++ b/internal/elevenlabs/client.go @@ -47,18 +47,13 @@ type listVoicesResponse struct { Next *string `json:"next_page_token,omitempty"` } -// ListVoices fetches voices; search filters by name substring when provided. -func (c *Client) ListVoices(ctx context.Context, search string) ([]Voice, error) { +// ListVoices fetches available voices. +func (c *Client) ListVoices(ctx context.Context) ([]Voice, error) { u, err := url.Parse(c.baseURL) if err != nil { return nil, err } u.Path = path.Join(u.Path, "/v1/voices") - if search != "" { - q := u.Query() - q.Set("search", search) - u.RawQuery = q.Encode() - } req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { diff --git a/internal/elevenlabs/client_test.go b/internal/elevenlabs/client_test.go index f169b8e..7bd310e 100644 --- a/internal/elevenlabs/client_test.go +++ b/internal/elevenlabs/client_test.go @@ -23,21 +23,18 @@ func TestListVoices(t *testing.T) { if r.URL.Path != "/v1/voices" { t.Fatalf("unexpected path: %s", r.URL.Path) } - if search := r.URL.Query().Get("search"); search != "roger" { - t.Fatalf("expected search query 'roger', got %q", search) - } w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"voices":[{"voice_id":"id1","name":"Roger","category":"premade"}]}`)) + _, _ = w.Write([]byte(`{"voices":[{"voice_id":"id1","name":"Sarah","category":"premade"},{"voice_id":"id2","name":"Roger","category":"premade"}]}`)) })) defer srv.Close() c := NewClient("key", srv.URL) - voices, err := c.ListVoices(context.Background(), "roger") + voices, err := c.ListVoices(context.Background()) if err != nil { t.Fatalf("ListVoices error: %v", err) } - if len(voices) != 1 || voices[0].VoiceID != "id1" { - t.Fatalf("unexpected voices: %+v", voices) + if len(voices) != 2 { + t.Fatalf("expected 2 voices, got: %+v", voices) } }