Parcourir la source

添加session测试

tangs il y a 7 ans
Parent
commit
2bd6ebfa52

+ 1 - 1
src/config.conf

@@ -1,5 +1,5 @@
 # 数据库连接地址
-db_url=user:password@tcp(localhost:5555)/dbname?charset=utf8
+db_url=ddpf:ddpf_123_ddpf@tcp(localhost:3306)/ddpf_new_13?charset=utf8
 
 # 程序监听的端口号
 listen_port=:8090

+ 0 - 68
src/ddpf/login/session.go

@@ -1,68 +0,0 @@
-package login
-
-import (
-	"net/http"
-	"github.com/tangs-drm/go-tool/log"
-	"github.com/tangs-drm/go-tool/util"
-	"strings"
-)
-
-type FilterFunc func(args util.Map) error
-
-type Filter struct {
-	FilterMap map[string]FilterFunc
-}
-
-func NewFilter() *Filter {
-	filter := &Filter{}
-	filter.DefaultFilter()
-	return filter
-}
-
-func (ft *Filter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	cookie, err := r.Cookie("session")
-	if err != nil || cookie == nil {
-		cookie = &http.Cookie{}
-	}
-	log.Debug("ServerHTTP ---- ")
-	// 过滤器
-	path  := r.URL.Path
-	ff := ft.getFilterFunc(path)
-	if nil == ff {
-		return
-	}
-
-
-
-	return
-}
-
-func (ft *Filter) getFilterFunc(path string) FilterFunc {
-	for key := range ft.FilterMap {
-		if strings.HasPrefix(key, path) {
-			return ft.FilterMap[key]
-		}
-	}
-	return nil
-}
-
-func (ft *Filter) Filter(router string, ff FilterFunc) {
-}
-
-func (ft *Filter) Print() {
-	log.Debug("Filter Print ---- ")
-	for key, _ := range ft.FilterMap {
-		log.Debug("-----> %v", key)
-	}
-}
-
-func (ft *Filter) DefaultFilter() {
-	var usrFilter FilterFunc = func(args util.Map) error {
-		return nil
-	}
-	ft.Filter("/usr", usrFilter)
-}
-
-func (ft *Filter) FindUser() {
-
-}

+ 21 - 3
src/ddpf/model/dbm/dbm.go

@@ -1,8 +1,26 @@
 package dbm
 
-import "github.com/tangs-drm/go-tool/dbm"
+import (
+	"github.com/tangs-drm/go-tool/dbm"
+	"io/ioutil"
+	"github.com/tangs-drm/go-tool/log"
+)
 
-func Db() *dbm.DB {
-	return dbm.DbMangerV.Db()
+var Db = func() *dbm.DB {
+	return dbm.DefaultDBManger.D()
 }
 
+func Init(path string) error {
+	bys, err := ioutil.ReadFile(path)
+	if err != nil {
+		return err
+	}
+
+	err = Db().ExecScripts(string(bys), true)
+	if err != nil {
+		return err
+	}
+
+	log.Debug("[Init] dbm init db with path(%v) success", path)
+	return nil
+}

+ 20 - 34
src/ddpf/model/session/session.go

@@ -8,37 +8,36 @@ import (
 	"net/http"
 )
 
-var ShowLog bool
-
 var Max_Session_Num = 10
-var Valid_Session_Time int64 = 2592000000 // 一个月
+var Valid_Session_Time int64 = 2592000 // 一个月
 
 var SessionLog bool
 
 func slog(format string, args... interface{}) {
 	if SessionLog {
-		log.LogD_(2, format, args...)
+		log.LogD_(3, format, args...)
 	}
 }
 
 // UserFilterFunc过滤器,检查用户是否是登录状态
 var UserFilterFunc thttp.FilterFunc = func(w http.ResponseWriter, r *http.Request) int {
 	cookie, err := r.Cookie("token")
+	slog("%v", util.S2Json(cookie))
 	if err != nil {
 		slog("[UserFilterFunc] get cookie token error ->(%v)", err)
 		http.Redirect(w, r, "/", 301)
 		return thttp.REQUEST_RETURN
 	}
-	token := cookie.String()
+	token := cookie.Value
 	// token 是36位的UUID,形如4C2FB50E-C530-7868-01DF-165B2BC47308
 	if len(token) != 36 {
-		slog("[UserFilterFunc] get empty token")
+		slog("[UserFilterFunc] get invalid token(%v)", token)
 		http.Redirect(w, r, "/", 301)
 		return thttp.REQUEST_RETURN
 	}
 
 	// 检查token是否合法
-	valid, err := CheckToken(token)
+	valid, err := CheckValidSession(token)
 	if err != nil {
 		slog("[UserFilterFunc] check token(%v) error -> (%v)", token, err)
 		http.Redirect(w, r, "/", 301)
@@ -51,11 +50,12 @@ var UserFilterFunc thttp.FilterFunc = func(w http.ResponseWriter, r *http.Reques
 		return thttp.REQUEST_RETURN
 	}
 
-	err = FlushToken(token)
+	err = FlushSession(token)
 	if err != nil {
 		slog("[UserFilterFunc] flush token with token(%v) error ->(%v)", token, err)
 	}
 
+	slog("--------------")
 	return thttp.REQUEST_CONTINUE
 }
 
@@ -65,8 +65,8 @@ func CreateSession(uid string) (*Session, error) {
 	session := &Session{
 		Id:util.UUID(),
 		Uid:uid,
-		Time:util.Now(),
-		LastTime:util.Now(),
+		Time:util.Now10(),
+		LastTime:util.Now10(),
 	}
 	var sqlString string = "INSERT INTO SESSION(ID,UID,TIME,LASTTIME) VALUE(?,?,?,?);"
 	stmt, err := dbm.Db().Prepare(sqlString)
@@ -87,11 +87,13 @@ func CreateSession(uid string) (*Session, error) {
 
 // CheckValidSession 是检查用户的session是否有效
 // 检查的指标有session是否存在,有效的session个数,session的有效时间
+// 返回值1 false: 无效, true: 有效
+// 返回值2 是否有错误信息,如有错误,返回err,否则返回nil
 func CheckValidSession(session string) (bool, error) {
 	if len(session) < 1 {
 		return false, nil
 	}
-	var sqlString = "SELECT ID, LASTTIME FROM (SELECT ID, LASTTIME FROM SESSION ORDER BY LASTTIME DESC LIMIT ?) WHERE ID = ?;"
+	var sqlString = "SELECT collect.ID, collect.LASTTIME FROM (SELECT ID, LASTTIME FROM SESSION ORDER BY LASTTIME DESC LIMIT ?) as collect WHERE ID = ?;"
 	rows, err := dbm.Db().Query(sqlString, Max_Session_Num, session)
 	if err != nil {
 		log.Error("CheckValidSession Query by session(%v), Max_Session_Num(%v) error ->(%v)", session, Max_Session_Num, err)
@@ -112,39 +114,23 @@ func CheckValidSession(session string) (bool, error) {
 	}
 
 	validTime := lastTime + Valid_Session_Time
-	if validTime < util.Now() {
-		return false, nil
-	}
-	return true, nil
-}
 
-// CheckToken 检查用户的session是否有效
-// 返回值1 false: 无效, true: 有效
-// 返回值2 是否有错误信息,如有错误,返回err,否则返回nil
-func CheckToken(session string) (bool, error) {
-	var sqlString = "SELECT * FROM SESSION WHERE TOKEN = ?"
-	rows, err := dbm.Db().Query(sqlString, session)
-	if err != nil {
-		log.Error("[CheckToken] check session(%v) error ->(%v)", session, err)
-		return false, err
+	log.Debug("%v, %v, %v, %v, %v", validTime, util.Now10(), lastTime, Valid_Session_Time, validTime > util.Now10())
+	if validTime > util.Now10() {
+		return true, nil
 	}
-	defer rows.Close()
-	if !rows.Next() {
-		return false, nil
-	}
-
-	return true, nil
+	return false, nil
 }
 
 // FlushToken更新session对应的最后一个登录时间
-func FlushToken(session string) error {
-	var sqlString = "UPDATE SESSION SET LASTTIME = ? WHERE ID = ?"
+func FlushSession(session string) error {
+	var sqlString = "UPDATE SESSION SET LASTTIME = ? WHERE ID = ?;"
 	stmt, err := dbm.Db().Prepare(sqlString)
 	if err != nil {
 		return err
 	}
 	defer stmt.Close()
 
-	_, err = stmt.Exec(util.Now(), session)
+	_, err = stmt.Exec(util.Now10(), session)
 	return err
 }

+ 193 - 0
src/ddpf/model/session/session_test.go

@@ -0,0 +1,193 @@
+package session
+
+import (
+	"testing"
+	tdbm "github.com/tangs-drm/go-tool/dbm"
+	"ddpf/model/dbm"
+	"time"
+	"github.com/tangs-drm/go-tool/log"
+	"fmt"
+	"github.com/tangs-drm/go-tool/util"
+	"github.com/tangs-drm/go-tool/http"
+	ghttp "net/http"
+	"io/ioutil"
+)
+
+func init() {
+	// 开启服务
+	SessionLog = true
+	db, err := tdbm.Default("testUser:123@tcp(localhost:3306)/testdb?charset=utf8")
+	if err != nil {
+		panic(err)
+		return
+	}
+	err = db.Ping()
+	if err != nil {
+		panic(err)
+		return
+	}
+	// init order.sql
+	var sqlConf string = "../../../order.sql"
+	err = dbm.Init(sqlConf)
+	if err != nil {
+		panic(err)
+		return
+	}
+}
+
+// TestSession 对session的单元测试
+func TestSession(t *testing.T) {
+	log.Debug("TestSession --- START")
+
+	fmt.Println(time.Now().Unix())
+	fmt.Println(time.Now().UnixNano())
+
+	fmt.Println(util.Now())
+	fmt.Println(util.Now10())
+	fmt.Println(util.Now13())
+
+	var uid string = "u100"
+	session, err := CreateSession(uid)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	if len(session.Id) < 1 || session.Uid != uid {
+		t.Error(session.Id, session.Uid)
+		return
+	}
+	if session.Time > 10000000001 || session.LastTime > 10000000001 {
+		t.Error(session.Time, session.LastTime)
+		return
+	}
+
+	Valid_Session_Time = 5
+	valid, err := CheckValidSession(session.Id)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	if !valid {
+		t.Error(valid)
+		return
+	}
+
+	time.Sleep(10*time.Second)
+	valid, err = CheckValidSession(session.Id)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	if valid {
+		t.Error(valid)
+		return
+	}
+
+	// flushSession
+	FlushSession(session.Id)
+	valid, err = CheckValidSession(session.Id)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	if !valid {
+		t.Error(valid)
+		return
+	}
+	log.Debug("TestSession --- END")
+}
+
+func HaloHandleFunc(w ghttp.ResponseWriter, r *ghttp.Request) {
+	fmt.Println("HaloHandleFunc come in")
+	w.Write([]byte("halo"))
+	return
+}
+
+func RootHandleFunc(w ghttp.ResponseWriter, r *ghttp.Request) {
+	fmt.Println("RootHandleFunc come in")
+	w.Write([]byte("root"))
+}
+
+func TestUserFilter(t *testing.T) {
+	SessionLog = true
+	mux := http.NewServerMux()
+	mux.FilterFunc("/usr", UserFilterFunc)
+	mux.HandleFunc("/login", RootHandleFunc)
+	mux.HandleFunc("/usr/halo", HaloHandleFunc)
+
+	go func() {
+		log.Error("%v", ghttp.ListenAndServe(":8023", mux))
+	}()
+
+	//time.Sleep(2 * time.Second)
+
+	resString, err := util.HTTPGetString("http://127.0.0.1:8023/usr/halo")
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	fmt.Println(resString)
+
+	// 设置session
+	session, err := CreateSession("u101")
+	if err != nil {
+		t.Error(err)
+		return
+	}
+
+	// 带上cookie的请求
+	req, err := ghttp.NewRequest("GET", "http://127.0.0.1:8023/usr/halo", nil)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	cookie := &ghttp.Cookie{
+		Domain:"/",
+		Name:"token",
+		Value:session.Id,
+	}
+	req.AddCookie(cookie)
+	client := ghttp.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	bys, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	if string(bys) != "halo" {
+		t.Error(string(bys))
+		return
+	}
+
+	//
+	req, err = ghttp.NewRequest("GET", "http://127.0.0.1:8023/usr/halo", nil)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	cookie = &ghttp.Cookie{
+		Domain:"/",
+		Name:"token",
+		Value:util.UUID(),
+	}
+	req.AddCookie(cookie)
+	client = ghttp.Client{}
+	resp, err = client.Do(req)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	bys, err = ioutil.ReadAll(resp.Body)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	if string(bys) != "root" {
+		t.Error(string(bys))
+		return
+	}
+}

+ 1 - 1
src/github.com/tangs-drm/go-tool

@@ -1 +1 @@
-Subproject commit da570539d081f8fd5616cab078a01bdecc6b5642
+Subproject commit 3d79ea8a677abc5ad618522b4fbd86578c8a91e4

+ 15 - 2
src/main.go

@@ -8,6 +8,7 @@ import (
 	"github.com/tangs-drm/go-tool/dbm"
 	_ "github.com/go-sql-driver/mysql"
 	"fmt"
+	tdbm "ddpf/model/dbm"
 	"os"
 	"ddpf/model/session"
 )
@@ -15,6 +16,10 @@ import (
 func main() {
 	var err error
 
+	execPath, err := os.Getwd()
+	fmt.Println(execPath, err)
+	log.Debug("[main] os.Getwd() get %v, %v", execPath, err)
+
 	// 初始化配置文件
 	var conf = "src/config.conf"
 	var config = config.NewConfig()
@@ -22,7 +27,7 @@ func main() {
 	fmt.Println(os.Getwd())
 	if err != nil {
 		fmt.Println(err)
-		log.Error("main read config with conf(%v) error ->(%v)", conf, err)
+		log.Error("[main] read config with conf(%v) error ->(%v)", conf, err)
 		return
 	}
 	config.Print()
@@ -32,7 +37,15 @@ func main() {
 	_, err = dbm.Default(db_url)
 	if err != nil {
 		fmt.Println(err)
-		log.Error("main init db by url(%v) error ->(%v)", db_url, err)
+		log.Error("[main] init db by url(%v) error ->(%v)", db_url, err)
+		return
+	}
+	// 初始化数据库的表,索引等
+	var mysqlPath = "src/order.sql"
+	err = tdbm.Init(mysqlPath)
+	if err != nil {
+		fmt.Println(err)
+		log.Error("[main] init sql command with path(%v) error ->(%v)", mysqlPath, err)
 		return
 	}
 	log.Debug("main init db success")

+ 60 - 0
src/order.sql

@@ -0,0 +1,60 @@
+/*==============================================================*/
+/* DBMS name:      MySQL 5.0                                    */
+/* Created on:     2017/8/6 14:47:07                            */
+/*==============================================================*/
+
+
+DROP INDEX LASTTIME ON SESSION;
+
+DROP INDEX TIME ON SESSION;
+
+DROP INDEX UID ON SESSION;
+
+DROP INDEX ID ON SESSION;
+
+DROP TABLE IF EXISTS SESSION;
+
+/*==============================================================*/
+/* Table: SESSION                                               */
+/*==============================================================*/
+CREATE TABLE SESSION
+(
+   ID                   VARCHAR(64) NOT NULL,
+   UID                  VARCHAR(64),
+   TIME                 INT,
+   LASTTIME             INT,
+   PRIMARY KEY (ID)
+);
+
+/*==============================================================*/
+/* Index: ID                                                    */
+/*==============================================================*/
+CREATE INDEX ID ON SESSION
+(
+   ID
+);
+
+/*==============================================================*/
+/* Index: UID                                                   */
+/*==============================================================*/
+CREATE INDEX UID ON SESSION
+(
+   UID
+);
+
+/*==============================================================*/
+/* Index: TIME                                                  */
+/*==============================================================*/
+CREATE INDEX TIME ON SESSION
+(
+   TIME
+);
+
+/*==============================================================*/
+/* Index: LASTTIME                                              */
+/*==============================================================*/
+CREATE INDEX LASTTIME ON SESSION
+(
+   LASTTIME
+);
+