Skip to content

Commit

Permalink
chore: session read args/bufs no more new func
Browse files Browse the repository at this point in the history
  • Loading branch information
carlmontanari committed Mar 9, 2025
1 parent 16c2811 commit 83d882e
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 56 deletions.
4 changes: 3 additions & 1 deletion src/ffi-driver.zig
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ pub const SendInputOperation = struct {
pub const SendPromptedInputOperation = struct {
id: u32,
input: []const u8,
prompt: []const u8,
prompt: ?[]const u8,
prompt_pattern: ?[]const u8,
response: []const u8,
options: operation.SendPromptedInputOptions,
};
Expand Down Expand Up @@ -246,6 +247,7 @@ pub const FfiDriver = struct {
self.allocator,
o.input,
o.prompt,
o.prompt_pattern,
o.response,
o.options,
) catch |err| blk: {
Expand Down
2 changes: 2 additions & 0 deletions src/ffi.zig
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ export fn sendPromptedInput(
cancel: *bool,
input: [*c]const u8,
prompt: [*c]const u8,
prompt_pattern: [*c]const u8,
response: [*c]const u8,
hidden_response: bool,
abort_input: [*c]const u8,
Expand All @@ -479,6 +480,7 @@ export fn sendPromptedInput(
.id = 0,
.input = std.mem.span(input),
.prompt = std.mem.span(prompt),
.prompt_pattern = std.mem.span(prompt_pattern),
.response = std.mem.span(response),
.options = options,
},
Expand Down
101 changes: 46 additions & 55 deletions src/session.zig
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,28 @@ const default_operation_timeout_ns: u64 = 10_000_000_000;
const default_operation_max_search_depth: u64 = 512;

const ReadThreadState = enum(u8) {
Uninitialized,
Run,
Stop,
uninitialized,
run,
stop,
};

// TODO get rid of this
fn NewReadArgs() ReadArgs {
return ReadArgs{
.pattern = null,
.patterns = null,
.actual = null,
};
}

const ReadArgs = struct {
pattern: ?*pcre2.pcre2_code_8,
patterns: ?[]const ?*pcre2.pcre2_code_8,
actual: ?[]const u8,
pattern: ?*pcre2.pcre2_code_8 = null,
patterns: ?[]const ?*pcre2.pcre2_code_8 = null,
actual: ?[]const u8 = null,
};

// TODO get rid of this (use init)
fn NewReadBufs(allocator: std.mem.Allocator) ReadBufs {
return ReadBufs{
.raw = std.ArrayList(u8).init(allocator),
.processed = std.ArrayList(u8).init(allocator),
};
}

const ReadBufs = struct {
raw: std.ArrayList(u8),
processed: std.ArrayList(u8),

fn init(allocator: std.mem.Allocator) ReadBufs {
return ReadBufs{
.raw = std.ArrayList(u8).init(allocator),
.processed = std.ArrayList(u8).init(allocator),
};
}

fn deinit(self: *ReadBufs) void {
self.raw.deinit();
self.processed.deinit();
Expand Down Expand Up @@ -192,7 +182,7 @@ pub const Session = struct {
.auth_options = auth_options,
.transport = t,
.read_thread = null,
.read_stop = std.atomic.Value(ReadThreadState).init(ReadThreadState.Uninitialized),
.read_stop = std.atomic.Value(ReadThreadState).init(ReadThreadState.uninitialized),
.read_lock = std.Thread.Mutex{},
.read_queue = std.fifo.LinearFifo(
u8,
Expand Down Expand Up @@ -250,7 +240,7 @@ pub const Session = struct {
pub fn deinit(self: *Session) void {
self.last_consumed_prompt.deinit();

if (self.read_stop.load(std.builtin.AtomicOrder.acquire) == ReadThreadState.Run) {
if (self.read_stop.load(std.builtin.AtomicOrder.acquire) == ReadThreadState.run) {
// if for whatever reason (likely because a call to driver.open failed causing a defer
// close to *not* trigger) the session didnt get "closed", ensure we do that...
self.close();
Expand Down Expand Up @@ -300,7 +290,7 @@ pub const Session = struct {
self.auth_options,
);

self.read_stop.store(ReadThreadState.Run, std.builtin.AtomicOrder.unordered);
self.read_stop.store(ReadThreadState.run, std.builtin.AtomicOrder.unordered);

// start read thread
self.read_thread = std.Thread.spawn(
Expand Down Expand Up @@ -329,7 +319,7 @@ pub const Session = struct {
}

pub fn close(self: *Session) void {
self.read_stop.store(ReadThreadState.Stop, std.builtin.AtomicOrder.unordered);
self.read_stop.store(ReadThreadState.stop, std.builtin.AtomicOrder.unordered);

if (self.read_thread != null) {
self.read_thread.?.join();
Expand All @@ -346,7 +336,7 @@ pub const Session = struct {

var cur_read_delay_ns: u64 = self.options.read_delay_min_ns;

while (self.read_stop.load(std.builtin.AtomicOrder.acquire) != ReadThreadState.Stop) {
while (self.read_stop.load(std.builtin.AtomicOrder.acquire) != ReadThreadState.stop) {
defer std.time.sleep(cur_read_delay_ns);

const n = try self.transport.read(buf);
Expand Down Expand Up @@ -413,7 +403,7 @@ pub const Session = struct {

var cur_read_delay_ns: u64 = self.options.read_delay_min_ns;

var bufs = NewReadBufs(allocator);
var bufs = ReadBufs.init(allocator);
defer bufs.deinit();

var cur_check_start_idx: usize = 0;
Expand Down Expand Up @@ -683,13 +673,9 @@ pub const Session = struct {
) ![2][]const u8 {
self.log.info("get prompt requested", .{});

var args = NewReadArgs();

args.pattern = self.compiled_prompt_pattern;

try self.writeReturn();

var bufs = NewReadBufs(allocator);
var bufs = ReadBufs.init(allocator);
defer bufs.deinit();

var timer = try std.time.Timer.start();
Expand All @@ -698,7 +684,9 @@ pub const Session = struct {
&timer,
options.cancel,
readUntilPatternCheckDone,
args,
.{
.pattern = self.compiled_prompt_pattern,
},
&bufs,
);

Expand Down Expand Up @@ -730,10 +718,10 @@ pub const Session = struct {
input_handling: operation.InputHandling,
bufs: *ReadBufs,
) !MatchPositions {
var args = NewReadArgs();

args.pattern = self.compiled_prompt_pattern;
args.actual = input;
const args = ReadArgs{
.pattern = self.compiled_prompt_pattern,
.actual = input,
};

try self.write(input, false);

Expand Down Expand Up @@ -780,7 +768,7 @@ pub const Session = struct {

var timer = try std.time.Timer.start();

var bufs = NewReadBufs(allocator);
var bufs = ReadBufs.init(allocator);
defer bufs.deinit();

if (self.last_consumed_prompt.items.len != 0) {
Expand All @@ -803,16 +791,14 @@ pub const Session = struct {
try bufs.processed.resize(0);
}

var args = NewReadArgs();

args.pattern = self.compiled_prompt_pattern;
args.actual = input;

var prompt_indexes = try self.readTimeout(
&timer,
options.cancel,
readUntilPatternCheckDone,
args,
.{
.pattern = self.compiled_prompt_pattern,
.actual = input,
},
&bufs,
);

Expand All @@ -823,7 +809,11 @@ pub const Session = struct {
if (!options.retain_trailing_prompt) {
// using the prompt indexes, replace that range holding the trailing prompt out
// of the processed buf
try bufs.processed.replaceRange(prompt_indexes.start, prompt_indexes.len(), "",);
try bufs.processed.replaceRange(
prompt_indexes.start,
prompt_indexes.len(),
"",
);
}

return bufs.toOwnedSlices(allocator);
Expand All @@ -845,9 +835,11 @@ pub const Session = struct {
var compiled_pattern: ?*pcre2.pcre2_code_8 = null;

if (prompt_pattern) |pattern| {
compiled_pattern = re.pcre2Compile(pattern);
if (compiled_pattern == null) {
return error.CompilePromptPatternFailed;
if (pattern.len > 0) {
compiled_pattern = re.pcre2Compile(pattern);
if (compiled_pattern == null) {
return error.CompilePromptPatternFailed;
}
}
}

Expand All @@ -868,7 +860,7 @@ pub const Session = struct {
}
}

var bufs = NewReadBufs(allocator);
var bufs = ReadBufs.init(allocator);
defer bufs.deinit();

if (self.last_consumed_prompt.items.len != 0) {
Expand All @@ -886,7 +878,9 @@ pub const Session = struct {
&bufs,
);

var args = NewReadArgs();
var args = ReadArgs{
.actual = prompt,
};

if (compiled_pattern) |cp| {
args.patterns = &[_]?*pcre2.pcre2_code_8{
Expand All @@ -897,9 +891,6 @@ pub const Session = struct {
args.pattern = self.compiled_prompt_pattern;
}

args.actual = prompt;
args.patterns = &[_]?*pcre2.pcre2_code_8{self.compiled_prompt_pattern};

_ = try self.readTimeout(
&timer,
options.cancel,
Expand Down

0 comments on commit 83d882e

Please sign in to comment.