|
from __future__ import annotations |
|
|
|
import re |
|
import requests |
|
from dataclasses import dataclass |
|
|
|
import gradio as gr |
|
from tree_sitter import Tree, Node |
|
from tree_sitter_languages import get_parser |
|
|
|
def non_whitespace_len(s: str) -> int:
    """Return the number of characters in ``s`` that are not whitespace.

    Whitespace is whatever the regex class ``\\s`` matches (spaces, tabs,
    newlines, ...).
    """
    # Raw string: "\s" in a plain literal is an invalid escape sequence.
    return len(re.sub(r"\s", "", s))
|
|
|
def get_line_number(index: int, source_code: str | bytes) -> int:
    """Return the 0-based index of the line that contains offset ``index``.

    Args:
        index: Offset into ``source_code``, counted in the same unit as
            ``len()`` of its lines (characters for ``str``, bytes for
            ``bytes``).
        source_code: The full text to scan; line endings are kept while
            counting so offsets stay aligned with the input.

    Returns:
        The 0-based line index holding ``index``. If ``index`` lies at or
        beyond the end of the input, the total number of lines is returned
        (0 for empty input).
    """
    total_chars = 0
    # Pre-bind so empty input (no loop iterations) returns 0 instead of
    # raising UnboundLocalError at the final return.
    line_number = 0
    for line_number, line in enumerate(source_code.splitlines(keepends=True), start=1):
        total_chars += len(line)
        if total_chars > index:
            # enumerate() counts from 1; callers expect a 0-based index.
            return line_number - 1
    return line_number
|
|
|
@dataclass
class Span:
    """A half-open ``[start, end)`` range used to slice text or lists of lines."""

    start: int = 0
    end: int = 0

    def __post_init__(self):
        """Turn ``end=None`` into an empty span positioned at ``start``."""
        if self.end is None:
            self.end = self.start

    def extract(self, s: str) -> str:
        """Return the slice ``s[start:end]``."""
        return s[self.start:self.end]

    def extract_lines(self, s: str) -> str:
        """Return lines ``start`` up to (excluding) ``end`` of ``s``, newline-joined."""
        selected = s.splitlines()[self.start:self.end]
        return "\n".join(selected)

    def __add__(self, other: Span | int) -> Span:
        """Shift or extend this span.

        Adding an ``int`` moves both endpoints by that amount. Adding another
        ``Span`` yields a span from this span's start to the other's end.

        Raises:
            NotImplementedError: If ``other`` is neither an ``int`` nor a ``Span``.
        """
        if isinstance(other, int):
            return Span(self.start + other, self.end + other)
        if isinstance(other, Span):
            return Span(self.start, other.end)
        raise NotImplementedError()

    def __len__(self) -> int:
        """Return ``end - start``."""
        return self.end - self.start
|
|
|
def chunk_tree(
    tree: Tree,
    source_code: bytes,
    MAX_CHARS=512 * 3,
    coalesce=50,
) -> list[Span]:
    """Split a parsed source file into contiguous line-range chunks.

    Args:
        tree: Parsed syntax tree for ``source_code``.
        source_code: The raw bytes that were parsed; node offsets
            (``start_byte`` / ``end_byte``) index into this buffer.
        MAX_CHARS: Upper bound on the size of a chunk, compared against
            node sizes computed from byte offsets.
        coalesce: A chunk is only emitted once it holds more than this many
            non-whitespace characters (and at least one newline); smaller
            chunks are merged into the following one.

    Returns:
        Non-empty ``Span`` objects whose ``start`` / ``end`` are line indices
        as computed by ``get_line_number`` (suitable for
        ``Span.extract_lines``).
    """

    # 1. Recursively form chunks from the syntax tree.
    def chunk_node(node: Node) -> list[Span]:
        """Greedily pack ``node``'s children into byte spans of at most MAX_CHARS."""
        chunks: list[Span] = []
        current_chunk: Span = Span(node.start_byte, node.start_byte)
        for child in node.children:
            child_size = child.end_byte - child.start_byte
            if child_size > MAX_CHARS:
                # Child is too big on its own: flush what we have and
                # descend into the child to split it further.
                chunks.append(current_chunk)
                current_chunk = Span(child.end_byte, child.end_byte)
                chunks.extend(chunk_node(child))
            elif child_size + len(current_chunk) > MAX_CHARS:
                # Child fits alone but not together with the current chunk.
                chunks.append(current_chunk)
                current_chunk = Span(child.start_byte, child.end_byte)
            else:
                current_chunk += Span(child.start_byte, child.end_byte)
        chunks.append(current_chunk)
        return chunks

    chunks = chunk_node(tree.root_node)

    # 2. Fill the gaps between consecutive chunks so nothing is skipped, and
    #    stretch the last chunk to the end of the tree. chunk_node always
    #    returns at least one span, so chunks[-1] is safe even when the loop
    #    below does not run.
    for prev, curr in zip(chunks[:-1], chunks[1:]):
        prev.end = curr.start
    chunks[-1].end = tree.root_node.end_byte

    # 3. Merge chunks that are too small to be useful on their own.
    new_chunks = []
    current_chunk = Span(0, 0)
    for chunk in chunks:
        current_chunk += chunk
        # Spans hold byte offsets, so slice the bytes first and decode only
        # that slice; slicing the decoded str would misalign on non-ASCII
        # input and re-decoding the whole file per chunk is wasteful.
        text = source_code[current_chunk.start:current_chunk.end].decode(
            "utf-8", errors="replace"
        )
        if non_whitespace_len(text) > coalesce and "\n" in text:
            new_chunks.append(current_chunk)
            current_chunk = Span(chunk.end, chunk.end)
    if len(current_chunk) > 0:
        new_chunks.append(current_chunk)

    # 4. Convert byte offsets to line indices (source_code stays bytes here
    #    so the offsets and the per-line lengths use the same unit).
    line_chunks = [
        Span(
            get_line_number(chunk.start, source_code),
            get_line_number(chunk.end, source_code),
        )
        for chunk in new_chunks
    ]

    # 5. Drop chunks that collapsed to zero lines.
    return [chunk for chunk in line_chunks if len(chunk) > 0]
|
|
|
# Custom stylesheet handed to gr.Blocks(css=...). The .code_container rule is
# currently empty; the class is attached to the input code box via elem_classes.
css = """
.code_container {
}
"""
|
|
|
def chunk_code(
    code: str,
    language: str,
    MAX_CHARS: int,
    coalesce: int
):
    """Chunk ``code`` and render the chunks as one delimited string.

    Args:
        code: Source text to chunk.
        language: Language name passed to ``get_parser``.
        MAX_CHARS: Forwarded to ``chunk_tree`` as the chunk size limit.
        coalesce: Forwarded to ``chunk_tree`` as the merge threshold.

    Returns:
        The chunk texts joined by a ``====================`` separator line.
        If anything fails, the exception message is returned instead so it
        can be shown in place of the output.
    """
    separator = "\n\n====================\n\n"
    try:
        parser = get_parser(language)
        source_bytes = code.encode("utf-8")
        tree = parser.parse(source_bytes)
        spans = chunk_tree(tree, source_bytes, MAX_CHARS=MAX_CHARS, coalesce=coalesce)
        pieces = [span.extract_lines(code) for span in spans]
        return separator.join(pieces)
    except Exception as e:
        # Deliberately broad: the message is displayed to the user as output.
        return str(e)
|
|
|
# Example sources offered in the UI: dropdown label -> (raw file URL, language).
examples_dict = {
    "Python: Sweep's GitHub Actions log handler": ("https://raw.githubusercontent.com/sweepai/sweep/b267b613d4c706eaf959fe6789f11e9a856521d1/sweepai/handlers/on_check_suite.py", "python"),
    "Typescript: LlamaIndex TS's BaseIndex abstract base class": ("https://raw.githubusercontent.com/run-llama/LlamaIndexTS/bfab1d407b7b390d76b3d7a1a1df0928e9f9ae11/packages/core/src/indices/BaseIndex.ts", "typescript"),
    "Rust: Ruff's autofix code modification algorithm": ("https://raw.githubusercontent.com/astral-sh/ruff/main/crates/ruff/src/autofix/codemods.rs", "rust"),
    "Go: Infisical's CLI's config manager": ("https://raw.githubusercontent.com/Infisical/infisical/de7bd27b4b48847c9ca7cd12d208225b06f170fe/cli/packages/util/config.go", "go"),
}

# The first entry is the example shown on startup; deriving the key from the
# dict keeps it from drifting out of sync with the labels above.
default_key = next(iter(examples_dict))
default_url, default_language = examples_dict[default_key]
# Fetched once at import time. The timeout keeps an unreachable host from
# hanging startup indefinitely (requests has no timeout by default).
default_code = requests.get(default_url, timeout=30).text
|
|
|
# Build the demo UI: a settings row, then input and chunked output side by side.
with gr.Blocks(css=css) as demo:
    gr.Markdown("## Code Chunking Demo")
    gr.Markdown("Start typing below and the chunked output will automatically show up.️ Drop a like if you enjoy the demo! \n\nCheckout how this algorithm works at https://docs.sweep.dev/blogs/chunking-2m-files and https://docs.sweep.dev/blogs/chunking-improvements or play with the notebook at https://github.com/sweepai/sweep/blob/main/notebooks/chunking.ipynb.")

    with gr.Row():
        language = gr.Dropdown(['python', 'javascript', 'typescript', 'rust', 'go', 'ruby', 'r', 'html', 'css', 'shell'], label="Language", value=default_language)
        max_chars = gr.Slider(10, 3000, 1500, label="Max Characters", step=10)
        coalesce = gr.Slider(0, 300, 100, label="Coalesce", step=10)
        examples = gr.Dropdown(list(examples_dict.keys()), label="Examples", value=default_key, interactive=True)
    with gr.Row():
        input_code = gr.Code(label="Input Code", language=language.value, lines=60, elem_classes="code_container", value=default_code)
        output_code = gr.Code(label="Chunked Code", language=language.value, lines=60, value=chunk_code(default_code, language.value, max_chars.value, coalesce.value))

    def update_examples(example_key, max_chars_value, coalesce_value):
        """Load the selected example and re-chunk it with the current slider settings.

        The slider values arrive as handler inputs; reading ``max_chars.value``
        or ``coalesce.value`` here would only give the components' initial
        values, not what the user has currently selected.
        """
        url, example_language = examples_dict[example_key]
        # Timeout so a stalled download cannot block the handler forever.
        code = requests.get(url, timeout=30).text
        chunked = chunk_code(code, example_language, max_chars_value, coalesce_value)
        return (
            gr.Code.update(language=example_language, value=code),
            gr.Code.update(language=example_language, value=chunked),
            example_language,
        )

    examples.change(
        fn=update_examples,
        inputs=[examples, max_chars, coalesce],
        outputs=[input_code, output_code, language],
    )

    def update_language(selected_language):
        """Switch syntax highlighting of both code boxes to the chosen language."""
        return gr.Code.update(language=selected_language), gr.Code.update(language=selected_language)

    # Any change to the language, sliders, or input text re-runs the chunker.
    chunk_inputs = [input_code, language, max_chars, coalesce]
    (
        language.change(fn=update_language, inputs=language, outputs=[input_code, output_code])
        .then(fn=chunk_code, inputs=chunk_inputs, outputs=output_code)
    )
    max_chars.change(fn=chunk_code, inputs=chunk_inputs, outputs=output_code)
    coalesce.change(fn=chunk_code, inputs=chunk_inputs, outputs=output_code)
    input_code.change(fn=chunk_code, inputs=chunk_inputs, outputs=output_code)

demo.launch()
|
|