From e9e118561710202de33a456b2c8d0a6d37e29cc9 Mon Sep 17 00:00:00 2001 From: Loek Le Blansch Date: Tue, 28 Oct 2025 11:56:51 +0100 Subject: support -C/--root for setting patchset root --- patchtree/cli.py | 51 ++++++++++++++++++++++++-------- patchtree/config.py | 3 +- patchtree/context.py | 83 +++++++++++++++++++++++++++++++++++----------------- patchtree/patch.py | 3 ++ 4 files changed, 100 insertions(+), 40 deletions(-) diff --git a/patchtree/cli.py b/patchtree/cli.py index 9ddb4fb..bef0c0d 100644 --- a/patchtree/cli.py +++ b/patchtree/cli.py @@ -1,6 +1,7 @@ from dataclasses import fields from sys import stderr from pathlib import Path +from argparse import ArgumentTypeError from .context import Context from .config import Config @@ -16,6 +17,13 @@ def load_config() -> Config: return Config(**init) +def path_dir(path: str) -> Path: + out = Path(path) + if not out.is_dir(): + raise ArgumentTypeError(f"not a directory: `{path}'") + return out + + def parse_arguments(config: Config) -> Context: parser = config.argument_parser( prog="patchtree", @@ -36,6 +44,7 @@ def parse_arguments(config: Config) -> Context: parser.add_argument( "-c", "--context", + metavar="NUM", help="lines of context in output diff", type=int, ) @@ -45,8 +54,32 @@ def parse_arguments(config: Config) -> Context: help="output shebang in resulting patch", action="store_true", ) - parser.add_argument("target", help="target directory or archive") - parser.add_argument("patch", help="patch input glob(s)", nargs="*") + parser.add_argument( + "-C", + "--root", + metavar="DIR", + help="patchset root directory", + type=path_dir, + ) + parser.add_argument( + "-g", + "--glob", + help="enable globbing for input(s)", + action="store_true", + ) + parser.add_argument( + "target", + metavar="TARGET", + help="target directory or archive", + type=Path, + ) + parser.add_argument( + "patch", + metavar="PATCH", + help="patchset input(s)", + nargs="*", + type=str, + ) options = parser.parse_args() @@ -70,24 +103,18 @@ def main(): context = parse_arguments(config) - file_set: set[Path] = set() - for pattern in context.options.patch: - for path in Path(".").glob(pattern): - if not path.is_file(): - continue - file_set.add(path) - files = sorted(file_set) - - if len(files) == 0: + if len(context.inputs) == 0: print("no files to patch!", file=stderr) return 0 config.header(config, context) - for file in files: + for file in context.inputs: patch = config.patch(config, file) patch.write(context) + context.close() + return 0 diff --git a/patchtree/config.py b/patchtree/config.py index fbe78e9..2f2373e 100644 --- a/patchtree/config.py +++ b/patchtree/config.py @@ -3,6 +3,7 @@ from __future__ import annotations from dataclasses import dataclass, field from argparse import ArgumentParser from importlib import metadata +from pathlib import Path from .context import Context from .patch import Patch @@ -95,7 +96,7 @@ class Config: """Whether to output a shebang line with the ``git patch`` command to apply the patch.""" - default_patch_sources: list[str] = field(default_factory=list) + default_patch_sources: list[Path] = field(default_factory=list) """List of default sources.""" def __post_init__(self): diff --git a/patchtree/context.py b/patchtree/context.py index 2538849..cca419a 100644 --- a/patchtree/context.py +++ b/patchtree/context.py @@ -18,9 +18,9 @@ ZIP_CREATE_SYSTEM_UNX = 3 class FS: - target: str + target: Path - def __init__(self, target: str): + def __init__(self, target: Path): self.target = target def get_dir(self, dir: str) -> list[str]: @@ -34,18 +34,15 @@ class FS: class DiskFS(FS): - path: Path - - def __init__(self, target: str): + def __init__(self, target): super(DiskFS, self).__init__(target) - self.path = Path(target) def get_dir(self, dir: str) -> list[str]: - here = self.path.joinpath(dir) + here = self.target.joinpath(dir) return [path.name for path in here.iterdir()] def get_content(self, file: str) -> str | None: - here = self.path.joinpath(file) + here = self.target.joinpath(file) if not here.exists(): return None bytes = here.read_bytes() @@ -55,7 +52,7 @@ class DiskFS(FS): return "" def get_mode(self, file: str) -> int: - here = self.path.joinpath(file) + here = self.target.joinpath(file) if not here.exists(): return 0 return here.stat().st_mode @@ -65,9 +62,9 @@ class ZipFS(FS): zip: ZipFile files: dict[Path, ZipInfo] = {} - def __init__(self, target: str): + def __init__(self, target): super(ZipFS, self).__init__(target) - self.zip = ZipFile(target) + self.zip = ZipFile(str(target)) for info in self.zip.infolist(): self.files[Path(info.filename)] = info @@ -117,29 +114,61 @@ class Context: fs: FS output: IO + root: Path + target: Path + inputs: list[Path] = [] + in_place: bool + config: Config - options: Namespace def __init__(self, config: Config, options: Namespace): self.config = config - self.options = options + + self.root = options.root + self.target = options.target + self.in_place = options.in_place + self.inputs = self.collect_inputs(options) self.fs = self.get_fs() - self.output = self.get_output() + self.output = self.get_output(options) - if self.options.in_place: + if self.in_place: self.apply(True) - def __del__(self): + def close(self): # patch must have a trailing newline self.output.write("\n") self.output.flush() - if self.options.in_place: + if self.in_place: self.apply(False) self.output.close() + def collect_inputs(self, options: Namespace) -> list[Path]: + inputs: set[Path] = set() + + if len(inputs) == 0: + options.glob = True + options.patch = [str(Path(options.root or ".").joinpath("**"))] + + if options.glob: + for pattern in options.patch: + for path in Path(".").glob(pattern): + if not path.is_file(): + continue + inputs.add(path) + return sorted(inputs) + else: + for input in options.patch: + path = Path(input) + if not path.exists(): + raise Exception(f"cannot open `{input}'") + if not path.is_file(): + raise Exception(f"not a file: `{input}'") + inputs.add(path) + return list(inputs) + def get_dir(self, dir: str) -> list[str]: return self.fs.get_dir(dir) @@ -150,32 +179,32 @@ class Context: return self.fs.get_mode(file) def get_fs(self) -> FS: - target: str = self.options.target + target = self.target - if not path.exists(target): + if not target.exists(): raise Exception(f"cannot open `{target}'") if path.isdir(target): return DiskFS(target) if is_zipfile(target): - if self.options.in_place: + if self.in_place: raise Exception("cannot edit zip in-place!") return ZipFS(target) raise Exception("cannot read `{target}'") - def get_output(self) -> IO: - if self.options.in_place: - if self.options.out is not None: + def get_output(self, options: Namespace) -> IO: + if self.in_place: + if options.out is not None: print("warning: --out is ignored when using --in-place", file=stderr) return TemporaryFile("w+") - if self.options.out is not None: - if self.options.out == "-": + if options.out is not None: + if options.out == "-": return stdout else: - return open(self.options.out, "w+") + return open(options.out, "w+") return stdout @@ -190,7 +219,7 @@ class Context: return cmd def apply(self, reverse: bool) -> None: - location = cast(DiskFS, self.fs).path + location = cast(DiskFS, self.fs).target cache = location.joinpath(".patchtree.diff") cmd = self.get_apply_cmd() diff --git a/patchtree/patch.py b/patchtree/patch.py index 32d3b7f..bff7ee4 100644 --- a/patchtree/patch.py +++ b/patchtree/patch.py @@ -42,6 +42,9 @@ class Patch: ) def write(self, context: Context) -> None: + if context.root is not None: + self.file = str(Path(self.file).relative_to(context.root)) + diff = Diff(self.config, self.file) diff.a = File( -- cgit v1.2.3