src/app/functions.go (view raw)
1package app
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log"
8 "net/http"
9 "os"
10 "regexp"
11 "strconv"
12 "strings"
13 "time"
14
15 "github.com/birabittoh/auth-boilerplate/src/email"
16)
17
// HabitDisplay is the view model for one habit, as produced by toHabitDisplay.
type HabitDisplay struct {
	// ID is copied from Habit.ID.
	ID uint
	// Class is classGood, classWarn, classBad, or "" when no status applies.
	// Name is copied from Habit.Name.
	// LastAck is a human-readable "time since last acknowledgement"
	// (see formatDuration), or "-" when the habit was never acknowledged.
	Class, Name, LastAck string
	// Disabled is copied from Habit.Disabled.
	Disabled bool
}
25
// Inclusive username length bounds (in bytes) enforced by sanitizeUsername.
const (
	minUsernameLength = 3
	maxUsernameLength = 10
)

// Status classes assigned to habits by toHabitDisplay and getClassForAck.
const (
	classGood = "good"
	classWarn = "warn"
	classBad  = "bad"
)
34
// validUsername accepts one or more letters (either case), digits, '.', '_'
// and '-'. NOTE(review): (?i) uses Unicode simple case folding, so a few
// non-ASCII code points (e.g. U+212A KELVIN SIGN) may also match — confirm
// whether that matters for usernames.
var validUsername = regexp.MustCompile(`(?i)^[a-z0-9._-]+$`)

// validEmail is a loose shape check for email addresses. It has no (?i), so
// it only matches lower-case input; sanitizeEmail lower-cases before matching.
var validEmail = regexp.MustCompile(`^[a-z0-9._%+-]+@[a-z0-9.-]+\.[a-z]{2,}$`)

// validHabitName accepts one or more letters (either case), digits,
// whitespace, and the punctuation . _ , ( ) -.
var validHabitName = regexp.MustCompile(`(?i)^[a-z0-9._,\s)(-]+$`)
40
41func getUserByName(username string, excluding uint) (user User, err error) {
42 err = db.Model(&User{}).Where("upper(username) == upper(?) AND id != ?", username, excluding).First(&user).Error
43 return
44}
45
46func sanitizeUsername(username string) (string, error) {
47 if !validUsername.MatchString(username) || len(username) < minUsernameLength || len(username) > maxUsernameLength {
48 return "", errors.New("invalid username")
49 }
50
51 return username, nil
52}
53
54func sanitizeEmail(email string) (string, error) {
55 email = strings.ToLower(email)
56
57 if !validEmail.MatchString(email) {
58 return "", fmt.Errorf("invalid email")
59 }
60
61 return email, nil
62}
63
64func checkHabitName(name string) bool {
65 return len(name) < 50 && validHabitName.MatchString(name)
66}
67
68func login(w http.ResponseWriter, userID uint, remember bool) {
69 var duration time.Duration
70 if remember {
71 duration = durationWeek
72 } else {
73 duration = durationDay
74 }
75
76 cookie, err := g.GenerateCookie(duration)
77 if err != nil {
78 http.Error(w, "Could not generate session cookie.", http.StatusInternalServerError)
79 }
80
81 ks.Set("session:"+cookie.Value, userID, duration)
82 http.SetCookie(w, cookie)
83}
84
85func loadEmailConfig() *email.Client {
86 address := os.Getenv("APP_SMTP_EMAIL")
87 password := os.Getenv("APP_SMTP_PASSWORD")
88 host := os.Getenv("APP_SMTP_HOST")
89 port := os.Getenv("APP_SMTP_PORT")
90
91 if address == "" || password == "" || host == "" {
92 log.Println("Missing email configuration.")
93 return nil
94 }
95
96 if port == "" {
97 port = "587"
98 }
99
100 return email.NewClient(address, password, host, port)
101}
102
103func sendEmail(mail email.Email) error {
104 if m == nil {
105 return errors.New("email client is not initialized")
106 }
107 return m.Send(mail)
108}
109
110func sendResetEmail(address, token string) {
111 resetURL := fmt.Sprintf("%s/reset-password-confirm?token=%s", baseUrl, token)
112 err := sendEmail(email.Email{
113 To: []string{address},
114 Subject: "Reset password",
115 Body: fmt.Sprintf("Use the following link to reset your password:\n%s", resetURL),
116 })
117 if err != nil {
118 log.Printf("Could not send reset email for %s. Link: %s", address, resetURL)
119 }
120}
121
122func readSessionCookie(r *http.Request) (userID *uint, err error) {
123 cookie, err := r.Cookie("session_token")
124 if err != nil {
125 return
126 }
127 return ks.Get("session:" + cookie.Value)
128}
129
130// Middleware to check if the user is logged in
131func loginRequired(next http.HandlerFunc) http.HandlerFunc {
132 return func(w http.ResponseWriter, r *http.Request) {
133 userID, err := readSessionCookie(r)
134 if err != nil {
135 http.Redirect(w, r, "/login", http.StatusFound)
136 return
137 }
138
139 ctx := context.WithValue(r.Context(), userContextKey, *userID)
140 next(w, r.WithContext(ctx))
141 }
142}
143
144func getLoggedUser(r *http.Request) (user User, ok bool) {
145 userID, ok := r.Context().Value(userContextKey).(uint)
146 if !ok {
147 return
148 }
149
150 if db.Find(&user, userID).Error != nil {
151 ok = true
152 }
153
154 return user, ok
155}
156
// formatDuration describes how long ago t was.
//
// days is the number of whole 24-hour periods elapsed since t, clamped to 0
// when t is in the future. s is a human-readable phrase:
//   - "Today" or "Yesterday" for 0 or 1 days;
//   - "N day(s) ago" up to 7 days;
//   - weeks plus leftover days up to 30 days;
//   - 30-day "months" plus leftover days beyond that.
//
// "Today" means "less than 24 hours ago", not "same calendar day".
func formatDuration(t time.Time) (days int, s string) {
	days = int(time.Since(t).Hours()) / 24
	if days < 0 {
		// A timestamp in the future (e.g. clock skew between writer and
		// reader) would otherwise be rendered as "-N day(s) ago".
		days = 0
	}

	switch {
	case days == 0:
		s = "Today"
	case days == 1:
		s = "Yesterday"
	case days <= 7:
		s = fmt.Sprintf("%d day(s) ago", days)
	case days <= 30:
		weeks, rest := days/7, days%7
		if rest == 0 {
			s = fmt.Sprintf("%d week(s) ago", weeks)
		} else {
			s = fmt.Sprintf("%d week(s), %d day(s) ago", weeks, rest)
		}
	default:
		months, rest := days/30, days%30
		if rest == 0 {
			s = fmt.Sprintf("%d month(s) ago", months)
		} else {
			s = fmt.Sprintf("%d month(s), %d day(s) ago", months, rest)
		}
	}
	return days, s
}
186
187func toHabitDisplay(habit Habit) (d HabitDisplay) {
188 if habit.LastAck != nil {
189 var days int
190 days, d.LastAck = formatDuration(*habit.LastAck)
191 d.Class = getClassForAck(habit, days)
192 } else {
193 d.LastAck = "-"
194 if habit.Negative {
195 d.Class = classGood
196 } else if !habit.Disabled {
197 d.Class = classBad
198 }
199 }
200
201 d.ID = habit.ID
202 d.Name = habit.Name
203 d.Disabled = habit.Disabled
204 return
205}
206
207func getClassForAck(habit Habit, days int) string {
208 if habit.Negative {
209 switch {
210 case days <= 1:
211 return classBad
212 case days <= 7:
213 return classWarn
214 default:
215 return classGood
216 }
217 }
218
219 if habit.Disabled {
220 return ""
221 }
222
223 switch {
224 case days <= 1:
225 return classGood
226 case days <= int(habit.Days):
227 return classWarn
228 default:
229 return classBad
230 }
231}
232
233func getHabit(id uint) (habit Habit, err error) {
234 err = db.Model(&Habit{}).Find(&habit, id).Error
235 return
236}
237
238func getHabitHelper(w http.ResponseWriter, r *http.Request) (habit Habit, err error) {
239 id := getID(r)
240 if id == 0 {
241 err = errors.New("no id")
242 http.Error(w, "bad request", http.StatusBadRequest)
243 return
244 }
245
246 user, ok := getLoggedUser(r)
247 if !ok {
248 err = errors.New("no logged user")
249 http.Error(w, "unauthorized", http.StatusUnauthorized)
250 return
251 }
252
253 habit, err = getHabit(id)
254 if err != nil {
255 http.Error(w, "not found", http.StatusNotFound)
256 return
257 }
258
259 if habit.UserID != user.ID {
260 err = errors.New("forbidden")
261 http.Error(w, "forbidden", http.StatusForbidden)
262 }
263 return
264}
265
266func getAllHabits(userID uint) (positives []HabitDisplay, negatives []HabitDisplay, err error) {
267 var habits []Habit
268 err = db.Model(&Habit{}).Where(&Habit{UserID: userID}).Find(&habits).Error
269 if err != nil {
270 return
271 }
272
273 for _, habit := range habits {
274 habitDisplay := toHabitDisplay(habit)
275 if habit.Negative {
276 negatives = append(negatives, habitDisplay)
277 } else {
278 positives = append(positives, habitDisplay)
279 }
280 }
281 return
282}
283
284func getID(r *http.Request) uint {
285 res, err := strconv.ParseUint(r.PathValue("id"), 10, 64)
286 if err != nil {
287 return 0
288 }
289 return uint(res)
290}