src/app/functions.go (view raw)
1package app
2
import (
	"context"
	"errors"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/birabittoh/auth-boilerplate/src/email"
)
17
// HabitDisplay is the presentation form of a Habit produced by
// toHabitDisplay: LastAck is already rendered as text ("-" when the habit
// was never acknowledged) and Class carries one of the class* labels.
type HabitDisplay struct {
	ID                   uint
	Class, Name, LastAck string
	Disabled             bool
}
25
// Inclusive bounds, in bytes, on the length of a username accepted by
// sanitizeUsername.
const (
	minUsernameLength = 3
	maxUsernameLength = 10
)

// Candidate values for HabitDisplay.Class.
const (
	classGood = "good"
	classWarn = "warn"
	classBad  = "bad"
)
34
var (
	// validUsername accepts ASCII letters of either case, digits, '.', '_'
	// and '-'. The letter range is spelled out as A-Za-z instead of using
	// (?i)a-z because Go's regexp case folding would also admit U+212A
	// KELVIN SIGN (folds to 'k') and U+017F LATIN SMALL LETTER LONG S
	// (folds to 's'), letting non-ASCII look-alike names through.
	validUsername = regexp.MustCompile(`^[A-Za-z0-9._-]+$`)

	// validEmail is deliberately lowercase-only; sanitizeEmail lowercases
	// its input before matching.
	validEmail = regexp.MustCompile(`^[a-z0-9._%+-]+@[a-z0-9.-]+\.[a-z]{2,}$`)

	// validHabitName accepts ASCII letters, digits, whitespace and the
	// punctuation . _ , ( ) -. Same A-Za-z reasoning as validUsername.
	validHabitName = regexp.MustCompile(`^[A-Za-z0-9._,\s)(-]+$`)
)
40
41func getUserByName(username string, excluding uint) (user User, err error) {
42 err = db.Model(&User{}).Where("upper(username) == upper(?) AND id != ?", username, excluding).First(&user).Error
43 return
44}
45
46func sanitizeUsername(username string) (string, error) {
47 if !validUsername.MatchString(username) || len(username) < minUsernameLength || len(username) > maxUsernameLength {
48 return "", errors.New("invalid username")
49 }
50
51 return username, nil
52}
53
54func sanitizeEmail(email string) (string, error) {
55 email = strings.ToLower(email)
56
57 if !validEmail.MatchString(email) {
58 return "", fmt.Errorf("invalid email")
59 }
60
61 return email, nil
62}
63
64func checkHabitName(name string) bool {
65 return len(name) < 50 && validHabitName.MatchString(name)
66}
67
68func login(w http.ResponseWriter, userID uint, remember bool) {
69 var duration time.Duration
70 if remember {
71 duration = durationWeek
72 } else {
73 duration = durationDay
74 }
75
76 cookie, err := g.GenerateCookie(duration)
77 if err != nil {
78 http.Error(w, "Could not generate session cookie.", http.StatusInternalServerError)
79 }
80
81 ks.Set("session:"+cookie.Value, userID, duration)
82 http.SetCookie(w, cookie)
83}
84
85func loadEmailConfig() *email.Client {
86 address := os.Getenv("APP_SMTP_EMAIL")
87 password := os.Getenv("APP_SMTP_PASSWORD")
88 host := os.Getenv("APP_SMTP_HOST")
89 port := os.Getenv("APP_SMTP_PORT")
90
91 if address == "" || password == "" || host == "" {
92 log.Println("Missing email configuration.")
93 return nil
94 }
95
96 if port == "" {
97 port = "587"
98 }
99
100 return email.NewClient(address, password, host, port)
101}
102
103func sendEmail(mail email.Email) error {
104 if m == nil {
105 return errors.New("email client is not initialized")
106 }
107 return m.Send(mail)
108}
109
110func sendResetEmail(address, token string) {
111 resetURL := fmt.Sprintf("%s/reset-password-confirm?token=%s", baseUrl, token)
112 err := sendEmail(email.Email{
113 To: []string{address},
114 Subject: "Reset password",
115 Body: fmt.Sprintf("Use the following link to reset your password:\n%s", resetURL),
116 })
117 if err != nil {
118 log.Printf("Could not send reset email for %s. Link: %s", address, resetURL)
119 }
120}
121
122func readSessionCookie(r *http.Request) (userID *uint, err error) {
123 cookie, err := r.Cookie("session_token")
124 if err != nil {
125 return
126 }
127 return ks.Get("session:" + cookie.Value)
128}
129
130// Middleware to check if the user is logged in
131func loginRequired(next http.HandlerFunc) http.HandlerFunc {
132 return func(w http.ResponseWriter, r *http.Request) {
133 userID, err := readSessionCookie(r)
134 if err != nil {
135 http.Redirect(w, r, "/login", http.StatusFound)
136 return
137 }
138
139 ctx := context.WithValue(r.Context(), userContextKey, *userID)
140 next(w, r.WithContext(ctx))
141 }
142}
143
144func getLoggedUser(r *http.Request) (user User, ok bool) {
145 userID, ok := r.Context().Value(userContextKey).(uint)
146 if !ok {
147 return
148 }
149
150 if db.Find(&user, userID).Error != nil {
151 ok = true
152 }
153
154 return user, ok
155}
156
// formatDuration renders how long ago t was as a coarse, human-readable
// string:
//
//   - "Today" for less than one day, and for times in the future (e.g.
//     clock skew), which previously rendered as "-N day(s) ago"
//   - "Yesterday" for one day
//   - "N day(s) ago" up to 7 days
//   - "N week(s)[, M day(s)] ago" up to 30 days
//   - "N month(s)[, M day(s)] ago" beyond that, with a month counted as
//     30 days
//
// Days are whole 24-hour periods elapsed since t, not calendar days.
func formatDuration(t time.Time) string {
	days := int(time.Since(t).Hours()) / 24

	switch {
	case days <= 0:
		return "Today"
	case days == 1:
		return "Yesterday"
	case days <= 7:
		return fmt.Sprintf("%d day(s) ago", days)
	case days <= 30:
		weeks, remainingDays := days/7, days%7
		if remainingDays == 0 {
			return fmt.Sprintf("%d week(s) ago", weeks)
		}
		return fmt.Sprintf("%d week(s), %d day(s) ago", weeks, remainingDays)
	default:
		months, remainingDays := days/30, days%30
		if remainingDays == 0 {
			return fmt.Sprintf("%d month(s) ago", months)
		}
		return fmt.Sprintf("%d month(s), %d day(s) ago", months, remainingDays)
	}
}
183
184func toHabitDisplay(habit Habit) HabitDisplay {
185 var lastAck string
186 if habit.LastAck == nil {
187 lastAck = "-"
188 } else {
189 lastAck = formatDuration(*habit.LastAck)
190 }
191
192 return HabitDisplay{
193 ID: habit.ID,
194 Name: habit.Name,
195 LastAck: lastAck,
196 Disabled: habit.Disabled,
197 Class: classGood,
198 }
199}
200
201func getHabit(id uint) (habit Habit, err error) {
202 err = db.Model(&Habit{}).Find(&habit, id).Error
203 return
204}
205
206func getHabitHelper(w http.ResponseWriter, r *http.Request) (habit Habit, err error) {
207 id := getID(r)
208 if id == 0 {
209 err = errors.New("no id")
210 http.Error(w, "bad request", http.StatusBadRequest)
211 return
212 }
213
214 user, ok := getLoggedUser(r)
215 if !ok {
216 err = errors.New("no logged user")
217 http.Error(w, "unauthorized", http.StatusUnauthorized)
218 return
219 }
220
221 habit, err = getHabit(id)
222 if err != nil {
223 http.Error(w, "not found", http.StatusNotFound)
224 return
225 }
226
227 if habit.UserID != user.ID {
228 err = errors.New("forbidden")
229 http.Error(w, "forbidden", http.StatusForbidden)
230 }
231 return
232}
233
234func getAllHabits(userID uint) (positives []HabitDisplay, negatives []HabitDisplay, err error) {
235 var habits []Habit
236 err = db.Model(&Habit{}).Where(&Habit{UserID: userID}).Find(&habits).Error
237 if err != nil {
238 return
239 }
240
241 for _, habit := range habits {
242 habitDisplay := toHabitDisplay(habit)
243 if habit.Negative {
244 negatives = append(negatives, habitDisplay)
245 } else {
246 positives = append(positives, habitDisplay)
247 }
248 }
249 return
250}
251
252func getID(r *http.Request) uint {
253 res, err := strconv.ParseUint(r.PathValue("id"), 10, 64)
254 if err != nil {
255 return 0
256 }
257 return uint(res)
258}