Commit
upd wandb weave
waleko committed Jun 22, 2024
1 parent de21236 commit c55aaff
Showing 10 changed files with 267 additions and 257 deletions.
7 changes: 3 additions & 4 deletions code_editing/agents/agent_codeeditor.py
@@ -1,6 +1,7 @@
import logging
from typing import Dict

import weave
from hydra.utils import instantiate
from langchain_core.runnables import RunnableConfig, RunnableLambda

@@ -11,7 +12,6 @@
from code_editing.code_editor import CEInput, CEOutput, CodeEditor
from code_editing.configs.agents.context_providers.context_config import ContextConfig
from code_editing.utils.git_utils import get_head_diff_unsafe
from code_editing.utils.wandb_utils import log_codeeditor_trace


class AgentCodeEditor(CodeEditor):
@@ -35,8 +35,8 @@ def __init__(
self.context_providers_cfg = context_providers_cfg
self.runnable_config = runnable_config

@log_codeeditor_trace()
def generate_diff(self, req: CEInput, root_span) -> CEOutput:
@weave.op()
def generate_diff(self, req: CEInput) -> CEOutput:
# Get repository full path
repo_path = req["code_base"].get(CheckoutExtractor.REPO_KEY, None)
if repo_path is None:
@@ -55,7 +55,6 @@ def generate_diff(self, req: CEInput, root_span) -> CEOutput:
# Tools available to the agent
tools = self.tool_factory.build(
run_overview_manager=run_overview_manager,
root_span=root_span, # W&B root span
)

# Build the graph runnable
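A minimal sketch of the pattern this hunk adopts: W&B Weave's weave.op() replaces the hand-rolled log_codeeditor_trace decorator, and because Weave links nested op calls to their caller automatically, the root_span parameter can be dropped from generate_diff. The project name and request shape below are assumptions for illustration, not taken from this repository.

import weave

weave.init("code-editing-demo")  # traces go to this (assumed) W&B project

@weave.op()
def generate_diff(req: dict) -> dict:
    # Inputs, outputs, and any nested weave.op calls are captured
    # automatically, so no root span has to be threaded through.
    return {"prediction": "example diff for " + req["instruction"]}

print(generate_diff({"instruction": "rename foo to bar"}))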
8 changes: 4 additions & 4 deletions code_editing/agents/run.py
@@ -8,8 +8,9 @@

class ToolInfo(TypedDict):
calls: int
errors: int
success: int
failures: int
errors: int


# enum class: calls, errors, failures
@@ -35,16 +36,15 @@ def __init__(
self.start_ms = wandb_utils.get_current_ms()

def log_tool_use(self, tool_name, status: ToolUseStatus):
status = status.value
self.tools_info.setdefault(tool_name, {}).setdefault(status, 0)
self.tools_info[tool_name][status] += 1

def get_run_summary(self):
end_ms = wandb_utils.get_current_ms()
return {
"tools": self.tools_info,
"start_ms": self.start_ms,
"end_ms": end_ms,
"duration_ms": end_ms - self.start_ms,
"duration_sec": (end_ms - self.start_ms) / 1000,
}

def get_ctx_provider(self, ctx_provider_name) -> ContextProvider:
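The run.py hunks expand the ToolInfo counters (success and failures alongside calls and errors), coerce the ToolUseStatus enum to its string value before counting, and report duration in seconds instead of milliseconds. Below is a standalone sketch of the counting pattern; the ToolUseStatus members are assumptions based on the "calls, errors, failures" comment in the diff, not the full enum from the repository.

from enum import Enum
from typing import Dict, TypedDict

class ToolUseStatus(Enum):
    CALL = "calls"
    SUCCESS = "success"
    FAILURE = "failures"
    ERROR = "errors"

class ToolInfo(TypedDict, total=False):
    calls: int
    success: int
    failures: int
    errors: int

tools_info: Dict[str, ToolInfo] = {}

def log_tool_use(tool_name: str, status: ToolUseStatus) -> None:
    key = status.value  # enum member -> plain string key, as the updated code does
    tools_info.setdefault(tool_name, {}).setdefault(key, 0)
    tools_info[tool_name][key] += 1

log_tool_use("edit", ToolUseStatus.CALL)
log_tool_use("edit", ToolUseStatus.SUCCESS)
print(tools_info)  # {'edit': {'calls': 1, 'success': 1}}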
8 changes: 2 additions & 6 deletions code_editing/agents/tools/edit_tool.py
@@ -27,15 +27,14 @@ class EditToolInput(BaseModel):
The instruction should be a prompt for the editing LLM."""
args_schema = EditToolInput

def __init__(self, backbone: CEBackbone = None, root_span=None, **kwargs):
def __init__(self, backbone: CEBackbone = None, **kwargs):
super().__init__(**kwargs)
self.args_schema = self.EditToolInput

if self.dry_run:
return

self.backbone = backbone
self.root_span = root_span

if self.backbone is None:
raise ValueError("Backbone is required for the edit tool")
@@ -48,9 +47,7 @@ def _run_tool(self, file_name: str, start_index: int, instruction: str, context:
file = parse_file(file_name, self.repo_path)
contents, lines, start, end = read_file(context, file, start_index)
# Send to the editing LLM
resp = self.backbone.generate_diff(
{"instruction": instruction, "code_base": {file_name: contents}}, parent_span=self.root_span
)
resp = self.backbone.generate_diff({"instruction": instruction, "code_base": {file_name: contents}})
new_contents = resp["prediction"]
# Save
with open(file, "w") as f:
@@ -67,4 +64,3 @@ def short_name(self) -> str:

backbone: CEBackbone = None
retrieval_helper: RetrievalHelper = None
root_span: Any = None
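This hunk deletes the root_span plumbing from the edit tool: with Weave, a traced call made inside another traced call is linked to its parent automatically. A sketch of that nesting behavior, with simplified stand-ins for the real backbone and tool (project name assumed):

import weave

weave.init("code-editing-demo")  # assumed project name

@weave.op()
def backbone_generate_diff(req: dict) -> dict:
    return {"prediction": req["code_base"].upper()}  # stand-in for an LLM call

@weave.op()
def run_edit_tool(instruction: str, contents: str) -> str:
    # This nested call appears as a child of run_edit_tool in the trace UI,
    # which is why no parent_span argument needs to be passed explicitly.
    resp = backbone_generate_diff({"instruction": instruction, "code_base": contents})
    return resp["prediction"]

print(run_edit_tool("shout it", "def f(): pass"))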
5 changes: 0 additions & 5 deletions code_editing/metrics/gpt4_eval.py
@@ -48,11 +48,6 @@ def _score_single(self, diff_true: str, diff_pred: str, full_row: Dict):
if found:
try:
res = float(found[0])
# Log to W&B
if wandb.run is not None:
wandb_utils.gpt4_eval_trace(
diff_true, patch, start_ms, end_ms, response, res, metadata={"model": self.model_name}
)
return res
except:
pass
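Here the commit only deletes the manual gpt4_eval_trace call; nothing replaces it in this hunk. Purely for illustration, a hedged sketch of how the surviving score-parsing step could be traced under the same weave.op migration; the regex and op name are assumptions, not code from this diff:

import re
from typing import Optional

import weave

weave.init("code-editing-demo")  # assumed project name

@weave.op(name="gpt4_eval_score")
def parse_score(response: str) -> Optional[float]:
    # Pull the first number out of the judge's response, mirroring the
    # float(found[0]) parsing visible in gpt4_eval.py.
    found = re.findall(r"\d+(?:\.\d+)?", response)
    if found:
        try:
            return float(found[0])
        except ValueError:
            return None
    return None

print(parse_score("Score: 7.5/10"))  # 7.5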
29 changes: 1 addition & 28 deletions code_editing/utils/backbones/baseline.py
@@ -1,6 +1,3 @@
import wandb
from wandb.sdk.data_types.trace_tree import StatusCode

from code_editing.code_editor import CEBackbone, CEInput, CEOutput, CodeEditor
from code_editing.utils import wandb_utils
from code_editing.utils.preprocessors.base_preprocessor import CEPreprocessor
@@ -13,37 +10,13 @@ def __init__(self, backbone: CEBackbone, preprocessor: CEPreprocessor):
self.run_name = backbone.name

def generate_diff(self, req: CEInput) -> CEOutput:
# Initialize the root span for W&B
root_span = None
if wandb.run is not None:
root_span = wandb_utils.build_main_trace(
req,
wandb_utils.get_current_ms(),
"Code Editing",
metadata={
"preprocessor_name": self.preprocessor.name,
"backbone_name": self.backbone.name,
},
)

# Preprocess the input
start_ms = wandb_utils.get_current_ms()
old_req = req
req = self.preprocessor(req)
after_preprocess_ms = wandb_utils.get_current_ms()
# Log the preprocessing trace to W&B
if wandb.run is not None:
wandb_utils.log_preprocessor_trace(old_req, req, start_ms, after_preprocess_ms, root_span)

# Generate the diff using the backbone
try:
resp = self.backbone.generate_diff(req, parent_span=root_span)
if wandb.run is not None:
wandb_utils.log_main_trace(root_span, old_req, resp, StatusCode.SUCCESS)
except Exception as e:
if wandb.run is not None:
wandb_utils.log_main_trace(root_span, old_req, None, StatusCode.ERROR, str(e))
raise e
resp = self.backbone.generate_diff(req)

return resp

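The baseline editor loses roughly two dozen lines of manual trace bookkeeping: building a root span, logging the preprocessing step, and wrapping the backbone call in try/except to record StatusCode.SUCCESS or StatusCode.ERROR. Weave records exceptions raised inside an op and marks the call as failed on its own, which is what makes the try/except removable. A small sketch of that behavior (project name assumed):

import weave

weave.init("code-editing-demo")  # assumed project name

@weave.op()
def flaky_generate_diff(req: dict) -> dict:
    if not req.get("instruction"):
        raise ValueError("empty instruction")  # recorded on the call's trace
    return {"prediction": "ok"}

try:
    flaky_generate_diff({})
except ValueError:
    pass  # the failure is already attached to the Weave call record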
24 changes: 4 additions & 20 deletions code_editing/utils/backbones/hf_backbone.py
@@ -1,13 +1,12 @@
import logging
from typing import Dict, Optional
from typing import Dict

import torch
import weave
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, set_seed
from wandb.sdk.data_types.trace_tree import Trace

from code_editing.code_editor import CEBackbone, CEInput, CEOutput
from code_editing.configs.backbones_configs import HFGenerationConfig, HFModelConfig
from code_editing.utils import wandb_utils
from code_editing.utils.prompts.base_prompt import CEPrompt


@@ -83,15 +82,7 @@ def generate_diff(self, req: CEInput, **kwargs) -> CEOutput:
if not self._prompt:
raise ValueError("Prompt is required for HuggingFace models.")

# Initialize the root span for W&B
parent_span: Optional[Trace] = kwargs.get("parent_span", None)

@wandb_utils.log_prompt_trace(
parent_span,
metadata={
"prompt_name": self._prompt.name,
},
)
@weave.op(name="prompt")
def get_inp(r):
return self._prompt.hf(
r,
@@ -102,14 +93,7 @@ def get_inp(r):
preprocessed_inputs = get_inp(req)
encoding = self._tokenizer(preprocessed_inputs, return_tensors="pt").to(self._device)

@wandb_utils.log_llm_trace(
parent_span=parent_span,
model_name=self._name_or_path,
metadata={
"model_config": self._model.config.to_dict(),
"generation_config": self._generation_config.to_dict(),
},
)
@weave.op(name="generate")
def get_resp(_):
return self._model.generate(
**encoding,
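In the HF backbone, the two locally defined closures keep their roles but trade the parameter-heavy log_prompt_trace / log_llm_trace decorators for weave.op(name=...), so prompt formatting and generation show up as separately named child calls in the trace. A sketch of that named-op pattern with trivial stand-ins for the prompt and model (project name assumed):

import weave

weave.init("code-editing-demo")  # assumed project name

@weave.op()
def generate(req: dict) -> str:
    @weave.op(name="prompt")
    def get_inp(r: dict) -> str:
        return "Instruction: " + r["instruction"]  # stand-in for prompt.hf(...)

    @weave.op(name="generate")
    def get_resp(prompt: str) -> str:
        return prompt.upper()  # stand-in for model.generate(...)

    return get_resp(get_inp(req))

print(generate({"instruction": "fix the bug"}))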
18 changes: 2 additions & 16 deletions code_editing/utils/backbones/openai_backbone.py
@@ -1,11 +1,8 @@
import logging
from typing import Optional

from openai import OpenAI
from wandb.sdk.data_types.trace_tree import Trace

from code_editing.code_editor import CEBackbone, CEInput, CEOutput
from code_editing.utils import wandb_utils
from code_editing.utils.prompts.base_prompt import CEPrompt


@@ -24,19 +21,8 @@ def __init__(self, model_name: str, prompt: CEPrompt, **kwargs):
logging.getLogger("httpx").setLevel(logging.WARNING)

def generate_diff(self, req: CEInput, **kwargs) -> CEOutput:
# Initialize the root span for W&B
parent_span: Optional[Trace] = kwargs.get("parent_span", None)

preprocessed_inputs = wandb_utils.log_prompt_trace(
parent_span,
metadata={
"prompt_name": self._prompt.name,
},
)(
self._prompt.chat
)(req)

@wandb_utils.log_llm_trace(parent_span=parent_span, model_name=self._model_name)
preprocessed_inputs = self._prompt.chat(req)

def openai_request(inp):
resp = self.api.chat.completions.create(
messages=inp,
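The OpenAI backbone drops its tracing decorators without adding a weave.op of its own, which is consistent with Weave's automatic OpenAI integration: once weave.init() has run, chat.completions.create calls are traced without extra code. A sketch of that setup; the model name and project are assumptions, and OPENAI_API_KEY must be set in the environment:

import weave
from openai import OpenAI

weave.init("code-editing-demo")  # also patches the OpenAI client for auto-tracing

client = OpenAI()
resp = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello."}],
)
print(resp.choices[0].message.content)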