sftp.go (view raw)
1// SFTP server for users with Flounder accounts
2// A lot of this is copied from SFTPGo, but simplified for our use case.
3package main
4
5import (
6 "fmt"
7 "io"
8 "io/ioutil"
9 "log"
10 "net"
11 "os"
12 "path"
13 "path/filepath"
14 "runtime/debug"
15 "strings"
16 "time"
17
18 "github.com/pkg/sftp"
19 "golang.org/x/crypto/ssh"
20)
21
// Connection holds the per-session state for one SFTP login. Its Fileread,
// Filewrite, Filecmd and Filelist methods serve as the request-server
// handlers (see buildHandlers).
type Connection struct {
	// User is the name the SSH session authenticated as.
	User string
}
25
26func (con *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
27 // check user perms -- cant read others hidden files
28 fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath))
29 f, err := os.OpenFile(fullpath, os.O_RDONLY, 0)
30 if err != nil {
31 return nil, err
32 }
33 return f, nil
34}
35
36func (con *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
37 // check user perms -- cant write others files
38 fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath))
39 userDir := getUserDirectory(con.User) // NOTE -- not cross platform
40 if strings.HasPrefix(fullpath, userDir) {
41 f, err := os.OpenFile(fullpath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
42 if err != nil {
43 return nil, err
44 }
45 return f, nil
46 } else {
47 return nil, fmt.Errorf("Invalid permissions")
48 }
49}
50
51func (conn *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
52 fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath))
53 if strings.Contains(request.Filepath, ".hidden") {
54 return nil, fmt.Errorf("Invalid permissions") // TODO fix better
55 }
56 switch request.Method {
57 case "List":
58 f, err := os.Open(fullpath)
59 if err != nil {
60 return nil, err
61 }
62 fileInfo, err := f.Readdir(-1)
63 if err != nil {
64 return nil, err
65 }
66 return listerat(fileInfo), nil
67 case "Stat":
68 stat, err := os.Stat(fullpath)
69 if err != nil {
70 return nil, err
71 }
72 return listerat([]os.FileInfo{stat}), nil
73 }
74 return nil, fmt.Errorf("Invalid command")
75}
76
77func (conn *Connection) Filecmd(request *sftp.Request) error {
78 // remove, rename, setstat? find out
79 fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath))
80 userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
81 writePerms := strings.HasPrefix(fullpath, userDir)
82 var err error
83 if writePerms {
84 switch request.Method {
85 case "Remove":
86 err = os.Remove(fullpath)
87 case "Mkdir":
88 err = os.Mkdir(fullpath, 0755)
89 }
90 if err != nil {
91 return err
92 }
93 } else {
94 return fmt.Errorf("Unauthorized")
95 }
96 // Rename, Mkdir
97 return nil
98}
99
// TODO: hide hidden folders.
// Users have write permissions on their own files and read permissions on all files.
102
103func buildHandlers(connection *Connection) sftp.Handlers {
104 return sftp.Handlers{
105 connection,
106 connection,
107 connection,
108 connection,
109 }
110}
111
112// Based on example server code from golang.org/x/crypto/ssh and server_standalone
113func runSFTPServer() {
114 if !c.EnableSFTP {
115 return
116 }
117 // An SSH server is represented by a ServerConfig, which holds
118 // certificate details and handles authentication of ServerConns.
119 config := &ssh.ServerConfig{
120 PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
121 // Should use constant-time compare (or better, salt+hash) in
122 // a production setting.
123 if isOkUsername(c.User()) != nil { // extra check, probably unnecessary
124 return nil, fmt.Errorf("Invalid username")
125 }
126 _, _, err := checkLogin(c.User(), string(pass))
127 // TODO maybe give admin extra permissions?
128 if err != nil {
129 return nil, fmt.Errorf("password rejected for %q", c.User())
130 } else {
131 log.Printf("Login: %s\n", c.User())
132 return nil, nil
133 }
134 },
135 }
136
137 // TODO generate key automatically
138 privateBytes, err := ioutil.ReadFile("id_rsa")
139 if err != nil {
140 log.Fatal("Failed to load private key", err)
141 }
142
143 private, err := ssh.ParsePrivateKey(privateBytes)
144 if err != nil {
145 log.Fatal("Failed to parse private key", err)
146 }
147
148 config.AddHostKey(private)
149
150 listener, err := net.Listen("tcp", "0.0.0.0:2024")
151 if err != nil {
152 log.Fatal("failed to listen for connection", err)
153 }
154
155 log.Printf("SFTP server listening on %v\n", listener.Addr())
156
157 for {
158 conn, err := listener.Accept()
159 if err != nil {
160 log.Fatal(err)
161 }
162 go acceptInboundConnection(conn, config)
163 }
164}
165
166func acceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
167 defer func() {
168 if r := recover(); r != nil {
169 log.Println("panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
170 }
171 }()
172 ipAddr := GetIPFromRemoteAddress(conn.RemoteAddr().String())
173 log.Println("Request from IP " + ipAddr)
174 limiter := getVisitor(ipAddr)
175 if limiter.Allow() == false {
176 conn.Close()
177 return
178 }
179 // Before beginning a handshake must be performed on the incoming net.Conn
180 // we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
181 conn.SetDeadline(time.Now().Add(2 * time.Minute))
182
183 // Before use, a handshake must be performed on the incoming net.Conn.
184 sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
185 if err != nil {
186 log.Printf("failed to accept an incoming connection: %v", err)
187 return
188 }
189 log.Println("login detected:", sconn.User())
190 fmt.Fprintf(os.Stderr, "SSH server established\n")
191 // handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
192 conn.SetDeadline(time.Time{})
193
194 defer conn.Close()
195
196 // The incoming Request channel must be serviced.
197 go ssh.DiscardRequests(reqs)
198
199 // Service the incoming Channel channel.
200 channelCounter := int64(0)
201 for newChannel := range chans {
202 // Channels have a type, depending on the application level
203 // protocol intended. In the case of an SFTP session, this is "subsystem"
204 // with a payload string of "<length=4>sftp"
205 fmt.Fprintf(os.Stderr, "Incoming channel: %s\n", newChannel.ChannelType())
206 if newChannel.ChannelType() != "session" {
207 newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
208 fmt.Fprintf(os.Stderr, "Unknown channel type: %s\n", newChannel.ChannelType())
209 continue
210 }
211 channel, requests, err := newChannel.Accept()
212 if err != nil {
213 log.Println("could not accept channel.", err)
214 continue
215 }
216
217 channelCounter++
218 fmt.Fprintf(os.Stderr, "Channel accepted\n")
219
220 // Sessions have out-of-band requests such as "shell",
221 // "pty-req" and "env". Here we handle only the
222 // "subsystem" request.
223 go func(in <-chan *ssh.Request) {
224 for req := range in {
225 fmt.Fprintf(os.Stderr, "Request: %v\n", req.Type)
226 ok := false
227 switch req.Type {
228 case "subsystem":
229 fmt.Fprintf(os.Stderr, "Subsystem: %s\n", req.Payload[4:])
230 if string(req.Payload[4:]) == "sftp" {
231 ok = true
232 }
233 }
234 fmt.Fprintf(os.Stderr, " - accepted: %v\n", ok)
235 req.Reply(ok, nil)
236 }
237 }(requests)
238 connection := Connection{sconn.User()}
239 root := buildHandlers(&connection)
240 server := sftp.NewRequestServer(channel, root)
241 if err := server.Serve(); err == io.EOF {
242 server.Close()
243 log.Println("sftp client exited session.")
244 } else if err != nil {
245 log.Println("sftp server completed with error:", err)
246 return
247 }
248 }
249}
250
// listerat is an in-memory directory listing. Filelist returns it as an
// sftp.ListerAt.
type listerat []os.FileInfo

// ListAt copies entries starting at index offset into ls and returns how many
// were copied. It mirrors strings.Reader's ReadAt and returns io.EOF whenever
// the listing is exhausted. That happens when offset is at or past the end,
// or when fewer than len(ls) entries remained to copy.
func (l listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
	total := int64(len(l))
	if total <= offset {
		return 0, io.EOF
	}
	copied := copy(ls, l[offset:])
	var err error
	if copied != len(ls) {
		err = io.EOF
	}
	return copied, err
}