util.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import git
  2. from git.exc import InvalidGitRepositoryError
  3. from git.config import GitConfigParser
  4. from io import BytesIO
  5. import weakref
  6. # typing -----------------------------------------------------------------------
  7. from typing import Any, Sequence, TYPE_CHECKING, Union
  8. from git.types import PathLike
  9. if TYPE_CHECKING:
  10. from .base import Submodule
  11. from weakref import ReferenceType
  12. from git.repo import Repo
  13. from git.refs import Head
  14. from git import Remote
  15. from git.refs import RemoteReference
  16. __all__ = (
  17. "sm_section",
  18. "sm_name",
  19. "mkhead",
  20. "find_first_remote_branch",
  21. "SubmoduleConfigParser",
  22. )
  23. # { Utilities
  24. def sm_section(name: str) -> str:
  25. """:return: section title used in .gitmodules configuration file"""
  26. return f'submodule "{name}"'
  27. def sm_name(section: str) -> str:
  28. """:return: name of the submodule as parsed from the section name"""
  29. section = section.strip()
  30. return section[11:-1]
  31. def mkhead(repo: "Repo", path: PathLike) -> "Head":
  32. """:return: New branch/head instance"""
  33. return git.Head(repo, git.Head.to_full_path(path))
  34. def find_first_remote_branch(remotes: Sequence["Remote"], branch_name: str) -> "RemoteReference":
  35. """Find the remote branch matching the name of the given branch or raise InvalidGitRepositoryError"""
  36. for remote in remotes:
  37. try:
  38. return remote.refs[branch_name]
  39. except IndexError:
  40. continue
  41. # END exception handling
  42. # END for remote
  43. raise InvalidGitRepositoryError("Didn't find remote branch '%r' in any of the given remotes" % branch_name)
  44. # } END utilities
  45. # { Classes
  46. class SubmoduleConfigParser(GitConfigParser):
  47. """
  48. Catches calls to _write, and updates the .gitmodules blob in the index
  49. with the new data, if we have written into a stream. Otherwise it will
  50. add the local file to the index to make it correspond with the working tree.
  51. Additionally, the cache must be cleared
  52. Please note that no mutating method will work in bare mode
  53. """
  54. def __init__(self, *args: Any, **kwargs: Any) -> None:
  55. self._smref: Union["ReferenceType[Submodule]", None] = None
  56. self._index = None
  57. self._auto_write = True
  58. super(SubmoduleConfigParser, self).__init__(*args, **kwargs)
  59. # { Interface
  60. def set_submodule(self, submodule: "Submodule") -> None:
  61. """Set this instance's submodule. It must be called before
  62. the first write operation begins"""
  63. self._smref = weakref.ref(submodule)
  64. def flush_to_index(self) -> None:
  65. """Flush changes in our configuration file to the index"""
  66. assert self._smref is not None
  67. # should always have a file here
  68. assert not isinstance(self._file_or_files, BytesIO)
  69. sm = self._smref()
  70. if sm is not None:
  71. index = self._index
  72. if index is None:
  73. index = sm.repo.index
  74. # END handle index
  75. index.add([sm.k_modules_file], write=self._auto_write)
  76. sm._clear_cache()
  77. # END handle weakref
  78. # } END interface
  79. # { Overridden Methods
  80. def write(self) -> None: # type: ignore[override]
  81. rval: None = super(SubmoduleConfigParser, self).write()
  82. self.flush_to_index()
  83. return rval
  84. # END overridden methods
  85. # } END classes