Source code for pykomodo.enhanced_chunker

from pykomodo.multi_dirs_chunker import ParallelChunker
import os
from typing import Optional, List, Tuple

class EnhancedParallelChunker(ParallelChunker):
    """ParallelChunker subclass that prefixes written chunks with a header.

    The header written by ``_create_and_write_chunk`` contains the chunk
    number, optional metadata (function/class/import names and a docstring
    excerpt gathered with simple line-prefix heuristics) and a relevance
    score.  Chunks are written as ``chunk-<n>.txt`` under ``output_dir``.
    """

    def __init__(
        self,
        equal_chunks: Optional[int] = None,
        max_chunk_size: Optional[int] = None,
        output_dir: str = "chunks",
        user_ignore: Optional[List[str]] = None,
        user_unignore: Optional[List[str]] = None,
        binary_extensions: Optional[List[str]] = None,
        priority_rules: Optional[List[Tuple[str, int]]] = None,
        num_threads: int = 4,
        extract_metadata: bool = True,
        add_summaries: bool = True,
        remove_redundancy: bool = True,
        context_window: int = 4096,
        min_relevance_score: float = 0.3
    ) -> None:
        """Forward the base options to ``ParallelChunker`` and store the rest.

        Args:
            equal_chunks, max_chunk_size, output_dir, user_ignore,
            user_unignore, binary_extensions, priority_rules, num_threads:
                Passed unchanged to ``ParallelChunker.__init__``.
            extract_metadata: When true, metadata is collected from the
                loaded files and written into chunk headers.
            add_summaries: Stored on the instance; not read by the methods
                of this class.
            remove_redundancy: When true, repeated ``def`` blocks are
                dropped from the combined text before chunking.
            context_window: Byte budget applied when writing a chunk; a
                falsy value disables the limit.
            min_relevance_score: Stored on the instance; not read by the
                methods of this class.
        """
        super().__init__(
            equal_chunks=equal_chunks,
            max_chunk_size=max_chunk_size,
            output_dir=output_dir,
            user_ignore=user_ignore,
            user_unignore=user_unignore,
            binary_extensions=binary_extensions,
            priority_rules=priority_rules,
            num_threads=num_threads
        )
        self.extract_metadata: bool = extract_metadata
        self.add_summaries: bool = add_summaries
        self.remove_redundancy: bool = remove_redundancy
        self.context_window: int = context_window
        self.min_relevance_score: float = min_relevance_score

    def _extract_file_metadata(self, content: str) -> dict:
        """Collect names from *content* using line-prefix heuristics.

        Returns:
            A dict with list values under ``"functions"`` (names after
            ``def``, excluding ``__init__``), ``"classes"``, ``"imports"``
            (``import x`` with any ``as`` alias removed, or the ``from x``
            part of a from-import) and ``"docstrings"`` (at most one entry:
            the text between the first pair of triple double quotes).
        """
        metadata = {
            "functions": [],
            "classes": [],
            "imports": [],
            "docstrings": []
        }

        for line in content.split('\n'):
            line_stripped = line.strip()
            if line_stripped.startswith('def '):
                func_name = line_stripped[4:].split('(')[0].strip()
                if func_name != '__init__':
                    metadata['functions'].append(func_name)
            elif line_stripped.startswith('class '):
                class_name = line_stripped[6:].split('(')[0].strip()
                # A class without bases ends in ':' rather than '('.
                metadata['classes'].append(class_name.rstrip(':'))
            elif line_stripped.startswith('import '):
                if ' as ' in line_stripped:
                    metadata['imports'].append(line_stripped.split(' as ')[0].strip())
                else:
                    metadata['imports'].append(line_stripped)
            elif line_stripped.startswith('from '):
                metadata['imports'].append(line_stripped.split(' import ')[0].strip())

        if '"""' in content:
            start = content.find('"""') + 3
            end = content.find('"""', start)
            # end == -1 (no closing quotes) or end == start (empty docstring)
            # both fail this test, so nothing is recorded.
            if end > start:
                metadata['docstrings'].append(content[start:end].strip())

        return metadata

    def _calculate_chunk_relevance(self, chunk_content: str) -> float:
        """Score a chunk from its ratio of ``#`` comment lines.

        Returns:
            ``0.0`` when the chunk has no non-blank lines, ``0.3`` when every
            non-blank line is a comment, ``0.8`` when more than half of the
            non-blank lines are comments, otherwise ``0.99``.
        """
        lines = [ln.strip() for ln in chunk_content.split('\n')]
        lines = [ln for ln in lines if ln]
        if not lines:
            return 0.0

        comment_lines = sum(1 for ln in lines if ln.startswith('#'))
        code_lines = len(lines) - comment_lines
        if code_lines == 0:
            return 0.3

        score = 1.0
        if comment_lines / len(lines) > 0.5:
            score *= 0.8
        return min(0.99, score)

    def _remove_redundancy_across_all_files(self, big_text: str) -> str:
        """Drop repeated ``def`` blocks from *big_text*.

        Lines before the first ``def`` line are kept verbatim.  From the
        first ``def`` line on, the text is split into blocks that each start
        at a line whose stripped form begins with ``def `` and run up to the
        next such line (or the end of the text).  A block is kept only the
        first time its normalised form (every line stripped, blank lines
        removed) is seen.
        """
        def normalize_function(block_lines: List[str]) -> str:
            stripped = (ln.strip() for ln in block_lines)
            return '\n'.join(ln for ln in stripped if ln)

        final_lines: List[str] = []
        seen_functions = set()

        def keep_if_new(block_lines: List[str]) -> None:
            normed = normalize_function(block_lines)
            if normed not in seen_functions:
                seen_functions.add(normed)
                final_lines.extend(block_lines)

        # Non-empty exactly when a ``def`` line has been seen, so it doubles
        # as the "inside a function block" flag.
        current_function: List[str] = []
        for line in big_text.split('\n'):
            if line.strip().startswith('def '):
                if current_function:
                    keep_if_new(current_function)
                current_function = [line]
            elif current_function:
                current_function.append(line)
            else:
                final_lines.append(line)

        if current_function:
            keep_if_new(current_function)

        return "\n".join(final_lines)

    def _chunk_by_equal_parts(self) -> None:
        """Combine all loaded files and write them as ``equal_chunks`` chunks.

        Every file in ``self.loaded_files`` is decoded as UTF-8, optionally
        mined for metadata, joined with newlines and optionally
        de-duplicated.  With ``equal_chunks`` unset or <= 1 the whole text
        becomes chunk 0.  Otherwise the text is cut into pieces of about
        ``total_bytes // equal_chunks`` bytes (capped at
        ``context_window - 50`` when ``context_window`` exceeds 200),
        preferring to cut at the last newline of each piece; whatever is
        left after ``equal_chunks - 1`` pieces goes into the final chunk.
        """
        if not self.loaded_files:
            return

        all_file_texts = []
        functions = set()
        classes = set()
        imports = []
        # dict used as an insertion-ordered set: the docstring shown in the
        # chunk header is then the first one encountered, instead of an
        # arbitrary element of a hash-randomised set.
        docstrings = {}

        for path, content_bytes, _ in self.loaded_files:
            try:
                content = content_bytes.decode('utf-8', errors='replace')
            except Exception:
                print(f"Error decoding file {path}")
                continue

            if self.extract_metadata:
                fm = self._extract_file_metadata(content)
                functions.update(fm["functions"])
                classes.update(fm["classes"])
                imports.extend(fm["imports"])
                docstrings.update(dict.fromkeys(fm["docstrings"]))

            all_file_texts.append(content)

        metadata = None
        if self.extract_metadata:
            metadata = {
                "functions": functions,
                "classes": classes,
                "imports": imports,
                "docstrings": list(docstrings)
            }

        combined_text = "\n".join(all_file_texts)
        if self.remove_redundancy:
            combined_text = self._remove_redundancy_across_all_files(combined_text)

        if not self.equal_chunks or self.equal_chunks <= 1:
            self._create_and_write_chunk(combined_text, 0, metadata)
            return

        total_size = len(combined_text.encode('utf-8'))
        target_size = total_size // self.equal_chunks
        if self.context_window and self.context_window > 200:
            # Leave 50 bytes of head-room below the context window.
            target_size = min(target_size, int(self.context_window - 50))

        chunk_num = 0
        remaining = combined_text
        while remaining:
            # Every character encodes to at least one byte, so the first
            # target_size bytes always come from the first target_size
            # characters; there is no need to encode the whole remainder.
            portion_bytes = remaining[:target_size].encode('utf-8')[:target_size]
            # 'ignore' drops a multi-byte character split by the byte cut, so
            # ``portion`` is an exact prefix of ``remaining`` and its length
            # can be used to advance.  ('replace' would turn the split
            # character into U+FFFD and skip the real one.)
            portion = portion_bytes.decode('utf-8', errors='ignore')

            last_newline = portion.rfind('\n')
            if last_newline > 0:
                portion = portion[:last_newline]

            self._create_and_write_chunk(portion, chunk_num, metadata)
            chunk_num += 1
            remaining = remaining[len(portion):]

            if chunk_num >= self.equal_chunks - 1:
                if remaining:
                    self._create_and_write_chunk(remaining, chunk_num, metadata)
                break

    def _create_and_write_chunk(self, text: str, chunk_num: int, metadata: Optional[dict] = None) -> None:
        """Write *text* with a descriptive header to ``chunk-<chunk_num>.txt``.

        Args:
            text: Chunk body.
            chunk_num: Chunk index; chunk 0 additionally gets the tree
                header from ``self.tree_generator`` when a walk root is set.
            metadata: Optional mapping with ``"functions"``, ``"classes"``,
                ``"imports"`` and ``"docstrings"`` collections; emitted only
                when ``self.extract_metadata`` is true.

        A failure while writing the file is reported with ``print`` and not
        raised.
        """
        content_parts = []

        walk_root = getattr(self, 'current_walk_root', None)
        if chunk_num == 0 and walk_root:
            content_parts.append(self.tree_generator.prepare_tree_header(walk_root))

        content_parts.append(f"CHUNK {chunk_num}")

        if metadata and self.extract_metadata:
            content_parts.append("METADATA:")
            if metadata["functions"]:
                funcs = ", ".join(sorted(metadata["functions"]))
                content_parts.append(f"FUNCTIONS: {funcs}")
            if metadata["classes"]:
                classes = ", ".join(sorted(metadata["classes"]))
                content_parts.append(f"CLASSES: {classes}")
            if metadata["imports"]:
                imports = ", ".join(metadata["imports"])
                content_parts.append(f"IMPORTS: {imports}")
            if metadata["docstrings"]:
                first_doc = next(iter(metadata["docstrings"]))
                doc = first_doc.replace('\n', ' ')[:100]
                content_parts.append(f"DOCSTRING: {doc}")

        score = self._calculate_chunk_relevance(text)
        content_parts.append(f"RELEVANCE_SCORE: {score:.2f}")

        header = "\n".join(content_parts) + "\n\n"
        full_content = header + text

        if self.context_window:
            max_bytes = self.context_window
            content_bytes = full_content.encode('utf-8')
            if len(content_bytes) > max_bytes:
                # 'ignore' discards a character split by the byte cut.
                truncated_str = content_bytes[:max_bytes].decode('utf-8', errors='ignore')
                last_newline = truncated_str.rfind('\n')
                # Only cut at a newline that lies inside the body.
                # NOTE(review): when no such newline exists the content is
                # written untruncated and may exceed context_window.
                if last_newline > len(header):
                    full_content = truncated_str[:last_newline]

        os.makedirs(self.output_dir, exist_ok=True)
        chunk_path = os.path.join(self.output_dir, f"chunk-{chunk_num}.txt")
        try:
            with open(chunk_path, 'w', encoding='utf-8') as f:
                f.write(full_content)
        except Exception:
            print(f"Failed to write chunk {chunk_num}")

    def _write_minimal_chunk(self, content_bytes: bytes, chunk_num: int) -> None:
        """Write raw content, without a metadata header, to ``chunk-<n>.txt``.

        Chunk 0 is prefixed with the tree header when a walk root is set.
        The encoded output is limited to ``context_window`` bytes when that
        attribute is truthy.  Any failure is reported with ``print`` and not
        raised.
        """
        try:
            tree_header = ""
            walk_root = getattr(self, 'current_walk_root', None)
            if chunk_num == 0 and walk_root:
                tree_header = self.tree_generator.prepare_tree_header(walk_root)

            if isinstance(content_bytes, bytes):
                content_str = content_bytes.decode('utf-8', errors='replace')
            else:
                content_str = str(content_bytes)

            final_bytes = (tree_header + content_str).encode('utf-8')

            if self.context_window and len(final_bytes) > self.context_window:
                # Truncate the bytes that are actually written; the
                # decode/encode round trip removes a multi-byte character
                # split by the cut so the file stays valid UTF-8.
                clipped = final_bytes[:self.context_window]
                final_bytes = clipped.decode('utf-8', errors='ignore').encode('utf-8')

            os.makedirs(self.output_dir, exist_ok=True)
            chunk_path = os.path.join(self.output_dir, f"chunk-{chunk_num}.txt")
            with open(chunk_path, 'wb') as f:
                f.write(final_bytes)
        except Exception:
            print(f"Error writing minimal chunk-{chunk_num}")