diff --git a/internal/buffer/backup.go b/internal/buffer/backup.go index dfbc698c..313761d9 100644 --- a/internal/buffer/backup.go +++ b/internal/buffer/backup.go @@ -65,23 +65,32 @@ func (b *Buffer) RequestBackup() { } } +func (b *Buffer) backupDir() string { + backupdir, err := util.ReplaceHome(b.Settings["backupdir"].(string)) + if backupdir == "" || err != nil { + backupdir = filepath.Join(config.ConfigDir, "backups") + } + return backupdir +} + +func (b *Buffer) keepBackup() bool { + return b.forceKeepBackup || b.Settings["permbackup"].(bool) +} + // Backup saves the current buffer to ConfigDir/backups func (b *Buffer) Backup() error { if !b.Settings["backup"].(bool) || b.Path == "" || b.Type != BTDefault { return nil } - backupdir, err := util.ReplaceHome(b.Settings["backupdir"].(string)) - if backupdir == "" || err != nil { - backupdir = filepath.Join(config.ConfigDir, "backups") - } + backupdir := b.backupDir() if _, err := os.Stat(backupdir); errors.Is(err, fs.ErrNotExist) { os.Mkdir(backupdir, os.ModePerm) } name := util.DetermineEscapePath(backupdir, b.AbsPath) - err = overwriteFile(name, encoding.Nop, func(file io.Writer) (e error) { + err := overwriteFile(name, encoding.Nop, func(file io.Writer) (e error) { b.Lock() defer b.Unlock() @@ -120,7 +129,7 @@ func (b *Buffer) Backup() error { // RemoveBackup removes any backup file associated with this buffer func (b *Buffer) RemoveBackup() { - if !b.Settings["backup"].(bool) || b.Settings["permbackup"].(bool) || b.Path == "" || b.Type != BTDefault { + if !b.Settings["backup"].(bool) || b.keepBackup() || b.Path == "" || b.Type != BTDefault { return } f := util.DetermineEscapePath(filepath.Join(config.ConfigDir, "backups"), b.AbsPath) diff --git a/internal/buffer/buffer.go b/internal/buffer/buffer.go index 64894fa6..eb8176cf 100644 --- a/internal/buffer/buffer.go +++ b/internal/buffer/buffer.go @@ -102,6 +102,7 @@ type SharedBuffer struct { diff map[int]DiffStatus requestedBackup bool + forceKeepBackup bool // ReloadDisabled allows the user to disable reloads if they // are viewing a file that is constantly changing diff --git a/internal/buffer/save.go b/internal/buffer/save.go index 29143e4d..6e6fc6e5 100644 --- a/internal/buffer/save.go +++ b/internal/buffer/save.go @@ -95,6 +95,50 @@ func overwriteFile(name string, enc encoding.Encoding, fn func(io.Writer) error, return } +func (b *Buffer) overwrite(name string, withSudo bool) (int, error) { + enc, err := htmlindex.Get(b.Settings["encoding"].(string)) + if err != nil { + return 0, err + } + + var size int + fwriter := func(file io.Writer) error { + if len(b.lines) == 0 { + return err + } + + // end of line + var eol []byte + if b.Endings == FFDos { + eol = []byte{'\r', '\n'} + } else { + eol = []byte{'\n'} + } + + // write lines + if size, err = file.Write(b.lines[0].data); err != nil { + return err + } + + for _, l := range b.lines[1:] { + if _, err = file.Write(eol); err != nil { + return err + } + if _, err = file.Write(l.data); err != nil { + return err + } + size += len(eol) + len(l.data) + } + return err + } + + if err = overwriteFile(name, enc, fwriter, withSudo); err != nil { + return size, err + } + + return size, err +} + // Save saves the buffer to its default path func (b *Buffer) Save() error { return b.SaveAs(b.Path) @@ -159,18 +203,29 @@ func (b *Buffer) saveToFile(filename string, withSudo bool, autoSave bool) error err = b.Serialize() }() - // Removes any tilde and replaces with the absolute path to home - absFilename, _ := util.ReplaceHome(filename) - - fileInfo, err := os.Stat(absFilename) - if err != nil && !errors.Is(err, fs.ErrNotExist) { + filename, err = util.ReplaceHome(filename) + if err != nil { return err } + + newFile := false + fileInfo, err := os.Stat(filename) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return err + } + newFile = true + } if err == nil && fileInfo.IsDir() { - return errors.New("Error: " + absFilename + " is a directory and cannot be saved") + return errors.New("Error: " + filename + " is a directory and cannot be saved") } if err == nil && !fileInfo.Mode().IsRegular() { - return errors.New("Error: " + absFilename + " is not a regular file and cannot be saved") + return errors.New("Error: " + filename + " is not a regular file and cannot be saved") + } + + absFilename, err := filepath.Abs(filename) + if err != nil { + return err } // Get the leading path to the file | "." is returned if there's no leading path provided @@ -190,49 +245,13 @@ func (b *Buffer) saveToFile(filename string, withSudo bool, autoSave bool) error } } - var fileSize int - - enc, err := htmlindex.Get(b.Settings["encoding"].(string)) + size, err := b.safeWrite(absFilename, withSudo, newFile) if err != nil { return err } - fwriter := func(file io.Writer) (e error) { - if len(b.lines) == 0 { - return - } - - // end of line - var eol []byte - if b.Endings == FFDos { - eol = []byte{'\r', '\n'} - } else { - eol = []byte{'\n'} - } - - // write lines - if fileSize, e = file.Write(b.lines[0].data); e != nil { - return - } - - for _, l := range b.lines[1:] { - if _, e = file.Write(eol); e != nil { - return - } - if _, e = file.Write(l.data); e != nil { - return - } - fileSize += len(eol) + len(l.data) - } - return - } - - if err = overwriteFile(absFilename, enc, fwriter, withSudo); err != nil { - return err - } - if !b.Settings["fastdirty"].(bool) { - if fileSize > LargeFileThreshold { + if size > LargeFileThreshold { // For large files 'fastdirty' needs to be on b.Settings["fastdirty"] = true } else { @@ -241,9 +260,47 @@ func (b *Buffer) saveToFile(filename string, withSudo bool, autoSave bool) error } b.Path = filename - absPath, _ := filepath.Abs(filename) - b.AbsPath = absPath + b.AbsPath = absFilename b.isModified = false b.ReloadSettings(true) return err } + +// safeWrite writes the buffer to a file in a "safe" way, preventing loss of the +// contents of the file if it fails to write the new contents. +// This means that the file is not overwritten directly but by writing to the +// backup file first. +func (b *Buffer) safeWrite(path string, withSudo bool, newFile bool) (int, error) { + backupDir := b.backupDir() + if _, err := os.Stat(backupDir); err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return 0, err + } + if err = os.Mkdir(backupDir, os.ModePerm); err != nil { + return 0, err + } + } + + backupName := util.DetermineEscapePath(backupDir, path) + _, err := b.overwrite(backupName, false) + if err != nil { + os.Remove(backupName) + return 0, err + } + + b.forceKeepBackup = true + size, err := b.overwrite(path, withSudo) + if err != nil { + if newFile { + os.Remove(path) + } + return size, err + } + b.forceKeepBackup = false + + if !b.keepBackup() { + os.Remove(backupName) + } + + return size, err +}