diff --git a/.gitignore b/.gitignore index 2a4350d0..65180906 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,9 @@ llvm-* *.zip */tmp + +src/config.yaml + +src/kparser + +commits/test.txt diff --git a/prompt_template/patch2semgrep.md b/prompt_template/patch2semgrep.md new file mode 100644 index 00000000..7bbbff01 --- /dev/null +++ b/prompt_template/patch2semgrep.md @@ -0,0 +1,144 @@ +# Instruction + +You will be provided with a patch in a software repository. +Please analyze the patch and find out the **bug pattern** in this patch. +A **bug pattern** is the root cause of this bug, meaning that programs with this pattern will have a great possibility of having the same bug. +Note that the bug pattern should be specific and accurate, which can be used to identify the buggy code provided in the patch. + +Then, please help to write a Semgrep rule to detect the specific bug pattern. +The rule should be written in YAML format and follow Semgrep syntax conventions. + +**Please read `Suggestions` section before writing the rule!** + +# Examples + +{{examples}} + +# Target Patch + +{{input_patch}} + +# Suggestions + +1. Semgrep rules use YAML format. Each rule should have an `id`, `pattern`, `languages`, `message`, and `severity`. + +2. Use `$VAR` for metavariables to match any variable name, `$FUNC` for function names, `$EXPR` for expressions. + +3. Use `pattern-not` to exclude patterns that should not trigger the rule (especially the fixed version). + +4. Use `pattern-either` for OR conditions when you need to match multiple variations. + +5. Use `pattern-inside` to limit matches to specific contexts (e.g., inside a function or class). + +6. Use `pattern-not-inside` to exclude specific contexts where the rule should not match. + +7. Use `...` to match zero or more statements/expressions between patterns. + +8. The `languages` field should specify the target programming language accurately (e.g., ["c"], ["cpp"], ["javascript"], ["python"]). + +9. 
The `message` should be clear and actionable, explaining: + - What the vulnerability is + - Why it's dangerous + - How to fix it + +10. Use appropriate `severity` levels: INFO for style issues, WARNING for potential problems, ERROR for definite bugs. + +11. For memory management issues in C/C++, be specific about pointer operations and null checks. + +12. Consider edge cases and variations of the pattern that should also be caught. + +13. Add relevant metadata like CWE numbers, OWASP categories, and source URLs. + +14. Make patterns as specific as possible to minimize false positives while catching variations. + +15. Use `metavariable-pattern` to add constraints on variables when needed. + +# SEMGREP PATTERN SYNTAX GUIDE + +- Use `$VARNAME` to match any expression or variable +- Use `...` to match any sequence of statements +- Use `pattern-inside` to limit matches to specific code blocks +- Use `pattern-not` to exclude specific patterns (like the fixed version) +- Use `pattern-either` to match multiple alternative patterns +- Use `metavariable-pattern` to add constraints on metavariables + +# Rule Template + +```yaml +rules: + - id: your-rule-id + pattern: | + your pattern here + pattern-not: | + exclusion pattern here (fixed version) + pattern-inside: | + context pattern here + languages: ["target-language"] + message: | + Detailed description of: + - What the vulnerability is + - Why it's dangerous + - How to fix it + severity: ERROR + metadata: + category: security + cwe: + - "CWE-XXX" + owasp: + - "A1:2017-Injection" + technology: + - target-language + references: + - "https://example.com/documentation" +``` + +# Required Fields + +1. **id**: Unique identifier (use lowercase, numbers, hyphens only) +2. **pattern**: Main pattern to match the vulnerable code +3. **languages**: Array of target programming languages +4. **message**: Clear, actionable description of the issue and fix +5. 
**severity**: One of [ERROR, WARNING, INFO] + +# Recommended Fields + +- **pattern-not**: Pattern for the fixed version (to avoid false positives) +- **pattern-inside**: Context where the rule should apply +- **pattern-not-inside**: Context where the rule should not apply +- **metadata**: Additional context including CWE, OWASP, references + +# Important Guidelines + +1. Make patterns specific enough to minimize false positives +2. Include `pattern-not` for the fixed version when possible +3. Add relevant metadata like CWE numbers and OWASP categories +4. Write clear, actionable messages explaining both problem and solution +5. Consider different variations of the vulnerable pattern +6. Test your pattern mentally against both positive cases (should match) and negative cases (should not match) + +# Formatting + +Please show me the completed Semgrep rule in proper YAML format. + +Your response should be a single YAML document like: + +```yaml +rules: + - id: rule-name + pattern: | + pattern content + pattern-not: | + fixed version pattern + languages: ["language"] + message: | + Detailed description of the vulnerability and how to fix it. + severity: ERROR + metadata: + category: security + cwe: + - "CWE-XXX" + technology: + - language +``` + +Remember to adapt the patterns to match the specific vulnerability while keeping them general enough to catch variations of the same issue. diff --git a/prompt_template/pattern2semplan.md b/prompt_template/pattern2semplan.md new file mode 100644 index 00000000..6534e1b1 --- /dev/null +++ b/prompt_template/pattern2semplan.md @@ -0,0 +1,72 @@ +# Instruction + +Please organize an elaborate plan to help write a Semgrep rule to detect the **bug pattern**. + +You will be provided with a **bug pattern** description and the corresponding patch to help you understand this bug pattern. 
+ +**Please read `Suggestions` section before writing the plan!** + +# Examples + +{{examples}} + +# Target Patch + +{{input_patch}} + +# Target Pattern + +{{input_pattern}} + +{{failed_plan_examples}} + +# Suggestions + +1. Semgrep rules use pattern matching syntax. Use `$VAR` for metavariables to match any variable. + +2. Use `pattern` for the main pattern to match, `pattern-not` to exclude certain patterns, and `pattern-either` for OR conditions. + +3. For function calls, use `$FUNC(...)` to match any function call, or `$FUNC($ARG1, $ARG2)` for specific arguments. + +4. Use `...` to match any number of statements or expressions between patterns. + +5. The `languages` field should specify the target programming language (e.g., ["c"], ["javascript"], ["python"]). + +6. The `message` should be **short** and clear, describing what the rule detects. + +7. Use appropriate `severity` levels: INFO, WARNING, ERROR. + +8. Consider using `pattern-inside` to limit matches to specific contexts (e.g., inside a function). + +9. For pointer dereferences, memory management, and similar C/C++ issues, be specific about the context. + +10. Use `metavariable-regex` when you need to match specific naming patterns. + +# Formatting + +Your plan should contain the following information: + +1. Identify the main pattern that needs to be detected (the buggy code pattern). + +2. Determine what variations of the pattern should be caught. + +3. Specify what legitimate code patterns should be excluded (using `pattern-not`). + +4. Choose appropriate metavariables for the rule. + +5. Determine the context where the rule should apply (e.g., inside functions, specific file types). + +6. Decide on the message and severity level. + +You only need to tell me the way to implement this Semgrep rule, extra information like testing or documentation is unnecessary. + +**Please try to use the simplest approach and fewer patterns to achieve your goal. 
But for every step, your response should be as concrete as possible so that I can easily follow your guidance and write a correct Semgrep rule!** + +# Plan + +Your plan should follow the format of example plans. +Note, your plan should be concise and clear. Do not include unnecessary information or example implementation code snippets. + +``` +Your plan here +``` diff --git a/prompt_template/plan2semgrep.md b/prompt_template/plan2semgrep.md new file mode 100644 index 00000000..3b290838 --- /dev/null +++ b/prompt_template/plan2semgrep.md @@ -0,0 +1,151 @@ +# Instruction + +You are proficient in writing Semgrep rules. + +Please help me write a Semgrep rule to detect a specific bug pattern. +You can refer to the `Target Bug Pattern` and `Target Patch` sections to help you understand the bug pattern. +Please make sure your rule can detect the bug shown in the buggy code pattern. +Please refer to the `Plan` section to implement the Semgrep rule. + +**Please read `Suggestions` section before writing the rule!** + +# Examples + +{{examples}} + +# Target Bug Pattern + +{{input_pattern}} + +# Target Patch + +{{input_patch}} + +# Target Plan + +{{input_plan}} + +# Suggestions + +1. Semgrep rules use YAML format. Each rule should have an `id`, `pattern`, `languages`, `message`, and `severity`. + +2. Use `$VAR` for metavariables to match any variable name, `$FUNC` for function names, `$EXPR` for expressions. + +3. Use `pattern-not` to exclude patterns that should not trigger the rule (especially the fixed version). + +4. Use `pattern-either` for OR conditions when you need to match multiple variations. + +5. Use `pattern-inside` to limit matches to specific contexts (e.g., inside a function or class). + +6. Use `pattern-not-inside` to exclude specific contexts where the rule should not match. + +7. Use `...` to match zero or more statements/expressions between patterns. + +8. 
The `languages` field should specify the target programming language accurately (e.g., ["c"], ["cpp"], ["javascript"], ["python"]). + +9. The `message` should be clear and actionable, explaining: + - What the vulnerability is + - Why it's dangerous + - How to fix it + +10. Use appropriate `severity` levels: INFO for style issues, WARNING for potential problems, ERROR for definite bugs. + +11. For memory management issues in C/C++, be specific about pointer operations and null checks. + +12. Consider edge cases and variations of the pattern that should also be caught. + +13. Add relevant metadata like CWE numbers, OWASP categories, and source URLs. + +14. Make patterns as specific as possible to minimize false positives while catching variations. + +15. Use `metavariable-pattern` to add constraints on variables when needed. + +# SEMGREP PATTERN SYNTAX GUIDE + +- Use `$VARNAME` to match any expression or variable +- Use `...` to match any sequence of statements +- Use `pattern-inside` to limit matches to specific code blocks +- Use `pattern-not` to exclude specific patterns (like the fixed version) +- Use `pattern-either` to match multiple alternative patterns +- Use `metavariable-pattern` to add constraints on metavariables + +# Rule Template + +```yaml +rules: + - id: your-rule-id + pattern: | + your pattern here + pattern-not: | + exclusion pattern here (fixed version) + pattern-inside: | + context pattern here + languages: ["target-language"] + message: | + Detailed description of: + - What the vulnerability is + - Why it's dangerous + - How to fix it + severity: ERROR + metadata: + category: security + cwe: + - "CWE-XXX" + owasp: + - "A1:2017-Injection" + technology: + - target-language + references: + - "https://example.com/documentation" +``` + +# Required Fields + +1. **id**: Unique identifier (use lowercase, numbers, hyphens only) +2. **pattern**: Main pattern to match the vulnerable code +3. **languages**: Array of target programming languages +4. 
**message**: Clear, actionable description of the issue and fix +5. **severity**: One of [ERROR, WARNING, INFO] + +# Recommended Fields + +- **pattern-not**: Pattern for the fixed version (to avoid false positives) +- **pattern-inside**: Context where the rule should apply +- **pattern-not-inside**: Context where the rule should not apply +- **metadata**: Additional context including CWE, OWASP, references + +# Important Guidelines + +1. Make patterns specific enough to minimize false positives +2. Include `pattern-not` for the fixed version when possible +3. Add relevant metadata like CWE numbers and OWASP categories +4. Write clear, actionable messages explaining both problem and solution +5. Consider different variations of the vulnerable pattern +6. Test your pattern mentally against both positive cases (should match) and negative cases (should not match) + +# Formatting + +Please show me the completed Semgrep rule in proper YAML format. + +Your response should be a single YAML document like: + +```yaml +rules: + - id: rule-name + pattern: | + pattern content + pattern-not: | + fixed version pattern + languages: ["language"] + message: | + Detailed description of the vulnerability and how to fix it. + severity: ERROR + metadata: + category: security + cwe: + - "CWE-XXX" + technology: + - language +``` + +Remember to adapt the patterns to match the specific vulnerability while keeping them general enough to catch variations of the same issue. diff --git a/prompt_template/repair_semgrep.md b/prompt_template/repair_semgrep.md new file mode 100644 index 00000000..d9533846 --- /dev/null +++ b/prompt_template/repair_semgrep.md @@ -0,0 +1,51 @@ +# Role + +You are an expert in writing and debugging Semgrep rules for static code analysis. + +# Instruction + +The following Semgrep rule has validation errors, and your task is to fix these errors based on the provided error messages. + +Here are common issues and solutions: + +1. 
**YAML Syntax Errors**: Fix indentation, quotes, and structure +2. **Invalid Pattern Syntax**: Correct Semgrep pattern syntax +3. **Missing Required Fields**: Add required fields like `id`, `message`, `languages` +4. **Invalid Field Values**: Fix invalid severity levels or language specifications + +**Please only fix the validation errors while preserving the original detection logic.** +**Return the complete corrected Semgrep rule.** + +# Common Semgrep Rule Fields + +Required fields: +- `id`: Unique identifier for the rule +- `pattern` or `patterns`: The detection pattern(s) +- `message`: Description of what the rule detects +- `languages`: Array of target languages (e.g., ["c"]) +- `severity`: ERROR, WARNING, or INFO + +Optional but recommended: +- `metadata`: Additional information about the rule +- `pattern-not`: Patterns to exclude +- `pattern-inside`: Context patterns + +# Current Semgrep Rule + +```yaml +{{semgrep_rule}} +``` + +# Validation Errors + +{{error_messages}} + +# Formatting + +Please provide the corrected Semgrep rule: + +```yaml +{{fixed_semgrep_rule}} +``` + +Note: Return the **complete** corrected Semgrep rule after fixing the validation errors. diff --git a/prompt_template/semgrep_examples/double-free/patch.md b/prompt_template/semgrep_examples/double-free/patch.md new file mode 100644 index 00000000..fb5075f9 --- /dev/null +++ b/prompt_template/semgrep_examples/double-free/patch.md @@ -0,0 +1,197 @@ +## Patch Description + +pinctrl: sophgo: fix double free in cv1800_pctrl_dt_node_to_map() + +'map' is allocated using devm_* which takes care of freeing the allocated +data, but in error paths there is a call to pinctrl_utils_free_map() +which also does kfree(map) which leads to a double free. + +Use kcalloc() instead of devm_kcalloc() as freeing is manually handled. 
+ +Fixes: a29d8e93e710 ("pinctrl: sophgo: add support for CV1800B SoC") +Signed-off-by: Harshit Mogalapalli +Link: https://lore.kernel.org/20241010111830.3474719-1-harshit.m.mogalapalli@oracle.com +Signed-off-by: Linus Walleij + +## Buggy Code + +```c +// drivers/pinctrl/sophgo/pinctrl-cv18xx.c +static int cv1800_pctrl_dt_node_to_map(struct pinctrl_dev *pctldev, + struct device_node *np, + struct pinctrl_map **maps, + unsigned int *num_maps) +{ + struct cv1800_pinctrl *pctrl = pinctrl_dev_get_drvdata(pctldev); + struct device *dev = pctrl->dev; + struct device_node *child; + struct pinctrl_map *map; + const char **grpnames; + const char *grpname; + int ngroups = 0; + int nmaps = 0; + int ret; + + for_each_available_child_of_node(np, child) + ngroups += 1; + + grpnames = devm_kcalloc(dev, ngroups, sizeof(*grpnames), GFP_KERNEL); + if (!grpnames) + return -ENOMEM; + + map = devm_kcalloc(dev, ngroups * 2, sizeof(*map), GFP_KERNEL); + if (!map) + return -ENOMEM; + + ngroups = 0; + mutex_lock(&pctrl->mutex); + for_each_available_child_of_node(np, child) { + int npins = of_property_count_u32_elems(child, "pinmux"); + unsigned int *pins; + struct cv1800_pin_mux_config *pinmuxs; + u32 config, power; + int i; + + if (npins < 1) { + dev_err(dev, "invalid pinctrl group %pOFn.%pOFn\n", + np, child); + ret = -EINVAL; + goto dt_failed; + } + + grpname = devm_kasprintf(dev, GFP_KERNEL, "%pOFn.%pOFn", + np, child); + if (!grpname) { + ret = -ENOMEM; + goto dt_failed; + } + + grpnames[ngroups++] = grpname; + + pins = devm_kcalloc(dev, npins, sizeof(*pins), GFP_KERNEL); + if (!pins) { + ret = -ENOMEM; + goto dt_failed; + } + + pinmuxs = devm_kcalloc(dev, npins, sizeof(*pinmuxs), GFP_KERNEL); + if (!pinmuxs) { + ret = -ENOMEM; + goto dt_failed; + } + + for (i = 0; i < npins; i++) { + ret = of_property_read_u32_index(child, "pinmux", + i, &config); + if (ret) + goto dt_failed; + + pins[i] = cv1800_dt_get_pin(config); + pinmuxs[i].config = config; + pinmuxs[i].pin = 
cv1800_get_pin(pctrl, pins[i]); + + if (!pinmuxs[i].pin) { + dev_err(dev, "failed to get pin %d\n", pins[i]); + ret = -ENODEV; + goto dt_failed; + } + + ret = cv1800_verify_pinmux_config(&pinmuxs[i]); + if (ret) { + dev_err(dev, "group %s pin %d is invalid\n", + grpname, i); + goto dt_failed; + } + } + + ret = cv1800_verify_pin_group(pinmuxs, npins); + if (ret) { + dev_err(dev, "group %s is invalid\n", grpname); + goto dt_failed; + } + + ret = of_property_read_u32(child, "power-source", &power); + if (ret) + goto dt_failed; + + if (!(power == PIN_POWER_STATE_3V3 || power == PIN_POWER_STATE_1V8)) { + dev_err(dev, "group %s have unsupported power: %u\n", + grpname, power); + ret = -ENOTSUPP; + goto dt_failed; + } + + ret = cv1800_set_power_cfg(pctrl, pinmuxs[0].pin->power_domain, + power); + if (ret) + goto dt_failed; + + map[nmaps].type = PIN_MAP_TYPE_MUX_GROUP; + map[nmaps].data.mux.function = np->name; + map[nmaps].data.mux.group = grpname; + nmaps += 1; + + ret = pinconf_generic_parse_dt_config(child, pctldev, + &map[nmaps].data.configs.configs, + &map[nmaps].data.configs.num_configs); + if (ret) { + dev_err(dev, "failed to parse pin config of group %s: %d\n", + grpname, ret); + goto dt_failed; + } + + ret = pinctrl_generic_add_group(pctldev, grpname, + pins, npins, pinmuxs); + if (ret < 0) { + dev_err(dev, "failed to add group %s: %d\n", grpname, ret); + goto dt_failed; + } + + /* don't create a map if there are no pinconf settings */ + if (map[nmaps].data.configs.num_configs == 0) + continue; + + map[nmaps].type = PIN_MAP_TYPE_CONFIGS_GROUP; + map[nmaps].data.configs.group_or_pin = grpname; + nmaps += 1; + } + + ret = pinmux_generic_add_function(pctldev, np->name, + grpnames, ngroups, NULL); + if (ret < 0) { + dev_err(dev, "error adding function %s: %d\n", np->name, ret); + goto function_failed; + } + + *maps = map; + *num_maps = nmaps; + mutex_unlock(&pctrl->mutex); + + return 0; + +dt_failed: + of_node_put(child); +function_failed: + 
pinctrl_utils_free_map(pctldev, map, nmaps); + mutex_unlock(&pctrl->mutex); + return ret; +} +``` + +## Bug Fix Patch + +```diff +diff --git a/drivers/pinctrl/sophgo/pinctrl-cv18xx.c b/drivers/pinctrl/sophgo/pinctrl-cv18xx.c +index d18fc5aa84f7..57f2674e75d6 100644 +--- a/drivers/pinctrl/sophgo/pinctrl-cv18xx.c ++++ b/drivers/pinctrl/sophgo/pinctrl-cv18xx.c +@@ -221,7 +221,7 @@ static int cv1800_pctrl_dt_node_to_map(struct pinctrl_dev *pctldev, + if (!grpnames) + return -ENOMEM; + +- map = devm_kcalloc(dev, ngroups * 2, sizeof(*map), GFP_KERNEL); ++ map = kcalloc(ngroups * 2, sizeof(*map), GFP_KERNEL); + if (!map) + return -ENOMEM; + +``` diff --git a/prompt_template/semgrep_examples/double-free/pattern.md b/prompt_template/semgrep_examples/double-free/pattern.md new file mode 100644 index 00000000..cd6750a0 --- /dev/null +++ b/prompt_template/semgrep_examples/double-free/pattern.md @@ -0,0 +1,18 @@ +## Bug Pattern + +The bug pattern is **double-free vulnerability caused by mixing device-managed memory allocation with manual deallocation**. + +The issue occurs when: +1. Memory is allocated using device-managed allocation functions (`devm_kcalloc`, `devm_kmalloc`, etc.) +2. The same memory is later manually freed using functions like `kfree`, `kvfree`, or `pinctrl_utils_free_map` + +Device-managed allocations are automatically freed when the device is removed or the driver is unloaded. 
Manual deallocation of such memory leads to a double-free condition, which can cause: +- Memory corruption +- System crashes +- Security vulnerabilities +- Undefined behavior + +The pattern specifically involves: +- Using `devm_*` allocation functions for memory management +- Having error paths or cleanup code that manually calls free functions +- The manual free functions operating on pointers that were allocated with `devm_*` diff --git a/prompt_template/semgrep_examples/double-free/plan.md b/prompt_template/semgrep_examples/double-free/plan.md new file mode 100644 index 00000000..1b439b9a --- /dev/null +++ b/prompt_template/semgrep_examples/double-free/plan.md @@ -0,0 +1,34 @@ +## Plan + +### Objective +Create a Semgrep rule to detect double-free vulnerabilities where device-managed allocations (`devm_*`) are manually freed. + +### Detection Strategy + +1. **Identify Device-Managed Allocations:** + - Pattern to match calls to `devm_kcalloc`, `devm_kmalloc`, `devm_kzalloc`, and other `devm_*` allocation functions + - Capture the variable that stores the return value of these functions + +2. **Detect Manual Deallocation:** + - Pattern to match calls to manual free functions: `kfree`, `kvfree`, `pinctrl_utils_free_map`, etc. + - Check if the argument to these free functions is the same variable allocated with `devm_*` + +3. **Pattern Matching Logic:** + - Use Semgrep's metavariable matching to track the same pointer across allocation and deallocation + - Look for the pattern where a `devm_*` allocated pointer is later passed to a manual free function + - Consider both direct usage and usage within the same function scope + +4. **Handle Common Scenarios:** + - Direct assignment: `ptr = devm_kcalloc(...); ... kfree(ptr);` + - Error path cleanup: allocated in main flow, freed in error handling + - Function parameter passing: allocated pointer passed to cleanup functions + +5. 
**Rule Structure:** + - Use `pattern-either` to catch multiple allocation functions (`devm_kcalloc`, `devm_kmalloc`, etc.) + - Use `pattern-inside` to ensure both allocation and deallocation happen in the same function + - Use metavariables to track the same pointer variable + - Provide clear error message explaining the double-free risk + +6. **Minimize False Positives:** + - Use `pattern-not` to exclude cases where the pointer is reassigned to non-devm allocation + - Consider function boundaries to avoid cross-function false positives diff --git a/prompt_template/semgrep_examples/double-free/rule.yml b/prompt_template/semgrep_examples/double-free/rule.yml new file mode 100644 index 00000000..11e0f0c0 --- /dev/null +++ b/prompt_template/semgrep_examples/double-free/rule.yml @@ -0,0 +1,60 @@ +rules: + - id: double-free-devm-allocation + patterns: + - pattern-inside: | + $FUNC(...) { + ... + } + - pattern-either: + - pattern: | + $PTR = devm_kmalloc(...); + ... + kfree($PTR); + - pattern: | + $PTR = devm_kzalloc(...); + ... + kfree($PTR); + - pattern: | + $PTR = devm_kcalloc(...); + ... + kfree($PTR); + - pattern: | + $PTR = devm_kmalloc(...); + ... + kvfree($PTR); + - pattern: | + $PTR = devm_kzalloc(...); + ... + kvfree($PTR); + - pattern: | + $PTR = devm_kcalloc(...); + ... + kvfree($PTR); + - pattern: | + $PTR = devm_kcalloc(...); + ... + pinctrl_utils_free_map(..., $PTR, ...); + pattern-not: | + $PTR = devm_kcalloc(...); + ... + $PTR = kcalloc(...); + ... + kfree($PTR); + languages: ["c"] + message: | + Double-free vulnerability detected: Memory allocated with devm_* functions is automatically + managed by the device framework and should not be manually freed. Manual calls to kfree(), + kvfree(), or pinctrl_utils_free_map() on devm_* allocated memory can cause double-free errors. 
+ + Fix: Use regular allocation functions (kcalloc, kmalloc, kzalloc) instead of devm_* functions + when manual memory management is required, or remove the manual free calls if automatic + cleanup is desired. + severity: ERROR + metadata: + category: security + cwe: "CWE-415" + owasp: "A06:2021-Vulnerable and Outdated Components" + confidence: HIGH + references: + - "https://cwe.mitre.org/data/definitions/415.html" + - "https://www.kernel.org/doc/html/latest/driver-api/device_resource_management.html" diff --git a/prompt_template/semgrep_examples/integer-overflow/patch.md b/prompt_template/semgrep_examples/integer-overflow/patch.md new file mode 100644 index 00000000..8737534e --- /dev/null +++ b/prompt_template/semgrep_examples/integer-overflow/patch.md @@ -0,0 +1,1355 @@ +## Patch Description + +update version of lazy_bdecode from libtorrent + +## Buggy Code + +```c +// Complete file: lazy_entry.hpp (tree-sitter fallback) +/* + +Copyright (c) 2003-2012, Arvid Norberg +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the distribution. + * Neither the name of the author nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef TORRENT_LAZY_ENTRY_HPP_INCLUDED +#define TORRENT_LAZY_ENTRY_HPP_INCLUDED + +#include +#include +#include +#include +#include + +#define TORRENT_EXPORT +#define TORRENT_EXTRA_EXPORT +#define TORRENT_ASSERT(x) assert(x) + +namespace libtorrent +{ + using boost::system::error_code; + + struct lazy_entry; + + // This function decodes bencoded_ data. + // + // .. _bencoded: http://wiki.theory.org/index.php/BitTorrentSpecification + // + // Whenever possible, ``lazy_bdecode()`` should be preferred over ``bdecode()``. + // It is more efficient and more secure. It supports having constraints on the + // amount of memory is consumed by the parser. + // + // *lazy* refers to the fact that it doesn't copy any actual data out of the + // bencoded buffer. It builds a tree of ``lazy_entry`` which has pointers into + // the bencoded buffer. This makes it very fast and efficient. On top of that, + // it is not recursive, which saves a lot of stack space when parsing deeply + // nested trees. However, in order to protect against potential attacks, the + // ``depth_limit`` and ``item_limit`` control how many levels deep the tree is + // allowed to get. With recursive parser, a few thousand levels would be enough + // to exhaust the threads stack and terminate the process. The ``item_limit`` + // protects against very large structures, not necessarily deep. 
Each bencoded + // item in the structure causes the parser to allocate some amount of memory, + // this memory is constant regardless of how much data actually is stored in + // the item. One potential attack is to create a bencoded list of hundreds of + // thousands empty strings, which would cause the parser to allocate a significant + // amount of memory, perhaps more than is available on the machine, and effectively + // provide a denial of service. The default item limit is set as a reasonable + // upper limit for desktop computers. Very few torrents have more items in them. + // The limit corresponds to about 25 MB, which might be a bit much for embedded + // systems. + // + // ``start`` and ``end`` defines the bencoded buffer to be decoded. ``ret`` is + // the ``lazy_entry`` which is filled in with the whole decoded tree. ``ec`` + // is a reference to an ``error_code`` which is set to describe the error encountered + // in case the function fails. ``error_pos`` is an optional pointer to an int, + // which will be set to the byte offset into the buffer where an error occurred, + // in case the function fails. + TORRENT_EXPORT int lazy_bdecode(char const* start, char const* end + , lazy_entry& ret, error_code& ec, int* error_pos = 0 + , int depth_limit = 1000, int item_limit = 1000000); + + // this is a string that is not NULL-terminated. Instead it + // comes with a length, specified in bytes. This is particularly + // useful when parsing bencoded structures, because strings are + // not NULL-terminated internally, and requiring NULL termination + // would require copying the string. + // + // see lazy_entry::string_pstr(). + struct TORRENT_EXPORT pascal_string + { + // construct a string pointing to the characters at ``p`` + // of length ``l`` characters. No NULL termination is required. + pascal_string(char const* p, int l): len(l), ptr(p) {} + + // the number of characters in the string. + int len; + + // the pointer to the first character in the string. 
This is + // not NULL terminated, but instead consult the ``len`` field + // to know how many characters follow. + char const* ptr; + + // lexicographical comparison of strings. Order is consisten + // with memcmp. + bool operator<(pascal_string const& rhs) const + { + return std::memcmp(ptr, rhs.ptr, (std::min)(len, rhs.len)) < 0 + || len < rhs.len; + } + }; + + struct lazy_dict_entry; + + // this object represent a node in a bencoded structure. It is a variant + // type whose concrete type is one of: + // + // 1. dictionary (maps strings -> lazy_entry) + // 2. list (sequence of lazy_entry, i.e. heterogenous) + // 3. integer + // 4. string + // + // There is also a ``none`` type, which is used for uninitialized + // lazy_entries. + struct TORRENT_EXPORT lazy_entry + { + // The different types a lazy_entry can have + enum entry_type_t + { + none_t, dict_t, list_t, string_t, int_t + }; + + lazy_entry() : m_begin(0), m_len(0), m_size(0), m_capacity(0), m_type(none_t) + { m_data.start = 0; } + + // tells you which specific type this lazy entry has. + // See entry_type_t. The type determines which subset of + // member functions are valid to use. + entry_type_t type() const { return (entry_type_t)m_type; } + + // start points to the first decimal digit + // length is the number of digits + void construct_int(char const* start, int length) + { + TORRENT_ASSERT(m_type == none_t); + m_type = int_t; + m_data.start = start; + m_size = length; + m_begin = start - 1; // include 'i' + m_len = length + 2; // include 'e' + } + + // if this is an integer, return the integer value + boost::int64_t int_value() const; + + // internal + void construct_string(char const* start, int length); + + // the string is not null-terminated! + // use string_length() to determine how many bytes + // are part of the string. 
+ char const* string_ptr() const + { + TORRENT_ASSERT(m_type == string_t); + return m_data.start; + } + + // this will return a null terminated string + // it will write to the source buffer! + char const* string_cstr() const + { + TORRENT_ASSERT(m_type == string_t); + const_cast(m_data.start)[m_size] = 0; + return m_data.start; + } + + // if this is a string, returns a pascal_string + // representing the string value. + pascal_string string_pstr() const + { + TORRENT_ASSERT(m_type == string_t); + return pascal_string(m_data.start, m_size); + } + + // if this is a string, returns the string as a std::string. + // (which requires a copy) + std::string string_value() const + { + TORRENT_ASSERT(m_type == string_t); + return std::string(m_data.start, m_size); + } + + // if the lazy_entry is a string, returns the + // length of the string, in bytes. + int string_length() const + { return m_size; } + + // internal + void construct_dict(char const* begin) + { + TORRENT_ASSERT(m_type == none_t); + m_type = dict_t; + m_size = 0; + m_capacity = 0; + m_begin = begin; + } + + // internal + lazy_entry* dict_append(char const* name); + // internal + void pop(); + + // if this is a dictionary, look for a key ``name``, and return + // a pointer to its value, or NULL if there is none. + lazy_entry* dict_find(char const* name); + lazy_entry const* dict_find(char const* name) const + { return const_cast(this)->dict_find(name); } + lazy_entry const* dict_find_string(char const* name) const; + + // if this is a dictionary, look for a key ``name`` whose value + // is a string. If such key exist, return a pointer to + // its value, otherwise NULL. + std::string dict_find_string_value(char const* name) const; + pascal_string dict_find_pstr(char const* name) const; + + // if this is a dictionary, look for a key ``name`` whose value + // is an int. If such key exist, return a pointer to its value, + // otherwise NULL. 
+ boost::int64_t dict_find_int_value(char const* name, boost::int64_t default_val = 0) const; + lazy_entry const* dict_find_int(char const* name) const; + + lazy_entry const* dict_find_dict(char const* name) const; + lazy_entry const* dict_find_list(char const* name) const; + + // if this is a dictionary, return the key value pair at + // position ``i`` from the dictionary. + std::pair dict_at(int i) const; + + // if this is a dictionary, return the number of items in it + int dict_size() const + { + TORRENT_ASSERT(m_type == dict_t); + return m_size; + } + + // internal + void construct_list(char const* begin) + { + TORRENT_ASSERT(m_type == none_t); + m_type = list_t; + m_size = 0; + m_capacity = 0; + m_begin = begin; + } + + // internal + lazy_entry* list_append(); + + // if this is a list, return the item at index ``i``. + lazy_entry* list_at(int i) + { + TORRENT_ASSERT(m_type == list_t); + TORRENT_ASSERT(i < int(m_size)); + return &m_data.list[i]; + } + lazy_entry const* list_at(int i) const + { return const_cast(this)->list_at(i); } + + std::string list_string_value_at(int i) const; + pascal_string list_pstr_at(int i) const; + boost::int64_t list_int_value_at(int i, boost::int64_t default_val = 0) const; + + // if this is a list, return the number of items in it. + int list_size() const + { + TORRENT_ASSERT(m_type == list_t); + return int(m_size); + } + + // end points one byte passed last byte in the source + // buffer backing the bencoded structure. + void set_end(char const* end) + { + TORRENT_ASSERT(end > m_begin); + m_len = end - m_begin; + } + + // internal + void clear(); + + // releases ownership of any memory allocated + void release() + { + m_data.start = 0; + m_size = 0; + m_capacity = 0; + m_type = none_t; + } + + // internal + ~lazy_entry() + { clear(); } + + // returns pointers into the source buffer where + // this entry has its bencoded data + std::pair data_section() const; + + // swap values of ``this`` and ``e``. 
+ void swap(lazy_entry& e) + { + using std::swap; + boost::uint32_t tmp = e.m_type; + e.m_type = m_type; + m_type = tmp; + tmp = e.m_capacity; + e.m_capacity = m_capacity; + m_capacity = tmp; + swap(m_data.start, e.m_data.start); + swap(m_size, e.m_size); + swap(m_begin, e.m_begin); + swap(m_len, e.m_len); + } + + private: + + union data_t + { + lazy_dict_entry* dict; + lazy_entry* list; + char const* start; + } m_data; + + // used for dictionaries and lists to record the range + // in the original buffer they are based on + char const* m_begin; + // the number of bytes this entry extends in the + // bencoded byffer + boost::uint32_t m_len; + + // if list or dictionary, the number of items + boost::uint32_t m_size; + // if list or dictionary, allocated number of items + boost::uint32_t m_capacity:29; + // element type (dict, list, int, string) + boost::uint32_t m_type:3; + + // non-copyable + lazy_entry(lazy_entry const&); + lazy_entry const& operator=(lazy_entry const&); + }; + + struct lazy_dict_entry + { + char const* name; + lazy_entry val; + }; + + TORRENT_EXTRA_EXPORT std::string print_entry(lazy_entry const& e + , bool single_line = false, int indent = 0); + + TORRENT_EXPORT boost::system::error_category& get_bdecode_category(); + + namespace bdecode_errors + { + // libtorrent uses boost.system's ``error_code`` class to represent errors. libtorrent has + // its own error category get_bdecode_category() whith the error codes defined by error_code_enum. 
+ enum error_code_enum + { + // Not an error + no_error = 0, + // expected string in bencoded string + expected_string, + // expected colon in bencoded string + expected_colon, + // unexpected end of file in bencoded string + unexpected_eof, + // expected value (list, dict, int or string) in bencoded string + expected_value, + // bencoded recursion depth limit exceeded + depth_exceeded, + // bencoded item count limit exceeded + limit_exceeded, + + // the number of error codes + error_code_max + }; + + // hidden + inline boost::system::error_code make_error_code(error_code_enum e) + { + return boost::system::error_code(e, get_bdecode_category()); + } + } +} + +#endif + +``` + +```c +// Complete file: lazy_bdecode.cpp (tree-sitter fallback) +/* + +Copyright (c) 2008-2012, Arvid Norberg +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the distribution. + * Neither the name of the author nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "lazy_entry.hpp" +#include +#define __STDC_FORMAT_MACROS +#include + +namespace +{ + const int lazy_entry_grow_factor = 150; // percent + const int lazy_entry_dict_init = 5; + const int lazy_entry_list_init = 5; +} + +namespace libtorrent +{ + +#define TORRENT_FAIL_BDECODE(code) \ + { \ + ec = make_error_code(code); \ + while (!stack.empty()) { \ + top = stack.back(); \ + if (top->type() == lazy_entry::dict_t || top->type() == lazy_entry::list_t) top->pop(); \ + stack.pop_back(); \ + } \ + if (error_pos) *error_pos = start - orig_start; \ + return -1; \ + } + + bool is_digit(char c) { return c >= '0' && c <= '9'; } + + bool is_print(char c) { return c >= 32 && c < 127; } + + // fills in 'val' with what the string between start and the + // first occurance of the delimiter is interpreted as an int. + // return the pointer to the delimiter, or 0 if there is a + // parse error. 
val should be initialized to zero + char const* parse_int(char const* start, char const* end, char delimiter, boost::int64_t& val) + { + while (start < end && *start != delimiter) + { + if (!is_digit(*start)) { return 0; } + val *= 10; + val += *start - '0'; + ++start; + } + return start; + } + + char const* find_char(char const* start, char const* end, char delimiter) + { + while (start < end && *start != delimiter) ++start; + return start; + } + + // return 0 = success + int lazy_bdecode(char const* start, char const* end, lazy_entry& ret + , error_code& ec, int* error_pos, int depth_limit, int item_limit) + { + char const* const orig_start = start; + ret.clear(); + if (start == end) return 0; + + std::vector stack; + + stack.push_back(&ret); + while (start < end) + { + if (stack.empty()) break; // done! + + lazy_entry* top = stack.back(); + + if (int(stack.size()) > depth_limit) TORRENT_FAIL_BDECODE(bdecode_errors::depth_exceeded); + if (start >= end) TORRENT_FAIL_BDECODE(bdecode_errors::unexpected_eof); + char t = *start; + ++start; + if (start >= end && t != 'e') TORRENT_FAIL_BDECODE(bdecode_errors::unexpected_eof); + + switch (top->type()) + { + case lazy_entry::dict_t: + { + if (t == 'e') + { + top->set_end(start); + stack.pop_back(); + continue; + } + if (!is_digit(t)) TORRENT_FAIL_BDECODE(bdecode_errors::expected_string); + boost::int64_t len = t - '0'; + start = parse_int(start, end, ':', len); + if (start == 0 || start + len + 3 > end || *start != ':') + TORRENT_FAIL_BDECODE(bdecode_errors::expected_colon); + ++start; + if (start == end) TORRENT_FAIL_BDECODE(bdecode_errors::unexpected_eof); + lazy_entry* ent = top->dict_append(start); + if (ent == 0) TORRENT_FAIL_BDECODE(boost::system::errc::not_enough_memory); + start += len; + if (start >= end) TORRENT_FAIL_BDECODE(bdecode_errors::unexpected_eof); + stack.push_back(ent); + t = *start; + ++start; + break; + } + case lazy_entry::list_t: + { + if (t == 'e') + { + top->set_end(start); + stack.pop_back(); + 
continue; + } + lazy_entry* ent = top->list_append(); + if (ent == 0) TORRENT_FAIL_BDECODE(boost::system::errc::not_enough_memory); + stack.push_back(ent); + break; + } + default: break; + } + + --item_limit; + if (item_limit <= 0) TORRENT_FAIL_BDECODE(bdecode_errors::limit_exceeded); + + top = stack.back(); + switch (t) + { + case 'd': + top->construct_dict(start - 1); + continue; + case 'l': + top->construct_list(start - 1); + continue; + case 'i': + { + char const* int_start = start; + start = find_char(start, end, 'e'); + top->construct_int(int_start, start - int_start); + if (start == end) TORRENT_FAIL_BDECODE(bdecode_errors::unexpected_eof); + TORRENT_ASSERT(*start == 'e'); + ++start; + stack.pop_back(); + continue; + } + default: + { + if (!is_digit(t)) + TORRENT_FAIL_BDECODE(bdecode_errors::expected_value); + + boost::int64_t len = t - '0'; + start = parse_int(start, end, ':', len); + if (start == 0 || start + len + 1 > end || *start != ':') + TORRENT_FAIL_BDECODE(bdecode_errors::expected_colon); + ++start; + top->construct_string(start, int(len)); + stack.pop_back(); + start += len; + continue; + } + } + return 0; + } + return 0; + } + + boost::int64_t lazy_entry::int_value() const + { + TORRENT_ASSERT(m_type == int_t); + boost::int64_t val = 0; + bool negative = false; + if (*m_data.start == '-') negative = true; + parse_int(negative?m_data.start+1:m_data.start, m_data.start + m_size, 'e', val); + if (negative) val = -val; + return val; + } + + lazy_entry* lazy_entry::dict_append(char const* name) + { + TORRENT_ASSERT(m_type == dict_t); + TORRENT_ASSERT(m_size <= m_capacity); + if (m_capacity == 0) + { + int capacity = lazy_entry_dict_init; + m_data.dict = new (std::nothrow) lazy_dict_entry[capacity]; + if (m_data.dict == 0) return 0; + m_capacity = capacity; + } + else if (m_size == m_capacity) + { + int capacity = m_capacity * lazy_entry_grow_factor / 100; + lazy_dict_entry* tmp = new (std::nothrow) lazy_dict_entry[capacity]; + if (tmp == 0) return 0; + 
std::memcpy(tmp, m_data.dict, sizeof(lazy_dict_entry) * m_size); + for (int i = 0; i < int(m_size); ++i) m_data.dict[i].val.release(); + delete[] m_data.dict; + m_data.dict = tmp; + m_capacity = capacity; + } + + TORRENT_ASSERT(m_size < m_capacity); + lazy_dict_entry& ret = m_data.dict[m_size++]; + ret.name = name; + return &ret.val; + } + + void lazy_entry::pop() + { + if (m_size > 0) --m_size; + } + + namespace + { + // the number of decimal digits needed + // to represent the given value + int num_digits(int val) + { + int ret = 1; + while (val >= 10) + { + ++ret; + val /= 10; + +// ... [TRUNCATED: 93 lines omitted] ... + + for (int i = 0; i < int(m_size); ++i) + { + lazy_dict_entry& e = m_data.dict[i]; + if (string_equal(name, e.name, e.val.m_begin - e.name)) + return &e.val; + } + return 0; + } + + lazy_entry* lazy_entry::list_append() + { + TORRENT_ASSERT(m_type == list_t); + TORRENT_ASSERT(m_size <= m_capacity); + if (m_capacity == 0) + { + int capacity = lazy_entry_list_init; + m_data.list = new (std::nothrow) lazy_entry[capacity]; + if (m_data.list == 0) return 0; + m_capacity = capacity; + } + else if (m_size == m_capacity) + { + int capacity = m_capacity * lazy_entry_grow_factor / 100; + lazy_entry* tmp = new (std::nothrow) lazy_entry[capacity]; + if (tmp == 0) return 0; + std::memcpy(tmp, m_data.list, sizeof(lazy_entry) * m_size); + for (int i = 0; i < int(m_size); ++i) m_data.list[i].release(); + delete[] m_data.list; + m_data.list = tmp; + m_capacity = capacity; + } + + TORRENT_ASSERT(m_size < m_capacity); + return m_data.list + (m_size++); + } + + std::string lazy_entry::list_string_value_at(int i) const + { + lazy_entry const* e = list_at(i); + if (e == 0 || e->type() != lazy_entry::string_t) return std::string(); + return e->string_value(); + } + + pascal_string lazy_entry::list_pstr_at(int i) const + { + lazy_entry const* e = list_at(i); + if (e == 0 || e->type() != lazy_entry::string_t) return pascal_string(0, 0); + return e->string_pstr(); + } + 
+ boost::int64_t lazy_entry::list_int_value_at(int i, boost::int64_t default_val) const + { + lazy_entry const* e = list_at(i); + if (e == 0 || e->type() != lazy_entry::int_t) return default_val; + return e->int_value(); + } + + void lazy_entry::clear() + { + switch (m_type) + { + case list_t: delete[] m_data.list; break; + case dict_t: delete[] m_data.dict; break; + default: break; + } + m_data.start = 0; + m_size = 0; + m_capacity = 0; + m_type = none_t; + } + + std::pair lazy_entry::data_section() const + { + typedef std::pair return_t; + return return_t(m_begin, m_len); + } + + int line_longer_than(lazy_entry const& e, int limit) + { + int line_len = 0; + switch (e.type()) + { + case lazy_entry::list_t: + line_len += 4; + if (line_len > limit) return -1; + for (int i = 0; i < e.list_size(); ++i) + { + int ret = line_longer_than(*e.list_at(i), limit - line_len); + if (ret == -1) return -1; + line_len += ret + 2; + } + break; + case lazy_entry::dict_t: + line_len += 4; + if (line_len > limit) return -1; + for (int i = 0; i < e.dict_size(); ++i) + { + line_len += 4 + e.dict_at(i).first.size(); + if (line_len > limit) return -1; + int ret = line_longer_than(*e.dict_at(i).second, limit - line_len); + if (ret == -1) return -1; + line_len += ret + 1; + } + break; + case lazy_entry::string_t: + line_len += 3 + e.string_length(); + break; + case lazy_entry::int_t: + { + boost::int64_t val = e.int_value(); + while (val > 0) + { + ++line_len; + val /= 10; + } + line_len += 2; + } + break; + case lazy_entry::none_t: + line_len += 4; + break; + } + + if (line_len > limit) return -1; + return line_len; + } + + std::string print_entry(lazy_entry const& e, bool single_line, int indent) + { + char indent_str[200]; + memset(indent_str, ' ', 200); + indent_str[0] = ','; + indent_str[1] = '\n'; + indent_str[199] = 0; + if (indent < 197 && indent >= 0) indent_str[indent+2] = 0; + std::string ret; + switch (e.type()) + { + case lazy_entry::none_t: return "none"; + case 
lazy_entry::int_t: + { + char str[100]; + snprintf(str, sizeof(str), "%" PRId64, e.int_value()); + return str; + } + case lazy_entry::string_t: + { + bool printable = true; + char const* str = e.string_ptr(); + for (int i = 0; i < e.string_length(); ++i) + { + using namespace std; + if (is_print((unsigned char)str[i])) continue; + printable = false; + break; + } + ret += "'"; + if (printable) + { + ret += e.string_value(); + ret += "'"; + return ret; + } + for (int i = 0; i < e.string_length(); ++i) + { + char tmp[5]; + snprintf(tmp, sizeof(tmp), "%02x", (unsigned char)str[i]); + ret += tmp; + } + ret += "'"; + return ret; + } + case lazy_entry::list_t: + { + ret += '['; + bool one_liner = line_longer_than(e, 200) != -1 || single_line; + + if (!one_liner) ret += indent_str + 1; + for (int i = 0; i < e.list_size(); ++i) + { + if (i == 0 && one_liner) ret += " "; + ret += print_entry(*e.list_at(i), single_line, indent + 2); + if (i < e.list_size() - 1) ret += (one_liner?", ":indent_str); + else ret += (one_liner?" ":indent_str+1); + } + ret += "]"; + return ret; + } + case lazy_entry::dict_t: + { + ret += "{"; + bool one_liner = line_longer_than(e, 200) != -1 || single_line; + + if (!one_liner) ret += indent_str+1; + for (int i = 0; i < e.dict_size(); ++i) + { + if (i == 0 && one_liner) ret += " "; + std::pair ent = e.dict_at(i); + ret += "'"; + ret += ent.first; + ret += "': "; + ret += print_entry(*ent.second, single_line, indent + 2); + if (i < e.dict_size() - 1) ret += (one_liner?", ":indent_str); + else ret += (one_liner?" 
":indent_str+1); + } + ret += "}"; + return ret; + } + } + return ret; + } + + struct bdecode_error_category : boost::system::error_category + { + virtual const char* name() const BOOST_SYSTEM_NOEXCEPT; + virtual std::string message(int ev) const BOOST_SYSTEM_NOEXCEPT; + virtual boost::system::error_condition default_error_condition(int ev) const BOOST_SYSTEM_NOEXCEPT + { return boost::system::error_condition(ev, *this); } + }; + + const char* bdecode_error_category::name() const BOOST_SYSTEM_NOEXCEPT + { + return "bdecode error"; + } + + std::string bdecode_error_category::message(int ev) const BOOST_SYSTEM_NOEXCEPT + { + static char const* msgs[] = + { + "no error", + "expected string in bencoded string", + "expected colon in bencoded string", + "unexpected end of file in bencoded string", + "expected value (list, dict, int or string) in bencoded string", + "bencoded nesting depth exceeded", + "bencoded item count limit exceeded", + }; + if (ev < 0 || ev >= int(sizeof(msgs)/sizeof(msgs[0]))) + return "Unknown error"; + return msgs[ev]; + } + + boost::system::error_category& get_bdecode_category() + { + static bdecode_error_category bdecode_category; + return bdecode_category; + } + +}; + +``` + +## Bug Fix Patch + +```diff +diff --git a/lazy_bdecode.cpp b/lazy_bdecode.cpp +index 3bd4080..0f7b292 100644 +--- a/lazy_bdecode.cpp ++++ b/lazy_bdecode.cpp +@@ -1,6 +1,6 @@ + /* + +-Copyright (c) 2008-2012, Arvid Norberg ++Copyright (c) 2008-2014, Arvid Norberg + All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without +@@ -45,35 +45,62 @@ namespace + namespace libtorrent + { + +-#define TORRENT_FAIL_BDECODE(code) \ +- { \ +- ec = make_error_code(code); \ +- while (!stack.empty()) { \ +- top = stack.back(); \ +- if (top->type() == lazy_entry::dict_t || top->type() == lazy_entry::list_t) top->pop(); \ +- stack.pop_back(); \ +- } \ +- if (error_pos) *error_pos = start - orig_start; \ +- return -1; \ ++ namespace ++ { ++ int fail(int* error_pos ++ , std::vector& stack ++ , char const* start ++ , char const* orig_start) ++ { ++ while (!stack.empty()) { ++ lazy_entry* top = stack.back(); ++ if (top->type() == lazy_entry::dict_t || top->type() == lazy_entry::list_t) ++ { ++ top->pop(); ++ break; ++ } ++ stack.pop_back(); ++ } ++ if (error_pos) *error_pos = start - orig_start; ++ return -1; ++ } + } + +- bool is_digit(char c) { return c >= '0' && c <= '9'; } ++#define TORRENT_FAIL_BDECODE(code) do { ec = make_error_code(code); return fail(error_pos, stack, start, orig_start); } while (false) + +- bool is_print(char c) { return c >= 32 && c < 127; } ++ namespace { bool numeric(char c) { return c >= '0' && c <= '9'; } } + + // fills in 'val' with what the string between start and the + // first occurance of the delimiter is interpreted as an int. + // return the pointer to the delimiter, or 0 if there is a + // parse error. 
val should be initialized to zero +- char const* parse_int(char const* start, char const* end, char delimiter, boost::int64_t& val) ++ char const* parse_int(char const* start, char const* end, char delimiter ++ , boost::int64_t& val, bdecode_errors::error_code_enum& ec) + { + while (start < end && *start != delimiter) + { +- if (!is_digit(*start)) { return 0; } ++ if (!numeric(*start)) ++ { ++ ec = bdecode_errors::expected_string; ++ return start; ++ } ++ if (val > INT64_MAX / 10) ++ { ++ ec = bdecode_errors::overflow; ++ return start; ++ } + val *= 10; +- val += *start - '0'; ++ int digit = *start - '0'; ++ if (val > INT64_MAX - digit) ++ { ++ ec = bdecode_errors::overflow; ++ return start; ++ } ++ val += digit; + ++start; + } ++ if (*start != delimiter) ++ ec = bdecode_errors::expected_colon; + return start; + } + +@@ -94,7 +121,7 @@ namespace libtorrent + std::vector stack; + + stack.push_back(&ret); +- while (start < end) ++ while (start <= end) + { + if (stack.empty()) break; // done! 
+ +@@ -116,11 +143,19 @@ namespace libtorrent + stack.pop_back(); + continue; + } +- if (!is_digit(t)) TORRENT_FAIL_BDECODE(bdecode_errors::expected_string); ++ if (!numeric(t)) TORRENT_FAIL_BDECODE(bdecode_errors::expected_string); + boost::int64_t len = t - '0'; +- start = parse_int(start, end, ':', len); +- if (start == 0 || start + len + 3 > end || *start != ':') +- TORRENT_FAIL_BDECODE(bdecode_errors::expected_colon); ++ bdecode_errors::error_code_enum e = bdecode_errors::no_error; ++ start = parse_int(start, end, ':', len, e); ++ if (e) ++ TORRENT_FAIL_BDECODE(e); ++ ++ if (start + len + 1 > end) ++ TORRENT_FAIL_BDECODE(bdecode_errors::unexpected_eof); ++ ++ if (len < 0) ++ TORRENT_FAIL_BDECODE(bdecode_errors::overflow); ++ + ++start; + if (start == end) TORRENT_FAIL_BDECODE(bdecode_errors::unexpected_eof); + lazy_entry* ent = top->dict_append(start); +@@ -173,13 +208,19 @@ namespace libtorrent + } + default: + { +- if (!is_digit(t)) ++ if (!numeric(t)) + TORRENT_FAIL_BDECODE(bdecode_errors::expected_value); + + boost::int64_t len = t - '0'; +- start = parse_int(start, end, ':', len); +- if (start == 0 || start + len + 1 > end || *start != ':') +- TORRENT_FAIL_BDECODE(bdecode_errors::expected_colon); ++ bdecode_errors::error_code_enum e = bdecode_errors::no_error; ++ start = parse_int(start, end, ':', len, e); ++ if (e) ++ TORRENT_FAIL_BDECODE(e); ++ if (start + len + 1 > end) ++ TORRENT_FAIL_BDECODE(bdecode_errors::unexpected_eof); ++ if (len < 0) ++ TORRENT_FAIL_BDECODE(bdecode_errors::overflow); ++ + ++start; + top->construct_string(start, int(len)); + stack.pop_back(); +@@ -198,7 +239,10 @@ namespace libtorrent + boost::int64_t val = 0; + bool negative = false; + if (*m_data.start == '-') negative = true; +- parse_int(negative?m_data.start+1:m_data.start, m_data.start + m_size, 'e', val); ++ bdecode_errors::error_code_enum ec = bdecode_errors::no_error; ++ parse_int(m_data.start + negative ++ , m_data.start + m_size, 'e', val, ec); ++ if (ec) return 0; + 
if (negative) val = -val; + return val; + } +@@ -331,6 +375,13 @@ namespace libtorrent + return e; + } + ++ lazy_entry const* lazy_entry::dict_find_dict(std::string const& name) const ++ { ++ lazy_entry const* e = dict_find(name); ++ if (e == 0 || e->type() != lazy_entry::dict_t) return 0; ++ return e; ++ } ++ + lazy_entry const* lazy_entry::dict_find_list(char const* name) const + { + lazy_entry const* e = dict_find(name); +@@ -350,6 +401,19 @@ namespace libtorrent + return 0; + } + ++ lazy_entry* lazy_entry::dict_find(std::string const& name) ++ { ++ TORRENT_ASSERT(m_type == dict_t); ++ for (int i = 0; i < int(m_size); ++i) ++ { ++ lazy_dict_entry& e = m_data.dict[i]; ++ if (name.size() != e.val.m_begin - e.name) continue; ++ if (std::equal(name.begin(), name.end(), e.name)) ++ return &e.val; ++ } ++ return 0; ++ } ++ + lazy_entry* lazy_entry::list_append() + { + TORRENT_ASSERT(m_type == list_t); +@@ -492,23 +556,50 @@ namespace libtorrent + char const* str = e.string_ptr(); + for (int i = 0; i < e.string_length(); ++i) + { +- using namespace std; +- if (is_print((unsigned char)str[i])) continue; ++ char c = str[i]; ++ if (c >= 32 && c < 127) continue; + printable = false; + break; + } + ret += "'"; + if (printable) + { +- ret += e.string_value(); ++ if (single_line && e.string_length() > 30) ++ { ++ ret.append(e.string_ptr(), 14); ++ ret += "..."; ++ ret.append(e.string_ptr() + e.string_length()-14, 14); ++ } ++ else ++ ret.append(e.string_ptr(), e.string_length()); + ret += "'"; + return ret; + } +- for (int i = 0; i < e.string_length(); ++i) ++ if (single_line && e.string_length() > 20) + { +- char tmp[5]; +- snprintf(tmp, sizeof(tmp), "%02x", (unsigned char)str[i]); +- ret += tmp; ++ for (int i = 0; i < 9; ++i) ++ { ++ char tmp[5]; ++ snprintf(tmp, sizeof(tmp), "%02x", (unsigned char)str[i]); ++ ret += tmp; ++ } ++ ret += "..."; ++ for (int i = e.string_length() - 9 ++ , len(e.string_length()); i < len; ++i) ++ { ++ char tmp[5]; ++ snprintf(tmp, sizeof(tmp), 
"%02x", (unsigned char)str[i]); ++ ret += tmp; ++ } ++ } ++ else ++ { ++ for (int i = 0; i < e.string_length(); ++i) ++ { ++ char tmp[5]; ++ snprintf(tmp, sizeof(tmp), "%02x", (unsigned char)str[i]); ++ ret += tmp; ++ } + } + ret += "'"; + return ret; +@@ -577,6 +668,7 @@ namespace libtorrent + "expected value (list, dict, int or string) in bencoded string", + "bencoded nesting depth exceeded", + "bencoded item count limit exceeded", ++ "integer overflow", + }; + if (ev < 0 || ev >= int(sizeof(msgs)/sizeof(msgs[0]))) + return "Unknown error"; +@@ -589,5 +681,12 @@ namespace libtorrent + return bdecode_category; + } + ++ namespace bdecode_errors ++ { ++ boost::system::error_code make_error_code(error_code_enum e) ++ { ++ return boost::system::error_code(e, get_bdecode_category()); ++ } ++ } + }; + +diff --git a/lazy_entry.hpp b/lazy_entry.hpp +index 70cec90..0e1bfb6 100644 +--- a/lazy_entry.hpp ++++ b/lazy_entry.hpp +@@ -1,6 +1,6 @@ + /* + +-Copyright (c) 2003-2012, Arvid Norberg ++Copyright (c) 2003-2014, Arvid Norberg + All rights reserved. + + Redistribution and use in source and binary forms, with or without +@@ -37,6 +37,7 @@ POSSIBILITY OF SUCH DAMAGE. + #include + #include + #include ++#include + #include + + #define TORRENT_EXPORT +@@ -136,6 +137,7 @@ namespace libtorrent + none_t, dict_t, list_t, string_t, int_t + }; + ++ // internal + lazy_entry() : m_begin(0), m_len(0), m_size(0), m_capacity(0), m_type(none_t) + { m_data.start = 0; } + +@@ -156,7 +158,7 @@ namespace libtorrent + m_len = length + 2; // include 'e' + } + +- // if this is an integer, return the integer value ++ // requires the type to be an integer. 
return the integer value + boost::int64_t int_value() const; + + // internal +@@ -221,6 +223,9 @@ namespace libtorrent + lazy_entry* dict_find(char const* name); + lazy_entry const* dict_find(char const* name) const + { return const_cast(this)->dict_find(name); } ++ lazy_entry* dict_find(std::string const& name); ++ lazy_entry const* dict_find(std::string const& name) const ++ { return const_cast(this)->dict_find(name); } + lazy_entry const* dict_find_string(char const* name) const; + + // if this is a dictionary, look for a key ``name`` whose value +@@ -235,14 +240,22 @@ namespace libtorrent + boost::int64_t dict_find_int_value(char const* name, boost::int64_t default_val = 0) const; + lazy_entry const* dict_find_int(char const* name) const; + ++ // these functions require that ``this`` is a dictionary. ++ // (this->type() == dict_t). They look for an element with the ++ // specified name in the dictionary. ``dict_find_dict`` only ++ // finds dictionaries and ``dict_find_list`` only finds lists. ++ // if no key with the corresponding value of the right type is ++ // found, NULL is returned. + lazy_entry const* dict_find_dict(char const* name) const; ++ lazy_entry const* dict_find_dict(std::string const& name) const; + lazy_entry const* dict_find_list(char const* name) const; + + // if this is a dictionary, return the key value pair at + // position ``i`` from the dictionary. + std::pair dict_at(int i) const; + +- // if this is a dictionary, return the number of items in it ++ // requires that ``this`` is a dictionary. return the ++ // number of items in it + int dict_size() const + { + TORRENT_ASSERT(m_type == dict_t); +@@ -262,7 +275,8 @@ namespace libtorrent + // internal + lazy_entry* list_append(); + +- // if this is a list, return the item at index ``i``. ++ // requires that ``this`` is a list. return ++ // the item at index ``i``. 
+ lazy_entry* list_at(int i) + { + TORRENT_ASSERT(m_type == list_t); +@@ -272,8 +286,19 @@ namespace libtorrent + lazy_entry const* list_at(int i) const + { return const_cast(this)->list_at(i); } + ++ // these functions require ``this`` to have the type list. ++ // (this->type() == list_t). ``list_string_value_at`` returns ++ // the string at index ``i``. ``list_pstr_at`` ++ // returns a pascal_string of the string value at index ``i``. ++ // if the element at ``i`` is not a string, an empty string ++ // is returned. + std::string list_string_value_at(int i) const; + pascal_string list_pstr_at(int i) const; ++ ++ // this function require ``this`` to have the type list. ++ // (this->type() == list_t). returns the integer value at ++ // index ``i``. If the element at ``i`` is not an integer ++ // ``default_val`` is returned, which defaults to 0. + boost::int64_t list_int_value_at(int i, boost::int64_t default_val = 0) const; + + // if this is a list, return the number of items in it. +@@ -283,7 +308,7 @@ namespace libtorrent + return int(m_size); + } + +- // end points one byte passed last byte in the source ++ // internal: end points one byte passed last byte in the source + // buffer backing the bencoded structure. + void set_end(char const* end) + { +@@ -294,7 +319,7 @@ namespace libtorrent + // internal + void clear(); + +- // releases ownership of any memory allocated ++ // internal: releases ownership of any memory allocated + void release() + { + m_data.start = 0; +@@ -361,9 +386,12 @@ namespace libtorrent + lazy_entry val; + }; + +- TORRENT_EXTRA_EXPORT std::string print_entry(lazy_entry const& e ++ // print the bencoded structure in a human-readable format to a stting ++ // that's returned. 
++ TORRENT_EXPORT std::string print_entry(lazy_entry const& e + , bool single_line = false, int indent = 0); + ++ // get the ``error_category`` for bdecode errors + TORRENT_EXPORT boost::system::error_category& get_bdecode_category(); + + namespace bdecode_errors +@@ -386,17 +414,21 @@ namespace libtorrent + depth_exceeded, + // bencoded item count limit exceeded + limit_exceeded, ++ // integer overflow ++ overflow, + + // the number of error codes + error_code_max + }; + + // hidden +- inline boost::system::error_code make_error_code(error_code_enum e) +- { +- return boost::system::error_code(e, get_bdecode_category()); +- } ++ TORRENT_EXPORT boost::system::error_code make_error_code(error_code_enum e); + } ++ ++ TORRENT_EXTRA_EXPORT char const* parse_int(char const* start ++ , char const* end, char delimiter, boost::int64_t& val ++ , bdecode_errors::error_code_enum& ec); ++ + } + + #endif +``` diff --git a/prompt_template/semgrep_examples/integer-overflow/pattern.md b/prompt_template/semgrep_examples/integer-overflow/pattern.md new file mode 100644 index 00000000..76e150df --- /dev/null +++ b/prompt_template/semgrep_examples/integer-overflow/pattern.md @@ -0,0 +1,26 @@ +## Bug Pattern + +The bug pattern is **integer overflow vulnerability caused by insufficient bounds checking in arithmetic operations**. + +The issue occurs when: +1. Integer arithmetic operations are performed without proper overflow checking +2. The result of arithmetic operations can exceed the maximum value for the integer type +3. 
No validation is performed before the arithmetic operation to ensure the result stays within valid bounds + +Integer overflow vulnerabilities can lead to: +- Buffer overflows when used for memory allocation sizes +- Security bypasses when used in bounds checking +- Unexpected program behavior due to wraparound +- Denial of service attacks +- Remote code execution in severe cases + +The pattern specifically involves: +- Performing arithmetic operations (addition, multiplication, etc.) on user-controlled or external input +- Missing overflow checks before arithmetic operations +- Using the result of potentially overflowing operations for critical decisions like memory allocation or array indexing +- Particularly dangerous when `int64_t` or similar large integer types are involved, as overflow can be subtle + +Common scenarios include: +- String parsing functions that accumulate digit values without checking for overflow +- Memory allocation calculations that multiply size by count +- Array indexing calculations that add offsets to base addresses diff --git a/prompt_template/semgrep_examples/integer-overflow/plan.md b/prompt_template/semgrep_examples/integer-overflow/plan.md new file mode 100644 index 00000000..04030f8f --- /dev/null +++ b/prompt_template/semgrep_examples/integer-overflow/plan.md @@ -0,0 +1,47 @@ +## Plan + +### Objective +Create a Semgrep rule to detect integer overflow vulnerabilities where arithmetic operations lack proper bounds checking. + +### Detection Strategy + +1. **Identify Vulnerable Arithmetic Operations:** + - Pattern to match arithmetic operations on integer variables, especially `+=`, `*=`, direct assignment with arithmetic + - Focus on operations involving user input or external data + - Pay special attention to loops where values are accumulated + +2. 
**Detect Missing Overflow Checks:** + - Look for arithmetic operations without preceding overflow validation + - Check for patterns where values are used directly in arithmetic without bounds checking + - Identify cases where maximum value constants (like `INT64_MAX`) are not referenced + +3. **Pattern Matching Logic:** + - Use Semgrep's metavariable matching to track variables across operations + - Look for patterns like `$VAR += $EXPR` without prior `$VAR > MAX_VAL - $EXPR` checks + - Focus on parsing functions, especially those processing numeric strings + - Capture arithmetic operations on function parameters or loop variables + +4. **Handle Common Scenarios:** + - String-to-integer parsing functions that accumulate digit values + - Memory allocation size calculations + - Array index calculations with arithmetic + - Loop counters that can overflow + +5. **Rule Structure:** + - Use `pattern-either` to catch multiple types of arithmetic operations (`+=`, `*=`, `= $VAR + $EXPR`) + - Use `pattern-inside` to focus on function contexts, especially parsing functions + - Use metavariables to track the same variable across operations + - Use `pattern-not` to exclude cases where overflow checks are present + - Provide clear error message explaining the overflow risk and mitigation + +6. **Minimize False Positives:** + - Use `pattern-not` to exclude cases where overflow checking is already implemented + - Consider excluding operations on small constants that cannot cause overflow + - Focus on operations involving external input or variables that could be large + - Exclude cases where the variable type is small enough that overflow is unlikely + +7. **Target High-Risk Functions:** + - Focus on parsing functions (like `parse_int`, `str_to_num`, etc.) 
+ - Memory allocation wrapper functions + - Functions that process user input or network data + - Mathematical utility functions that combine multiple values diff --git a/prompt_template/semgrep_examples/integer-overflow/rule.yml b/prompt_template/semgrep_examples/integer-overflow/rule.yml new file mode 100644 index 00000000..afe56f81 --- /dev/null +++ b/prompt_template/semgrep_examples/integer-overflow/rule.yml @@ -0,0 +1,32 @@ +rules: +- id: vuln-bootstrap-dht-bbc0b719 + pattern: "char const* parse_int(char const* $START, char const* $END, char $DELIMITER,\ + \ boost::int64_t& $VAL)\n{\n ...\n if (!is_digit(*$START)) { return 0; }\n\ + \ ...\n $VAL += *$START - '0';\n ...\n}\n" + pattern-not: "char const* parse_int(char const* $START, char const* $END, char $DELIMITER,\ + \ boost::int64_t& $VAL, bdecode_errors::error_code_enum& $EC)\n{\n ...\n \ + \ if (!numeric(*$START))\n {\n $EC = bdecode_errors::expected_string;\n\ + \ return $START;\n }\n ...\n int $DIGIT = *$START - '0';\n \ + \ if ($VAL > INT64_MAX - $DIGIT)\n {\n $EC = bdecode_errors::overflow;\n\ + \ return $START;\n }\n $VAL += $DIGIT;\n ...\n}\n" + languages: + - cpp + message: 'The function `parse_int` is vulnerable to integer overflow due to insufficient + bounds checking when adding digits to the integer value. This can lead to undefined + behavior or security vulnerabilities. The fix adds proper overflow checks before + performing arithmetic operations. To fix this, ensure that all integer arithmetic + operations are checked for potential overflow. 
+ + ' + severity: ERROR + metadata: + source-url: github.com/bittorrent/bootstrap-dht/commit/bbc0b7191e3f48461ca6e5b1b34bdf4b3f1e79a9 + category: security + cwe: + - CWE-190 + owasp: + - A9:2021-Security Logging and Monitoring Failures + references: + - https://cwe.mitre.org/data/definitions/190.html + technology: + - cpp diff --git a/prompt_template/semgrep_examples/null-ptr-dereference/patch.md b/prompt_template/semgrep_examples/null-ptr-dereference/patch.md new file mode 100644 index 00000000..f393d8e2 --- /dev/null +++ b/prompt_template/semgrep_examples/null-ptr-dereference/patch.md @@ -0,0 +1,192 @@ +## Patch Description + +check whether referenced PPS exists (fixes #393) + +The code was accessing a PPS object field (pps_read) without first checking +if the PPS object itself is null. This can lead to a null pointer dereference +vulnerability that causes crashes or undefined behavior. + +The fix adds a null check before accessing the object's fields and properly +handles the error case with appropriate warning and return false. 
+ +Fixes: CVE-XXXX (potential null pointer dereference) +Source: https://github.com/strukturag/libde265/commit/0b1752abff97cb542941d317a0d18aa50cb199b1 + +## Buggy Code + +```cpp +// libde265/decctx.cc +// returns whether we can continue decoding the stream or whether we should give up +bool decoder_context::process_slice_segment_header(slice_segment_header* hdr, + de265_error* err, de265_PTS pts, + nal_header* nal_hdr, + void* user_data) +{ + *err = DE265_OK; + + flush_reorder_buffer_at_this_frame = false; + + + // get PPS and SPS for this slice + + int pps_id = hdr->slice_pic_parameter_set_id; + if (pps[pps_id]->pps_read==false) { + logerror(LogHeaders, "PPS %d has not been read\n", pps_id); + assert(false); // TODO + + } + + current_pps = pps[pps_id]; + current_sps = sps[ (int)current_pps->seq_parameter_set_id ]; + current_vps = vps[ (int)current_sps->video_parameter_set_id ]; + + calc_tid_and_framerate_ratio(); + + + // --- prepare decoding of new picture --- + + if (hdr->first_slice_segment_in_pic_flag) { + + // previous picture has been completely decoded + + //ctx->push_current_picture_to_output_queue(); + + current_image_poc_lsb = hdr->slice_pic_order_cnt_lsb; + + + seq_parameter_set* sps = current_sps.get(); + + + // --- find and allocate image buffer for decoding --- + + int image_buffer_idx; + bool isOutputImage = (!sps->sample_adaptive_offset_enabled_flag || param_disable_sao); + image_buffer_idx = dpb.new_image(current_sps, this, pts, user_data, isOutputImage); + if (image_buffer_idx < 0) { + *err = (de265_error)(-image_buffer_idx); + return false; + } + + /*de265_image* */ img = dpb.get_image(image_buffer_idx); + img->nal_hdr = *nal_hdr; + + // Note: sps is already set in new_image() -> ??? still the case with shared_ptr ? 
+ + img->set_headers(current_vps, current_sps, current_pps); + + img->decctx = this; + + img->clear_metadata(); + + + if (isIRAP(nal_unit_type)) { + if (isIDR(nal_unit_type) || + isBLA(nal_unit_type) || + first_decoded_picture || + FirstAfterEndOfSequenceNAL) + { + NoRaslOutputFlag = true; + FirstAfterEndOfSequenceNAL = false; + } + else if (0) // TODO: set HandleCraAsBlaFlag by external means + { + } + else + { + NoRaslOutputFlag = false; + HandleCraAsBlaFlag = false; + } + } + + + if (isRASL(nal_unit_type) && + NoRaslOutputFlag) + { + img->PicOutputFlag = false; + } + else + { + img->PicOutputFlag = !!hdr->pic_output_flag; + } + + process_picture_order_count(hdr); + + if (hdr->first_slice_segment_in_pic_flag) { + // mark picture so that it is not overwritten by unavailable reference frames + img->PicState = UsedForShortTermReference; + + *err = process_reference_picture_set(hdr); + if (*err != DE265_OK) { + return false; + } + } + + img->PicState = UsedForShortTermReference; + + log_set_current_POC(img->PicOrderCntVal); + + + // next image is not the first anymore + + first_decoded_picture = false; + } + else { + // claims to be not the first slice, but there is no active image available + + if (img == NULL) { + return false; + } + } + + if (hdr->slice_type == SLICE_TYPE_B || + hdr->slice_type == SLICE_TYPE_P) + { + bool success = construct_reference_picture_lists(hdr); + if (!success) { + return false; + } + } + + //printf("process slice segment header\n"); + + loginfo(LogHeaders,"end of process-slice-header\n"); + dpb.log_dpb_content(); + + + if (hdr->dependent_slice_segment_flag==0) { + hdr->SliceAddrRS = hdr->slice_segment_address; + } else { + hdr->SliceAddrRS = previous_slice_header->SliceAddrRS; + } + + previous_slice_header = hdr; + + + loginfo(LogHeaders,"SliceAddrRS = %d\n",hdr->SliceAddrRS); + + return true; +} +``` + + +## Bug Fix Patch + +```diff +diff --git a/libde265/decctx.cc b/libde265/decctx.cc +index abc123..def456 100644 +--- 
a/libde265/decctx.cc ++++ b/libde265/decctx.cc +@@ -2004,9 +2004,10 @@ bool decoder_context::process_slice_segment_header(slice_segment_header* hdr, + // get PPS and SPS for this slice + + int pps_id = hdr->slice_pic_parameter_set_id; +- if (pps[pps_id]->pps_read==false) { ++ if (pps[pps_id] == nullptr || pps[pps_id]->pps_read == false) { + logerror(LogHeaders, "PPS %d has not been read\n", pps_id); +- assert(false); // TODO ++ img->decctx->add_warning(DE265_WARNING_NONEXISTING_PPS_REFERENCED, false); ++ return false; + } + + current_pps = pps[pps_id]; +``` \ No newline at end of file diff --git a/prompt_template/semgrep_examples/null-ptr-dereference/pattern.md b/prompt_template/semgrep_examples/null-ptr-dereference/pattern.md new file mode 100644 index 00000000..e7fa6560 --- /dev/null +++ b/prompt_template/semgrep_examples/null-ptr-dereference/pattern.md @@ -0,0 +1,23 @@ +### Bug Pattern + +The bug pattern identified in this semgrep rule is a **null pointer dereference vulnerability**. The code directly accesses a field of a PPS (Picture Parameter Set) object without first verifying that the object pointer is not null. + +**Problematic Pattern:** +```cpp +if (pps[id]->field == false) { + // ... 
code that processes the condition +} +``` + +**Root Cause:** +- The code assumes that `pps[id]` is a valid pointer without checking if it's null +- Accessing `pps[id]->field` when `pps[id]` is null leads to undefined behavior +- This can cause application crashes, security vulnerabilities, or unpredictable program behavior + +**Vulnerability Type:** CWE-476 (NULL Pointer Dereference) + +**Risk:** This pattern can lead to: +- Application crashes +- Denial of service attacks +- Potential exploitation in security-critical contexts +- Undefined behavior that may be exploited by attackers diff --git a/prompt_template/semgrep_examples/null-ptr-dereference/plan.md b/prompt_template/semgrep_examples/null-ptr-dereference/plan.md new file mode 100644 index 00000000..cdf7cd6a --- /dev/null +++ b/prompt_template/semgrep_examples/null-ptr-dereference/plan.md @@ -0,0 +1,40 @@ +### Plan + +1. **Pattern Detection:** + - Use semgrep to identify code patterns where array/pointer elements are dereferenced without null checks + - Target pattern: `if ($PPS[$ID]->$FIELD == false) { ... }` + - Exclude safe patterns that already include null checks: `if ($PPS[$ID] == nullptr || $PPS[$ID]->$FIELD == false) { ... }` + +2. **Static Analysis Approach:** + - **Pattern Matching:** Identify direct field access on array elements without prior null validation + - **Context Analysis:** Ensure the pattern occurs in conditional statements where the dereference is the primary condition + - **Exclusion Rules:** Skip cases where null checks are already present in the same condition + +3. **Vulnerability Validation:** + - Verify that the array element (`$PPS[$ID]`) can potentially be null + - Confirm that the field access (`->$FIELD`) occurs without prior validation + - Check if the pattern is in a context where null values are possible + +4. 
**Fix Strategy:** + - **Add Null Check:** Insert null pointer validation before field access + - **Safe Pattern:** Transform `if (ptr->field == value)` to `if (ptr != nullptr && ptr->field == value)` + - **Defensive Programming:** Implement consistent null checking patterns throughout the codebase + +5. **Recommended Fix:** + ```cpp + // Before (vulnerable): + if (pps[id]->field == false) { + // ... processing code + } + + // After (safe): + if (pps[id] != nullptr && pps[id]->field == false) { + // ... processing code + } + ``` + +6. **Prevention Measures:** + - Establish coding standards requiring null checks before pointer dereference + - Use static analysis tools to catch similar patterns during development + - Implement unit tests that verify null pointer handling + - Consider using smart pointers or optional types where appropriate diff --git a/prompt_template/semgrep_examples/null-ptr-dereference/rule.yml b/prompt_template/semgrep_examples/null-ptr-dereference/rule.yml new file mode 100644 index 00000000..8452513f --- /dev/null +++ b/prompt_template/semgrep_examples/null-ptr-dereference/rule.yml @@ -0,0 +1,23 @@ +rules: +- id: vuln-libde265-0b1752ab + pattern: "if ($PPS[$ID]->$FIELD == false) {\n ...\n}\n" + pattern-not: "if ($PPS[$ID] == nullptr || $PPS[$ID]->$FIELD == false) {\n ...\n\ + }\n" + languages: + - cpp + message: 'Detected a potential null pointer dereference vulnerability. The code + checks a field of a PPS object without first verifying that the object is not + null. This can lead to a crash or undefined behavior. To fix this, add a null + check before accessing the object''s fields. 
+ + ' severity: ERROR + metadata: + source-url: github.com/strukturag/libde265/commit/0b1752abff97cb542941d317a0d18aa50cb199b1 + category: security + cwe: + - CWE-476 + owasp: + - A8:2017-Insecure Deserialization + technology: + - cpp \ No newline at end of file diff --git a/prompt_template/semgrep_examples/out-of-bound/patch.md b/prompt_template/semgrep_examples/out-of-bound/patch.md new file mode 100644 index 00000000..d9f18743 --- /dev/null +++ b/prompt_template/semgrep_examples/out-of-bound/patch.md @@ -0,0 +1,751 @@ +## Patch Description + +Add Inflator::BadDistanceErr exception (Issue 414) +The improved validation and exception clears the Address Sanitizer and Undefined Behavior Sanitizer findings + +## Buggy Code + +```c +// Complete file: zinflate.cpp (tree-sitter fallback) +// zinflate.cpp - originally written and placed in the public domain by Wei Dai + +// This is a complete reimplementation of the DEFLATE decompression algorithm. +// It should not be affected by any security vulnerabilities in the zlib +// compression library. In particular it is not affected by the double free bug +// (http://www.kb.cert.org/vuls/id/368819). 
+ +#include "pch.h" + +#include "zinflate.h" +#include "secblock.h" +#include "smartptr.h" + +NAMESPACE_BEGIN(CryptoPP) + +struct CodeLessThan +{ + inline bool operator()(CryptoPP::HuffmanDecoder::code_t lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs) + {return lhs < rhs.code;} + // needed for MSVC .NET 2005 + inline bool operator()(const CryptoPP::HuffmanDecoder::CodeInfo &lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs) + {return lhs.code < rhs.code;} +}; + +inline bool LowFirstBitReader::FillBuffer(unsigned int length) +{ + while (m_bitsBuffered < length) + { + byte b; + if (!m_store.Get(b)) + return false; + m_buffer |= (unsigned long)b << m_bitsBuffered; + m_bitsBuffered += 8; + } + CRYPTOPP_ASSERT(m_bitsBuffered <= sizeof(unsigned long)*8); + return true; +} + +inline unsigned long LowFirstBitReader::PeekBits(unsigned int length) +{ + bool result = FillBuffer(length); + CRYPTOPP_UNUSED(result); CRYPTOPP_ASSERT(result); + return m_buffer & (((unsigned long)1 << length) - 1); +} + +inline void LowFirstBitReader::SkipBits(unsigned int length) +{ + CRYPTOPP_ASSERT(m_bitsBuffered >= length); + m_buffer >>= length; + m_bitsBuffered -= length; +} + +inline unsigned long LowFirstBitReader::GetBits(unsigned int length) +{ + unsigned long result = PeekBits(length); + SkipBits(length); + return result; +} + +inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits) +{ + return code << (MAX_CODE_BITS - codeBits); +} + +void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes) +{ + // the Huffman codes are represented in 3 ways in this code: + // + // 1. most significant code bit (i.e. top of code tree) in the least significant bit position + // 2. most significant code bit (i.e. top of code tree) in the most significant bit position + // 3. most significant code bit (i.e. 
top of code tree) in n-th least significant bit position, + // where n is the maximum code length for this code tree + // + // (1) is the way the codes come in from the deflate stream + // (2) is used to sort codes so they can be binary searched + // (3) is used in this function to compute codes from code lengths + // + // a code in representation (2) is called "normalized" here + // The BitReverse() function is used to convert between (1) and (2) + // The NormalizeCode() function is used to convert from (3) to (2) + + if (nCodes == 0) + throw Err("null code"); + + m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes); + + if (m_maxCodeBits > MAX_CODE_BITS) + throw Err("code length exceeds maximum"); + + if (m_maxCodeBits == 0) + throw Err("null code"); + + // count number of codes of each length + SecBlockWithHint blCount(m_maxCodeBits+1); + std::fill(blCount.begin(), blCount.end(), 0); + unsigned int i; + for (i=0; i nextCode(m_maxCodeBits+1); + nextCode[1] = 0; + for (i=2; i<=m_maxCodeBits; i++) + { + // compute this while checking for overflow: code = (code + blCount[i-1]) << 1 + if (code > code + blCount[i-1]) + throw Err("codes oversubscribed"); + code += blCount[i-1]; + if (code > (code << 1)) + throw Err("codes oversubscribed"); + code <<= 1; + nextCode[i] = code; + } + + // MAX_CODE_BITS is 32, m_maxCodeBits may be smaller. 
+ const word64 shiftedMaxCode = ((word64)1 << m_maxCodeBits); + if (code > shiftedMaxCode - blCount[m_maxCodeBits]) + throw Err("codes oversubscribed"); + else if (m_maxCodeBits != 1 && code < shiftedMaxCode - blCount[m_maxCodeBits]) + throw Err("codes incomplete"); + + // compute a vector of triples sorted by code + m_codeToValue.resize(nCodes - blCount[0]); + unsigned int j=0; + for (i=0; ilen) + { + entry.type = 2; + entry.len = codeInfo.len; + } + else + { + entry.type = 3; + entry.end = last+1; + } + } +} + +inline unsigned int HuffmanDecoder::Decode(code_t code, /* out */ value_t &value) const +{ + CRYPTOPP_ASSERT(m_codeToValue.size() > 0); + LookupEntry &entry = m_cache[code & m_cacheMask]; + + code_t normalizedCode = 0; + if (entry.type != 1) + normalizedCode = BitReverse(code); + + if (entry.type == 0) + FillCacheEntry(entry, normalizedCode); + + if (entry.type == 1) + { + value = entry.value; + return entry.len; + } + else + { + const CodeInfo &codeInfo = (entry.type == 2) + ? 
entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))] + : *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan())-1); + value = codeInfo.value; + return codeInfo.len; + } +} + +bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const +{ + bool result = reader.FillBuffer(m_maxCodeBits); + CRYPTOPP_UNUSED(result); // CRYPTOPP_ASSERT(result); + + unsigned int codeBits = Decode(reader.PeekBuffer(), value); + if (codeBits > reader.BitsBuffered()) + return false; + reader.SkipBits(codeBits); + return true; +} + +// ************************************************************* + +Inflator::Inflator(BufferedTransformation *attachment, bool repeat, int propagation) + : AutoSignaling(propagation) + , m_state(PRE_STREAM), m_repeat(repeat), m_eof(0), m_wrappedAround(0) + , m_blockType(0xff), m_storedLen(0xffff), m_nextDecode(), m_literal(0) + , m_distance(0), m_reader(m_inQueue), m_current(0), m_lastFlush(0) +{ + Detach(attachment); +} + +void Inflator::IsolatedInitialize(const NameValuePairs ¶meters) +{ + m_state = PRE_STREAM; + parameters.GetValue("Repeat", m_repeat); + m_inQueue.Clear(); + m_reader.SkipBits(m_reader.BitsBuffered()); +} + +void Inflator::OutputByte(byte b) +{ + m_window[m_current++] = b; + if (m_current == m_window.size()) + { + ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush); + m_lastFlush = 0; + m_current = 0; + m_wrappedAround = true; + } +} + +// ... [TRUNCATED: 139 lines omitted] ... 
+ + break; + } + case 1: // fixed codes + m_nextDecode = LITERAL; + break; + case 2: // dynamic codes + { + if (!m_reader.FillBuffer(5+5+4)) + throw UnexpectedEndErr(); + unsigned int hlit = m_reader.GetBits(5); + unsigned int hdist = m_reader.GetBits(5); + unsigned int hclen = m_reader.GetBits(4); + + FixedSizeSecBlock codeLengths; + unsigned int i; + static const unsigned int border[] = { // Order of the bit length code lengths + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; + std::fill(codeLengths.begin(), codeLengths+19, 0); + for (i=0; i hlit+257+hdist+1) + throw BadBlockErr(); + std::fill(codeLengths + i, codeLengths + i + count, repeater); + i += count; + } + m_dynamicLiteralDecoder.Initialize(codeLengths, hlit+257); + if (hdist == 0 && codeLengths[hlit+257] == 0) + { + if (hlit != 0) // a single zero distance code length means all literals + throw BadBlockErr(); + } + else + m_dynamicDistanceDecoder.Initialize(codeLengths+hlit+257, hdist+1); + m_nextDecode = LITERAL; + } + catch (HuffmanDecoder::Err &) + { + throw BadBlockErr(); + } + break; + } + default: + throw BadBlockErr(); // reserved block type + } + m_state = DECODING_BODY; +} + +bool Inflator::DecodeBody() +{ + bool blockEnd = false; + switch (m_blockType) + { + case 0: // stored + CRYPTOPP_ASSERT(m_reader.BitsBuffered() == 0); + while (!m_inQueue.IsEmpty() && !blockEnd) + { + size_t size; + const byte *block = m_inQueue.Spy(size); + size = UnsignedMin(m_storedLen, size); + CRYPTOPP_ASSERT(size <= 0xffff); + + OutputString(block, size); + m_inQueue.Skip(size); + m_storedLen = m_storedLen - (word16)size; + if (m_storedLen == 0) + blockEnd = true; + } + break; + case 1: // fixed codes + case 2: // dynamic codes + static const unsigned int lengthStarts[] = { + 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, + 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258}; + static const unsigned int lengthExtraBits[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, + 3, 
3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0}; + static const unsigned int distanceStarts[] = { + 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, + 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, + 8193, 12289, 16385, 24577}; + static const unsigned int distanceExtraBits[] = { + 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, + 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, + 12, 12, 13, 13}; + + const HuffmanDecoder& literalDecoder = GetLiteralDecoder(); + const HuffmanDecoder& distanceDecoder = GetDistanceDecoder(); + + switch (m_nextDecode) + { + case LITERAL: + while (true) + { + if (!literalDecoder.Decode(m_reader, m_literal)) + { + m_nextDecode = LITERAL; + break; + } + if (m_literal < 256) + OutputByte((byte)m_literal); + else if (m_literal == 256) // end of block + { + blockEnd = true; + break; + } + else + { + if (m_literal > 285) + throw BadBlockErr(); + unsigned int bits; + case LENGTH_BITS: + bits = lengthExtraBits[m_literal-257]; + if (!m_reader.FillBuffer(bits)) + { + m_nextDecode = LENGTH_BITS; + break; + } + m_literal = m_reader.GetBits(bits) + lengthStarts[m_literal-257]; + case DISTANCE: + if (!distanceDecoder.Decode(m_reader, m_distance)) + { + m_nextDecode = DISTANCE; + break; + } + case DISTANCE_BITS: + // TODO: this surfaced during fuzzing. What do we do??? + CRYPTOPP_ASSERT(m_distance < COUNTOF(distanceExtraBits)); + bits = (m_distance >= COUNTOF(distanceExtraBits)) ? 
distanceExtraBits[29] : distanceExtraBits[m_distance]; + if (!m_reader.FillBuffer(bits)) + { + m_nextDecode = DISTANCE_BITS; + break; + } + m_distance = m_reader.GetBits(bits) + distanceStarts[m_distance]; + OutputPast(m_literal, m_distance); + } + } + break; + default: + CRYPTOPP_ASSERT(0); + } + } + if (blockEnd) + { + if (m_eof) + { + FlushOutput(); + m_reader.SkipBits(m_reader.BitsBuffered()%8); + if (m_reader.BitsBuffered()) + { + // undo too much lookahead + SecBlockWithHint buffer(m_reader.BitsBuffered() / 8); + for (unsigned int i=0; i= m_lastFlush); + ProcessDecompressedData(m_window + m_lastFlush, m_current - m_lastFlush); + m_lastFlush = m_current; + } +} + +struct NewFixedLiteralDecoder +{ + HuffmanDecoder * operator()() const + { + unsigned int codeLengths[288]; + std::fill(codeLengths + 0, codeLengths + 144, 8); + std::fill(codeLengths + 144, codeLengths + 256, 9); + std::fill(codeLengths + 256, codeLengths + 280, 7); + std::fill(codeLengths + 280, codeLengths + 288, 8); + member_ptr pDecoder(new HuffmanDecoder); + pDecoder->Initialize(codeLengths, 288); + return pDecoder.release(); + } +}; + +struct NewFixedDistanceDecoder +{ + HuffmanDecoder * operator()() const + { + unsigned int codeLengths[32]; + std::fill(codeLengths + 0, codeLengths + 32, 5); + member_ptr pDecoder(new HuffmanDecoder); + pDecoder->Initialize(codeLengths, 32); + return pDecoder.release(); + } +}; + +const HuffmanDecoder& Inflator::GetLiteralDecoder() const +{ + return m_blockType == 1 ? Singleton().Ref() : m_dynamicLiteralDecoder; +} + +const HuffmanDecoder& Inflator::GetDistanceDecoder() const +{ + return m_blockType == 1 ? Singleton().Ref() : m_dynamicDistanceDecoder; +} + +NAMESPACE_END +``` + +```c +// Complete file: zinflate.h (tree-sitter fallback) +#ifndef CRYPTOPP_ZINFLATE_H +#define CRYPTOPP_ZINFLATE_H + +#include "cryptlib.h" +#include "secblock.h" +#include "filters.h" +#include "stdcpp.h" + +NAMESPACE_BEGIN(CryptoPP) + +//! \class LowFirstBitReader +//! 
\since Crypto++ 1.0 +class LowFirstBitReader +{ +public: + LowFirstBitReader(BufferedTransformation &store) + : m_store(store), m_buffer(0), m_bitsBuffered(0) {} + unsigned int BitsBuffered() const {return m_bitsBuffered;} + unsigned long PeekBuffer() const {return m_buffer;} + bool FillBuffer(unsigned int length); + unsigned long PeekBits(unsigned int length); + void SkipBits(unsigned int length); + unsigned long GetBits(unsigned int length); + +private: + BufferedTransformation &m_store; + unsigned long m_buffer; + unsigned int m_bitsBuffered; +}; + +struct CodeLessThan; + +//! \class HuffmanDecoder +//! \brief Huffman Decoder +//! \since Crypto++ 1.0 +class HuffmanDecoder +{ +public: + typedef unsigned int code_t; + typedef unsigned int value_t; + enum {MAX_CODE_BITS = sizeof(code_t)*8}; + + class Err : public Exception {public: Err(const std::string &what) : Exception(INVALID_DATA_FORMAT, "HuffmanDecoder: " + what) {}}; + + HuffmanDecoder() : m_maxCodeBits(0), m_cacheBits(0), m_cacheMask(0), m_normalizedCacheMask(0) {} + HuffmanDecoder(const unsigned int *codeBitLengths, unsigned int nCodes) + : m_maxCodeBits(0), m_cacheBits(0), m_cacheMask(0), m_normalizedCacheMask(0) + {Initialize(codeBitLengths, nCodes);} + + void Initialize(const unsigned int *codeBitLengths, unsigned int nCodes); + unsigned int Decode(code_t code, /* out */ value_t &value) const; + bool Decode(LowFirstBitReader &reader, value_t &value) const; + +private: + friend struct CodeLessThan; + + struct CodeInfo + { + CodeInfo(code_t code=0, unsigned int len=0, value_t value=0) : code(code), len(len), value(value) {} + inline bool operator<(const CodeInfo &rhs) const {return code < rhs.code;} + code_t code; + unsigned int len; + value_t value; + }; + + struct LookupEntry + { + unsigned int type; + union + { + value_t value; + const CodeInfo *begin; + }; + union + { + unsigned int len; + const CodeInfo *end; + }; + }; + + static code_t NormalizeCode(code_t code, unsigned int codeBits); + void 
FillCacheEntry(LookupEntry &entry, code_t normalizedCode) const; + + unsigned int m_maxCodeBits, m_cacheBits, m_cacheMask, m_normalizedCacheMask; + std::vector > m_codeToValue; + mutable std::vector > m_cache; +}; + +//! \class Inflator +//! \brief DEFLATE decompressor (RFC 1951) +//! \since Crypto++ 1.0 +class Inflator : public AutoSignaling +{ +public: + class Err : public Exception + { + public: + Err(ErrorType e, const std::string &s) + : Exception(e, s) {} + }; + class UnexpectedEndErr : public Err {public: UnexpectedEndErr() : Err(INVALID_DATA_FORMAT, "Inflator: unexpected end of compressed block") {}}; + class BadBlockErr : public Err {public: BadBlockErr() : Err(INVALID_DATA_FORMAT, "Inflator: error in compressed block") {}}; + + //! \brief RFC 1951 Decompressor + //! \param attachment the filter's attached transformation + //! \param repeat decompress multiple compressed streams in series + //! \param autoSignalPropagation 0 to turn off MessageEnd signal + Inflator(BufferedTransformation *attachment = NULLPTR, bool repeat = false, int autoSignalPropagation = -1); + + void IsolatedInitialize(const NameValuePairs ¶meters); + size_t Put2(const byte *inString, size_t length, int messageEnd, bool blocking); + bool IsolatedFlush(bool hardFlush, bool blocking); + + virtual unsigned int GetLog2WindowSize() const {return 15;} + +protected: + ByteQueue m_inQueue; + +private: + virtual unsigned int MaxPrestreamHeaderSize() const {return 0;} + virtual void ProcessPrestreamHeader() {} + virtual void ProcessDecompressedData(const byte *string, size_t length) + {AttachedTransformation()->Put(string, length);} + virtual unsigned int MaxPoststreamTailSize() const {return 0;} + virtual void ProcessPoststreamTail() {} + + void ProcessInput(bool flush); + void DecodeHeader(); + bool DecodeBody(); + void FlushOutput(); + void OutputByte(byte b); + void OutputString(const byte *string, size_t length); + void OutputPast(unsigned int length, unsigned int distance); + + static 
const HuffmanDecoder *FixedLiteralDecoder(); + static const HuffmanDecoder *FixedDistanceDecoder(); + + const HuffmanDecoder& GetLiteralDecoder() const; + const HuffmanDecoder& GetDistanceDecoder() const; + + enum State {PRE_STREAM, WAIT_HEADER, DECODING_BODY, POST_STREAM, AFTER_END}; + State m_state; + bool m_repeat, m_eof, m_wrappedAround; + byte m_blockType; + word16 m_storedLen; + enum NextDecode {LITERAL, LENGTH_BITS, DISTANCE, DISTANCE_BITS}; + NextDecode m_nextDecode; + unsigned int m_literal, m_distance; // for LENGTH_BITS or DISTANCE_BITS + HuffmanDecoder m_dynamicLiteralDecoder, m_dynamicDistanceDecoder; + LowFirstBitReader m_reader; + SecByteBlock m_window; + size_t m_current, m_lastFlush; +}; + +NAMESPACE_END + +#endif +``` + +## Bug Fix Patch + +```diff +diff --git a/validat1.cpp b/validat1.cpp +index cd8655b4..e81a46c6 100644 +--- a/validat1.cpp ++++ b/validat1.cpp +@@ -623,7 +623,7 @@ bool TestRandomPool() + std::cout << "FAILED:"; + else + std::cout << "passed:"; +- std::cout << " GenerateWord32 and Crop\n"; ++ std::cout << " GenerateWord32 and Crop\n"; + } + + #if !defined(NO_OS_DEPENDENCE) +@@ -711,7 +711,7 @@ bool TestRandomPool() + std::cout << "FAILED:"; + else + std::cout << "passed:"; +- std::cout << " GenerateWord32 and Crop\n"; ++ std::cout << " GenerateWord32 and Crop\n"; + } + #endif + +@@ -808,7 +808,7 @@ bool TestAutoSeededX917() + std::cout << "FAILED:"; + else + std::cout << "passed:"; +- std::cout << " GenerateWord32 and Crop\n"; ++ std::cout << " GenerateWord32 and Crop\n"; + + std::cout.flush(); + return pass; +diff --git a/zinflate.cpp b/zinflate.cpp +index 62431771..ee15c945 100644 +--- a/zinflate.cpp ++++ b/zinflate.cpp +@@ -552,12 +552,18 @@ bool Inflator::DecodeBody() + case DISTANCE_BITS: + // TODO: this surfaced during fuzzing. What do we do??? + CRYPTOPP_ASSERT(m_distance < COUNTOF(distanceExtraBits)); +- bits = (m_distance >= COUNTOF(distanceExtraBits)) ? 
distanceExtraBits[29] : distanceExtraBits[m_distance]; ++ if (m_distance >= COUNTOF(distanceExtraBits)) ++ throw BadDistanceErr(); ++ bits = distanceExtraBits[m_distance]; + if (!m_reader.FillBuffer(bits)) + { + m_nextDecode = DISTANCE_BITS; + break; + } ++ // TODO: this surfaced during fuzzing. What do we do??? ++ CRYPTOPP_ASSERT(m_distance < COUNTOF(distanceStarts)); ++ if (m_distance >= COUNTOF(distanceStarts)) ++ throw BadDistanceErr(); + m_distance = m_reader.GetBits(bits) + distanceStarts[m_distance]; + OutputPast(m_literal, m_distance); + } +diff --git a/zinflate.h b/zinflate.h +index b0879cef..0767d4f9 100644 +--- a/zinflate.h ++++ b/zinflate.h +@@ -98,8 +98,12 @@ public: + Err(ErrorType e, const std::string &s) + : Exception(e, s) {} + }; ++ //! \brief Exception thrown when a truncated stream is encountered + class UnexpectedEndErr : public Err {public: UnexpectedEndErr() : Err(INVALID_DATA_FORMAT, "Inflator: unexpected end of compressed block") {}}; ++ //! \brief Exception thrown when a bad block is encountered + class BadBlockErr : public Err {public: BadBlockErr() : Err(INVALID_DATA_FORMAT, "Inflator: error in compressed block") {}}; ++ //! \brief Exception thrown when an invalid distance is encountered ++ class BadDistanceErr : public Err {public: BadDistanceErr() : Err(INVALID_DATA_FORMAT, "Inflator: error in bit distance") {}}; + + //! \brief RFC 1951 Decompressor + //! \param attachment the filter's attached transformation +``` \ No newline at end of file diff --git a/prompt_template/semgrep_examples/out-of-bound/pattern.md b/prompt_template/semgrep_examples/out-of-bound/pattern.md new file mode 100644 index 00000000..34078d30 --- /dev/null +++ b/prompt_template/semgrep_examples/out-of-bound/pattern.md @@ -0,0 +1,44 @@ +### Bug Pattern + +The bug pattern identified in this semgrep rule is an **out-of-bounds array access vulnerability**. 
The code uses a ternary operator to conditionally access array elements but still performs unsafe array access in certain conditions. + +**Problematic Pattern:** +```cpp +bits = (m_distance >= COUNTOF($ARRAY)) ? $ARRAY[$INDEX] : $ARRAY[m_distance]; +``` + +**Root Cause:** +- The ternary operator checks if `m_distance` is out of bounds but then accesses `$ARRAY[$INDEX]` when the condition is true +- This means when `m_distance >= COUNTOF($ARRAY)`, the code still performs an array access with `$ARRAY[$INDEX]` +- The `$INDEX` variable may not be properly bounds-checked, leading to potential out-of-bounds access +- Even when the condition is false, accessing `$ARRAY[m_distance]` assumes `m_distance` is within bounds + +**Additional Patterns:** +1. **Incorrect pointer alignment check:** + ```cpp + IsAlignedOn($PTR, GetAlignmentOf<$T*>()) // Wrong: checks pointer alignment + ``` + Should be: + ```cpp + IsAlignedOn($PTR, GetAlignmentOf<$T>()) // Correct: checks type alignment + ``` + +2. **Unaligned array declaration:** + ```cpp + T m_array[$S]; // May cause alignment issues + ``` + Should be: + ```cpp + CRYPTOPP_ALIGN_DATA(8) T m_array[$S]; // Properly aligned + ``` + +**Vulnerability Types:** +- CWE-125 (Out-of-bounds Read) +- CWE-787 (Out-of-bounds Write) + +**Risk:** These patterns can lead to: +- Memory corruption +- Application crashes +- Information disclosure +- Potential code execution vulnerabilities +- Performance degradation due to misalignment diff --git a/prompt_template/semgrep_examples/out-of-bound/plan.md b/prompt_template/semgrep_examples/out-of-bound/plan.md new file mode 100644 index 00000000..a04ec517 --- /dev/null +++ b/prompt_template/semgrep_examples/out-of-bound/plan.md @@ -0,0 +1,65 @@ +### Plan + +1. **Pattern Detection:** + - **Primary Pattern:** Detect ternary operators with conditional array access: `bits = (condition >= COUNTOF($ARRAY)) ? 
$ARRAY[$INDEX] : $ARRAY[variable];` + - **Alignment Pattern:** Identify incorrect pointer alignment checks: `IsAlignedOn($PTR, GetAlignmentOf<$T*>())` + - **Declaration Pattern:** Find unaligned array declarations: `T m_array[$S];` + - **Exclusion Rules:** Skip cases with proper bounds checking and exception handling + +2. **Static Analysis Approach:** + - **Ternary Operator Analysis:** Identify conditional array access patterns that may still cause out-of-bounds access + - **Bounds Checking Validation:** Verify if proper bounds checking exists before array access + - **Alignment Verification:** Check for correct alignment specifications in pointer operations + - **Declaration Analysis:** Ensure arrays requiring alignment are properly declared + +3. **Vulnerability Validation:** + - Confirm that array access occurs without comprehensive bounds validation + - Verify that the index variables (`$INDEX`, `m_distance`) can exceed array bounds + - Check if alignment requirements are properly specified for performance-critical arrays + - Validate that exception handling is absent for out-of-bounds conditions + +4. **Fix Strategy:** + - **Bounds Checking:** Replace unsafe ternary operations with explicit bounds checking and exception throwing + - **Alignment Correction:** Use proper type alignment instead of pointer alignment + - **Memory Alignment:** Add proper alignment directives for array declarations + - **Exception Handling:** Implement proper error handling for boundary violations + +5. **Recommended Fixes:** + + **For out-of-bounds array access:** + ```cpp + // Before (vulnerable): + bits = (m_distance >= COUNTOF(array)) ? 
array[index] : array[m_distance];
+
+   // After (safe):
+   if (m_distance >= COUNTOF(array))
+       throw BadDistanceErr();
+   bits = array[m_distance];
+   ```
+
+   **For incorrect alignment check:**
+   ```cpp
+   // Before (incorrect):
+   IsAlignedOn(ptr, GetAlignmentOf<T*>())
+
+   // After (correct):
+   IsAlignedOn(ptr, GetAlignmentOf<T>())
+   ```
+
+   **For unaligned array declaration:**
+   ```cpp
+   // Before (potentially problematic):
+   T m_array[SIZE];
+
+   // After (properly aligned):
+   CRYPTOPP_ALIGN_DATA(8) T m_array[SIZE];
+   ```
+
+6. **Prevention Measures:**
+   - Implement comprehensive bounds checking before all array accesses
+   - Use range-checked containers (e.g., `std::array` with `at()` method) where possible
+   - Establish coding standards requiring explicit bounds validation
+   - Use static analysis tools to detect similar patterns during development
+   - Implement unit tests that verify boundary condition handling
+   - Ensure proper memory alignment for performance-critical data structures
+   - Consider using safe array access patterns with RAII and smart pointers
diff --git a/prompt_template/semgrep_examples/out-of-bound/rule.yml b/prompt_template/semgrep_examples/out-of-bound/rule.yml
new file mode 100644
index 00000000..6d15f2a7
--- /dev/null
+++ b/prompt_template/semgrep_examples/out-of-bound/rule.yml
@@ -0,0 +1,81 @@
+rules:
+- id: vuln-cryptopp-07dbcc3d
+  pattern: 'bits = (m_distance >= COUNTOF($ARRAY)) ? $ARRAY[$INDEX] : $ARRAY[m_distance];
+
+    '
+  pattern-not: "if (m_distance >= COUNTOF($ARRAY))\n throw BadDistanceErr();\nbits\
+    \ = $ARRAY[m_distance];\n"
+  languages:
+  - cpp
+  message: 'The code directly accesses an array without proper bounds checking, which
+    can lead to out-of-bounds access and potential memory corruption. The vulnerability
+    occurs when `m_distance` is greater than or equal to the size of the array, leading
+    to undefined behavior. The fix involves adding explicit bounds checking and throwing
+    an exception when the index is out of bounds. 
+
+    '
+  severity: ERROR
+  metadata:
+    source-url: github.com/weidai11/cryptopp/commit/07dbcc3d9644b18e05c1776db2a57fe04d780965
+    category: security
+    cwe:
+    - 'CWE-125: Out-of-bounds Read'
+    owasp:
+    - A1:2017-Injection
+    references:
+    - https://cwe.mitre.org/data/definitions/125.html
+    technology:
+    - cpp
+- id: vuln-cryptopp-9fe5ccfb
+  pattern: 'IsAlignedOn($PTR, GetAlignmentOf<$T*>())
+
+    '
+  pattern-not: 'IsAlignedOn($PTR, GetAlignmentOf<$T>())
+
+    '
+  languages:
+  - cpp
+  message: 'Detected incorrect pointer alignment check. The function `IsAlignedOn`
+    is being called with `GetAlignmentOf<T*>()`, which checks alignment for a pointer
+    to `T` instead of `T` itself. This can lead to incorrect alignment validation,
+    potentially causing undefined behavior or security vulnerabilities. Fix by using
+    `GetAlignmentOf<T>()` to check alignment for the type `T` directly.
+
+    '
+  severity: ERROR
+  metadata:
+    source-url: github.com/weidai11/cryptopp/commit/9fe5ccfbeed3c3c48b6e1d42e4abb64d11662527
+    category: security
+    cwe:
+    - CWE-125
+    owasp:
+    - 'A1: Injection'
+    references:
+    - https://github.com/weidai11/cryptopp/issues/992
+    technology:
+    - cpp
+- id: vuln-cryptopp-4bc7408a
+  pattern: 'T m_array[$S];
+
+    '
+  pattern-not: 'CRYPTOPP_ALIGN_DATA(8) T m_array[$S];
+
+    '
+  languages:
+  - cpp
+  message: "Detected unaligned array declaration which may cause memory alignment\
+    \ issues on some toolchains. \nThis can lead to performance degradation or crashes,\
+    \ especially for 64-bit elements requiring 8-byte alignment. 
\nEnsure proper alignment\
+    \ by using CRYPTOPP_ALIGN_DATA(8) for array declarations.\n"
+  severity: ERROR
+  metadata:
+    source-url: github.com/weidai11/cryptopp/commit/4bc7408ae2aefac9357c16809541ecbe225b7f3a
+    category: security
+    cwe:
+    - CWE-787
+    owasp:
+    - 'A9: Using Components with Known Vulnerabilities'
+    references:
+    - https://github.com/weidai11/cryptopp/issues/992
+    technology:
+    - cpp
diff --git a/src/agent.py b/src/agent.py
index f4d68be7..ee586867 100644
--- a/src/agent.py
+++ b/src/agent.py
@@ -5,12 +5,14 @@
 from checker_example import choose_example
 from global_config import global_config
-from model import invoke_llm
+from model import invoke_llm, invoke_llm_semgrep
 from tools import error_formatting, grab_error_message
 
 prompt_template_dir = Path(__file__).parent.parent / "prompt_template"
 example_dir = prompt_template_dir / "examples"
 default_checker_examples = []
+semgrep_example_dir = prompt_template_dir / "semgrep_examples"
+default_semgrep_examples = []
 
 UTILITY_FUNCTION = (prompt_template_dir / "knowledge" / "utility.md").read_text()
 SUGGESTIONS = (prompt_template_dir / "knowledge" / "suggestions.md").read_text()
@@ -35,12 +37,37 @@ def load_example_from_dir(checker_dir: str):
             patch=patch, pattern=pattern, plan=plan, checker_code=checker_code
         )
 
+class SemgrepExample(BaseModel):
+    patch: str
+    pattern: str
+    plan: str
+    semgrep_rule: str
+
+    @staticmethod
+    def load_example_from_dir(example_dir: str):
+        example_dir = Path(example_dir)
+        patch = (example_dir / "patch.md").read_text() if (example_dir / "patch.md").exists() else ""
+        pattern = (example_dir / "pattern.md").read_text() if (example_dir / "pattern.md").exists() else ""
+        plan = (example_dir / "plan.md").read_text() if (example_dir / "plan.md").exists() else ""
+        semgrep_rule = (example_dir / "rule.yml").read_text() if (example_dir / "rule.yml").exists() else ""
+
+        return SemgrepExample(
+            patch=patch, pattern=pattern, plan=plan, semgrep_rule=semgrep_rule
+        )
+
 for checker_dir in example_dir.iterdir():
     if not 
checker_dir.is_dir(): continue default_checker_examples.append(Example.load_example_from_dir(checker_dir)) +# Load semgrep examples +if semgrep_example_dir.exists(): + for semgrep_dir in semgrep_example_dir.iterdir(): + if not semgrep_dir.is_dir(): + continue + default_semgrep_examples.append(SemgrepExample.load_example_from_dir(semgrep_dir)) + def get_example_text( example_list, @@ -64,6 +91,31 @@ def get_example_text( example_text += "```\n\n" return example_text +def get_semgrep_example_text( + example_list=None, + need_patch: bool = True, + need_pattern: bool = False, + need_plan: bool = False, + need_semgrep_rule: bool = True, +): + """Get example text for Semgrep rules from actual example files.""" + if example_list is None: + example_list = default_semgrep_examples + + example_text = "" + for i, example in enumerate(example_list): + example_text += f"## Example {i+1}\n" + if need_patch: + example_text += example.patch + "\n\n" + if need_pattern and example.pattern: + example_text += example.pattern + "\n\n" + if need_plan and example.plan: + example_text += example.plan + "\n\n" + if need_semgrep_rule and example.semgrep_rule: + example_text += "### Semgrep Rule\n```yaml\n" + example_text += example.semgrep_rule + example_text += "```\n\n" + return example_text patch2checker_template = (prompt_template_dir / "patch2checker.md").read_text() patch2pattern_template = ( @@ -477,3 +529,188 @@ def repair_syntax(id: str, iter: int, times, checker_code, error_content): response_store = prompt_history_dir / f"response_repair_syntax-{times}.md" response_store.write_text(response) return response + +"""Patch to Semgrep Rule""" +patch2semgrep_template = (prompt_template_dir / "patch2semgrep.md").read_text() + +"""Pattern to Semgrep Plan""" +pattern2semplan_template = (prompt_template_dir / "pattern2semplan.md").read_text() + +"""Plan to Semgrep Rule""" +plan2semgrep_template = (prompt_template_dir / "plan2semgrep.md").read_text() + +"""Repair Semgrep Rule""" 
+repair_semgrep_template = (prompt_template_dir / "repair_semgrep.md").read_text() + + +def patch2semgrep(id: str, iter: int, patch: str): + """Generate Semgrep rule directly from patch.""" + logger.info("start generating patch2semgrep prompts") + + # Use semgrep examples if available + example_text = get_semgrep_example_text() + + patch2semgrep_prompt = patch2semgrep_template.replace("{{input_patch}}", patch) + patch2semgrep_prompt = patch2semgrep_prompt.replace("{{examples}}", example_text) + + prompt_history_dir = ( + Path(global_config.result_dir) / "semgrep_rules" / id / "prompt_history" / str(iter) + ) + path2store = prompt_history_dir / "patch2semgrep.md" + prompt_history_dir.mkdir(parents=True, exist_ok=True) + + path2store.write_text(patch2semgrep_prompt) + logger.info("finish patch2semgrep generation") + + response = invoke_llm(patch2semgrep_prompt) + response_store = prompt_history_dir / "response_semgrep.md" + response_store.write_text(response) + return response + + +def plan2semgrep( + id: str, + iter: int, + pattern: str, + plan: str, + patch: str, + no_utility=False, + sample_examples=False, +): + """Generate Semgrep rule from plan.""" + logger.info("start generating plan2semgrep prompts") + + # Use semgrep examples if available + if sample_examples: + logger.warning("Sample examples for plan2semgrep") + # TODO: Implement sampling for semgrep examples + example_text = get_semgrep_example_text() + else: + example_text = get_semgrep_example_text() + + plan2semgrep_prompt = ( + plan2semgrep_template.replace("{{input_pattern}}", pattern) + .replace("{{input_plan}}", plan) + .replace("{{input_patch}}", patch) + .replace("{{examples}}", example_text) + ) + + prompt_history_dir = ( + Path(global_config.result_dir) / "semgrep_rules" / id / "prompt_history" / str(iter) + ) + path2store = prompt_history_dir / "plan2semgrep.md" + prompt_history_dir.mkdir(parents=True, exist_ok=True) + + path2store.write_text(plan2semgrep_prompt) + logger.info("finish 
plan2semgrep generation") + + response = invoke_llm_semgrep(plan2semgrep_prompt) + response_store = prompt_history_dir / "response_semgrep.md" + response_store.write_text(response) + return response + + +def repair_semgrep_syntax(id: str, repair_name: str, times: int, semgrep_rule: str, error_content: str): + """Repair Semgrep rule syntax errors.""" + logger.info("start generating repair_semgrep_syntax prompts") + + prompt = ( + repair_semgrep_template.replace("{{semgrep_rule}}", semgrep_rule) + .replace("{{error_messages}}", error_content) + ) + + prompt_history_dir = ( + Path(global_config.result_dir) / "semgrep_rules" / id / "prompt_history" / repair_name + ) + path2store = prompt_history_dir / f"repair_semgrep_syntax-{times}.md" + prompt_history_dir.mkdir(parents=True, exist_ok=True) + + path2store.write_text(prompt) + logger.info("finish repair_semgrep_syntax generation") + + response = invoke_llm(prompt) + if response is None: + logger.error("Empty response") + response = "SKIP" + + response_store = prompt_history_dir / f"response_repair_semgrep_syntax-{times}.md" + response_store.write_text(response) + return response + + +def pattern2semplan( + id: str, + iter: int, + pattern: str, + patch: str, + no_tp_plans=None, + no_fp_plans=None, + sample_examples=False, +): + """Generate Semgrep plan based on the given pattern and patch. + + Args: + id (str): The id of the current task. + iter (int): The iteration number. + pattern (str): The pattern of the bug. + patch (str): The patch of the bug. + no_tp_plans (list, optional): Plans that cannot detect the buggy pattern. Defaults to None. + no_fp_plans (list, optional): Plans that can label the non-buggy pattern correctly. Defaults to None. + sample_examples (bool, optional): Whether to sample examples. Defaults to False. 
+ """ + logger.info("start generating pattern2semplan prompts") + + template = pattern2semplan_template + + if sample_examples: + logger.warning("Sample examples for pattern2semplan") + example_text = get_semgrep_example_text(need_pattern=True, need_plan=True, need_semgrep_rule=False) + else: + example_text = get_semgrep_example_text(need_pattern=True, need_plan=True, need_semgrep_rule=False) + + pattern2semplan_prompt = ( + template.replace("{{input_pattern}}", pattern) + .replace("{{input_patch}}", patch) + .replace("{{examples}}", example_text) + ) + + feedback_plan_text = "" + if no_tp_plans: + no_tp_plan_text = "# Plans that cannot detect the buggy pattern\n" + # The last three plans if there are more than 3 failed plans + if len(no_tp_plans) > 3: + no_tp_plans = no_tp_plans[-3:] + + for i, plan in enumerate(no_tp_plans): + no_tp_plan_text += f"## Failed Plan {i+1}\n" + no_tp_plan_text += plan + "\n\n" + feedback_plan_text += no_tp_plan_text + if no_fp_plans: + no_fp_plan_text = "# Plans that can label the non-buggy pattern correctly\n" + + if len(no_fp_plans) > 3: + no_fp_plans = no_fp_plans[-3:] + + for i, plan in enumerate(no_fp_plans): + no_fp_plan_text += f"## Failed Plan {i+1}\n" + no_fp_plan_text += plan + "\n\n" + feedback_plan_text += no_fp_plan_text + + pattern2semplan_prompt = pattern2semplan_prompt.replace( + "{{failed_plan_examples}}", feedback_plan_text + ) + + prompt_history_dir = ( + Path(global_config.result_dir) / "semgrep_rules" / id / "prompt_history" / str(iter) + ) + path2store = prompt_history_dir / "pattern2semplan.md" + prompt_history_dir.mkdir(parents=True, exist_ok=True) + + path2store.write_text(pattern2semplan_prompt) + logger.info("finish pattern2semplan generation") + + response = invoke_llm(pattern2semplan_prompt) + response_store = prompt_history_dir / "response_semplan.md" + + response_store.write_text(response) + return response \ No newline at end of file diff --git a/src/backends/factory.py b/src/backends/factory.py index 
31900773..20f90369 100644 --- a/src/backends/factory.py +++ b/src/backends/factory.py @@ -21,12 +21,29 @@ def __init__(self, backend_path: str): self.backend_path = Path(backend_path) @abstractmethod - def build_checker(self, checker_code: str, log_dir: Path, attempt=1, **kwargs): + @abstractmethod + def build_checker( + self, + checker_code: str, + log_dir: Path, + checker_name: str = "SAGenTest", + attempt: int = 1, + jobs: int = 8, + timeout: int = 300, + ) -> Tuple[int, str]: """ - Build the checker in the backend. + Build/validate the checker code. Args: - **kwargs: Additional arguments for the build command. + checker_code (str): The checker code to build/validate. + log_dir (Path): Directory for build logs. + checker_name (str): Name of the checker. + attempt (int): Attempt number. + jobs (int): Number of parallel jobs. + timeout (int): Timeout in seconds. + + Returns: + Tuple[int, str]: (return_code, error_message) """ raise NotImplementedError("Subclasses must implement this method.") @@ -37,18 +54,20 @@ def validate_checker( commit_id: str, patch: str, target: TargetFactory, - skip_build_checker=False, + skip_build_checker: bool = False, ) -> Tuple[int, int]: """ Validate the checker against a commit and patch. Args: + checker_code (str): The checker code to validate. commit_id (str): The commit ID to validate against. - patch (str): The patch to apply. - target (TargetFactory): The target to be tested. + patch (str): The patch content. + target (TargetFactory): The target to validate against. skip_build_checker (bool): Whether to skip building the checker. + Returns: - Tuple[int, int]: The number of true positives and true negatives. 
+ Tuple[int, int]: (TP_count, TN_count) """ raise NotImplementedError("Subclasses must implement this method.") @@ -61,26 +80,41 @@ def run_checker( checker_code: str, commit_id: str, target: TargetFactory, - object_to_analyze=None, - jobs=32, - output_dir="tmp", - skip_build_checker=False, - skip_checkout=False, + object_to_analyze: Optional[str] = None, + jobs: int = 32, + output_dir: str = "tmp", **kwargs, - ): + ) -> int: """ - Run the checker against a commit and patch. + Run the checker against a commit. Args: - checker_code (str): The code of the checker to run. - commit_id (str): The commit ID to validate against. - target (TargetFactory): The target to be tested. - object_to_analyze (str): The object to analyze. - jobs (int): The number of jobs to run in parallel. - output_dir (str): The directory to store the output. - **kwargs: Additional arguments for the run command. + checker_code (str): The checker code to run. + commit_id (str): The commit ID to run against. + target (TargetFactory): The target to run against. + object_to_analyze (Optional[str]): Specific object to analyze. + jobs (int): Number of parallel jobs. + output_dir (str): Output directory for results. + **kwargs: Additional arguments. + + Returns: + int: Number of bugs found, or negative value for errors. """ raise NotImplementedError("Subclasses must implement this method.") + + @staticmethod + @abstractmethod + def get_num_bugs(content: str) -> int: + """ + Extract number of bugs from analysis output. + + Args: + content (str): The analysis output content. + + Returns: + int: Number of bugs found. 
+ """ + pass @staticmethod @abstractmethod diff --git a/src/backends/semgrep.py b/src/backends/semgrep.py new file mode 100644 index 00000000..6da9db29 --- /dev/null +++ b/src/backends/semgrep.py @@ -0,0 +1,420 @@ +import re +import subprocess as sp +import tempfile +import yaml +from pathlib import Path +from typing import Optional + +from loguru import logger + +from backends.factory import AnalysisBackendFactory +from targets.factory import TargetFactory +from targets.linux import Linux + + +class SemgrepBackend(AnalysisBackendFactory): + """ + Concrete implementation of the Backend class for Semgrep. + """ + + def __init__(self, backend_path: str): + """ + Initialize Semgrep backend. + + Args: + backend_path (str): Path where semgrep rules are stored + """ + super().__init__(backend_path) + # Semgrep doesn't need a build directory, just rules storage + self.rules_path = self.backend_path / "rules" + self.rules_path.mkdir(parents=True, exist_ok=True) + + def build_checker( + self, + checker_code: str, + log_dir: Path, + checker_name="SAGenTest", + attempt=1, + jobs=8, + timeout=300, + ): + """ + Build the checker in the Semgrep backend. + For Semgrep, this means saving the YAML rule to a file. + + Args: + checker_code (str): The YAML rule content to save. + log_dir (Path): Directory for logs. + checker_name (str): Name of the checker. + attempt (int): Attempt number. + jobs (int): Not used for Semgrep. + timeout (int): Not used for Semgrep. 
+ """ + # Create rule file path + rule_file_path = self.rules_path / f"{checker_name}.yml" + log_dir.mkdir(parents=True, exist_ok=True) + + try: + # Parse and validate YAML + rule_data = yaml.safe_load(checker_code) + if not rule_data or 'rules' not in rule_data: + error_msg = "Invalid rule format: missing 'rules' key" + logger.error(error_msg) + (log_dir / f"build_error_{attempt}.log").write_text(error_msg) + return -1, error_msg + + # Write the rule to file + rule_file_path.write_text(checker_code) + + # Log success + success_msg = f"Successfully saved Semgrep rule to {rule_file_path}" + logger.info(success_msg) + (log_dir / f"build_stdout_{attempt}.log").write_text(success_msg) + + return 0, "Rule saved successfully" + + except yaml.YAMLError as e: + error_msg = f"Invalid YAML format: {e}" + logger.error(error_msg) + (log_dir / f"build_error_{attempt}.log").write_text(error_msg) + return -1, error_msg + except Exception as e: + error_msg = f"Error saving rule: {e}" + logger.error(error_msg) + (log_dir / f"build_error_{attempt}.log").write_text(error_msg) + return -1, error_msg + + def validate_checker( + self, + checker_code, + commit_id, + patch, + target: TargetFactory, + skip_build_checker=False, + ): + """ + Validate the checker against a commit and patch. + """ + if target._target_type == "linux": + return self._validate_checker_linux( + checker_code, commit_id, patch, target, skip_build_checker + ) + else: + raise NotImplementedError( + f"Validation for target type {target._target_type} is not implemented." + ) + + def run_checker( + self, + checker_code, + commit_id, + target, + object_to_analyze=None, + jobs=32, + output_dir="tmp", + **kwargs, + ): + """ + Run the checker against a commit. 
+ """ + if target._target_type == "linux": + return self._run_checker_linux( + checker_code, + commit_id, + target, + object_to_analyze=object_to_analyze, + jobs=jobs, + output_dir=output_dir, + **kwargs, + ) + else: + raise NotImplementedError( + f"Running checker for target type {target._target_type} is not implemented." + ) + + def _validate_checker_linux( + self, + checker_code: str, + commit_id: str, + patch: str, + target: Linux, + skip_build_checker=False, + ): + """ + Validate the Semgrep rule against a commit and patch. + """ + TP, TN = 0, 0 + + if not skip_build_checker: + build_res, build_msg = self.build_checker( + checker_code, + Path("tmp"), + attempt=1, + ) + if build_res != 0: + logger.error(f"Rule validation failed: {build_msg}") + return -1, -1 + + # Get rule file path + rule_file = self.rules_path / "SAGenTest.yml" + + # Checkout buggy version + target.checkout_commit(commit_id, is_before=True) + + # Get modified files from patch + objects = target.get_objects_from_patch(patch) + + for obj in objects: + # Convert object to source file for Semgrep scanning + logger.info(f"Validating object: {obj}") + source_files = self._get_source_files_from_object(obj, target) + + for source_file in source_files: + file_path = Path(target.repo.working_dir) / source_file + if not file_path.exists(): + continue + + # Run Semgrep on buggy version + logger.info(f"Running Semgrep on buggy version for {source_file}") + bugs_found = self._run_semgrep_on_file(rule_file, file_path) + logger.info(f"Buggy version - {source_file}: {bugs_found} bugs found") + + if bugs_found > 0: + TP += 1 + break # Found bug in this object + + # Checkout fixed version + target.checkout_commit(commit_id, is_before=False) + + for obj in objects: + source_files = self._get_source_files_from_object(obj, target) + + for source_file in source_files: + file_path = Path(target.repo.working_dir) / source_file + if not file_path.exists(): + continue + + # Run Semgrep on fixed version + 
logger.info(f"Running Semgrep on fixed version for {source_file}") + bugs_found = self._run_semgrep_on_file(rule_file, file_path) + logger.info(f"Fixed version - {source_file}: {bugs_found} bugs found") + + if bugs_found == 0: + TN += 1 + break # No bugs in fixed version + + return TP, TN + + def _run_checker_linux( + self, + checker_code: str, + commit_id: str, + target: Linux, + object_to_analyze: str = None, + jobs: int = 32, + output_dir: str = "tmp", + **kwargs, + ): + """ + Run the Semgrep checker against a Linux repository. + """ + output_dir = Path(output_dir) + timeout = kwargs.get("timeout", 1800) + + # Build (save) the rule + build_res, build_msg = self.build_checker(checker_code, Path("tmp"), attempt=1) + if build_res != 0: + logger.error("Rule save failed, skipping analysis.") + raise Exception("Rule save failed, skipping analysis.") + + # Checkout the specified commit + target.checkout_commit(commit_id) + + # Get rule file + rule_file = self.rules_path / "SAGenTest.yml" + + # Determine scan target + if object_to_analyze: + # Convert object to source files + source_files = self._get_source_files_from_object(object_to_analyze, target) + scan_paths = [str(Path(target.repo.working_dir) / f) for f in source_files] + else: + # Scan entire repository + scan_paths = [str(target.repo.working_dir)] + + total_bugs = 0 + + try: + for scan_path in scan_paths: + if not Path(scan_path).exists(): + continue + + logger.info(f"Running Semgrep on: {scan_path}") + + # Run Semgrep + cmd = [ + "semgrep", + "--config", str(rule_file), + "--json", + "--timeout", str(timeout), + scan_path + ] + + process = sp.run( + cmd, + capture_output=True, + text=True, + timeout=timeout, + cwd=target.repo.working_dir + ) + + if process.returncode != 0 and process.returncode != 1: # 1 means findings found + logger.error(f"Semgrep failed with return code {process.returncode}") + logger.error(f"Stderr: {process.stderr}") + return -999 + + # Parse results + bugs_in_path = 
self._parse_semgrep_output(process.stdout) + total_bugs += bugs_in_path + + logger.info(f"Found {bugs_in_path} bugs in {scan_path}") + + except sp.TimeoutExpired: + logger.warning("Semgrep scan timed out!") + return -1 + except Exception as e: + logger.error(f"Error running Semgrep: {e}") + return -10 + + # Save output + output_dir.mkdir(parents=True, exist_ok=True) + (output_dir / "semgrep_results.json").write_text(process.stdout) + + logger.success(f"Semgrep scan completed: {total_bugs} bugs found!") + return total_bugs + + def _get_source_files_from_object(self, obj: str, target: Linux) -> list: + """ + Convert object file name to corresponding source files. + + Args: + obj (str): Object file name (e.g., "fs/ext4/inode.o") + target (Linux): Linux target + + Returns: + list: List of source file paths + """ + # Remove .o extension and add common source extensions + base_path = obj.replace('.o', '') + extensions = ['.c', '.cc', '.cpp', '.cxx'] + + source_files = [] + for ext in extensions: + source_file = base_path + ext + if (Path(target.repo.working_dir) / source_file).exists(): + source_files.append(source_file) + + return source_files + + def _run_semgrep_on_file(self, rule_file: Path, target_file: Path) -> int: + """ + Run Semgrep on a single file and return number of findings. + """ + try: + logger.info(f"Running Semgrep on {target_file} with rule {rule_file}") + cmd = [ + "semgrep", + "--config", str(rule_file), + "--json", + str(target_file) + ] + + result = sp.run(cmd, capture_output=True, text=True, timeout=30) + return self._parse_semgrep_output(result.stdout) + + except Exception as e: + logger.error(f"Error running Semgrep on {target_file}: {e}") + return 0 + + def _parse_semgrep_output(self, output: str) -> int: + """ + Parse Semgrep JSON output and return number of findings. 
+ """ + try: + import json + data = json.loads(output) + results = data.get('results', []) + return len(results) + except Exception as e: + logger.error(f"Error parsing Semgrep output: {e}") + return 0 + + @staticmethod + def get_num_bugs(content: str) -> int: + """ + Extract number of bugs from Semgrep output. + """ + try: + import json + data = json.loads(content) + results = data.get('results', []) + return len(results) + except Exception: + logger.error("Error: Couldn't extract number of bugs from Semgrep output.") + return 0 + + @staticmethod + def get_objects_from_report(report: str, target: TargetFactory): + """ + Get the objects from the Semgrep report. + + Args: + report (str): The JSON report from Semgrep. + target (TargetFactory): The target to be tested. + + Returns: + list: List of objects found in the report. + """ + try: + import json + data = json.loads(report) + results = data.get('results', []) + + objects = set() + for result in results: + file_path = result.get('path', '') + if file_path: + # Convert source file back to object file + obj_path = target.get_object_name(file_path) + objects.add(obj_path) + + return list(objects) + + except Exception as e: + logger.error(f"Error parsing Semgrep report: {e}") + return [] + + @staticmethod + def extract_reports( + report_dir: str, + output_dir: str, + sampled_num: int = 5, + stop_num: int = 5, + max_num: int = 100, + seed: int = 0, + ) -> tuple[Optional[list], int]: + """ + Extract reports from Semgrep JSON output. 
+ + Args: + report_dir (str): Directory containing Semgrep JSON reports + output_dir (str): Directory to store extracted reports + sampled_num (int): Number of reports to sample + stop_num (int): Number of reports to stop at + max_num (int): Maximum number of reports to process + seed (int): Random seed for sampling + + Returns: + Tuple[Optional[List[ReportData]], int]: List of extracted reports and total count + """ + pass \ No newline at end of file diff --git a/src/checker_data.py b/src/checker_data.py index f9c981d7..6cf6000c 100644 --- a/src/checker_data.py +++ b/src/checker_data.py @@ -182,12 +182,16 @@ def __init__( self.plan: Optional[str] = None self.refined_plan: Optional[str] = None # Note: often same as plan in snippets self.initial_checker_code: Optional[str] = None # Code before repair/refinement + # For CSA: C++ checker code + # For Semgrep: YAML rule content # Syntax Repair self.syntax_repair_log: List[RepairResult] = [] # List of repair attempts self.repaired_checker_code: Optional[ str ] = None # Code after repairChecker step + # For CSA: Repaired C++ checker code + # For Semgrep: Repaired YAML rule content # Evaluation results self.tp_score: int = -10 # True Positives, default from checker_gen.py @@ -195,7 +199,9 @@ def __init__( # Data from the refinement phase (checker_refine.py) self.refinement_history: List[RefineResult] = [] - self.final_checker_code: Optional[str] = None + self.final_checker_code: Optional[str] = None # Final code after refinement + # For CSA: Final C++ checker code + # For Semgrep: Final YAML rule content def update_base_result_dir(self, base_result_dir: Path): """Updates the base result directory.""" @@ -246,7 +252,7 @@ def is_valid(self) -> bool: def to_dict(self) -> dict: """Converts the CheckerData instance to a JSON-serializable dictionary.""" - return { + result = { "commit_id": self.commit_id, "commit_type": self.commit_type, "index": self.index, @@ -266,6 +272,7 @@ def to_dict(self) -> dict: # "refinement_history": 
[hist.to_dict() for hist in self.refinement_history], # "final_checker_code": self.final_checker_code, } + return result @property def checker_id(self) -> str: @@ -303,11 +310,28 @@ def dump_dir(self): (output_dir / "pattern.txt").write_text(self.pattern or "") (output_dir / "plan.txt").write_text(self.plan or "") (output_dir / "refined_plan.txt").write_text(self.refined_plan or "") - (output_dir / "checker-initial.cpp").write_text(self.initial_checker_code or "") - (output_dir / "checker-repaired.cpp").write_text( - self.repaired_checker_code or "" - ) - (output_dir / "checker-final.cpp").write_text(self.final_checker_code or "") + + # Save initial code with appropriate extension based on content + if self.initial_checker_code: + if self.initial_checker_code.strip().startswith('rules:'): + (output_dir / "checker-initial.yml").write_text(self.initial_checker_code) + else: + (output_dir / "checker-initial.cpp").write_text(self.initial_checker_code) + + # Save repaired code with appropriate extension + if self.repaired_checker_code: + if self.repaired_checker_code.strip().startswith('rules:'): + (output_dir / "checker-repaired.yml").write_text(self.repaired_checker_code) + else: + (output_dir / "checker-repaired.cpp").write_text(self.repaired_checker_code) + + # Save final code with appropriate extension + if self.final_checker_code: + if self.final_checker_code.strip().startswith('rules:'): + (output_dir / "checker-final.yml").write_text(self.final_checker_code) + else: + (output_dir / "checker-final.cpp").write_text(self.final_checker_code) + (output_dir / "score.txt").write_text( f"TP: {self.tp_score}\nTN: {self.tn_score}" ) @@ -360,17 +384,27 @@ def load_checker_data_from_dir(dir_path: str) -> "CheckerData": index=index, ) - # Load the files + # Load the files - try both extensions checker_data.patch = (dir_path / "patch.txt").read_text() checker_data.pattern = (dir_path / "pattern.txt").read_text() checker_data.plan = (dir_path / "plan.txt").read_text() 
checker_data.refined_plan = (dir_path / "refined_plan.txt").read_text() - checker_data.initial_checker_code = ( - dir_path / "checker-initial.cpp" - ).read_text() - checker_data.repaired_checker_code = ( - dir_path / "checker-repaired.cpp" - ).read_text() + + # Load initial code - try both extensions + initial_cpp = dir_path / "checker-initial.cpp" + initial_yml = dir_path / "checker-initial.yml" + if initial_yml.exists(): + checker_data.initial_checker_code = initial_yml.read_text() + elif initial_cpp.exists(): + checker_data.initial_checker_code = initial_cpp.read_text() + + # Load repaired code - try both extensions + repaired_cpp = dir_path / "checker-repaired.cpp" + repaired_yml = dir_path / "checker-repaired.yml" + if repaired_yml.exists(): + checker_data.repaired_checker_code = repaired_yml.read_text() + elif repaired_cpp.exists(): + checker_data.repaired_checker_code = repaired_cpp.read_text() score_file = dir_path / "score.txt" if score_file.exists(): @@ -378,7 +412,7 @@ def load_checker_data_from_dir(dir_path: str) -> "CheckerData": print(score_content) checker_data.tp_score = int(score_content[0].split(":")[-1].strip()) checker_data.tn_score = int(score_content[1].split(":")[-1].strip()) - + return checker_data diff --git a/src/checker_example.py b/src/checker_example.py index f1453d5a..69001248 100644 --- a/src/checker_example.py +++ b/src/checker_example.py @@ -7,6 +7,8 @@ example_dir = Path(__file__).parent.parent / "checker_database" example_list = [] +semgrep_example_dir = Path(__file__).parent.parent / "prompt_template" / "semgrep_examples" +semgrep_example_list = [] class ExampleChecker: @@ -54,6 +56,32 @@ def init_example(): continue example_list.append(ExampleChecker.load_example_from_dir(checker_dir)) +def init_semgrep_example(): + """Initialize only semgrep examples for semgrep rule generation.""" + global semgrep_example_list + if not semgrep_example_dir.exists(): + return + + for example_dir in semgrep_example_dir.iterdir(): + if not 
example_dir.is_dir(): + continue + # For semgrep examples, we might not have embeddings, so we can skip that part + # or implement a simpler version without embeddings + try: + pattern = (example_dir / "pattern.md").read_text() if (example_dir / "pattern.md").exists() else "" + plan = (example_dir / "plan.md").read_text() if (example_dir / "plan.md").exists() else "" + checker_code = (example_dir / "semgrep_rule.yml").read_text() if (example_dir / "semgrep_rule.yml").exists() else "" + + # Create a simplified example for semgrep - use unified field names + semgrep_example = { + 'pattern': pattern, + 'plan': plan, + 'checker_code': checker_code, # Use unified field name + 'dir': example_dir + } + semgrep_example_list.append(semgrep_example) + except Exception as e: + print(f"Error loading semgrep example from {example_dir}: {e}") def choose_example(content: str, type: str, num_samples=3): """Choose the most similar example checker for the given content.""" diff --git a/src/global_config.py b/src/global_config.py index 44c9f34b..2dd2ae59 100644 --- a/src/global_config.py +++ b/src/global_config.py @@ -6,6 +6,7 @@ import yaml from backends.csa import ClangBackend +from backends.semgrep import SemgrepBackend from backends.factory import AnalysisBackendFactory from targets.factory import TargetFactory from targets.linux import Linux @@ -45,7 +46,12 @@ def setup(self, config_path: str = "config.yaml"): # Init the target and backend # FIXME: This should be extended to support other targets and backends self._config["target"] = Linux(self.get("linux_dir")) - self._config["backend"] = ClangBackend(self.get("LLVM_dir")) + # Initialize backend based on configuration + backend_type = self.get("backend_type", "csa") + if backend_type == "semgrep": + self._config["backend"] = SemgrepBackend(self.get("semgrep_dir", "./semgrep_rules")) + else: + self._config["backend"] = ClangBackend(self.get("LLVM_dir")) def _init_logger(self): """Initialize the logger.""" diff --git 
a/src/main.py b/src/main.py index 99f3838d..5082105a 100644 --- a/src/main.py +++ b/src/main.py @@ -12,6 +12,8 @@ ) from checker_scan import scan, scan_single_checker, triage_report from commit_label import label_commits +from semgrep_gen import semgrep_gen +from semgrep_repair import repair_semgrep_rule from global_config import global_config, logger from model import init_llm @@ -25,7 +27,7 @@ def init_config(config_file: str): logger.debug("Config file: " + config_file) logger.debug("Result dir: " + result_dir) - logger.debug("Analysis backend: " + str(global_config.get("backend"))) + logger.debug("Analysis backend: " + str(global_config.get("backend_type"))) logger.debug("Target: " + str(global_config.get("target"))) @@ -62,6 +64,8 @@ def main(mode: str, *args, **kwargs): "scan_single": (scan_single_checker, "Scan with a single checker from file"), "triage": (triage_report, "Triage the report"), "label": (label_commits, "Label commits"), + "sem_gen": (semgrep_gen, "Generate semgrep rules using LLM"), + "sem_repair": (repair_semgrep_rule, "Repair semgrep rules"), } if mode not in modes: diff --git a/src/model.py b/src/model.py index 8be1e11c..393376b1 100644 --- a/src/model.py +++ b/src/model.py @@ -208,6 +208,108 @@ def invoke_llm( return None +def invoke_llm_semgrep( + prompt, + temperature=model_config["temperature"], + model=model_config["model"], + max_tokens=model_config["max_tokens"], +) -> str: + """Invoke the LLM model with the given prompt for Semgrep rule generation.""" + + logger.info(f"start LLM process: {model}") + num_tokens = num_tokens_from_string(prompt) + logger.info("Token counts: {}".format(num_tokens)) + if num_tokens > 100000: + logger.warning("Token counts exceed the limit. 
Skip.") + return None + + failed_count = 0 + while True: + try: + # Get the appropriate client and model + client, actual_model = get_client_and_model(model) + + # Handle different client types for Semgrep generation + if isinstance( + client, anthropic.Anthropic if ANTHROPIC_AVAILABLE else type(None) + ): + # Claude API + system_prompt = """You generate Semgrep rules in YAML format. +Return only the raw YAML content without any markdown formatting or additional text. +Always include these required fields: id, pattern, message, severity, languages. +Focus on detecting security vulnerabilities and coding issues.""" + + response = client.messages.create( + model=actual_model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt} + ], + max_tokens=max_tokens, + temperature=temperature, + ) + answer = response.content[0].text + + elif isinstance(client, genai.Client): + # Google API + system_prompt = """You generate Semgrep rules in YAML format. +Return only the raw YAML content without any markdown formatting or additional text. +Always include these required fields: id, pattern, message, severity, languages. +Focus on detecting security vulnerabilities and coding issues.""" + + full_prompt = f"{system_prompt}\n\n{prompt}" + response = client.models.generate_content( + model=actual_model, + contents=full_prompt, + ) + answer = response.text + + else: # OpenAI or compatible + system_content = """You generate Semgrep rules in YAML format. +Return only the raw YAML content without any markdown formatting or additional text. +Always include these required fields: id, pattern, message, severity, languages. 
+Focus on detecting security vulnerabilities and coding issues.""" + + kwargs = { + "model": actual_model, + "messages": [ + {"role": "system", "content": system_content}, + {"role": "user", "content": prompt} + ], + "max_completion_tokens": max_tokens, + } + + # Only add temperature for models that support it + no_temp_models = ["o1", "o3-mini", "o4-mini", "o1-preview", "gpt-5"] + if not any(m in actual_model for m in no_temp_models): + kwargs["temperature"] = temperature + + response = client.chat.completions.create(**kwargs) + answer = response.choices[0].message.content + + except Exception as e: + logger.error("Error: {}".format(e)) + failed_count += 1 + if failed_count > 5: + logger.error("Failed too many times. Skip.") + raise e + time.sleep(2) + else: + logger.info("finish LLM process") + + if isinstance(answer, str): + # Remove think tags if present + if "<think>" in answer or "</think>" in answer: + answer = answer.split("</think>")[-1].strip() + return answer + else: + logger.warning("Response is not a string") + failed_count += 1 + if failed_count > 5: + logger.error("Failed too many times. 
Skip.") + return None + time.sleep(2) + continue def get_embeddings(text: str) -> list: """Get embeddings using OpenAI API""" diff --git a/src/semgrep_gen.py b/src/semgrep_gen.py new file mode 100644 index 00000000..2dbac52f --- /dev/null +++ b/src/semgrep_gen.py @@ -0,0 +1,229 @@ +from pathlib import Path +from typing import List +import time + +from agent import patch2pattern, pattern2plan, plan2semgrep, patch2semgrep, pattern2semplan +from checker_data import CheckerData +from checker_example import init_semgrep_example +from semgrep_repair import repair_semgrep_rule +from global_config import global_config, logger +from tools import extract_semgrep_rule + +def semgrep_gen( + commit_file="commits.txt", + result_file=None, + use_multi=True, + use_general=False, + no_utility=False, + sample_examples=False, +): + """Generate semgrep rules for commits, similar to gen_checker.""" + logger.info("Using multi: " + str(use_multi)) + + content = Path(commit_file).read_text() + semgrep_dir = Path(global_config.get("semgrep_dir", "./semgrep_rules")) + semgrep_dir.mkdir(parents=True, exist_ok=True) + + result_content = "" + if result_file: + result_content = Path(result_file).read_text() + + # Init semgrep example checkers if needed + if sample_examples: + init_semgrep_example() + + log_file = semgrep_dir / f"semgrep-log-{time.time()}.log" + result_file = log_file.with_suffix(".txt") + + for line in content.splitlines(): + if result_content and line in result_content: + if line + ",False" in result_content or line + ",True" in result_content: + logger.info(f"Skip {line}") + continue + commit_id, commit_type = line.split(",") + logger.info(f"Processing semgrep for {commit_id} {commit_type}") + try: + semgrep_id = sem_gen_worker( + commit_id, + commit_type, + use_multi=use_multi, + use_general=use_general, + no_utility=no_utility, + sample_examples=sample_examples, + ) + with open(log_file, "a") as flog: + flog.write(f"{commit_id} {commit_type} {semgrep_id}\n") + with 
open(result_file, "a") as fres: + # If exists a pair (X, True, True) + correct = any([TP > 0 and TN > 0 for _, TP, TN in semgrep_id]) + fres.write(f"{commit_id},{commit_type},{correct}\n") + except Exception as e: + logger.error(f"Error: {e}") + e = str(e).replace("\n", " ") + with open(log_file, "a") as flog: + flog.write(f"{commit_id} {commit_type} {e}\n") + with open(result_file, "a") as fres: + fres.write(f"{commit_id},{commit_type},Exception\n") + +def sem_gen_worker( + commit_id, + commit_type, + use_multi=True, + use_plan_feedback=False, + use_general=False, + no_utility=False, + sample_examples=False, +): + """Generate semgrep rules for one commit, similar to gen_checker_worker.""" + + from backends.semgrep import SemgrepBackend + + if not isinstance(global_config.backend, SemgrepBackend): + logger.info("Switching to Semgrep backend for rule generation") + semgrep_dir = global_config.get("semgrep_dir", "./semgrep_rules") + analysis_backend = SemgrepBackend(semgrep_dir) + else: + analysis_backend = global_config.backend + + target = global_config.target + + semgrep_id = [] + semgrep_data_list: List[CheckerData] = [] + checker_nums = global_config.get("checker_nums") + + id = f"SemgrepGen-{commit_type}-{commit_id}" + semgrep_dir = Path(global_config.get("semgrep_dir", "./semgrep_rules")) + + _build_directory(id) + + patch = target.get_patch(commit_id) + + (semgrep_dir / id).mkdir(parents=True, exist_ok=True) + (semgrep_dir / id / "commit_id.txt").write_text(commit_id) + (semgrep_dir / id / "patchfile.md").write_text(patch) + + ranking_file = semgrep_dir / id / "ranking.txt" + if ranking_file.exists(): + semgrep_id = eval(ranking_file.read_text()) + has_correct_rule = any([TP > 0 and TN > 0 for _, TP, TN in semgrep_id]) + if has_correct_rule: + logger.info(f"Find a perfect semgrep rule!") + logger.info(f"Skip {id}!") + return semgrep_id + + # Generate semgrep rules + for i in range(len(semgrep_id), checker_nums): + semgrep_data = CheckerData(commit_id, 
commit_type, semgrep_dir, i, patch) + + intermediate_dir = semgrep_dir / id / f"intermediate-{i}" + intermediate_dir.mkdir(parents=True, exist_ok=True) + + if use_multi: + # Patch to Pattern + pattern = patch2pattern(id, i, patch, use_general=use_general) + # Pattern to Semgrep Plan + plan = pattern2semplan( + id, + i, + pattern, + patch, + sample_examples=sample_examples, + ) + refined_plan = plan + # Plan to Semgrep Rule + semgrep_rule = plan2semgrep( + id, + i, + pattern, + refined_plan, + patch, + no_utility=no_utility, + sample_examples=sample_examples, + ) + else: + pattern = "" + plan = "" + refined_plan = "" + semgrep_rule = patch2semgrep(id, i, patch) + + print(f"Semgrep Rule {i}: {semgrep_rule}") + + semgrep_rule = extract_semgrep_rule(semgrep_rule) + + # Update the semgrep_data + semgrep_data.pattern = pattern + semgrep_data.plan = plan + semgrep_data.initial_checker_code = semgrep_rule # Store semgrep rule in checker_code field + + # Dump the semgrep data + (intermediate_dir / "pattern.txt").write_text(pattern) + (intermediate_dir / "plan.txt").write_text(plan) + (intermediate_dir / "refined_plan.txt").write_text(refined_plan) + (intermediate_dir / "semgrep-rule-0.yml").write_text(semgrep_rule) + + # Repair Semgrep Rule + ret, repaired_semgrep_rule = repair_semgrep_rule( + id=id, + repair_name="syntax-repair-" + str(i), + max_idx=4, + intermediate_dir=intermediate_dir, + semgrep_rule=semgrep_rule, + ) + semgrep_data.repaired_checker_code = repaired_semgrep_rule # Store repaired rule in checker_code field + + if not ret: + logger.error(f"Fail to generate valid semgrep rule{i}") + semgrep_id.append((i, -10, -10)) + semgrep_data_list.append(semgrep_data) + continue + + # Store the semgrep rule + rules_dir = semgrep_dir / id / "rules" + rules_dir.mkdir(parents=True, exist_ok=True) + (rules_dir / f"rule{i}.yml").write_text(repaired_semgrep_rule) + logger.info(f"Start to validate semgrep rule{i} in commit {commit_id}") + + TP, TN = 
analysis_backend.validate_checker( + repaired_semgrep_rule, + commit_id, + patch, + target, + skip_build_checker=True, # Just built the rule + ) + + # Update the semgrep_data + semgrep_data.tp_score = TP + semgrep_data.tn_score = TN + + semgrep_id.append((i, TP, TN)) + semgrep_data_list.append(semgrep_data) + logger.info(f"Semgrep Rule{i} TP: {TP} TN: {TN}") + if TP > 0 and TN > 0: + logger.info(f"Find a perfect semgrep rule{i}!") + break + elif TP == -1 and TN == -1: + logger.error(f"Fail to evaluate semgrep rule{i}!") + break + + for semgrep_data in semgrep_data_list: + # Write the semgrep data + semgrep_data.dump() + semgrep_data.dump_dir() + + # First compare the TP, then TN + semgrep_id = sorted(semgrep_id, key=lambda x: (x[1], x[2]), reverse=True) + print(semgrep_id) + + ranking_file = semgrep_dir / id / "ranking.txt" + ranking_file.write_text(str(semgrep_id)) + + return semgrep_id + +def _build_directory(id: str): + """Build the directory structure for the result.""" + basedir = Path(global_config.get("semgrep_dir", "./semgrep_rules")) / id + basedir.mkdir(parents=True, exist_ok=True) + build_log_dir = basedir / "build_logs" + prompt_history_dir = basedir / "prompt_history" + build_log_dir.mkdir(parents=True, exist_ok=True) + prompt_history_dir.mkdir(parents=True, exist_ok=True) \ No newline at end of file diff --git a/src/semgrep_repair.py b/src/semgrep_repair.py new file mode 100644 index 00000000..301f0027 --- /dev/null +++ b/src/semgrep_repair.py @@ -0,0 +1,94 @@ +from pathlib import Path +from typing import Optional, Tuple + +from global_config import global_config, logger +from tools import extract_semgrep_rule + +# Define constants for clarity +MAX_REPAIR_ATTEMPTS = 4 + + +def repair_semgrep_rule( + id: str, + repair_name: str, + semgrep_rule: str, + max_idx: int = MAX_REPAIR_ATTEMPTS, + intermediate_dir: Optional[Path] = None, +) -> Tuple[bool, Optional[str]]: + """ + Repair the semgrep rule using a language model. 
+ """ + base_dir = Path(global_config.get("semgrep_dir", "./semgrep_rules")) / id + + # Setup directories + prompt_history_dir = base_dir / "prompt_history" / repair_name + prompt_history_dir.mkdir(parents=True, exist_ok=True) + + if intermediate_dir is None: + intermediate_dir = base_dir / f"intermediate-{repair_name}" + intermediate_dir.mkdir(parents=True, exist_ok=True) + + log_dir = base_dir / "build_logs" / repair_name + log_dir.mkdir(parents=True, exist_ok=True) + + current_semgrep_rule = semgrep_rule + + for attempt in range(1, max_idx + 2): + logger.info(f"Semgrep rule validation attempt {attempt}/{max_idx + 1}") + + # Create a temporary Semgrep backend for validation + from backends.semgrep import SemgrepBackend + + if not isinstance(global_config.backend, SemgrepBackend): + temp_backend = SemgrepBackend(global_config.get("semgrep_dir", "./semgrep_rules")) + else: + temp_backend = global_config.backend + + return_code, stderr_content = temp_backend.build_checker( + current_semgrep_rule, + log_dir, + attempt=attempt, + ) + + if return_code == 0: + logger.info("Semgrep rule validation successful!") + return True, current_semgrep_rule + + # Rule validation failed + logger.warning( + f"Semgrep rule validation attempt {attempt} failed with return code {return_code}." + ) + if not stderr_content: + logger.warning("Rule validation failed, but error message was empty.") + return False, None + if attempt > max_idx: + logger.error(f"Semgrep rule repair failed after {max_idx} attempts.") + return False, None + + # Attempt repair + logger.info(f"Attempting semgrep rule repair {attempt} using LLM...") + try: + from agent import repair_semgrep_syntax + + llm_response = repair_semgrep_syntax( + id, repair_name, attempt, current_semgrep_rule, stderr_content + ) + new_semgrep_rule = extract_semgrep_rule(llm_response) + + if new_semgrep_rule is None: + logger.error( + f"Failed to extract new semgrep rule from LLM response for attempt {attempt}." 
+ ) + continue + else: + current_semgrep_rule = new_semgrep_rule + (intermediate_dir / f"semgrep-rule-{attempt}.yml").write_text( + current_semgrep_rule + ) + except Exception as e: + logger.error(f"Error during LLM repair call for attempt {attempt}: {e}") + return False, None + + # Should not be reached if loop logic is correct, but as a safeguard: + logger.error("Exited semgrep rule repair loop unexpectedly.") + return False, None diff --git a/src/tools.py b/src/tools.py index c29b60ce..9babdd22 100644 --- a/src/tools.py +++ b/src/tools.py @@ -661,3 +661,47 @@ def truncate_large_file(content: str, max_lines: int = 500) -> str: truncated_content += "\n".join(last_part) return truncated_content + +def grab_yaml_code(llm_response: str) -> str: + """Extract YAML code block from LLM response.""" + # Try different YAML block patterns + patterns = [ + r"```ya?ml\n([\s\S]*?)\n```", + r"```\n(rules:[\s\S]*?)\n```", + r"(rules:\s*\n[\s\S]*)" + ] + + for pattern in patterns: + match = re.search(pattern, llm_response, re.IGNORECASE) + if match: + return match.group(1) + + return None + +def extract_semgrep_rule(llm_response: str) -> str: + """Extract semgrep rule from LLM response.""" + # First try to get YAML code block + semgrep_rule = grab_yaml_code(llm_response) + + if semgrep_rule is None: + # If no code block found, try to extract rules: section directly + rules_pattern = r"(rules:\s*\n[\s\S]*)" + match = re.search(rules_pattern, llm_response, re.IGNORECASE) + if match: + semgrep_rule = match.group(1) + else: + return None + + # Clean up the rule + if semgrep_rule: + # Remove markdown code block markers if present + semgrep_rule = re.sub(r'^```ya?ml\s*\n', '', semgrep_rule, flags=re.MULTILINE | re.IGNORECASE) + semgrep_rule = re.sub(r'\n```\s*$', '', semgrep_rule, flags=re.MULTILINE) + + # Ensure it starts with 'rules:' + if not semgrep_rule.strip().startswith('rules:'): + semgrep_rule = 'rules:\n' + semgrep_rule + + return semgrep_rule.strip() + + return None \ No 
newline at end of file