main.go (view raw)
1package main
2
3import (
4 "crypto/rand"
5 "database/sql"
6 "flag"
7 "fmt"
8 "github.com/gorilla/sessions"
9 "io"
10 "io/ioutil"
11 "log"
12 mathrand "math/rand"
13 "mime"
14 "os"
15 "path"
16 "path/filepath"
17 "sort"
18 "strings"
19 "sync"
20 "time"
21)
22
// c is the package-level configuration value. main assigns it once from
// getConfig; the helpers in this file read fields such as c.FilesDirectory
// and c.Host from it.
var c Config // global var to hold static configuration
24
// File describes one entry found under the files directory. The same type is
// used for regular files and for folders.
type File struct { // also folders
	// Creator is the name of the user folder the entry belongs to.
	Creator string
	// Name is the entry's path as returned by getLocalPath, so it includes
	// any sub-folders below the user's folder.
	Name string // includes folder
	// UpdatedTime is the modification time reported by the file system.
	UpdatedTime time.Time
	// TimeAgo is the string produced by timeago for UpdatedTime. It is only
	// filled in by getIndexFiles.
	TimeAgo string
	// IsText is true when the MIME type derived from the file extension
	// starts with "text". It is only filled in by getMyFilesRecursive.
	IsText bool
	// Children holds the contents of a folder, as collected by
	// getMyFilesRecursive.
	Children []*File
	// Host is copied from c.Host by getMyFilesRecursive.
	Host string
}
34
// User mirrors the columns of the user table that getUsers selects.
type User struct {
	Username string
	Email    string
	Active   bool
	Admin    bool
	// CreatedAt is the created_at column; the table schema defaults it to
	// strftime('%s', 'now'), i.e. seconds since the Unix epoch.
	CreatedAt int // timestamp
}
42
43// returns in a random order
44func getActiveUserNames() ([]string, error) {
45 rows, err := DB.Query(`SELECT username from user WHERE active is true`)
46 if err != nil {
47 return nil, err
48 }
49 var users []string
50 for rows.Next() {
51 var user string
52 err = rows.Scan(&user)
53 if err != nil {
54 return nil, err
55 }
56 users = append(users, user)
57 }
58
59 dest := make([]string, len(users))
60 perm := mathrand.Perm(len(users))
61 for i, v := range perm {
62 dest[v] = users[i]
63 }
64 return dest, nil
65}
66
67func getUsers() ([]User, error) {
68 rows, err := DB.Query(`SELECT username, email, active, admin, created_at from user ORDER BY created_at DESC`)
69 if err != nil {
70 return nil, err
71 }
72 var users []User
73 for rows.Next() {
74 var user User
75 err = rows.Scan(&user.Username, &user.Email, &user.Active, &user.Admin, &user.CreatedAt)
76 if err != nil {
77 return nil, err
78 }
79 users = append(users, user)
80 }
81 return users, nil
82}
83
84// get the user-reltaive local path from the filespath
85// NOTE -- dont use on unsafe input ( I think )
86func getLocalPath(filesPath string) string {
87 l := len(strings.Split(c.FilesDirectory, "/"))
88 return strings.Join(strings.Split(filesPath, "/")[l+1:], "/")
89}
90
91func getCreator(filePath string) string {
92 l := len(strings.Split(c.FilesDirectory, "/"))
93 r := strings.Split(filePath, "/")[l]
94 fmt.Println(filePath, c.FilesDirectory, r)
95 return r
96}
97
98func getIndexFiles() ([]*File, error) { // cache this function
99 result := []*File{}
100 err := filepath.Walk(c.FilesDirectory, func(thepath string, info os.FileInfo, err error) error {
101 if err != nil {
102 log.Printf("Failure accessing a path %q: %v\n", thepath, err)
103 return err // think about
104 }
105 // make this do what it should
106 if !info.IsDir() {
107 creatorFolder := getCreator(thepath)
108 updatedTime := info.ModTime()
109 result = append(result, &File{
110 Name: getLocalPath(thepath),
111 Creator: path.Base(creatorFolder),
112 UpdatedTime: updatedTime,
113 TimeAgo: timeago(&updatedTime),
114 })
115 }
116 return nil
117 })
118 if err != nil {
119 return nil, err
120 }
121 sort.Slice(result, func(i, j int) bool {
122 return result[i].UpdatedTime.After(result[j].UpdatedTime)
123 })
124 if len(result) > 50 {
125 result = result[:50]
126 }
127 return result, nil
128} // todo clean up paths
129
130func getMyFilesRecursive(p string, creator string) ([]*File, error) {
131 result := []*File{}
132 files, err := ioutil.ReadDir(p)
133 if err != nil {
134 return nil, err
135 }
136 for _, file := range files {
137 isText := strings.HasPrefix(mime.TypeByExtension(path.Ext(file.Name())), "text")
138 fullPath := path.Join(p, file.Name())
139 localPath := getLocalPath(fullPath)
140 f := &File{
141 Name: localPath,
142 Creator: creator,
143 UpdatedTime: file.ModTime(),
144 IsText: isText,
145 Host: c.Host,
146 }
147 if file.IsDir() {
148 f.Children, err = getMyFilesRecursive(path.Join(p, file.Name()), creator)
149 }
150 result = append(result, f)
151 }
152 return result, nil
153}
154
155func createTablesIfDNE() {
156 _, err := DB.Exec(`CREATE TABLE IF NOT EXISTS user (
157 id INTEGER PRIMARY KEY NOT NULL,
158 username TEXT NOT NULL UNIQUE,
159 email TEXT NOT NULL UNIQUE,
160 password_hash TEXT NOT NULL,
161 active boolean NOT NULL DEFAULT false,
162 admin boolean NOT NULL DEFAULT false,
163 created_at INTEGER DEFAULT (strftime('%s', 'now'))
164);
165
166CREATE TABLE IF NOT EXISTS cookie_key (
167 value TEXT NOT NULL
168);`)
169 if err != nil {
170 log.Fatal(err)
171 }
172}
173
174// Generate a cryptographically secure key for the cookie store
175func generateCookieKeyIfDNE() []byte {
176 rows, err := DB.Query("SELECT value FROM cookie_key LIMIT 1")
177 defer rows.Close()
178 if err != nil {
179 log.Fatal(err)
180 }
181 if rows.Next() {
182 var cookie []byte
183 err := rows.Scan(&cookie)
184 if err != nil {
185 log.Fatal(err)
186 }
187 return cookie
188 } else {
189 k := make([]byte, 32)
190 _, err := io.ReadFull(rand.Reader, k)
191 if err != nil {
192 log.Fatal(err)
193 }
194 _, err = DB.Exec("insert into cookie_key values ($1)", k)
195 if err != nil {
196 log.Fatal(err)
197 }
198 return k
199 }
200}
201
202func main() {
203 configPath := flag.String("c", "flounder.toml", "path to config file") // doesnt work atm
204 flag.Parse()
205 args := flag.Args()
206 if len(args) < 1 {
207 fmt.Println("expected 'admin' or 'serve' subcommand")
208 os.Exit(1)
209 }
210
211 var err error
212 c, err = getConfig(*configPath)
213 if err != nil {
214 log.Fatal(err)
215 }
216 logFile, err := os.OpenFile(c.LogFile, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0644)
217 if err != nil {
218 panic(err)
219 }
220 mw := io.MultiWriter(os.Stdout, logFile)
221 log.SetOutput(mw)
222
223 if c.HttpsEnabled {
224 _, err1 := os.Stat(c.TLSCertFile)
225 _, err2 := os.Stat(c.TLSKeyFile)
226 if os.IsNotExist(err1) || os.IsNotExist(err2) {
227 log.Fatal("Keyfile or certfile does not exist.")
228 }
229 }
230
231 // Generate session cookie key if does not exist
232 DB, err = sql.Open("sqlite3", c.DBFile)
233 if err != nil {
234 log.Fatal(err)
235 }
236
237 createTablesIfDNE()
238 cookie := generateCookieKeyIfDNE()
239 SessionStore = sessions.NewCookieStore(cookie)
240
241 switch args[0] {
242 case "serve":
243 wg := new(sync.WaitGroup)
244 wg.Add(2)
245 go func() {
246 runHTTPServer()
247 wg.Done()
248 }()
249 go func() {
250 runGeminiServer()
251 wg.Done()
252 }()
253 wg.Wait()
254 case "admin":
255 runAdminCommand()
256 }
257}