Working MVP for sftp server
alex wennerberg alex@alexwennerberg.com
Fri, 26 Feb 2021 19:27:15 -0800
M
db.go
→
db.go
@@ -3,6 +3,7 @@
import ( "crypto/rand" "database/sql" + "fmt" "golang.org/x/crypto/bcrypt" "io" "io/ioutil"@@ -11,6 +12,7 @@ "os"
"path" "path/filepath" "sort" + "strings" "time" )@@ -23,6 +25,32 @@ if err != nil {
log.Fatal(err) } createTablesIfDNE() +} + +// returns nil if login OK, err otherwise +// log in with email or username +func checkLogin(name string, password string) (string, bool, error) { + row := DB.QueryRow("SELECT username, password_hash, active, admin FROM user where username = $1 OR email = $1", name) + var db_password []byte + var username string + var active bool + var isAdmin bool + err := row.Scan(&username, &db_password, &active, &isAdmin) + if err != nil { + if strings.Contains(err.Error(), "no rows") { + return username, isAdmin, fmt.Errorf("Username or email '" + name + "' does not exist") + } else { + return username, isAdmin, err + } + } + if db_password != nil && !active { + return username, isAdmin, fmt.Errorf("Your account is not active yet. Pending admin approval %v", c) + } + if bcrypt.CompareHashAndPassword(db_password, []byte(password)) == nil { + return username, isAdmin, nil + } else { + return username, isAdmin, fmt.Errorf("Invalid password") + } } func getAnalyticsDB() (*sql.DB, error) {
M
example-config.toml
→
example-config.toml
@@ -21,6 +21,11 @@ # SMTPServer = mail.goodsite.com:587
# SMTPUsername = myemail@coolplace.com # SMTPPassword = hunter2 +# Whether to enable user SFTP access +# experimental feature, enable at your own risk +EnableSFTP=true +HostKeyPath="id_rsa" # will be generated for you. Pub key at x.pub + # Templates and static files # Everything in the static subfolder will be served at / TemplatesDirectory="./templates"
M
http.go
→
http.go
@@ -360,36 +360,8 @@ } else if r.Method == "POST" {
r.ParseForm() name := strings.ToLower(r.Form.Get("username")) password := r.Form.Get("password") - row := DB.QueryRow("SELECT username, password_hash, active, admin FROM user where username = $1 OR email = $1", name) - var db_password []byte - var username string - var active bool - var isAdmin bool - err := row.Scan(&username, &db_password, &active, &isAdmin) + username, isAdmin, err := checkLogin(name, password) if err != nil { - if strings.Contains(err.Error(), "no rows") { - data := struct { - Error string - Config Config - }{"Username or email '" + name + "' does not exist", c} - w.WriteHeader(401) - t.ExecuteTemplate(w, "login.html", data) - return - } else { - serverError(w, err) - return - } - } - if db_password != nil && !active { - data := struct { - Error string - Config Config - }{"Your account is not active yet. Pending admin approval", c} - w.WriteHeader(401) - t.ExecuteTemplate(w, "login.html", data) - return - } - if bcrypt.CompareHashAndPassword(db_password, []byte(password)) == nil { log.Println("logged in") session, _ := SessionStore.Get(r, "cookie-session") session.Values["auth_user"] = username@@ -400,7 +372,7 @@ } else {
data := struct { Error string Config Config - }{"Invalid login or password", c} + }{err.Error(), c} w.WriteHeader(401) err := t.ExecuteTemplate(w, "login.html", data) if err != nil {
M
main.go
→
main.go
@@ -49,16 +49,19 @@ }
switch args[0] { case "serve": - runSFTPServer() // s1.StartAsync() wg := new(sync.WaitGroup) - wg.Add(2) + wg.Add(3) go func() { runHTTPServer() wg.Done() }() go func() { runGeminiServer() + wg.Done() + }() + go func() { + runSFTPServer() wg.Done() }() wg.Wait()
M
sftp.go
→
sftp.go
@@ -1,10 +1,12 @@
-// An example SFTP server implementation using the golang SSH package. -// Serves the whole filesystem visible to the user, and has a hard-coded username and password, -// so not for real use! +// SFTP server for users with Flounder accounts +// A lot of this is copied from SFTPGo, but simplified for our use case. package main import ( - "flag" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "fmt" "io" "io/ioutil"@@ -13,7 +15,9 @@ "net"
"os" "path" "path/filepath" + "runtime/debug" "strings" + "time" "github.com/pkg/sftp" "golang.org/x/crypto/ssh"@@ -35,7 +39,6 @@ }
func (con *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) { // check user perms -- cant write others files - // check if file is inside your directory -- strings prefix? fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath)) userDir := getUserDirectory(con.User) // NOTE -- not cross platform if strings.HasPrefix(fullpath, userDir) {@@ -51,6 +54,9 @@ }
func (conn *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) { fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath)) + if strings.Contains(request.Filepath, ".hidden") { + return nil, fmt.Errorf("Invalid permissions") // TODO fix better + } switch request.Method { case "List": f, err := os.Open(fullpath)@@ -72,8 +78,26 @@ }
return nil, fmt.Errorf("Invalid command") } -func (c *Connection) Filecmd(request *sftp.Request) error { +func (conn *Connection) Filecmd(request *sftp.Request) error { // remove, rename, setstat? find out + fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath)) + userDir := getUserDirectory(conn.User) // NOTE -- not cross platform + writePerms := strings.HasPrefix(fullpath, userDir) + var err error + if writePerms { + switch request.Method { + case "Remove": + err = os.Remove(fullpath) + case "Mkdir": + err = os.Mkdir(fullpath, 0755) + } + if err != nil { + return err + } + } else { + return fmt.Errorf("Unauthorized") + } + // Rename, Mkdir return nil }@@ -91,37 +115,40 @@ }
// Based on example server code from golang.org/x/crypto/ssh and server_standalone func runSFTPServer() { - - var ( - readOnly bool - debugStderr bool - ) - - flag.BoolVar(&readOnly, "R", false, "read-only server") - flag.BoolVar(&debugStderr, "e", false, "debug to stderr") - flag.Parse() - - debugStream := ioutil.Discard - if debugStderr { - debugStream = os.Stderr + if !c.EnableSFTP { + return } - // An SSH server is represented by a ServerConfig, which holds // certificate details and handles authentication of ServerConns. config := &ssh.ServerConfig{ - // PublicKeyCallback PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { // Should use constant-time compare (or better, salt+hash) in // a production setting. - fmt.Fprintf(debugStream, "Login: %s\n", c.User()) - if c.User() == "alex" && string(pass) == "alex" { + if isOkUsername(c.User()) != nil { // extra check, probably unnecessary + return nil, fmt.Errorf("Invalid username") + } + _, _, err := checkLogin(c.User(), string(pass)) + // TODO maybe give admin extra permissions? + if err != nil { + return nil, fmt.Errorf("password rejected for %q", c.User()) + } else { + log.Printf("Login: %s\n", c.User()) return nil, nil } - return nil, fmt.Errorf("password rejected for %q", c.User()) }, } - privateBytes, err := ioutil.ReadFile("id_rsa") + // TODO generate key automatically + if _, err := os.Stat(c.HostKeyPath); os.IsNotExist(err) { + // path/to/whatever does not exist + log.Println("Host key not found, generating host key") + err := GenerateRSAKeys() + if err != nil { + log.Fatal(err) + } + } + + privateBytes, err := ioutil.ReadFile(c.HostKeyPath) if err != nil { log.Fatal("Failed to load private key", err) }@@ -133,75 +160,136 @@ }
config.AddHostKey(private) - // Once a ServerConfig has been configured, connections can be - // accepted. listener, err := net.Listen("tcp", "0.0.0.0:2024") if err != nil { log.Fatal("failed to listen for connection", err) } - fmt.Printf("Listening on %v\n", listener.Addr()) - nConn, err := listener.Accept() - if err != nil { - log.Fatal("failed to accept incoming connection", err) + log.Printf("SFTP server listening on %v\n", listener.Addr()) + + for { + conn, err := listener.Accept() + if err != nil { + log.Fatal(err) + } + go acceptInboundConnection(conn, config) } +} + +func acceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) { + defer func() { + if r := recover(); r != nil { + log.Printf("panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack())) + } + }() + ipAddr := GetIPFromRemoteAddress(conn.RemoteAddr().String()) + log.Println("Request from IP " + ipAddr) + limiter := getVisitor(ipAddr) + if limiter.Allow() == false { + conn.Close() + return + } + // Before beginning a handshake must be performed on the incoming net.Conn + // we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH + conn.SetDeadline(time.Now().Add(2 * time.Minute)) // Before use, a handshake must be performed on the incoming net.Conn. - sconn, chans, reqs, err := ssh.NewServerConn(nConn, config) + sconn, chans, reqs, err := ssh.NewServerConn(conn, config) if err != nil { - log.Fatal("failed to handshake", err) + log.Printf("failed to accept an incoming connection: %v", err) + return } log.Println("login detected:", sconn.User()) - fmt.Fprintf(debugStream, "SSH server established\n") + fmt.Fprintf(os.Stderr, "SSH server established\n") + // handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on + conn.SetDeadline(time.Time{}) + + defer conn.Close() // The incoming Request channel must be serviced. go ssh.DiscardRequests(reqs) // Service the incoming Channel channel. + channelCounter := int64(0) for newChannel := range chans { // Channels have a type, depending on the application level // protocol intended. In the case of an SFTP session, this is "subsystem" // with a payload string of "<length=4>sftp" - fmt.Fprintf(debugStream, "Incoming channel: %s\n", newChannel.ChannelType()) + fmt.Fprintf(os.Stderr, "Incoming channel: %s\n", newChannel.ChannelType()) if newChannel.ChannelType() != "session" { newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") - fmt.Fprintf(debugStream, "Unknown channel type: %s\n", newChannel.ChannelType()) + fmt.Fprintf(os.Stderr, "Unknown channel type: %s\n", newChannel.ChannelType()) continue } channel, requests, err := newChannel.Accept() if err != nil { - log.Fatal("could not accept channel.", err) + log.Println("could not accept channel.", err) + continue } - fmt.Fprintf(debugStream, "Channel accepted\n") + + channelCounter++ + fmt.Fprintf(os.Stderr, "Channel accepted\n") // Sessions have out-of-band requests such as "shell", // "pty-req" and "env". Here we handle only the // "subsystem" request. go func(in <-chan *ssh.Request) { for req := range in { - fmt.Fprintf(debugStream, "Request: %v\n", req.Type) + fmt.Fprintf(os.Stderr, "Request: %v\n", req.Type) ok := false switch req.Type { case "subsystem": - fmt.Fprintf(debugStream, "Subsystem: %s\n", req.Payload[4:]) + fmt.Fprintf(os.Stderr, "Subsystem: %s\n", req.Payload[4:]) if string(req.Payload[4:]) == "sftp" { ok = true } } - fmt.Fprintf(debugStream, " - accepted: %v\n", ok) + fmt.Fprintf(os.Stderr, " - accepted: %v\n", ok) req.Reply(ok, nil) } }(requests) - connection := Connection{"alex"} + connection := Connection{sconn.User()} root := buildHandlers(&connection) server := sftp.NewRequestServer(channel, root) if err := server.Serve(); err == io.EOF { server.Close() - log.Print("sftp client exited session.") + log.Println("sftp client exited session.") } else if err != nil { - log.Fatal("sftp server completed with error:", err) + log.Println("sftp server completed with error:", err) + return } } +} + +// GenerateRSAKeys generate rsa private and public keys and write the +// private key to specified file and the public key to the specified +// file adding the .pub suffix +func GenerateRSAKeys() error { + key, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return err + } + + o, err := os.OpenFile(c.HostKeyPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer o.Close() + + priv := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + } + + if err := pem.Encode(o, priv); err != nil { + return err + } + + pub, err := ssh.NewPublicKey(&key.PublicKey) + if err != nil { + return err + } + return ioutil.WriteFile(c.HostKeyPath+".pub", ssh.MarshalAuthorizedKey(pub), 0600) } type listerat []os.FileInfo
M
utils.go
→
utils.go
@@ -6,6 +6,7 @@ "bufio"
"fmt" "io" "mime" + "net" "os" "path" "path/filepath"@@ -125,6 +126,13 @@ return "1 day ago"
} return fmt.Sprintf("%d days ago", days) } +} +func GetIPFromRemoteAddress(remoteAddress string) string { + ip, _, err := net.SplitHostPort(remoteAddress) + if err == nil { + return ip + } + return remoteAddress } // safe