Diff between 923ac696f099bb6bfb87aadcdd98500fc6cc24a4 and c524c43b1ca5cb4e53af182b4129cd0cc952297d

Changed Files

File Additions Deletions Status
app.py +62 -8 modified
git/commit.py +6 -1 modified
git/misc.py +52 -1 modified

Full Patch

diff --git a/app.py b/app.py
index 46f6d63..7d9d190 100644
--- a/app.py
+++ b/app.py
@@ -1,5 +1,5 @@
 import os
-from flask import Flask, render_template, request
+from flask import Flask, render_template, request, abort
 from datetime import datetime
 from dotenv import load_dotenv
 
@@ -8,7 +8,7 @@ from git.commit import get_commits, get_commit
 from git.ref import get_refs
 from git.tree import get_tree_items
 from git.blob import get_blob
-from git.misc import get_version
+from git.misc import get_version, validate_repo_name, validate_ref, validate_ref_as_commit, sanitize_path
 from git.diff import get_diff
 from git.blame import get_blame
 from highlight import highlight_diff
@@ -22,7 +22,16 @@ app.jinja_env.globals['request'] = request
 
 @app.context_processor
 def inject_current_ref():
-    return {'current_ref': request.args.get('ref', 'HEAD')}
+    ref = request.args.get('ref', 'HEAD').strip()
+    # if ref does not resolve to a commit (or validation raises), fall back to HEAD to prevent broken links
+    repo_name = request.view_args.get('repo_name') if request.view_args else None
+    if repo_name:
+        try:
+            if not validate_ref_as_commit(f"{repo_path}/{repo_name}", ref):
+                ref = 'HEAD'
+        except:
+            ref = 'HEAD'
+    return {'current_ref': ref}
 
 repo_path = os.getenv('GIT_REPO_PATH')
 
@@ -75,13 +84,22 @@ def index():
 
 @app.route("/<repo_name>")
 def repo_detail(repo_name):
-    commits = get_commits(f"{repo_path}/{repo_name}", ref="HEAD", max_count=10)
+    if not validate_repo_name(repo_name):
+        abort(404)
+    ref = request.args.get('ref', 'HEAD').strip()
+    if not validate_ref_as_commit(f"{repo_path}/{repo_name}", ref):
+        abort(400, "Invalid ref")
+    commits = get_commits(f"{repo_path}/{repo_name}", ref=ref, max_count=10)
     refs = get_refs(f"{repo_path}/{repo_name}")
     return render_template("repo.html", repo_name=repo_name, refs=refs, commits=commits)
 
 @app.route("/<repo_name>/commits")
 def repo_commits(repo_name):
-    ref = request.args.get('ref', 'HEAD')
+    if not validate_repo_name(repo_name):
+        abort(404)
+    ref = request.args.get('ref', 'HEAD').strip()
+    if not validate_ref_as_commit(f"{repo_path}/{repo_name}", ref):
+        abort(400, "Invalid ref")
     refs = get_refs(f"{repo_path}/{repo_name}")
     page = int(request.args.get('page', 0))
     # maybe pages are not the wisest way to do this?
@@ -95,32 +113,62 @@ def repo_commits(repo_name):
 
 @app.route("/<repo_name>/commits/<commit_id>")
 def commit_detail(repo_name, commit_id):
+    if not validate_repo_name(repo_name):
+        abort(404)
+    if not validate_ref_as_commit(f"{repo_path}/{repo_name}", commit_id):
+        abort(400, "Invalid commit id")
     commit = get_commit(f"{repo_path}/{repo_name}", commit_id)
     return render_template("commit.html", repo_name=repo_name, commit=commit)
 
 @app.route("/<repo_name>/refs")
 def repo_refs(repo_name):
+    if not validate_repo_name(repo_name):
+        abort(404)
     refs = get_refs(f"{repo_path}/{repo_name}")
     return render_template("refs.html", repo_name=repo_name, refs=refs)
 
 @app.route("/<repo_name>/tree", defaults={'path': ''})
 @app.route("/<repo_name>/tree/<path:path>")
 def repo_tree_path(repo_name, path):
-    ref = request.args.get('ref', 'HEAD')
+    if not validate_repo_name(repo_name):
+        abort(404)
+    ref = request.args.get('ref', 'HEAD').strip()
+    if not validate_ref(f"{repo_path}/{repo_name}", ref):
+        abort(400, "Invalid ref")
+    try:
+        path = sanitize_path(path)
+    except ValueError:
+        abort(400, "Invalid path")
     refs = get_refs(f"{repo_path}/{repo_name}")
     tree_items = get_tree_items(f"{repo_path}/{repo_name}", ref, path)
     return render_template("tree.html", repo_name=repo_name, ref=ref, path=path, tree_items=tree_items, refs=refs)
 
 @app.route("/<repo_name>/blob/<path:path>")
 def repo_blob_path(repo_name, path):
-    ref = request.args.get('ref', 'HEAD')
+    if not validate_repo_name(repo_name):
+        abort(404)
+    ref = request.args.get('ref', 'HEAD').strip()
+    if not validate_ref(f"{repo_path}/{repo_name}", ref):
+        abort(400, "Invalid ref")
+    try:
+        path = sanitize_path(path)
+    except ValueError:
+        abort(400, "Invalid path")
     refs = get_refs(f"{repo_path}/{repo_name}")
     blob = get_blob(f"{repo_path}/{repo_name}", ref, path)
     return render_template("blob.html", repo_name=repo_name, ref=ref, path=path, blob=blob, refs=refs)
 
 @app.route("/<repo_name>/blame/<path:path>")
 def repo_blame_path(repo_name, path):
-    ref = request.args.get('ref', 'HEAD')
+    if not validate_repo_name(repo_name):
+        abort(404)
+    ref = request.args.get('ref', 'HEAD').strip()
+    if not validate_ref(f"{repo_path}/{repo_name}", ref):
+        abort(400, "Invalid ref")
+    try:
+        path = sanitize_path(path)
+    except ValueError:
+        abort(400, "Invalid path")
     refs = get_refs(f"{repo_path}/{repo_name}")
     
     # if ajax (for loading)
@@ -133,9 +181,15 @@ def repo_blame_path(repo_name, path):
 
 @app.route("/<repo_name>/diff")
 def repo_diff(repo_name):
+    if not validate_repo_name(repo_name):
+        abort(404)
     refs = get_refs(f"{repo_path}/{repo_name}")
     id1 = request.args.get('id1')
     id2 = request.args.get('id2')
+    if id1 and not validate_ref_as_commit(f"{repo_path}/{repo_name}", id1):
+        abort(400, "Invalid id1 (reference from)")
+    if id2 and not validate_ref_as_commit(f"{repo_path}/{repo_name}", id2):
+        abort(400, "Invalid id2 (reference to)")
     context_lines = int(request.args.get('context_lines', 3))
     interhunk_lines = int(request.args.get('interhunk_lines', 0))
     # TODO: ADD ERROR HANDLING EVERYWHERE!!
diff --git a/git/commit.py b/git/commit.py
index 34ab713..9792037 100644
--- a/git/commit.py
+++ b/git/commit.py
@@ -19,7 +19,12 @@ def get_commits(path, ref="HEAD", max_count=None, skip=0):
     repo = git.Repository(path)
     commits = []
     # TODO: accept blob oids to filter commits that touch specific blobs
-    walker = repo.walk(repo.revparse_single(ref).id, git.GIT_SORT_TIME)
+    obj = repo.revparse_single(ref)
+    if obj.type == git.GIT_OBJECT_COMMIT:
+        commit = obj
+    else:
+        commit = obj.peel(git.GIT_OBJECT_COMMIT)
+    walker = repo.walk(commit.id, git.GIT_SORT_TIME)
 
     n = 0
     for commit in walker:
diff --git a/git/misc.py b/git/misc.py
index 4fc917e..a6f910e 100644
--- a/git/misc.py
+++ b/git/misc.py
@@ -1,4 +1,55 @@
 import pygit2 as git
 
 def get_version():
-    return git.LIBGIT2_VERSION
\ No newline at end of file
+    return git.LIBGIT2_VERSION
+
+def validate_repo_name(name):
+    if not name or not isinstance(name, str):
+        return False
+    # no path traversal or hidden dirs
+    if '/' in name or '\\' in name or '..' in name or name.startswith('.'):
+        return False
+    # reject characters that are reserved in Windows file names
+    invalid_chars = ['<', '>', ':', '"', '|', '?', '*']
+    if any(char in name for char in invalid_chars):
+        return False
+    return True
+
+def validate_ref(repo_path, ref):
+    if not ref or not isinstance(ref, str):
+        return False
+    # try to resolve the ref; if resolution fails, the ref is invalid
+    try:
+        repo = git.Repository(repo_path)
+        repo.revparse_single(ref.strip())
+        return True
+    except:
+        return False
+
+def validate_ref_as_commit(repo_path, ref):
+    # Checks that ref resolves to a commit, either directly or by peeling.
+    # This succeeds for commits, branches, and tags that point to a commit.
+    # Blobs and trees cannot be peeled to a commit, so they are rejected.
+    # More about refs:
+    # https://git-scm.com/book/en/v2/Git-Internals-Git-References
+
+    if not validate_ref(repo_path, ref):
+        return False
+    try:
+        repo = git.Repository(repo_path)
+        obj = repo.revparse_single(ref.strip())
+        if obj.type == git.GIT_OBJECT_COMMIT:
+            return True
+        obj.peel(git.GIT_OBJECT_COMMIT)
+        return True
+    except:
+        return False
+
+def sanitize_path(path):
+    # strips leading/trailing slashes and rejects any path containing '..' (path traversal)
+    if not path:
+        return ""
+    path = path.strip('/')
+    if '..' in path:
+        raise ValueError("invalid path! contains '..'")
+    return path
\ No newline at end of file