# Copyright (c) 2009, Rotem Yaari # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of organization nor the # names of its contributors may be used to endorse or promote products # derived from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY Rotem Yaari ''AS IS'' AND ANY # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL Rotem Yaari BE LIABLE FOR ANY # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from collections import Sequence import re import os import subprocess import sys from . import branch from . import tag from . import commit from . import config from .files import ModifiedFile from . import ref from . import ref_container from . import remotes from .utils import quote_for_shell from .utils import CommandString as CMD #exceptions from .exceptions import CannotFindRepository from .exceptions import GitException from .exceptions import GitCommandFailedException from .exceptions import MergeConflict from .exceptions import NonexistentRefException BRANCH_ALIAS_MARKER = ' -> ' class Repository(ref_container.RefContainer): _git_command = None def setCommand(self, command): self._git_command = command ############################# internal methods ############################# _loggingEnabled = False def _getWorkingDirectory(self): return '.' def _logGitCommand(self, command, cwd): if self._loggingEnabled: print >> sys.stderr, ">>", command def enableLogging(self): self._loggingEnabled = True def disableLogging(self): self._loggingEnabled = False def _executeGitCommand(self, command, cwd = None): if cwd is None: cwd = self._getWorkingDirectory() command = '%s %s' % (self._git_command, str(command)) self._logGitCommand(command, cwd) returned = subprocess.Popen(command, shell = True, cwd = cwd, stdout = subprocess.PIPE, stderr = subprocess.PIPE) returned.wait() return returned def _executeGitCommandAssertSuccess(self, command, **kwargs): returned = self._executeGitCommand(command, **kwargs) assert returned.returncode is not None if returned.returncode != 0: raise GitCommandFailedException(kwargs.get('cwd', self._getWorkingDirectory()), command, returned) return returned def _getOutputAssertSuccess(self, command, **kwargs): return self._executeGitCommandAssertSuccess(command, **kwargs).stdout.read() def _getMergeBase(self, a, b): raise NotImplementedError() def getMergeBase(self, a, b): repo = self if isinstance(b, commit.Commit) and isinstance(b.repo, LocalRepository): repo = b.repo elif isinstance(a, commit.Commit) and isinstance(a.repo, LocalRepository): repo = a.repo return repo._getMergeBase(a, b) ############################## remote repositories ############################# class RemoteRepository(Repository): def __init__(self, url, command = 'git'): self.setCommand(command) super(RemoteRepository, self).__init__() self.url = url def _getRefs(self, prefix = ''): output = self._executeGitCommandAssertSuccess("ls-remote %s" % (self.url,)) for output_line in output.stdout: commit, refname = output_line.split() if refname.startswith(prefix): yield refname[len(prefix):], commit.strip() def _getRefsAsClass(self, prefix, cls): return [cls(self, ref) for ref, _ in self._getRefs(prefix)] def _getCommitByRefName(self, refname): sha_by_ref = dict(self._getRefs()) for prefix in 'refs/tags/', 'refs/heads/': sha = sha_by_ref.get(prefix + refname, None) if sha is not None: return commit.Commit(self, sha) raise NonexistentRefException("Cannot find ref name %r in %s" % (refname, self)) def getBranches(self): return self._getRefsAsClass('refs/heads/', branch.RemoteBranch) def getTags(self): return self._getRefsAsClass('refs/tags/', tag.RemoteTag) ############################## local repositories ############################## class LocalRepository(Repository): def __init__(self, path, command = 'git'): self.setCommand(command) super(LocalRepository, self).__init__() self.path = path self.config = config.GitConfiguration(self) self._version = None def __repr__(self): return "" % (self.path,) def _getWorkingDirectory(self): return self.path def _getCommitByHash(self, sha): return commit.Commit(self, sha) def _getCommitByRefName(self, name): return commit.Commit(self, self._getOutputAssertSuccess("rev-parse %s" % name).strip()) def _getCommitByPartialHash(self, sha): return self._getCommitByRefName(sha) def getGitVersion(self): if self._version is None: version_output = self._getOutputAssertSuccess("version") version_match = re.match(r"git\s+version\s+(\S+)[\s\(]?", version_output, re.I) if version_match is None: raise GitException("Cannot extract git version (unfamiliar output format %r?)" % version_output) self._version = version_match.group(1) return self._version ########################### Initializing a repository ########################## def init(self, bare = False): if not os.path.exists(self.path): os.mkdir(self.path) if not os.path.isdir(self.path): raise GitException("Cannot create repository in %s - " "not a directory" % self.path) self._executeGitCommandAssertSuccess("init %s" % ("--bare" if bare else "")) def _asURL(self, repo): if isinstance(repo, LocalRepository): repo = repo.path elif isinstance(repo, RemoteRepository): repo = repo.url elif not isinstance(repo, basestring): raise TypeError("Cannot clone from %r" % (repo,)) return repo def clone(self, repo): self._executeGitCommandAssertSuccess("clone %s %s" % (self._asURL(repo), self.path), cwd = ".") ########################### Querying repository refs ########################### def getBranches(self): returned = [] for git_branch_line in self._executeGitCommandAssertSuccess("branch").stdout: if git_branch_line.startswith("*"): git_branch_line = git_branch_line[1:] git_branch_line = git_branch_line.strip() if BRANCH_ALIAS_MARKER in git_branch_line: alias_name, aliased = git_branch_line.split(BRANCH_ALIAS_MARKER) returned.append(branch.LocalBranchAlias(self, alias_name, aliased)) else: returned.append(branch.LocalBranch(self, git_branch_line)) return returned def getTags(self): returned = [] for git_tag_line in self._executeGitCommandAssertSuccess("tag").stdout: returned.append(tag.LocalTag(self, git_tag_line.strip())) return returned def _getCommits(self, specs, includeMerges): command = "log --pretty=format:%%H %s" % specs if not includeMerges: command += " --no-merges" for c in self._executeGitCommandAssertSuccess(command).stdout: yield commit.Commit(self, c.strip()) def getCommits(self, start = None, end = "HEAD", includeMerges = True): spec = self._normalizeRefName(start or "") spec += ".." spec += self._normalizeRefName(end) return list(self._getCommits(spec, includeMerges = includeMerges)) def getCurrentBranch(self): #todo: improve this method of obtaining current branch for branch_name in self._executeGitCommandAssertSuccess("branch").stdout: branch_name = branch_name.strip() if not branch_name.startswith("*"): continue branch_name = branch_name[1:].strip() if branch_name == '(no branch)': return None return self.getBranchByName(branch_name) def getRemotes(self): config_dict = self.config.getDict() returned = [] for line in self._getOutputAssertSuccess("remote show -n").splitlines(): line = line.strip() returned.append(remotes.Remote(self, line, config_dict.get('remote.%s.url' % line.strip()))) return returned def getRemoteByName(self, name): return self._getByName(self.getRemotes, name) def _getMergeBase(self, a, b): if isinstance(a, ref.Ref): a = a.getHead() if isinstance(b, ref.Ref): b = b.getHead() returned = self._executeGitCommand("merge-base %s %s" % (a, b)) if returned.returncode == 0: return commit.Commit(self, returned.stdout.read().strip()) # make sure this is not a misc. error with git unused = self.getHead() return None ################################ Querying Status ############################### def containsCommit(self, commit): try: self._executeGitCommandAssertSuccess("log -1 %s" % (commit,)) except GitException: return False return True def getHead(self): return self._getCommitByRefName("HEAD") def _getFiles(self, *flags): flags = ["--exclude-standard"] + list(flags) return [f.strip() for f in self._getOutputAssertSuccess("ls-files %s" % (" ".join(flags))).splitlines()] def _getRawDiff(self, *flags, **options): match_statuses = options.pop('fileStatuses', None) if match_statuses is not None and not isinstance(match_statuses, Sequence): raise ValueError("matchedStatuses must be a sequence") if options: raise TypeError("Unknown arguments specified: %s" % ", ".join(options)) flags = " ".join(str(f) for f in flags) modified_files = [] for line in self._getOutputAssertSuccess("diff --raw %s" % flags).splitlines(): file_status = line.split()[-2] file_name = line.split()[-1] if match_statuses is None or file_status in match_statuses: modified_files.append(ModifiedFile(file_name)) return modified_files def getStagedFiles(self): if self.isInitialized(): return self._getRawDiff('--cached') return self._getFiles() def getUnchangedFiles(self): return self._getFiles() def getChangedFiles(self): return self._getRawDiff() def getDeletedFiles(self): return self._getRawDiff(fileStatuses = ['D']) def getUntrackedFiles(self): return self._getFiles("--others") def isInitialized(self): try: self.getHead() return True except GitException: return False def isValid(self): return os.path.isdir(os.path.join(self.path, ".git")) or \ (os.path.isfile(os.path.join(self.path, "HEAD")) and os.path.isdir(os.path.join(self.path, "objects"))) def isWorkingDirectoryClean(self): return not (self.getUntrackedFiles() or self.getChangedFiles() or self.getStagedFiles()) def __contains__(self, thing): if isinstance(thing, basestring) or isinstance(thing, commit.Commit): return self.containsCommit(thing) raise NotImplementedError() ################################ Staging content ############################### def add(self, path): self._executeGitCommandAssertSuccess("add %s" % quote_for_shell(path)) def delete(self, path, recursive = False, force = False): flags = "" if recursive: flags += "-r " if force: flags += "-f " self._executeGitCommandAssertSuccess("rm %s%s" % (flags, quote_for_shell(path))) def addAll(self): return self.add('.') ################################## Committing ################################## def _normalizeRefName(self, thing): if isinstance(thing, ref.Ref): thing = thing.getNormalizedName() return str(thing) def _deduceNewCommitFromCommitOutput(self, output): for pattern in [ # new-style commit pattern r"^\[\S+\s+(?:\(root-commit\)\s+)?(\S+)\]", ]: match = re.search(pattern, output) if match: return commit.Commit(self, match.group(1)) return None def commit(self, message, allowEmpty = False, commitAll = False): args = '' if commitAll: args = args + '--all' command = "commit %s -m %s" % (args, quote_for_shell(message)) if allowEmpty: command += " --allow-empty" output = self._getOutputAssertSuccess(command) return self._deduceNewCommitFromCommitOutput(output) ################################ Changing state ################################ def _createBranchOrTag(self, objname, name, startingPoint, returned_class): command = "%s %s " % (objname, name) if startingPoint is not None: command += self._normalizeRefName(startingPoint) self._executeGitCommandAssertSuccess(command) return returned_class(self, name) def createBranch(self, name, startingPoint = None): return self._createBranchOrTag('branch', name, startingPoint, branch.LocalBranch) def createTag(self, name, startingPoint = None): return self._createBranchOrTag('tag', name, startingPoint, tag.LocalTag) def checkout(self, thing = None, targetBranch = None, files = ()): if thing is None: thing = "" command = "checkout %s" % (self._normalizeRefName(thing),) if targetBranch is not None: command += " -b %s" % (targetBranch,) if files: command += " -- %s" % " ".join(files) self._executeGitCommandAssertSuccess(command) def mergeMultiple(self, srcs, allowFastForward = True, log = False, message = None): try: self._executeGitCommandAssertSuccess(CMD("merge", " ".join(self._normalizeRefName(src) for src in srcs), "--no-ff" if not allowFastForward else None, "--log" if log else None, ("-m \"%s\"" % message) if message is not None else None)) except GitCommandFailedException, e: # git-merge tends to ignore the stderr rule... output = e.stdout + e.stderr if 'conflict' in output.lower(): raise MergeConflict() raise def merge(self, src, *args, **kwargs): return self.mergeMultiple([src], *args, **kwargs) def _reset(self, flag, thing): command = "reset %s %s" % ( flag, self._normalizeRefName(thing)) self._executeGitCommandAssertSuccess(command) def resetSoft(self, thing = "HEAD"): return self._reset("--soft", thing) def resetHard(self, thing = "HEAD"): return self._reset("--hard", thing) def resetMixed(self, thing = "HEAD"): return self._reset("--mixed", thing) def _clean(self, flags): self._executeGitCommandAssertSuccess("clean -q " + flags) def cleanIgnoredFiles(self): """Cleans files that match the patterns in .gitignore""" return self._clean("-f -X") def cleanUntrackedFiles(self): return self._clean("-f -d") ################################# collaboration ################################ def addRemote(self, name, url): self._executeGitCommandAssertSuccess("remote add %s %s" % (name, url)) return remotes.Remote(self, name, url) def fetch(self, repo = None): command = "fetch" if repo is not None: command += " " command += self._asURL(repo) self._executeGitCommandAssertSuccess(command) def pull(self, repo = None): command = "pull" if repo is not None: command += " " command += self._asURL(repo) self._executeGitCommandAssertSuccess(command) def _getRefspec(self, fromBranch = None, toBranch = None, force = False): returned = "" if fromBranch is not None: returned += self._normalizeRefName(fromBranch) if returned or toBranch is not None: returned += ":" if toBranch is not None: if isinstance(toBranch, branch.RegisteredRemoteBranch): toBranch = toBranch.name returned += self._normalizeRefName(toBranch) if returned and force: returned = "+%s" % returned return returned def push(self, remote = None, fromBranch = None, toBranch = None, force = False): command = "push" #build push arguments refspec = self._getRefspec(toBranch = toBranch, fromBranch = fromBranch, force = force) if refspec and not remote: remote = "origin" if isinstance(remote, remotes.Remote): remote = remote.name elif isinstance(remote, RemoteRepository): remote = remote.url elif isinstance(remote, LocalRepository): remote = remote.path if remote is not None and not isinstance(remote, basestring): raise TypeError("Invalid type for 'remote' parameter: %s" % (type(remote),)) command = "push %s %s" % (remote if remote is not None else "", refspec) self._executeGitCommandAssertSuccess(command) def rebase(self, src): self._executeGitCommandAssertSuccess("rebase %s" % self._normalizeRefName(src)) #################################### Stashes ################################### def saveStash(self, name = None): command = "stash save" if name is not None: command += " %s" % name self._executeGitCommandAssertSuccess(command) def popStash(self, arg = None): command = "stash pop" if arg is not None: command += " %s" % arg self._executeGitCommandAssertSuccess(command) ################################# Configuration ################################ ################################### Shortcuts ################################## def clone(source, location): returned = LocalRepository(location) returned.clone(source) return returned def find_repository(): orig_path = path = os.path.realpath('.') drive, path = os.path.splitdrive(path) while path: current_path = os.path.join(drive, path) current_repo = LocalRepository(current_path) if current_repo.isValid(): return current_repo path, path_tail = os.path.split(current_path) if not path_tail: raise CannotFindRepository("Cannot find repository for %s" % (orig_path,))