"""
This module performs git operations
"""
import itertools
import pathlib
import shutil
import tempfile
from typing import Dict, List, Optional, Sequence, Union
from urllib.parse import urlparse

import pygit2
import requests
from charset_normalizer import from_bytes
[docs]
def clone_repository(repo_url: str) -> pathlib.Path:
    """
    Clones the GitHub repository to a temporary directory.

    The temporary directory is created with ``tempfile.mkdtemp`` and is NOT
    deleted by this function on success: the caller owns it and should remove
    it (e.g. ``shutil.rmtree(repo_path.parent)``) when finished. If cloning
    fails, the temporary directory is removed before the error propagates.

    Args:
        repo_url (str): The URL of the GitHub repository.
    Returns:
        pathlib.Path: Path to the cloned repository.
    Raises:
        Any exception raised by ``pygit2.clone_repository`` is re-raised
        unchanged after cleanup.
    """
    # Create a temporary directory to hold the clone.
    temp_dir = pathlib.Path(tempfile.mkdtemp())
    # Clone into a "repo" subdirectory of the temporary directory.
    repo_path = temp_dir / "repo"
    try:
        pygit2.clone_repository(repo_url, str(repo_path))
    except BaseException:
        # Don't leak the (possibly partially populated) temp dir when the clone
        # fails or is interrupted; the original exception is re-raised.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return repo_path
[docs]
def get_commits(repo: pygit2.Repository) -> List[pygit2.Commit]:
    """
    Collects every commit reachable from ``HEAD``.

    Args:
        repo (pygit2.Repository): The Git repository to walk.
    Returns:
        List[pygit2.Commit]: The commits in the order yielded by a
        ``pygit2.GIT_SORT_TIME`` revision walk starting at ``HEAD``.
    """
    head_commit = repo.revparse_single("HEAD")
    time_ordered_walk = repo.walk(head_commit.id, pygit2.GIT_SORT_TIME)
    return list(time_ordered_walk)
[docs]
def get_edited_files(
    repo: pygit2.Repository,
    source_commit: pygit2.Commit,
    target_commit: pygit2.Commit,
) -> List[str]:
    """
    Finds all files that have been edited, added, or deleted between two specific commits.

    Both the old-side and new-side path of every diff entry are collected, so a
    renamed file (if the diff reports it as a rename) contributes both names.

    Args:
        repo (pygit2.Repository): The Git repository.
        source_commit (pygit2.Commit): The source commit.
        target_commit (pygit2.Commit): The target commit.
    Returns:
        List[str]: Sorted, de-duplicated file paths that have been edited,
        added, or deleted between the two commits.
    """
    diff = repo.diff(source_commit, target_commit)
    file_names = set()
    # Only paths are needed, so walk the deltas rather than generating full
    # patches (which would compute hunk/line content for every file).
    for delta in diff.deltas:
        for path in (delta.old_file.path, delta.new_file.path):
            if path:
                file_names.add(path)
    # Sort so the result is deterministic (set iteration order of str varies
    # between interpreter runs because of hash randomization).
    return sorted(file_names)
[docs]
def get_loc_changed(
    repo_path: pathlib.Path, source: str, target: str, file_names: List[str]
) -> Dict[str, int]:
    """
    Finds the total number of code lines changed for each specified file between two commits.

    Files are matched against the new-side path of each diff entry. Requested
    files that do not appear in the diff are omitted from the result.

    Args:
        repo_path (pathlib.Path): The path to the git repository.
        source (str): The source commit hash.
        target (str): The target commit hash.
        file_names (List[str]): List of file names to calculate changes for.
    Returns:
        Dict[str, int]: A dictionary where the key is the filename, and the value
        is the lines changed (added plus removed).
    """
    repo = pygit2.Repository(str(repo_path))
    # Resolve the source and target commits by their hashes
    source_commit = repo.revparse_single(source)
    target_commit = repo.revparse_single(target)
    # Build the lookup set once: O(1) membership per patch instead of a list scan.
    wanted = set(file_names)
    changes: Dict[str, int] = {}
    for patch in repo.diff(source_commit, target_commit):
        path = patch.delta.new_file.path
        if path not in wanted:
            continue
        # line_stats is (context, additions, deletions); it counts the same
        # "+" / "-" lines the manual hunk walk counted.
        _context, additions, deletions = patch.line_stats
        changes[path] = additions + deletions
    return changes
[docs]
def get_most_recent_commits(repo_path: pathlib.Path) -> tuple[str, str]:
    """
    Retrieves the two most recent commit hashes reachable from HEAD.

    Commits are taken from a ``pygit2.GIT_SORT_TIME`` walk starting at HEAD.
    Only the first two commits are pulled from the walker, rather than the
    whole history being materialised into a list.

    Args:
        repo_path (pathlib.Path): The path to the git repository.
    Returns:
        tuple[str, str]: ``(source, target)`` where ``target`` is the first
        commit of the walk (the most recent) and ``source`` the second.
    Raises:
        IndexError: If fewer than two commits are reachable from HEAD.
    """
    repo = pygit2.Repository(str(repo_path))
    head = repo.revparse_single("HEAD")
    walker = repo.walk(head.id, pygit2.GIT_SORT_TIME)
    newest_two = list(itertools.islice(walker, 2))
    if len(newest_two) < 2:
        # Same exception type the previous list indexing raised, with a clearer message.
        raise IndexError(
            f"Expected at least two commits reachable from HEAD, found {len(newest_two)}."
        )
    target_commit, source_commit = newest_two  # most recent, second most recent
    return str(source_commit.id), str(target_commit.id)
[docs]
def detect_encoding(blob_data: bytes) -> str:
    """
    Guess the character encoding of raw bytes using charset-normalizer.

    Args:
        blob_data (bytes): Raw blob contents to inspect.
    Returns:
        str: Name of the encoding that charset-normalizer ranks best.
    Raises:
        ValueError: If ``blob_data`` is empty, or if charset-normalizer
            yields no usable match.
    """
    if blob_data:
        top_match = from_bytes(blob_data).best()
        if top_match:
            return top_match.encoding
        raise ValueError("Encoding could not be detected.")
    raise ValueError("No data provided for encoding detection.")
[docs]
def _find_blob_case_insensitive(
    repo: pygit2.Repository, tree: pygit2.Tree, full_path: str
) -> Optional[pygit2.Object]:
    """Walk ``full_path`` component by component, matching names case-insensitively; return the blob entry or None."""
    path_parts = full_path.lower().split("/")
    last_index = len(path_parts) - 1
    current_tree = tree
    for i, part in enumerate(path_parts):
        # First entry in this directory whose lowercased name matches the component.
        entry = next((e for e in current_tree if e.name.lower() == part), None)
        if entry is None:
            return None
        if entry.type == pygit2.GIT_OBJECT_TREE:
            # Descend into the matching subdirectory.
            current_tree = repo[entry.id]
        elif entry.type == pygit2.GIT_OBJECT_BLOB and i == last_index:
            return entry
        else:
            # A blob in the middle of the path, or some other object type.
            return None
    # The path ended on a directory, not a file.
    return None


def find_file(
    repo: pygit2.Repository,
    filepath: str,
    case_insensitive: bool = False,
    extensions: Sequence[str] = (".md", ".txt", ".rtf", ".rst", ""),
) -> Optional[pygit2.Object]:
    """
    Locate a file (blob) in the tree of the HEAD commit by its path.

    Each extension is appended to ``filepath`` in the given order and the first
    path that resolves to a file is returned. Directories never match, in
    either case-sensitive or case-insensitive mode.

    Args:
        repo (pygit2.Repository):
            The repository object.
        filepath (str):
            The path to the file within the repository ("/"-separated).
        case_insensitive (bool):
            If True, perform case-insensitive comparison of each path component.
        extensions (Sequence[str]):
            Possible file extensions to try, in order (e.g., (".md", "")).
    Returns:
        Optional[pygit2.Object]:
            The entry of the found file, or None if no matching file is found
            or the repository has no commits yet (unborn HEAD).
    """
    if repo.head_is_unborn:
        # An empty repository has no commit and therefore no tree to search.
        return None
    # Tree of the commit HEAD points to.
    tree = repo.head.peel().tree
    for ext in extensions:
        full_path = f"{filepath}{ext}"
        if case_insensitive:
            entry = _find_blob_case_insensitive(repo, tree, full_path)
        else:
            try:
                entry = tree[full_path]
            except KeyError:
                entry = None
            # Only files count, so the case-sensitive lookup agrees with the
            # case-insensitive one (which never returns directories).
            if entry is not None and entry.type != pygit2.GIT_OBJECT_BLOB:
                entry = None
        if entry is not None:
            return entry
    return None
[docs]
def count_files(tree: Union[pygit2.Tree, pygit2.Blob]) -> int:
    """
    Count the files (``pygit2.Blob`` objects) in a Git tree, including every
    nested subdirectory.

    The traversal uses an explicit stack of pending objects rather than
    recursion.

    Args:
        tree (Union[pygit2.Tree, pygit2.Blob]):
            Where to start counting, normally the root tree of a commit. A
            single Blob counts as one file.
    Returns:
        int:
            Number of Blob objects found. Objects that are neither a Blob nor
            a Tree contribute nothing.
    """
    file_total = 0
    to_visit = [tree]
    while to_visit:
        node = to_visit.pop()
        if isinstance(node, pygit2.Blob):
            file_total += 1
        elif isinstance(node, pygit2.Tree):
            # Queue the directory's entries for inspection.
            to_visit.extend(node)
    return file_total
[docs]
def read_file(
    repo: pygit2.Repository,
    entry: Optional[pygit2.Object] = None,
    filepath: Optional[str] = None,
    case_insensitive: bool = False,
) -> Optional[str]:
    """
    Read the content of a file from the repository.

    Args:
        repo (pygit2.Repository):
            The repository object.
        entry (Optional[pygit2.Object]):
            The entry of the file to read. If not provided, filepath must be specified.
        filepath (Optional[str]):
            The path to the file within the repository. Used if entry is not provided.
        case_insensitive (bool):
            If True, perform case-insensitive comparison when using filepath.
    Returns:
        Optional[str]:
            The content of the file as a string ("" for an empty file), or None
            if the file is not found, the looked-up object has no ``data``, its
            encoding cannot be detected, or decoding fails.
    Raises:
        ValueError: If neither ``entry`` nor ``filepath`` is provided.
    """
    if entry is None:
        if filepath is None:
            raise ValueError("Either entry or filepath must be provided.")
        entry = find_file(repo, filepath, case_insensitive)
        if entry is None:
            return None
    try:
        blob_data: bytes = repo[entry.id].data
    except AttributeError:
        # Object without ``data`` (presumably not a blob), or an entry without ``id``.
        return None
    if not blob_data:
        # detect_encoding rejects empty input; an empty file simply has empty content.
        return ""
    try:
        return blob_data.decode(detect_encoding(blob_data))
    except (ValueError, LookupError):
        # ValueError: detect_encoding found no encoding, or UnicodeDecodeError
        # (a ValueError subclass). LookupError: the detected encoding name is
        # unknown to Python's codecs.
        return None
[docs]
def resolve_redirects(url: str, timeout: int = 10) -> str:
    """
    Follow HTTP redirects until the final URL is reached.

    Args:
        url (str): The starting URL to check.
        timeout (int): Timeout in seconds passed to ``requests.get``. Defaults to 10.
    Returns:
        str: The URL of the final response after redirects were followed, or
        ``url`` unchanged if the request fails with a ``requests.RequestException``.
    """
    try:
        # stream=True defers downloading the body (only the final URL is needed);
        # the ``with`` block closes the response so its connection is released.
        with requests.get(
            url, allow_redirects=True, timeout=timeout, stream=True
        ) as response:
            return response.url
    except requests.RequestException:
        # return the original URL if any error in redirection occurs
        return url
[docs]
def get_remote_url(repo: pygit2.Repository) -> Optional[str]:
    """
    Determine the web URL of a repository's remote, if one is available.

    Remotes are tried in this order: ``upstream``, then ``origin``, then every
    remote in the order ``repo.remotes`` yields them. ``upstream`` is preferred
    because it is the remote used for referential data lookups (such as GitHub
    issues, stars, etc.). A remote qualifies when its URL parses with an
    ``http``, ``https`` or ``ssh`` scheme and a non-empty network location.
    The first qualifying URL, with any trailing ``.git`` removed, is passed
    through ``resolve_redirects`` and returned.

    Args:
        repo (pygit2.Repository): The pygit2 repository object.
    Returns:
        Optional[str]: The resolved remote URL, or None if no remote qualifies.
    """

    def _is_web_remote(candidate_url):
        parts = urlparse(candidate_url)
        return parts.scheme in {"http", "https", "ssh"} and parts.netloc

    for preferred_name in ("upstream", "origin"):
        try:
            trimmed_url = repo.remotes[preferred_name].url.removesuffix(".git")
            if _is_web_remote(trimmed_url):
                return resolve_redirects(url=trimmed_url)
        except (KeyError, AttributeError):
            # No remote by this name, or its URL is not a usable string.
            continue

    # Fallback: accept the first qualifying remote of any name.
    try:
        for any_remote in repo.remotes:
            raw_url = any_remote.url
            if _is_web_remote(raw_url):
                return resolve_redirects(url=raw_url.removesuffix(".git"))
    except AttributeError:
        pass

    return None
[docs]
def file_exists_in_repo(
    repo: pygit2.Repository,
    expected_file_name: str,
    check_extension: bool = False,
    extensions: Sequence[str] = (".md", ".txt", ".rtf", ""),
) -> bool:
    """
    Check, case-insensitively, whether a root-level entry with the given name
    exists in the HEAD commit of the repository.

    Only the top level of the tree is inspected (no subdirectories). Any entry
    counts, whether it is a file or a directory.

    Args:
        repo (pygit2.Repository):
            The repository object to search in.
        expected_file_name (str):
            The base file name to check (e.g., "readme").
        check_extension (bool):
            If True, an entry matches only when its name equals the base name
            followed by one of ``extensions``. If False, an entry matches when
            its name up to the first "." equals the base name.
        extensions (Sequence[str]):
            Possible file extensions to check (e.g., (".md", "")). Used only
            when ``check_extension`` is True.
    Returns:
        bool:
            True if a matching entry exists, False otherwise (including when
            the repository has no commits yet).
    """
    if repo.head_is_unborn:
        # No commit means no tree, so nothing can exist.
        return False
    tree = repo.revparse_single("HEAD").tree
    expected = expected_file_name.lower()
    if check_extension:
        # Build the accepted names once instead of per tree entry.
        accepted_names = {f"{expected}{ext.lower()}" for ext in extensions}
        return any(entry.name.lower() in accepted_names for entry in tree)
    # Compare the part of each name before the first "." to the base name.
    return any(entry.name.lower().split(".", 1)[0] == expected for entry in tree)