sftp.go (view raw)
// SFTP server for users with Flounder accounts.
// Much of this is copied from SFTPGo, but simplified for our use case.
3package main
4
5import (
6 "crypto/rand"
7 "crypto/rsa"
8 "crypto/x509"
9 "encoding/pem"
10 "fmt"
11 "io"
12 "io/ioutil"
13 "log"
14 "net"
15 "os"
16 "path"
17 "path/filepath"
18 "runtime/debug"
19 "time"
20
21 "github.com/pkg/sftp"
22 "golang.org/x/crypto/ssh"
23)
24
// Connection is the per-session handler state for one SFTP login. A pointer
// to it fills every slot of sftp.Handlers, so each SFTP request in the
// session is resolved on behalf of User.
type Connection struct {
	// User is the username the client authenticated as over SSH.
	User string
}
28
29func (con *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
30 // check user perms -- cant read others hidden files
31 fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath))
32 f, err := os.OpenFile(fullpath, os.O_RDONLY, 0)
33 if err != nil {
34 return nil, err
35 }
36 return f, nil
37}
38
39func (con *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
40 // check user perms -- cant write others files
41 userDir := getUserDirectory(con.User) // NOTE -- not cross platform
42 fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
43 f, err := os.OpenFile(fullpath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
44 if err != nil {
45 return nil, err
46 }
47 return f, nil
48}
49
50func (conn *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
51 userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
52 fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
53 switch request.Method {
54 case "List":
55 f, err := os.Open(fullpath)
56 if err != nil {
57 return nil, err
58 }
59 fileInfo, err := f.Readdir(-1)
60 if err != nil {
61 return nil, err
62 }
63 return listerat(fileInfo), nil
64 case "Stat":
65 stat, err := os.Stat(fullpath)
66 if err != nil {
67 return nil, err
68 }
69 return listerat([]os.FileInfo{stat}), nil
70 }
71 return nil, fmt.Errorf("Invalid command")
72}
73
74func (conn *Connection) Filecmd(request *sftp.Request) error {
75 // remove, rename, setstat? find out
76 userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
77 fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
78 var err error
79 switch request.Method {
80 case "Remove":
81 err = os.Remove(fullpath)
82 case "Mkdir":
83 err = os.Mkdir(fullpath, 0755)
84 }
85 if err != nil {
86 return err
87 }
88 // Rename, Mkdir
89 return nil
90}
91
// TODO: hide hidden folders.
// Users have write perms on their own files and read perms on all files.
94
95func buildHandlers(connection *Connection) sftp.Handlers {
96 return sftp.Handlers{
97 connection,
98 connection,
99 connection,
100 connection,
101 }
102}
103
104// Based on example server code from golang.org/x/crypto/ssh and server_standalone
105func runSFTPServer() {
106 if !c.EnableSFTP {
107 return
108 }
109 // An SSH server is represented by a ServerConfig, which holds
110 // certificate details and handles authentication of ServerConns.
111 config := &ssh.ServerConfig{
112 PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
113 // Should use constant-time compare (or better, salt+hash) in
114 // a production setting.
115 if isOkUsername(c.User()) != nil { // extra check, probably unnecessary
116 return nil, fmt.Errorf("Invalid username")
117 }
118 _, _, err := checkLogin(c.User(), string(pass))
119 // TODO maybe give admin extra permissions?
120 if err != nil {
121 return nil, fmt.Errorf("password rejected for %q", c.User())
122 } else {
123 log.Printf("Login: %s\n", c.User())
124 return nil, nil
125 }
126 },
127 }
128
129 // TODO generate key automatically
130 if _, err := os.Stat(c.HostKeyPath); os.IsNotExist(err) {
131 // path/to/whatever does not exist
132 log.Println("Host key not found, generating host key")
133 err := GenerateRSAKeys()
134 if err != nil {
135 log.Fatal(err)
136 }
137 }
138
139 privateBytes, err := ioutil.ReadFile(c.HostKeyPath)
140 if err != nil {
141 log.Fatal("Failed to load private key", err)
142 }
143
144 private, err := ssh.ParsePrivateKey(privateBytes)
145 if err != nil {
146 log.Fatal("Failed to parse private key", err)
147 }
148
149 config.AddHostKey(private)
150
151 listener, err := net.Listen("tcp", "0.0.0.0:2024")
152 if err != nil {
153 log.Fatal("failed to listen for connection", err)
154 }
155
156 log.Printf("SFTP server listening on %v\n", listener.Addr())
157
158 for {
159 conn, err := listener.Accept()
160 if err != nil {
161 log.Fatal(err)
162 }
163 go acceptInboundConnection(conn, config)
164 }
165}
166
167func acceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
168 defer func() {
169 if r := recover(); r != nil {
170 log.Printf("panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
171 }
172 }()
173 ipAddr := GetIPFromRemoteAddress(conn.RemoteAddr().String())
174 log.Println("Request from IP " + ipAddr)
175 limiter := getVisitor(ipAddr)
176 if limiter.Allow() == false {
177 conn.Close()
178 return
179 }
180 // Before beginning a handshake must be performed on the incoming net.Conn
181 // we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
182 conn.SetDeadline(time.Now().Add(2 * time.Minute))
183
184 // Before use, a handshake must be performed on the incoming net.Conn.
185 sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
186 if err != nil {
187 log.Printf("failed to accept an incoming connection: %v", err)
188 return
189 }
190 log.Println("login detected:", sconn.User())
191 fmt.Fprintf(os.Stderr, "SSH server established\n")
192 // handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
193 conn.SetDeadline(time.Time{})
194
195 defer conn.Close()
196
197 // The incoming Request channel must be serviced.
198 go ssh.DiscardRequests(reqs)
199
200 // Service the incoming Channel channel.
201 channelCounter := int64(0)
202 for newChannel := range chans {
203 // Channels have a type, depending on the application level
204 // protocol intended. In the case of an SFTP session, this is "subsystem"
205 // with a payload string of "<length=4>sftp"
206 fmt.Fprintf(os.Stderr, "Incoming channel: %s\n", newChannel.ChannelType())
207 if newChannel.ChannelType() != "session" {
208 newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
209 fmt.Fprintf(os.Stderr, "Unknown channel type: %s\n", newChannel.ChannelType())
210 continue
211 }
212 channel, requests, err := newChannel.Accept()
213 if err != nil {
214 log.Println("could not accept channel.", err)
215 continue
216 }
217
218 channelCounter++
219 fmt.Fprintf(os.Stderr, "Channel accepted\n")
220
221 // Sessions have out-of-band requests such as "shell",
222 // "pty-req" and "env". Here we handle only the
223 // "subsystem" request.
224 go func(in <-chan *ssh.Request) {
225 for req := range in {
226 fmt.Fprintf(os.Stderr, "Request: %v\n", req.Type)
227 ok := false
228 switch req.Type {
229 case "subsystem":
230 fmt.Fprintf(os.Stderr, "Subsystem: %s\n", req.Payload[4:])
231 if string(req.Payload[4:]) == "sftp" {
232 ok = true
233 }
234 }
235 fmt.Fprintf(os.Stderr, " - accepted: %v\n", ok)
236 req.Reply(ok, nil)
237 }
238 }(requests)
239 connection := Connection{sconn.User()}
240 root := buildHandlers(&connection)
241 server := sftp.NewRequestServer(channel, root)
242 if err := server.Serve(); err == io.EOF {
243 server.Close()
244 log.Println("sftp client exited session.")
245 } else if err != nil {
246 log.Println("sftp server completed with error:", err)
247 return
248 }
249 }
250}
251
252// GenerateRSAKeys generate rsa private and public keys and write the
253// private key to specified file and the public key to the specified
254// file adding the .pub suffix
255func GenerateRSAKeys() error {
256 key, err := rsa.GenerateKey(rand.Reader, 4096)
257 if err != nil {
258 return err
259 }
260
261 o, err := os.OpenFile(c.HostKeyPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
262 if err != nil {
263 return err
264 }
265 defer o.Close()
266
267 priv := &pem.Block{
268 Type: "RSA PRIVATE KEY",
269 Bytes: x509.MarshalPKCS1PrivateKey(key),
270 }
271
272 if err := pem.Encode(o, priv); err != nil {
273 return err
274 }
275
276 pub, err := ssh.NewPublicKey(&key.PublicKey)
277 if err != nil {
278 return err
279 }
280 return ioutil.WriteFile(c.HostKeyPath+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
281}
282
// listerat adapts a []os.FileInfo to the sftp.ListerAt interface, so that
// directory listings and single stat results can be paged to the client.
type listerat []os.FileInfo

// ListAt copies the entries starting at offset into ls and returns how many
// were copied. It is modeled after strings.Reader's ReadAt() implementation:
// io.EOF is returned when offset is at or past the end, and also alongside a
// copy too short to fill ls.
func (entries listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
	if int64(len(entries)) <= offset {
		return 0, io.EOF
	}
	copied := copy(ls, entries[offset:])
	var err error
	if copied < len(ls) {
		err = io.EOF
	}
	return copied, err
}