Skip to content

Commit

Permalink
Merge pull request #77 from kounoike:feature/whisper-trascribe
Browse files Browse the repository at this point in the history
OpenAI whisper による録画の自動文字起こし
  • Loading branch information
kounoike authored Apr 9, 2023
2 parents 24d2d6f + 35ed029 commit 9652673
Show file tree
Hide file tree
Showing 15 changed files with 339 additions and 100 deletions.
4 changes: 2 additions & 2 deletions cmd/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (c *BotCommand) Execute(ctx context.Context, f *flag.FlagSet, args ...inter
var asynqClient *asynq.Client
var asynqInspector *asynq.Inspector

if config.Encoding.Enabled {
if config.Encoding.Enabled || config.Transcription.Enabled {
redisAddr := fmt.Sprintf("%s:%d", config.Redis.Host, config.Redis.Port)
asynqClient = asynq.NewClient(asynq.RedisClientOpt{Addr: redisAddr})
defer asynqClient.Close()
Expand Down Expand Up @@ -217,7 +217,7 @@ func (c *BotCommand) Execute(ctx context.Context, f *flag.FlagSet, args ...inter
logger.Info("CreateChannels OK")

// エンコード結果取得タスク
if config.Encoding.Enabled {
if config.Encoding.Enabled || config.Transcription.Enabled {
scheduler.Every("1m").Do(func() {
err := usecase.CheckCompletedTask(ctx)
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion cmd/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/hibiken/asynq"
"github.com/jinzhu/configor"
"github.com/kounoike/dtv-discord-go/config"
"github.com/kounoike/dtv-discord-go/gpt"
"github.com/kounoike/dtv-discord-go/tasks"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
Expand Down Expand Up @@ -76,9 +77,12 @@ func (c *WorkerCommand) Execute(ctx context.Context, f *flag.FlagSet, args ...in
},
)

gpt := gpt.NewGPTClient(config.OpenAI.Enabled, config.OpenAI.Token, logger)

mux := asynq.NewServeMux()
tmpl := template.Must(template.New("output-name-tmpl").Parse(config.Encoding.EncodeCommandTemplate))
tmpl := template.Must(template.New("encode-command-tmpl").Parse(config.Encoding.EncodeCommandTemplate))
mux.Handle(tasks.TypeProgramEncode, tasks.NewProgramEncoder(logger, tmpl, config.Recording.BasePath, config.Encoding.BasePath))
mux.Handle(tasks.TypeProgramTranscription, tasks.NewProgramTranscriber(logger, gpt, tmpl, config.Recording.BasePath, config.Transcription.BasePath))
mux.HandleFunc(tasks.TypeHello, tasks.HelloTask)

logger.Debug("Starting worker server")
Expand Down
5 changes: 5 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ type Config struct {
OutputPathTemplate string `required:"true" env:"ENCODING_OUTPUT_PATH_TEMPLATE"`
EncodeCommandTemplate string `required:"true" env:"ENCODING_COMMAND"`
}
Transcription struct {
Enabled bool `required:"true" env:"TRANSCRIPTION_ENABLED"`
BasePath string `required:"true" env:"TRANSCRIPTION_BASE_PATH"`
OutputPathTemplate string `required:"true" env:"TRANSCRIPTION_OUTPUT_PATH_TEMPLATE"`
}
Match struct {
KanaMatch bool `default:"true" env:"KANA_MATCH"`
FuzzyMatch bool `default:"true" env:"FUZZY_MATCH"`
Expand Down
13 changes: 7 additions & 6 deletions discord/emoji.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package discord

const (
RecordingReactionEmoji = "🔴"
RecordedReactionEmoji = "📼"
EncodedReactionEmoji = "🗜️"
OkReactionEmoji = "🆗"
NotifyReactionEmoji = "👀"
AutoSearchReactionEmoji = "🔍"
RecordingReactionEmoji = "🔴"
RecordedReactionEmoji = "📼"
EncodedReactionEmoji = "🗜️"
OkReactionEmoji = "🆗"
NotifyReactionEmoji = "👀"
AutoSearchReactionEmoji = "🔍"
TranscriptionReactionEmoji = "📝"
)
4 changes: 4 additions & 0 deletions docker-compose/config.yml.example
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ encoding:
basepath: "/encoded"
outputpathtemplate: "{{.Program.Name | fold}}-{{.Program.StartTime.Format \"20060102-1504\"}}-{{.Service.Name | fold}}.mp4"
encodecommandtemplate: "ffmpeg -i {{.InputPath}} {{.OutputPath}} -y"
transcription:
enabled: true
basepath: "/transcribed"
outputpathtemplate: "{{.Title | fold}} #{{.Episode}} [{{.Subtitle}}] {{.Program.StartTime.Format \"20060102-1504\"}} {{.Service.Name | fold}}.txt"
match:
kanamatch: true
fuzzymatch: true
Expand Down
1 change: 1 addition & 0 deletions docker-compose/docker-compose.yml.example
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ services:
- ./config.yml:/config.yml:ro
- ./mirakc/recorded:/recorded:rw
- ./mirakc/encoded:/encoded:rw
- ./mirakc/transcribed:/transcribed:rw
environment:
TZ: Asia/Tokyo
links:
Expand Down
120 changes: 85 additions & 35 deletions dtv/check_completed_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,93 @@ import (
"encoding/json"
"fmt"

"github.com/hibiken/asynq"
"github.com/kounoike/dtv-discord-go/db"
"github.com/kounoike/dtv-discord-go/discord"
"github.com/kounoike/dtv-discord-go/tasks"
"github.com/pkg/errors"
"go.uber.org/zap"
)

func (dtv *DTVUsecase) onProgramEncoded(ctx context.Context, taskInfo *asynq.TaskInfo) error {
_, err := dtv.queries.GetEncodeTaskByTaskID(ctx, taskInfo.ID)
if errors.Cause(err) != sql.ErrNoRows {
return err
}
var payload tasks.ProgramEncodePayload
err = json.Unmarshal(taskInfo.Payload, &payload)
if err != nil {
dtv.logger.Warn("task payload json.Unmarshal error", zap.Error(err))
return err
}
err = dtv.queries.InsertEncodeTask(ctx, db.InsertEncodeTaskParams{TaskID: taskInfo.ID, Status: "success"})
if err != nil {
dtv.logger.Warn("failed to InsertEncodeTask", zap.Error(err))
return err
}
_, err = dtv.discord.SendMessage(discord.InformationCategory, discord.RecordingChannel, fmt.Sprintf("**エンコード完了** `%s`のエンコードが完了しました", payload.OutputPath))
if err != nil {
dtv.logger.Warn("failed to SendMessage", zap.Error(err))
return err
}
programMessage, err := dtv.queries.GetProgramMessageByProgramID(ctx, payload.ProgramId)
if errors.Cause(err) == sql.ErrNoRows {
dtv.logger.Warn("failed to GetProgramMessageByProgramID", zap.Error(err))
return err
}
if err != nil {
dtv.logger.Warn("failed to GetProgramMessageByProgramID", zap.Error(err))
return err
}

err = dtv.discord.MessageReactionAdd(programMessage.ChannelID, programMessage.MessageID, discord.EncodedReactionEmoji)
if err != nil {
dtv.logger.Warn("failed to MessageReactionAdd", zap.Error(err))
return err
}
return nil
}

func (dtv *DTVUsecase) onProgramTranscribed(ctx context.Context, taskInfo *asynq.TaskInfo) error {
// _, err := dtv.queries.GetEncodeTaskByTaskID(ctx, taskInfo.ID)
// if errors.Cause(err) != sql.ErrNoRows {
// return err
// }
var payload tasks.ProgramTranscriptionPayload
err := json.Unmarshal(taskInfo.Payload, &payload)
if err != nil {
dtv.logger.Warn("task payload json.Unmarshal error", zap.Error(err))
return err
}
// err = dtv.queries.InsertEncodeTask(ctx, db.InsertEncodeTaskParams{TaskID: taskInfo.ID, Status: "success"})
// if err != nil {
// dtv.logger.Warn("failed to InsertEncodeTask", zap.Error(err))
// return err
// }
_, err = dtv.discord.SendMessage(discord.InformationCategory, discord.RecordingChannel, fmt.Sprintf("**文字起こし完了** `%s`の文字起こしが完了しました", payload.OutputPath))
if err != nil {
dtv.logger.Warn("failed to SendMessage", zap.Error(err))
return err
}
programMessage, err := dtv.queries.GetProgramMessageByProgramID(ctx, payload.ProgramId)
if errors.Cause(err) == sql.ErrNoRows {
dtv.logger.Warn("failed to GetProgramMessageByProgramID", zap.Error(err))
return err
}
if err != nil {
dtv.logger.Warn("failed to GetProgramMessageByProgramID", zap.Error(err))
return err
}

err = dtv.discord.MessageReactionAdd(programMessage.ChannelID, programMessage.MessageID, discord.TranscriptionReactionEmoji)
if err != nil {
dtv.logger.Warn("failed to MessageReactionAdd", zap.Error(err))
return err
}
return nil
}

func (dtv *DTVUsecase) CheckCompletedTask(ctx context.Context) error {
dtv.logger.Debug("Start CheckCompletedTask")
if dtv.inspector == nil {
return nil
}
Expand All @@ -23,41 +101,13 @@ func (dtv *DTVUsecase) CheckCompletedTask(ctx context.Context) error {
return err
}
for _, taskInfo := range taskInfoList {
if taskInfo.Type != tasks.TypeProgramEncode {
switch taskInfo.Type {
case tasks.TypeHello:
continue
}
_, err := dtv.queries.GetEncodeTaskByTaskID(ctx, taskInfo.ID)
if errors.Cause(err) != sql.ErrNoRows {
continue
}
var payload tasks.ProgramEncodePayload
err = json.Unmarshal(taskInfo.Payload, &payload)
if err != nil {
dtv.logger.Warn("task payload json.Unmarshal error", zap.Error(err))
continue
}
err = dtv.queries.InsertEncodeTask(ctx, db.InsertEncodeTaskParams{TaskID: taskInfo.ID, Status: "success"})
if err != nil {
dtv.logger.Warn("failed to InsertEncodeTask", zap.Error(err))
continue
}
_, err = dtv.discord.SendMessage(discord.InformationCategory, discord.RecordingChannel, fmt.Sprintf("**エンコード完了** `%s`のエンコードが完了しました", payload.OutputPath))
if err != nil {
dtv.logger.Warn("failed to SendMessage", zap.Error(err))
continue
}
programMessage, err := dtv.queries.GetProgramMessageByProgramID(ctx, payload.ProgramId)
if errors.Cause(err) == sql.ErrNoRows {
dtv.logger.Warn("failed to GetProgramMessageByProgramID", zap.Error(err))
continue
}
if err != nil {
dtv.logger.Warn("failed to GetProgramMessageByProgramID", zap.Error(err))
}

err = dtv.discord.MessageReactionAdd(programMessage.ChannelID, programMessage.MessageID, discord.EncodedReactionEmoji)
if err != nil {
dtv.logger.Warn("failed to MessageReactionAdd", zap.Error(err))
case tasks.TypeProgramEncode:
_ = dtv.onProgramEncoded(ctx, taskInfo)
case tasks.TypeProgramTranscription:
_ = dtv.onProgramTranscribed(ctx, taskInfo)
}
}
return nil
Expand Down
1 change: 0 additions & 1 deletion dtv/check_failed_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
)

func (dtv *DTVUsecase) CheckFailedTask(ctx context.Context) error {
dtv.logger.Debug("Start CheckFailedTask")
if dtv.inspector == nil {
return nil
}
Expand Down
62 changes: 36 additions & 26 deletions dtv/dtv_usecase.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,22 @@ import (
)

type DTVUsecase struct {
asynq *asynq.Client
inspector *asynq.Inspector
discord *discord_client.DiscordClient
mirakc *mirakc_client.MirakcClient
scheduler *gocron.Scheduler
queries *db.Queries
logger *zap.Logger
contentPathTmpl *template.Template
outputPathTmpl *template.Template
autoSearchChannel *discordgo.Channel
kanaMatch bool
fuzzyMatch bool
gpt *gpt.GPTClient
asynq *asynq.Client
inspector *asynq.Inspector
discord *discord_client.DiscordClient
mirakc *mirakc_client.MirakcClient
scheduler *gocron.Scheduler
queries *db.Queries
logger *zap.Logger
contentPathTmpl *template.Template
encodingOutputPathTmpl *template.Template
transcriptionOutputPathTmpl *template.Template
autoSearchChannel *discordgo.Channel
gpt *gpt.GPTClient
kanaMatch bool
fuzzyMatch bool
encodingEnabled bool
transcriptionEnabled bool
}

func fold(str string) string {
Expand All @@ -55,22 +58,29 @@ func NewDTVUsecase(
if err != nil {
return nil, err
}
outputTmpl, err := template.New("output-path").Funcs(funcMap).Parse(cfg.Encoding.OutputPathTemplate)
encodingOutputTmpl, err := template.New("encoding-output-path").Funcs(funcMap).Parse(cfg.Encoding.OutputPathTemplate)
if err != nil {
return nil, err
}
transcriptionOutputTmpl, err := template.New("transcription-output-path").Funcs(funcMap).Parse(cfg.Transcription.OutputPathTemplate)
if err != nil {
return nil, err
}
return &DTVUsecase{
asynq: asynqClient,
inspector: inspector,
discord: discordClient,
mirakc: mirakcClient,
scheduler: scheduler,
queries: queries,
logger: logger,
contentPathTmpl: contentTmpl,
outputPathTmpl: outputTmpl,
kanaMatch: kanaMatch,
fuzzyMatch: fuzzyMatch,
gpt: gpt,
asynq: asynqClient,
inspector: inspector,
discord: discordClient,
mirakc: mirakcClient,
scheduler: scheduler,
queries: queries,
logger: logger,
gpt: gpt,
contentPathTmpl: contentTmpl,
encodingOutputPathTmpl: encodingOutputTmpl,
transcriptionOutputPathTmpl: transcriptionOutputTmpl,
kanaMatch: kanaMatch,
fuzzyMatch: fuzzyMatch,
encodingEnabled: cfg.Encoding.Enabled,
transcriptionEnabled: cfg.Transcription.Enabled,
}, nil
}
22 changes: 11 additions & 11 deletions dtv/get_content_path.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,21 @@ func (dtv *DTVUsecase) getContentPath(ctx context.Context, program db.Program, s
return contentPath, nil
}

func (dtv *DTVUsecase) getOutputPath(ctx context.Context, program db.Program, service db.Service) (string, error) {
func (dtv *DTVUsecase) getEncodingOutputPath(ctx context.Context, program db.Program, service db.Service, pathData *template.PathTemplateData) (string, error) {
var b bytes.Buffer
data := template.PathTemplateData{}

_ = dtv.gpt.ParseTitle(ctx, program.Name, &data)

data.Program = template.PathProgram{
Name: program.Name,
StartTime: program.StartTime(),
}
data.Service = template.PathService{
Name: service.Name,
err := dtv.encodingOutputPathTmpl.Execute(&b, pathData)
if err != nil {
return "", err
}
outputPath := b.String()

err := dtv.outputPathTmpl.Execute(&b, data)
return outputPath, nil
}

func (dtv *DTVUsecase) getTranscriptionOutputPath(ctx context.Context, program db.Program, service db.Service, pathData *template.PathTemplateData) (string, error) {
var b bytes.Buffer
err := dtv.transcriptionOutputPathTmpl.Execute(&b, pathData)
if err != nil {
return "", err
}
Expand Down
Loading

0 comments on commit 9652673

Please sign in to comment.