Harden backup paths and backup handling

This commit is contained in:
chenxiangtong
2026-06-05 20:26:24 +08:00
parent d0103519d4
commit 9cc35b9aac
4 changed files with 437 additions and 36 deletions

View File

@@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
)
@@ -21,7 +22,10 @@ func executeAction(versionDir string, action Action, index int, backupDir string
return executeAdd(versionDir, action, index, backupDir, prefix)
case "delete":
return func() tea.Msg {
absPath := filepath.Join(versionDir, action.Path)
absPath, _, err := safeJoin(versionDir, action.Path)
if err != nil {
return actionErrorMsg{index: index, err: err}
}
if err := backupPath(versionDir, action.Path, backupDir); err != nil {
return actionErrorMsg{index: index, err: fmt.Errorf("backup failed: %w", err)}
}
@@ -32,13 +36,22 @@ func executeAction(versionDir string, action Action, index int, backupDir string
}
case "copy":
return func() tea.Msg {
dst := filepath.Join(versionDir, action.NewPath)
dst, _, err := safeJoin(versionDir, action.NewPath)
if err != nil {
return actionErrorMsg{index: index, err: err}
}
src, _, err := safeJoin(versionDir, action.Path)
if err != nil {
return actionErrorMsg{index: index, err: err}
}
if _, err := validateCopyPath(src, dst); err != nil {
return actionErrorMsg{index: index, err: fmt.Errorf("copy failed: %w", err)}
}
if _, err := os.Stat(dst); err == nil {
if err := backupPath(versionDir, action.NewPath, backupDir); err != nil {
return actionErrorMsg{index: index, err: fmt.Errorf("backup failed: %w", err)}
}
}
src := filepath.Join(versionDir, action.Path)
if err := copyPath(src, dst); err != nil {
return actionErrorMsg{index: index, err: fmt.Errorf("copy failed: %w", err)}
}
@@ -46,8 +59,14 @@ func executeAction(versionDir string, action Action, index int, backupDir string
}
case "move":
return func() tea.Msg {
src := filepath.Join(versionDir, action.Path)
dst := filepath.Join(versionDir, action.NewPath)
src, _, err := safeJoin(versionDir, action.Path)
if err != nil {
return actionErrorMsg{index: index, err: err}
}
dst, _, err := safeJoin(versionDir, action.NewPath)
if err != nil {
return actionErrorMsg{index: index, err: err}
}
if _, err := os.Stat(dst); err == nil {
if err := backupPath(versionDir, action.NewPath, backupDir); err != nil {
return actionErrorMsg{index: index, err: fmt.Errorf("backup failed: %w", err)}
@@ -60,13 +79,18 @@ func executeAction(versionDir string, action Action, index int, backupDir string
if err := copyPath(src, dst); err != nil {
return actionErrorMsg{index: index, err: fmt.Errorf("move failed: %w", err)}
}
os.RemoveAll(src)
if err := os.RemoveAll(src); err != nil {
return actionErrorMsg{index: index, err: fmt.Errorf("cleanup failed: %w", err)}
}
}
return actionCompleteMsg{index: index}
}
case "new":
return func() tea.Msg {
absPath := filepath.Join(versionDir, action.Path)
absPath, _, err := safeJoin(versionDir, action.Path)
if err != nil {
return actionErrorMsg{index: index, err: err}
}
if _, err := os.Stat(absPath); err == nil {
return actionCompleteMsg{index: index}
}
@@ -95,7 +119,10 @@ func executeAction(versionDir string, action Action, index int, backupDir string
func executeAdd(versionDir string, action Action, index int, backupDir, prefix string) tea.Cmd {
return func() tea.Msg {
absPath := filepath.Join(versionDir, action.Path)
absPath, _, err := safeJoin(versionDir, action.Path)
if err != nil {
return actionErrorMsg{index: index, err: err}
}
if err := os.MkdirAll(filepath.Dir(absPath), 0o755); err != nil {
return actionErrorMsg{index: index, err: fmt.Errorf("mkdir failed: %w", err)}
@@ -107,7 +134,7 @@ func executeAdd(versionDir string, action Action, index int, backupDir, prefix s
}
}
err := downloadFile(action.URL, absPath)
err = downloadFile(action.URL, absPath)
if err != nil {
if len(action.Mirrors) > 0 {
return mirrorChoiceMsg{index: index, mirrors: action.Mirrors, action: action}
@@ -117,10 +144,12 @@ func executeAdd(versionDir string, action Action, index int, backupDir, prefix s
if action.Unzip {
destDir := filepath.Dir(absPath)
if err := unzipFile(absPath, destDir); err != nil {
if err := unzipFile(absPath, destDir, versionDir, backupDir); err != nil {
return actionErrorMsg{index: index, err: fmt.Errorf("unzip failed: %w", err)}
}
os.Remove(absPath)
if err := os.Remove(absPath); err != nil {
return actionErrorMsg{index: index, err: fmt.Errorf("cleanup failed: %w", err)}
}
}
return actionCompleteMsg{index: index}
@@ -161,25 +190,43 @@ func downloadFile(url, destPath string) error {
return err
}
return os.Rename(tmpPath, destPath)
if err := os.Rename(tmpPath, destPath); err != nil {
os.Remove(tmpPath)
return err
}
return nil
}
func unzipFile(zipPath, destDir string) error {
func unzipFile(zipPath, destDir, versionDir, backupDir string) error {
r, err := zip.OpenReader(zipPath)
if err != nil {
return err
}
defer r.Close()
for _, f := range r.File {
target := filepath.Join(destDir, f.Name)
versionAbs, err := filepath.Abs(versionDir)
if err != nil {
return err
}
destAbs, err := filepath.Abs(destDir)
if err != nil {
return err
}
if !strings.HasPrefix(filepath.Clean(target), filepath.Clean(destDir)+string(os.PathSeparator)) {
continue
for _, f := range r.File {
if f.FileInfo().Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("refusing to extract symlink: %s", f.Name)
}
target, err := safeZipTarget(destAbs, f.Name)
if err != nil {
return err
}
if f.FileInfo().IsDir() {
os.MkdirAll(target, 0o755)
if err := os.MkdirAll(target, 0o755); err != nil {
return err
}
continue
}
@@ -187,12 +234,30 @@ func unzipFile(zipPath, destDir string) error {
return err
}
rel, err := filepath.Rel(versionAbs, target)
if err != nil {
return err
}
if rel == "." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || filepath.IsAbs(rel) {
return fmt.Errorf("zip target escapes version directory: %s", f.Name)
}
if isBackupRelativePath(rel) {
return fmt.Errorf("refusing to extract into backup directory: %s", f.Name)
}
if _, err := os.Stat(target); err == nil {
if err := backupPath(versionDir, rel, backupDir); err != nil {
return err
}
} else if !os.IsNotExist(err) {
return err
}
rc, err := f.Open()
if err != nil {
return err
}
outFile, err := os.Create(target)
outFile, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.FileInfo().Mode().Perm())
if err != nil {
rc.Close()
return err
@@ -208,21 +273,108 @@ func unzipFile(zipPath, destDir string) error {
return nil
}
func backupPath(versionDir, relativePath, backupDir string) error {
src := filepath.Join(versionDir, relativePath)
if _, err := os.Stat(src); os.IsNotExist(err) {
return nil
func safeJoin(versionDir, relativePath string) (string, string, error) {
if strings.TrimSpace(relativePath) == "" {
return "", "", fmt.Errorf("empty path is not allowed")
}
if filepath.IsAbs(relativePath) || filepath.VolumeName(relativePath) != "" ||
strings.HasPrefix(relativePath, "/") || strings.HasPrefix(relativePath, "\\") {
return "", "", fmt.Errorf("absolute path is not allowed: %s", relativePath)
}
dst := filepath.Join(versionDir, backupDir, relativePath)
if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
cleanRel := filepath.Clean(relativePath)
if cleanRel == "." || cleanRel == ".." || strings.HasPrefix(cleanRel, ".."+string(os.PathSeparator)) {
return "", "", fmt.Errorf("path escapes version directory: %s", relativePath)
}
if isBackupRelativePath(cleanRel) {
return "", "", fmt.Errorf("refusing to operate on backup directory: %s", relativePath)
}
base, err := filepath.Abs(versionDir)
if err != nil {
return "", "", err
}
target, err := filepath.Abs(filepath.Join(base, cleanRel))
if err != nil {
return "", "", err
}
if !pathInside(base, target) {
return "", "", fmt.Errorf("path escapes version directory: %s", relativePath)
}
return target, cleanRel, nil
}
func safeZipTarget(destDir, zipName string) (string, error) {
if strings.TrimSpace(zipName) == "" {
return "", fmt.Errorf("empty zip entry name")
}
if filepath.IsAbs(zipName) || filepath.VolumeName(zipName) != "" ||
strings.HasPrefix(zipName, "/") || strings.HasPrefix(zipName, "\\") {
return "", fmt.Errorf("absolute zip entry is not allowed: %s", zipName)
}
cleanName := filepath.Clean(zipName)
if cleanName == "." || cleanName == ".." || strings.HasPrefix(cleanName, ".."+string(os.PathSeparator)) {
return "", fmt.Errorf("zip entry escapes target directory: %s", zipName)
}
target, err := filepath.Abs(filepath.Join(destDir, cleanName))
if err != nil {
return "", err
}
if !pathInside(destDir, target) {
return "", fmt.Errorf("zip entry escapes target directory: %s", zipName)
}
return target, nil
}
func isBackupRelativePath(relativePath string) bool {
clean := filepath.ToSlash(filepath.Clean(relativePath))
return clean == "amt/backup" || strings.HasPrefix(clean, "amt/backup/")
}
func pathInside(base, target string) bool {
rel, err := filepath.Rel(base, target)
if err != nil {
return false
}
return rel == "." || (rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) && !filepath.IsAbs(rel))
}
func newBackupDir() string {
return fmt.Sprintf("amt/backup/%s", time.Now().Format("20060102_150405_000000000"))
}
func backupPath(versionDir, relativePath, backupDir string) error {
src, cleanRel, err := safeJoin(versionDir, relativePath)
if err != nil {
return err
}
info, err := os.Stat(src)
info, err := os.Lstat(src)
if os.IsNotExist(err) {
return nil
}
if err != nil {
return err
}
if info.Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("refusing to back up symlink: %s", relativePath)
}
backupRoot, err := backupRootPath(versionDir, backupDir)
if err != nil {
return err
}
dst := filepath.Join(backupRoot, cleanRel)
if info.IsDir() && pathInside(src, dst) {
return fmt.Errorf("refusing to back up %s into itself", relativePath)
}
dst, err = uniqueBackupDestination(dst)
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
return err
}
if info.IsDir() {
return copyDir(src, dst)
@@ -230,8 +382,77 @@ func backupPath(versionDir, relativePath, backupDir string) error {
return copyFile(src, dst)
}
func backupRootPath(versionDir, backupDir string) (string, error) {
if strings.TrimSpace(backupDir) == "" {
return "", fmt.Errorf("empty backup directory")
}
if filepath.IsAbs(backupDir) {
return "", fmt.Errorf("absolute backup directory is not allowed: %s", backupDir)
}
cleanRel := filepath.Clean(backupDir)
cleanSlash := filepath.ToSlash(cleanRel)
if cleanSlash != "amt/backup" && !strings.HasPrefix(cleanSlash, "amt/backup/") {
return "", fmt.Errorf("backup directory must be under amt/backup: %s", backupDir)
}
base, err := filepath.Abs(versionDir)
if err != nil {
return "", err
}
root, err := filepath.Abs(filepath.Join(base, cleanRel))
if err != nil {
return "", err
}
if !pathInside(base, root) {
return "", fmt.Errorf("backup directory escapes version directory: %s", backupDir)
}
return root, nil
}
func uniqueBackupDestination(dst string) (string, error) {
if _, err := os.Lstat(dst); os.IsNotExist(err) {
return dst, nil
} else if err != nil {
return "", err
}
for i := 1; ; i++ {
candidate := fmt.Sprintf("%s.%d", dst, i)
if _, err := os.Lstat(candidate); os.IsNotExist(err) {
return candidate, nil
} else if err != nil {
return "", err
}
}
}
func validateCopyPath(src, dst string) (os.FileInfo, error) {
info, err := os.Lstat(src)
if err != nil {
return nil, err
}
if info.Mode()&os.ModeSymlink != 0 {
return nil, fmt.Errorf("refusing to copy symlink: %s", src)
}
srcAbs, err := filepath.Abs(src)
if err != nil {
return nil, err
}
dstAbs, err := filepath.Abs(dst)
if err != nil {
return nil, err
}
if srcAbs == dstAbs {
return nil, fmt.Errorf("source and destination are the same: %s", src)
}
if info.IsDir() && pathInside(srcAbs, dstAbs) {
return nil, fmt.Errorf("refusing to copy directory into itself: %s", src)
}
return info, nil
}
func copyPath(src, dst string) error {
info, err := os.Stat(src)
info, err := validateCopyPath(src, dst)
if err != nil {
return err
}
@@ -247,27 +468,54 @@ func copyPath(src, dst string) error {
}
func copyFile(src, dst string) error {
info, err := os.Lstat(src)
if err != nil {
return err
}
if info.Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("refusing to copy symlink: %s", src)
}
in, err := os.Open(src)
if err != nil {
return err
}
defer in.Close()
out, err := os.Create(dst)
out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode().Perm())
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, in)
return err
if _, err := io.Copy(out, in); err != nil {
out.Close()
return err
}
if err := out.Close(); err != nil {
return err
}
if err := os.Chmod(dst, info.Mode().Perm()); err != nil {
return err
}
return os.Chtimes(dst, info.ModTime(), info.ModTime())
}
func copyDir(src, dst string) error {
rootInfo, err := os.Lstat(src)
if err != nil {
return err
}
if rootInfo.Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("refusing to copy symlink: %s", src)
}
return filepath.Walk(src, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("refusing to copy symlink: %s", path)
}
rel, err := filepath.Rel(src, path)
if err != nil {
@@ -277,7 +525,10 @@ func copyDir(src, dst string) error {
target := filepath.Join(dst, rel)
if info.IsDir() {
return os.MkdirAll(target, 0o755)
if err := os.MkdirAll(target, info.Mode().Perm()); err != nil {
return err
}
return os.Chtimes(target, info.ModTime(), info.ModTime())
}
return copyFile(path, target)
})