package cmd import ( "fmt" "io" "math/rand" "net/http" "net/url" "os" "path/filepath" "strings" "time" "github.com/PuerkitoBio/goquery" "github.com/google/uuid" "github.com/jzelinskie/geddit" "github.com/kf5grd/keybasebot" ) func getMetaLoginCreds(b *keybasebot.Bot) (username, password string, err error) { user, ok := b.Meta["reddit-user"] if !ok { err = fmt.Errorf("No reddit username has been set.") } pass, ok := b.Meta["reddit-pass"] if !ok { err = fmt.Errorf("No reddit password has been set.") } username = user.(string) password = pass.(string) return } func getRedditSubmissions(b *keybasebot.Bot, sub string, sort string, count int) ([]*geddit.Submission, error) { // check the sort type allowedSort := map[string]bool{ "hot": true, "new": true, "rising": true, "top": true, "controversial": true, "": true, } if !allowedSort[sort] { return nil, fmt.Errorf("%s is not an allowed sort method", sort) } user, pass, err := getMetaLoginCreds(b) if err != nil { return nil, err } session, err := geddit.NewLoginSession( user, pass, "golang:pub.keybase.haukened.ssh0le:v2.0 (by /u/no-names-here)", ) if err != nil { return nil, err } subOpts := geddit.ListingOptions{ Limit: count, } submissions, err := session.SubredditSubmissions(sub, geddit.PopularitySort(sort), subOpts) if err != nil { return nil, err } return submissions, nil } func filterSubmissions(gs []*geddit.Submission, test func(*geddit.Submission) bool) (ret []*geddit.Submission) { for _, s := range gs { if test(s) { ret = append(ret, s) } } return } func filterMedia(gs []*geddit.Submission) (ret []*geddit.Submission) { test := func(s *geddit.Submission) bool { return strings.HasSuffix(s.URL, ".jpg") || strings.HasSuffix(s.URL, ".png") || strings.HasSuffix(s.URL, ".gif") || strings.HasSuffix(s.URL, ".mp4") } ret = filterSubmissions(gs, test) return } func fixGfycatURL(sub *geddit.Submission) error { resp, err := http.Get(sub.URL) if err != nil { return err } defer resp.Body.Close() doc, err := goquery.NewDocumentFromReader(resp.Body) if err != nil { return err } vids := doc.Find("source") found := false for _, vid := range vids.Nodes { for _, element := range vid.Attr { if strings.HasSuffix(element.Val, ".mp4") && strings.Contains(element.Val, "-mobile") { sub.URL = element.Val found = true break } } } if !found { return fmt.Errorf("Unable to find Gfycat Video, because they suck.") } return nil } func downloadRedditMedia(s *geddit.Submission) (path string, err error) { // break down the url to get the path, less the host and any arguments or fragments URLParts, err := url.Parse(s.URL) if err != nil { return } // get the file extension ext := filepath.Ext(URLParts.Path) if ext == "" { err = fmt.Errorf("URL had no media file extension") return } // get the file response, err := http.Get(s.URL) if err != nil { return } defer response.Body.Close() // open a tmp file for writing uid := uuid.New() path = fmt.Sprintf("/tmp/%s.%s", uid, ext) file, err := os.Create(path) if err != nil { return } defer file.Close() _, err = io.Copy(file, response.Body) return } func getRandomRedditMedia(b *keybasebot.Bot, sub string, sortby string, count int) (path, description string, err error) { submissions, err := getRedditSubmissions(b, sub, sortby, count) if err != nil { return } mediaSubs := filterMedia(submissions) // select a random post rand.Seed(time.Now().Unix()) randSub := mediaSubs[rand.Intn(len(mediaSubs))] path, err = downloadRedditMedia(randSub) if err != nil { return } // generate the description description = fmt.Sprintf("image by /u/%s", randSub.Author) return }