From c2c0325384d259162e167a6da78ee179bd16f715 Mon Sep 17 00:00:00 2001 From: Bonnie <41487185+estrogently@users.noreply.github.com> Date: Fri, 3 Jan 2020 23:39:12 +0100 Subject: [PATCH] Fix #1383: "Save with Sudo" rewrite (#1424) * Rewrite save with sudo (Fixes #1383) * Combine overrideFile & overrideFileAsRoot into 1 function --- internal/buffer/backup.go | 2 +- internal/buffer/save.go | 105 ++++++++++++----------------------- internal/buffer/serialize.go | 2 +- 3 files changed, 36 insertions(+), 73 deletions(-) diff --git a/internal/buffer/backup.go b/internal/buffer/backup.go index 8a2e24be..192ab638 100644 --- a/internal/buffer/backup.go +++ b/internal/buffer/backup.go @@ -71,7 +71,7 @@ func (b *Buffer) Backup(checkTime bool) error { } } return - }) + }, false) return err } diff --git a/internal/buffer/save.go b/internal/buffer/save.go index 8bf04259..990e2f25 100644 --- a/internal/buffer/save.go +++ b/internal/buffer/save.go @@ -27,76 +27,42 @@ const LargeFileThreshold = 50000 // overwriteFile opens the given file for writing, truncating if one exists, and then calls // the supplied function with the file as io.Writer object, also making sure the file is // closed afterwards. -func overwriteFile(name string, enc encoding.Encoding, fn func(io.Writer) error) (err error) { - var file *os.File +func overwriteFile(name string, enc encoding.Encoding, fn func(io.Writer) error, withSudo bool) (err error) { + var writeCloser io.WriteCloser - if file, err = os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644); err != nil { - return - } + if withSudo { + cmd := exec.Command(config.GlobalSettings["sucmd"].(string), "dd", "bs=4k", "of="+name) - defer func() { - if e := file.Close(); e != nil && err == nil { - err = e - } - }() + if writeCloser, err = cmd.StdinPipe(); err != nil { + return + } - w := transform.NewWriter(file, enc.NewEncoder()) - // w := bufio.NewWriter(file) + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + go func() { + <-c + cmd.Process.Kill() + }() - if err = fn(w); err != nil { - return - } + defer func() { + screenb := screen.TempFini() + if e := cmd.Run(); e != nil && err == nil { + err = e + } + screen.TempStart(screenb) + }() + } else if writeCloser, err = os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644); err != nil { + return + } - // err = w.Flush() - return -} + w := transform.NewWriter(writeCloser, enc.NewEncoder()) + err = fn(w) -// overwriteFileAsRoot executes dd as root and then calls the supplied function -// with dd's standard input as an io.Writer object. Dd opens the given file for writing, -// truncating it if it exists, and writes what it receives on its standard input to the file. -func overwriteFileAsRoot(name string, enc encoding.Encoding, fn func(io.Writer) error) (err error) { - var cmd *exec.Cmd - if runtime.GOOS == "windows" { - return errors.New("Save with sudo not supported on Windows") - } else if runtime.GOOS == "darwin" { - cmd = exec.Command(config.GlobalSettings["sucmd"].(string), "dd", "bs=4k", "of="+name) - } else { - cmd = exec.Command(config.GlobalSettings["sucmd"].(string), "dd", "status=none", "bs=4K", "of="+name) - } - var stdin io.WriteCloser + if e := writeCloser.Close(); e != nil && err == nil { + err = e + } - screenb := screen.TempFini() - defer screen.TempStart(screenb) - - // This is a trap for Ctrl-C so that it doesn't kill micro - // Instead we trap Ctrl-C to kill the program we're running - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - go func() { - for range c { - cmd.Process.Kill() - } - }() - - if stdin, err = cmd.StdinPipe(); err != nil { - return - } - - if err = cmd.Start(); err != nil { - return - } - - e := fn(stdin) - - if err = stdin.Close(); err != nil { - return - } - - if err = cmd.Wait(); err != nil { - return - } - - return e + return } // Save saves the buffer to its default path @@ -125,6 +91,9 @@ func (b *Buffer) saveToFile(filename string, withSudo bool) error { if b.Type.Scratch { return errors.New("Cannot save scratch buffer") } + if withSudo && runtime.GOOS == "windows" { + return errors.New("Save with sudo not supported on Windows") + } b.UpdateRules() if b.Settings["rmtrailingws"].(bool) { @@ -208,14 +177,8 @@ func (b *Buffer) saveToFile(filename string, withSudo bool) error { return } - if withSudo { - err = overwriteFileAsRoot(absFilename, enc, fwriter) - } else { - err = overwriteFile(absFilename, enc, fwriter) - } - - if err != nil { - return err + if err = overwriteFile(absFilename, enc, fwriter, withSudo); err != nil { + return err } if !b.Settings["fastdirty"].(bool) { diff --git a/internal/buffer/serialize.go b/internal/buffer/serialize.go index 16c4b6bf..6bd08200 100644 --- a/internal/buffer/serialize.go +++ b/internal/buffer/serialize.go @@ -39,7 +39,7 @@ func (b *Buffer) Serialize() error { b.ModTime, }) return err - }) + }, false) } func (b *Buffer) Unserialize() error {