diff --git a/CLAUDE.md b/AGENTS.md similarity index 100% rename from CLAUDE.md rename to AGENTS.md diff --git a/actions.go b/actions.go index f70e703..82cff07 100644 --- a/actions.go +++ b/actions.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" "strings" + "time" tea "github.com/charmbracelet/bubbletea" ) @@ -21,7 +22,10 @@ func executeAction(versionDir string, action Action, index int, backupDir string return executeAdd(versionDir, action, index, backupDir, prefix) case "delete": return func() tea.Msg { - absPath := filepath.Join(versionDir, action.Path) + absPath, _, err := safeJoin(versionDir, action.Path) + if err != nil { + return actionErrorMsg{index: index, err: err} + } if err := backupPath(versionDir, action.Path, backupDir); err != nil { return actionErrorMsg{index: index, err: fmt.Errorf("backup failed: %w", err)} } @@ -32,13 +36,22 @@ func executeAction(versionDir string, action Action, index int, backupDir string } case "copy": return func() tea.Msg { - dst := filepath.Join(versionDir, action.NewPath) + dst, _, err := safeJoin(versionDir, action.NewPath) + if err != nil { + return actionErrorMsg{index: index, err: err} + } + src, _, err := safeJoin(versionDir, action.Path) + if err != nil { + return actionErrorMsg{index: index, err: err} + } + if _, err := validateCopyPath(src, dst); err != nil { + return actionErrorMsg{index: index, err: fmt.Errorf("copy failed: %w", err)} + } if _, err := os.Stat(dst); err == nil { if err := backupPath(versionDir, action.NewPath, backupDir); err != nil { return actionErrorMsg{index: index, err: fmt.Errorf("backup failed: %w", err)} } } - src := filepath.Join(versionDir, action.Path) if err := copyPath(src, dst); err != nil { return actionErrorMsg{index: index, err: fmt.Errorf("copy failed: %w", err)} } @@ -46,8 +59,14 @@ func executeAction(versionDir string, action Action, index int, backupDir string } case "move": return func() tea.Msg { - src := filepath.Join(versionDir, action.Path) - dst := filepath.Join(versionDir, action.NewPath) + src, _, err := safeJoin(versionDir, action.Path) + if err != nil { + return actionErrorMsg{index: index, err: err} + } + dst, _, err := safeJoin(versionDir, action.NewPath) + if err != nil { + return actionErrorMsg{index: index, err: err} + } if _, err := os.Stat(dst); err == nil { if err := backupPath(versionDir, action.NewPath, backupDir); err != nil { return actionErrorMsg{index: index, err: fmt.Errorf("backup failed: %w", err)} @@ -60,13 +79,18 @@ func executeAction(versionDir string, action Action, index int, backupDir string if err := copyPath(src, dst); err != nil { return actionErrorMsg{index: index, err: fmt.Errorf("move failed: %w", err)} } - os.RemoveAll(src) + if err := os.RemoveAll(src); err != nil { + return actionErrorMsg{index: index, err: fmt.Errorf("cleanup failed: %w", err)} + } } return actionCompleteMsg{index: index} } case "new": return func() tea.Msg { - absPath := filepath.Join(versionDir, action.Path) + absPath, _, err := safeJoin(versionDir, action.Path) + if err != nil { + return actionErrorMsg{index: index, err: err} + } if _, err := os.Stat(absPath); err == nil { return actionCompleteMsg{index: index} } @@ -95,7 +119,10 @@ func executeAction(versionDir string, action Action, index int, backupDir string func executeAdd(versionDir string, action Action, index int, backupDir, prefix string) tea.Cmd { return func() tea.Msg { - absPath := filepath.Join(versionDir, action.Path) + absPath, _, err := safeJoin(versionDir, action.Path) + if err != nil { + return actionErrorMsg{index: index, err: err} + } if err := os.MkdirAll(filepath.Dir(absPath), 0o755); err != nil { return actionErrorMsg{index: index, err: fmt.Errorf("mkdir failed: %w", err)} @@ -107,7 +134,7 @@ func executeAdd(versionDir string, action Action, index int, backupDir, prefix s } } - err := downloadFile(action.URL, absPath) + err = downloadFile(action.URL, absPath) if err != nil { if len(action.Mirrors) > 0 { return mirrorChoiceMsg{index: index, mirrors: action.Mirrors, action: action} @@ -117,10 +144,12 @@ func executeAdd(versionDir string, action Action, index int, backupDir, prefix s if action.Unzip { destDir := filepath.Dir(absPath) - if err := unzipFile(absPath, destDir); err != nil { + if err := unzipFile(absPath, destDir, versionDir, backupDir); err != nil { return actionErrorMsg{index: index, err: fmt.Errorf("unzip failed: %w", err)} } - os.Remove(absPath) + if err := os.Remove(absPath); err != nil { + return actionErrorMsg{index: index, err: fmt.Errorf("cleanup failed: %w", err)} + } } return actionCompleteMsg{index: index} @@ -161,25 +190,43 @@ func downloadFile(url, destPath string) error { return err } - return os.Rename(tmpPath, destPath) + if err := os.Rename(tmpPath, destPath); err != nil { + os.Remove(tmpPath) + return err + } + return nil } -func unzipFile(zipPath, destDir string) error { +func unzipFile(zipPath, destDir, versionDir, backupDir string) error { r, err := zip.OpenReader(zipPath) if err != nil { return err } defer r.Close() - for _, f := range r.File { - target := filepath.Join(destDir, f.Name) + versionAbs, err := filepath.Abs(versionDir) + if err != nil { + return err + } + destAbs, err := filepath.Abs(destDir) + if err != nil { + return err + } - if !strings.HasPrefix(filepath.Clean(target), filepath.Clean(destDir)+string(os.PathSeparator)) { - continue + for _, f := range r.File { + if f.FileInfo().Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("refusing to extract symlink: %s", f.Name) + } + + target, err := safeZipTarget(destAbs, f.Name) + if err != nil { + return err } if f.FileInfo().IsDir() { - os.MkdirAll(target, 0o755) + if err := os.MkdirAll(target, 0o755); err != nil { + return err + } continue } @@ -187,12 +234,30 @@ func unzipFile(zipPath, destDir string) error { return err } + rel, err := filepath.Rel(versionAbs, target) + if err != nil { + return err + } + if rel == "." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || filepath.IsAbs(rel) { + return fmt.Errorf("zip target escapes version directory: %s", f.Name) + } + if isBackupRelativePath(rel) { + return fmt.Errorf("refusing to extract into backup directory: %s", f.Name) + } + if _, err := os.Stat(target); err == nil { + if err := backupPath(versionDir, rel, backupDir); err != nil { + return err + } + } else if !os.IsNotExist(err) { + return err + } + rc, err := f.Open() if err != nil { return err } - outFile, err := os.Create(target) + outFile, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.FileInfo().Mode().Perm()) if err != nil { rc.Close() return err @@ -208,21 +273,108 @@ func unzipFile(zipPath, destDir string) error { return nil } -func backupPath(versionDir, relativePath, backupDir string) error { - src := filepath.Join(versionDir, relativePath) - if _, err := os.Stat(src); os.IsNotExist(err) { - return nil +func safeJoin(versionDir, relativePath string) (string, string, error) { + if strings.TrimSpace(relativePath) == "" { + return "", "", fmt.Errorf("empty path is not allowed") + } + if filepath.IsAbs(relativePath) || filepath.VolumeName(relativePath) != "" || + strings.HasPrefix(relativePath, "/") || strings.HasPrefix(relativePath, "\\") { + return "", "", fmt.Errorf("absolute path is not allowed: %s", relativePath) } - dst := filepath.Join(versionDir, backupDir, relativePath) - if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + cleanRel := filepath.Clean(relativePath) + if cleanRel == "." || cleanRel == ".." || strings.HasPrefix(cleanRel, ".."+string(os.PathSeparator)) { + return "", "", fmt.Errorf("path escapes version directory: %s", relativePath) + } + if isBackupRelativePath(cleanRel) { + return "", "", fmt.Errorf("refusing to operate on backup directory: %s", relativePath) + } + + base, err := filepath.Abs(versionDir) + if err != nil { + return "", "", err + } + target, err := filepath.Abs(filepath.Join(base, cleanRel)) + if err != nil { + return "", "", err + } + if !pathInside(base, target) { + return "", "", fmt.Errorf("path escapes version directory: %s", relativePath) + } + return target, cleanRel, nil +} + +func safeZipTarget(destDir, zipName string) (string, error) { + if strings.TrimSpace(zipName) == "" { + return "", fmt.Errorf("empty zip entry name") + } + if filepath.IsAbs(zipName) || filepath.VolumeName(zipName) != "" || + strings.HasPrefix(zipName, "/") || strings.HasPrefix(zipName, "\\") { + return "", fmt.Errorf("absolute zip entry is not allowed: %s", zipName) + } + cleanName := filepath.Clean(zipName) + if cleanName == "." || cleanName == ".." || strings.HasPrefix(cleanName, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("zip entry escapes target directory: %s", zipName) + } + + target, err := filepath.Abs(filepath.Join(destDir, cleanName)) + if err != nil { + return "", err + } + if !pathInside(destDir, target) { + return "", fmt.Errorf("zip entry escapes target directory: %s", zipName) + } + return target, nil +} + +func isBackupRelativePath(relativePath string) bool { + clean := filepath.ToSlash(filepath.Clean(relativePath)) + return clean == "amt/backup" || strings.HasPrefix(clean, "amt/backup/") +} + +func pathInside(base, target string) bool { + rel, err := filepath.Rel(base, target) + if err != nil { + return false + } + return rel == "." || (rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) && !filepath.IsAbs(rel)) +} + +func newBackupDir() string { + return fmt.Sprintf("amt/backup/%s", time.Now().Format("20060102_150405_000000000")) +} + +func backupPath(versionDir, relativePath, backupDir string) error { + src, cleanRel, err := safeJoin(versionDir, relativePath) + if err != nil { return err } - - info, err := os.Stat(src) + info, err := os.Lstat(src) + if os.IsNotExist(err) { + return nil + } if err != nil { return err } + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("refusing to back up symlink: %s", relativePath) + } + + backupRoot, err := backupRootPath(versionDir, backupDir) + if err != nil { + return err + } + dst := filepath.Join(backupRoot, cleanRel) + if info.IsDir() && pathInside(src, dst) { + return fmt.Errorf("refusing to back up %s into itself", relativePath) + } + dst, err = uniqueBackupDestination(dst) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + return err + } if info.IsDir() { return copyDir(src, dst) @@ -230,8 +382,77 @@ func backupPath(versionDir, relativePath, backupDir string) error { return copyFile(src, dst) } +func backupRootPath(versionDir, backupDir string) (string, error) { + if strings.TrimSpace(backupDir) == "" { + return "", fmt.Errorf("empty backup directory") + } + if filepath.IsAbs(backupDir) { + return "", fmt.Errorf("absolute backup directory is not allowed: %s", backupDir) + } + cleanRel := filepath.Clean(backupDir) + cleanSlash := filepath.ToSlash(cleanRel) + if cleanSlash != "amt/backup" && !strings.HasPrefix(cleanSlash, "amt/backup/") { + return "", fmt.Errorf("backup directory must be under amt/backup: %s", backupDir) + } + + base, err := filepath.Abs(versionDir) + if err != nil { + return "", err + } + root, err := filepath.Abs(filepath.Join(base, cleanRel)) + if err != nil { + return "", err + } + if !pathInside(base, root) { + return "", fmt.Errorf("backup directory escapes version directory: %s", backupDir) + } + return root, nil +} + +func uniqueBackupDestination(dst string) (string, error) { + if _, err := os.Lstat(dst); os.IsNotExist(err) { + return dst, nil + } else if err != nil { + return "", err + } + + for i := 1; ; i++ { + candidate := fmt.Sprintf("%s.%d", dst, i) + if _, err := os.Lstat(candidate); os.IsNotExist(err) { + return candidate, nil + } else if err != nil { + return "", err + } + } +} + +func validateCopyPath(src, dst string) (os.FileInfo, error) { + info, err := os.Lstat(src) + if err != nil { + return nil, err + } + if info.Mode()&os.ModeSymlink != 0 { + return nil, fmt.Errorf("refusing to copy symlink: %s", src) + } + srcAbs, err := filepath.Abs(src) + if err != nil { + return nil, err + } + dstAbs, err := filepath.Abs(dst) + if err != nil { + return nil, err + } + if srcAbs == dstAbs { + return nil, fmt.Errorf("source and destination are the same: %s", src) + } + if info.IsDir() && pathInside(srcAbs, dstAbs) { + return nil, fmt.Errorf("refusing to copy directory into itself: %s", src) + } + return info, nil +} + func copyPath(src, dst string) error { - info, err := os.Stat(src) + info, err := validateCopyPath(src, dst) if err != nil { return err } @@ -247,27 +468,54 @@ func copyPath(src, dst string) error { } func copyFile(src, dst string) error { + info, err := os.Lstat(src) + if err != nil { + return err + } + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("refusing to copy symlink: %s", src) + } + in, err := os.Open(src) if err != nil { return err } defer in.Close() - out, err := os.Create(dst) + out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode().Perm()) if err != nil { return err } - defer out.Close() - _, err = io.Copy(out, in) - return err + if _, err := io.Copy(out, in); err != nil { + out.Close() + return err + } + if err := out.Close(); err != nil { + return err + } + if err := os.Chmod(dst, info.Mode().Perm()); err != nil { + return err + } + return os.Chtimes(dst, info.ModTime(), info.ModTime()) } func copyDir(src, dst string) error { + rootInfo, err := os.Lstat(src) + if err != nil { + return err + } + if rootInfo.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("refusing to copy symlink: %s", src) + } + return filepath.Walk(src, func(path string, info os.FileInfo, err error) error { if err != nil { return err } + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("refusing to copy symlink: %s", path) + } rel, err := filepath.Rel(src, path) if err != nil { @@ -277,7 +525,10 @@ func copyDir(src, dst string) error { target := filepath.Join(dst, rel) if info.IsDir() { - return os.MkdirAll(target, 0o755) + if err := os.MkdirAll(target, info.Mode().Perm()); err != nil { + return err + } + return os.Chtimes(target, info.ModTime(), info.ModTime()) } return copyFile(path, target) }) diff --git a/actions_test.go b/actions_test.go new file mode 100644 index 0000000..a9c38ca --- /dev/null +++ b/actions_test.go @@ -0,0 +1,147 @@ +package main + +import ( + "archive/zip" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestSafeJoinRejectsUnsafePaths(t *testing.T) { + versionDir := t.TempDir() + + unsafePaths := []string{ + "", + ".", + "/outside.txt", + "\\outside.txt", + "..", + filepath.Join("..", "outside.txt"), + filepath.Join("amt", "backup", "old.txt"), + } + + for _, p := range unsafePaths { + if _, _, err := safeJoin(versionDir, p); err == nil { + t.Fatalf("safeJoin(%q) returned nil error", p) + } + } + + target, cleanRel, err := safeJoin(versionDir, filepath.Join("mods", "ok.jar")) + if err != nil { + t.Fatalf("safeJoin valid path failed: %v", err) + } + if cleanRel != filepath.Join("mods", "ok.jar") { + t.Fatalf("unexpected clean path: %q", cleanRel) + } + if !strings.HasPrefix(target, versionDir) { + t.Fatalf("target %q is not under %q", target, versionDir) + } +} + +func TestBackupPathDoesNotOverwriteExistingBackup(t *testing.T) { + versionDir := t.TempDir() + src := filepath.Join(versionDir, "mods", "config.txt") + if err := os.MkdirAll(filepath.Dir(src), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(src, []byte("first"), 0o644); err != nil { + t.Fatal(err) + } + + backupDir := filepath.Join("amt", "backup", "run") + if err := backupPath(versionDir, filepath.Join("mods", "config.txt"), backupDir); err != nil { + t.Fatalf("first backup failed: %v", err) + } + if err := os.WriteFile(src, []byte("second"), 0o644); err != nil { + t.Fatal(err) + } + if err := backupPath(versionDir, filepath.Join("mods", "config.txt"), backupDir); err != nil { + t.Fatalf("second backup failed: %v", err) + } + + first, err := os.ReadFile(filepath.Join(versionDir, backupDir, "mods", "config.txt")) + if err != nil { + t.Fatal(err) + } + second, err := os.ReadFile(filepath.Join(versionDir, backupDir, "mods", "config.txt.1")) + if err != nil { + t.Fatal(err) + } + if string(first) != "first" || string(second) != "second" { + t.Fatalf("unexpected backups: first=%q second=%q", first, second) + } +} + +func TestBackupPathRejectsSelfNestedBackup(t *testing.T) { + versionDir := t.TempDir() + if err := os.MkdirAll(filepath.Join(versionDir, "amt", "data"), 0o755); err != nil { + t.Fatal(err) + } + + err := backupPath(versionDir, "amt", filepath.Join("amt", "backup", "run")) + if err == nil { + t.Fatal("backupPath allowed backing up a directory into itself") + } +} + +func TestUnzipBacksUpOverwrittenFiles(t *testing.T) { + versionDir := t.TempDir() + if err := os.WriteFile(filepath.Join(versionDir, "config.txt"), []byte("old"), 0o644); err != nil { + t.Fatal(err) + } + + zipPath := filepath.Join(versionDir, "pack.zip") + writeTestZip(t, zipPath, map[string]string{"config.txt": "new"}) + + backupDir := filepath.Join("amt", "backup", "run") + if err := unzipFile(zipPath, versionDir, versionDir, backupDir); err != nil { + t.Fatalf("unzipFile failed: %v", err) + } + + current, err := os.ReadFile(filepath.Join(versionDir, "config.txt")) + if err != nil { + t.Fatal(err) + } + backup, err := os.ReadFile(filepath.Join(versionDir, backupDir, "config.txt")) + if err != nil { + t.Fatal(err) + } + if string(current) != "new" || string(backup) != "old" { + t.Fatalf("unexpected files: current=%q backup=%q", current, backup) + } +} + +func TestUnzipRejectsBackupTarget(t *testing.T) { + versionDir := t.TempDir() + zipPath := filepath.Join(versionDir, "pack.zip") + writeTestZip(t, zipPath, map[string]string{"amt/backup/evil.txt": "bad"}) + + err := unzipFile(zipPath, versionDir, versionDir, filepath.Join("amt", "backup", "run")) + if err == nil { + t.Fatal("unzipFile allowed writing into backup directory") + } +} + +func writeTestZip(t *testing.T, zipPath string, files map[string]string) { + t.Helper() + + out, err := os.Create(zipPath) + if err != nil { + t.Fatal(err) + } + defer out.Close() + + w := zip.NewWriter(out) + defer w.Close() + + for name, body := range files { + f, err := w.Create(name) + if err != nil { + t.Fatal(err) + } + if _, err := f.Write([]byte(body)); err != nil { + t.Fatal(err) + } + } +} diff --git a/pages.go b/pages.go index 5267ca0..2d2ad43 100644 --- a/pages.go +++ b/pages.go @@ -2,8 +2,8 @@ package main import ( "fmt" + "path/filepath" "strings" - "time" "github.com/charmbracelet/bubbles/list" "github.com/charmbracelet/bubbles/progress" @@ -72,7 +72,7 @@ func updateMainMenu(m model, msg tea.Msg) (model, tea.Cmd) { return m, m.codeInput.Focus() case 1: m.currentPage = pageVersionSelect - m.logLines = nil // Bug2: 清除残留 logLines,否则版本列表被错误文本遮挡 + m.logLines = nil // Bug2: 清除残留 logLines,否则版本列表被错误文本遮挡 m.versionList.SetItems(nil) // Bug10: 清除旧列表,避免短暂显示过时数据 return m, scanVersions(m.exeDir) } @@ -155,7 +155,7 @@ func updateExecuting(m model, msg tea.Msg) (model, tea.Cmd) { m.logLines = append(m.logLines, successStyle.Render("没有需要执行的操作")) return m, nil } - m.backupDir = fmt.Sprintf("amt/backup/%s", time.Now().Format("20060102_150405")) + m.backupDir = newBackupDir() m.logLines = append(m.logLines, boldStyle.Render(fmt.Sprintf("共 %d 个操作", len(m.actions)))) m.logLines = append(m.logLines, describeAction(m.actions[0], 0, len(m.actions))) return m, tea.Batch(m.spinner.Tick, executeAction(m.versionDir, m.actions[0], 0, m.backupDir)) @@ -200,6 +200,9 @@ func updateExecuting(m model, msg tea.Msg) (model, tea.Cmd) { m.progressCh = nil m.execErr = msg.err m.logLines = append(m.logLines, fmt.Sprintf("%s [%d/%d] 失败: %s", crossMark, msg.index+1, len(m.actions), msg.err.Error())) + if m.backupDir != "" { + m.logLines = append(m.logLines, subtleStyle.Render("备份目录: "+filepath.Join(m.versionDir, m.backupDir))) + } return m, nil case mirrorChoiceMsg: