all repos — flounder @ ea5e67c4e14bd7fc12d4a64118252a2dbc3a1abb

A small site builder for the Gemini protocol

sftp.go (view raw)

  1// SFTP server for users with Flounder accounts
  2// 	A lot of this is copied from SFTPGo, but simplified for our use case.
  3package main
  4
  5import (
  6	"crypto/rand"
  7	"crypto/rsa"
  8	"crypto/x509"
  9	"encoding/pem"
 10	"fmt"
 11	"io"
 12	"io/ioutil"
 13	"log"
 14	"net"
 15	"os"
 16	"path"
 17	"path/filepath"
 18	"runtime/debug"
 19	"time"
 20
 21	"github.com/pkg/sftp"
 22	"golang.org/x/crypto/ssh"
 23)
 24
 25type Connection struct {
 26	User string
 27}
 28
 29func (con *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
 30	// check user perms -- cant read others hidden files
 31	userDir := getUserDirectory(con.User) // NOTE -- not cross platform
 32	fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
 33	f, err := os.OpenFile(fullpath, os.O_RDONLY, 0)
 34	if err != nil {
 35		return nil, err
 36	}
 37	return f, nil
 38}
 39
 40func (con *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
 41	// check user perms -- cant write others files
 42	userDir := getUserDirectory(con.User) // NOTE -- not cross platform
 43	fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
 44	f, err := os.OpenFile(fullpath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
 45	if err != nil {
 46		return nil, err
 47	}
 48	return f, nil
 49}
 50
 51func (conn *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
 52	userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
 53	fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
 54	switch request.Method {
 55	case "List":
 56		f, err := os.Open(fullpath)
 57		if err != nil {
 58			return nil, err
 59		}
 60		fileInfo, err := f.Readdir(-1)
 61		if err != nil {
 62			return nil, err
 63		}
 64		return listerat(fileInfo), nil
 65	case "Stat":
 66		stat, err := os.Stat(fullpath)
 67		if err != nil {
 68			return nil, err
 69		}
 70		return listerat([]os.FileInfo{stat}), nil
 71	}
 72	return nil, fmt.Errorf("Invalid command")
 73}
 74
 75func (conn *Connection) Filecmd(request *sftp.Request) error {
 76	// remove, rename, setstat? find out
 77	userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
 78	fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
 79	targetPath := path.Join(userDir, filepath.Clean(request.Target))
 80	var err error
 81	switch request.Method {
 82	case "Remove":
 83		err = os.Remove(fullpath)
 84	case "Mkdir":
 85		err = os.Mkdir(fullpath, 0755)
 86	case "Rename":
 87		err = os.Rename(fullpath, targetPath)
 88	}
 89	if err != nil {
 90		return err
 91	}
 92	// Rename, Mkdir
 93	return nil
 94}
 95
// TODO hide hidden folders
// Users have write perms on their files, read perms on all
 98
 99func buildHandlers(connection *Connection) sftp.Handlers {
100	return sftp.Handlers{
101		connection,
102		connection,
103		connection,
104		connection,
105	}
106}
107
108// Based on example server code from golang.org/x/crypto/ssh and server_standalone
109func runSFTPServer() {
110	if !c.EnableSFTP {
111		return
112	}
113	// An SSH server is represented by a ServerConfig, which holds
114	// certificate details and handles authentication of ServerConns.
115	config := &ssh.ServerConfig{
116		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
117			// Should use constant-time compare (or better, salt+hash) in
118			// a production setting.
119			if isOkUsername(c.User()) != nil { // extra check, probably unnecessary
120				return nil, fmt.Errorf("Invalid username")
121			}
122			_, _, err := checkLogin(c.User(), string(pass))
123			// TODO maybe give admin extra permissions?
124			if err != nil {
125				return nil, fmt.Errorf("password rejected for %q", c.User())
126			} else {
127				log.Printf("Login: %s\n", c.User())
128				return nil, nil
129			}
130		},
131	}
132
133	// TODO generate key automatically
134	if _, err := os.Stat(c.HostKeyPath); os.IsNotExist(err) {
135		// path/to/whatever does not exist
136		log.Println("Host key not found, generating host key")
137		err := GenerateRSAKeys()
138		if err != nil {
139			log.Fatal(err)
140		}
141	}
142
143	privateBytes, err := ioutil.ReadFile(c.HostKeyPath)
144	if err != nil {
145		log.Fatal("Failed to load private key", err)
146	}
147
148	private, err := ssh.ParsePrivateKey(privateBytes)
149	if err != nil {
150		log.Fatal("Failed to parse private key", err)
151	}
152
153	config.AddHostKey(private)
154
155	listener, err := net.Listen("tcp", "0.0.0.0:2024")
156	if err != nil {
157		log.Fatal("failed to listen for connection", err)
158	}
159
160	log.Printf("SFTP server listening on %v\n", listener.Addr())
161
162	for {
163		conn, err := listener.Accept()
164		if err != nil {
165			log.Fatal(err)
166		}
167		go acceptInboundConnection(conn, config)
168	}
169}
170
171func acceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
172	defer func() {
173		if r := recover(); r != nil {
174			log.Printf("panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
175		}
176	}()
177	ipAddr := GetIPFromRemoteAddress(conn.RemoteAddr().String())
178	log.Println("Request from IP " + ipAddr)
179	limiter := getVisitor(ipAddr)
180	if limiter.Allow() == false {
181		conn.Close()
182		return
183	}
184	// Before beginning a handshake must be performed on the incoming net.Conn
185	// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
186	conn.SetDeadline(time.Now().Add(2 * time.Minute))
187
188	// Before use, a handshake must be performed on the incoming net.Conn.
189	sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
190	if err != nil {
191		log.Printf("failed to accept an incoming connection: %v", err)
192		return
193	}
194	log.Println("login detected:", sconn.User())
195	fmt.Fprintf(os.Stderr, "SSH server established\n")
196	// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
197	conn.SetDeadline(time.Time{})
198
199	defer conn.Close()
200
201	// The incoming Request channel must be serviced.
202	go ssh.DiscardRequests(reqs)
203
204	// Service the incoming Channel channel.
205	channelCounter := int64(0)
206	for newChannel := range chans {
207		// Channels have a type, depending on the application level
208		// protocol intended. In the case of an SFTP session, this is "subsystem"
209		// with a payload string of "<length=4>sftp"
210		fmt.Fprintf(os.Stderr, "Incoming channel: %s\n", newChannel.ChannelType())
211		if newChannel.ChannelType() != "session" {
212			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
213			fmt.Fprintf(os.Stderr, "Unknown channel type: %s\n", newChannel.ChannelType())
214			continue
215		}
216		channel, requests, err := newChannel.Accept()
217		if err != nil {
218			log.Println("could not accept channel.", err)
219			continue
220		}
221
222		channelCounter++
223		fmt.Fprintf(os.Stderr, "Channel accepted\n")
224
225		// Sessions have out-of-band requests such as "shell",
226		// "pty-req" and "env".  Here we handle only the
227		// "subsystem" request.
228		go func(in <-chan *ssh.Request) {
229			for req := range in {
230				fmt.Fprintf(os.Stderr, "Request: %v\n", req.Type)
231				ok := false
232				switch req.Type {
233				case "subsystem":
234					fmt.Fprintf(os.Stderr, "Subsystem: %s\n", req.Payload[4:])
235					if string(req.Payload[4:]) == "sftp" {
236						ok = true
237					}
238				}
239				fmt.Fprintf(os.Stderr, " - accepted: %v\n", ok)
240				req.Reply(ok, nil)
241			}
242		}(requests)
243		connection := Connection{sconn.User()}
244		root := buildHandlers(&connection)
245		server := sftp.NewRequestServer(channel, root)
246		if err := server.Serve(); err == io.EOF {
247			server.Close()
248			log.Println("sftp client exited session.")
249		} else if err != nil {
250			log.Println("sftp server completed with error:", err)
251			return
252		}
253	}
254}
255
256// GenerateRSAKeys generate rsa private and public keys and write the
257// private key to specified file and the public key to the specified
258// file adding the .pub suffix
259func GenerateRSAKeys() error {
260	key, err := rsa.GenerateKey(rand.Reader, 4096)
261	if err != nil {
262		return err
263	}
264
265	o, err := os.OpenFile(c.HostKeyPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
266	if err != nil {
267		return err
268	}
269	defer o.Close()
270
271	priv := &pem.Block{
272		Type:  "RSA PRIVATE KEY",
273		Bytes: x509.MarshalPKCS1PrivateKey(key),
274	}
275
276	if err := pem.Encode(o, priv); err != nil {
277		return err
278	}
279
280	pub, err := ssh.NewPublicKey(&key.PublicKey)
281	if err != nil {
282		return err
283	}
284	return ioutil.WriteFile(c.HostKeyPath+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
285}
286
// listerat is an in-memory list of file metadata, returned by Filelist as the
// sftp.ListerAt the sftp server pages through.
type listerat []os.FileInfo

// ListAt copies entries starting at offset into ls and reports how many were
// copied. Like strings.Reader's ReadAt, it returns io.EOF when offset is at or
// past the end, or when fewer entries remain than ls can hold.
func (l listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
	if int64(len(l)) <= offset {
		return 0, io.EOF
	}
	copied := copy(ls, l[offset:])
	var err error
	if copied < len(ls) {
		err = io.EOF
	}
	return copied, err
}