#-*-coding:utf-8-*-
#此模块定义了几种数据库连接和操作类
#编写者:lianzhang(785674410@qq.com)
#2015-02-07
# (forum paste residue: "hidden content of this post")
import abc
import datetime
import functools
import sqlite3
from collections import OrderedDict

import MySQLdb
import pandas as pd
import pymssql
import pytz
from pandas.io.sql import read_sql, to_sql
# Index used as the source of the trading calendar.
IndexName = "csi300"
# Start and end dates of the trading calendar (Asia/Shanghai time).
# A pytz zone has to be attached with localize(); passing it as tzinfo= to the
# datetime constructor selects the zone's historical LMT offset, not +08:00.
Start = pytz.timezone("Asia/Shanghai").localize(datetime.datetime(2005, 1, 1, 0, 0, 0))
End = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai'))
# MySQL connection parameters.
MySQLParas1 = {"host": "localhost", "user": "root", "pwd": "*********", "db": "chihiro"}
MySQLParas2 = {"host": "localhost", "user": "root", "pwd": "*********", "db": "factor"}
# SQL Server connection parameters.
MSSQLParas = {}
# SQLite connection parameters.
SQLiteParas = {}
def CloseConnect(func):
    """Decorate a method so the connection it opened is committed and closed.

    The wrapped method is expected to leave its open connection on
    ``self.conn``. When the method succeeds, the wrapper commits and returns
    the method's result. The connection is closed whether or not the method
    (or the commit) raised, and ``self.conn`` is then reset to None so that a
    later failing call cannot try to close the same connection twice.

    Args:
        func: Method taking ``self`` as its first argument.

    Returns:
        The wrapping function, carrying the metadata of ``func``.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            result = func(self, *args, **kwargs)
            self.conn.commit()
            return result
        finally:
            # self.conn may be missing if the method failed before connecting.
            conn = getattr(self, "conn", None)
            if conn is not None:
                self.conn = None
                conn.close()

    return wrapper
class DataBase(metaclass=abc.ABCMeta):
    """Abstract base class shared by the database helpers in this module.

    Subclasses implement GetConnect(), which must open a connection, store it
    on ``self.conn`` and return a cursor. The query helpers below are wrapped
    with CloseConnect, which commits and closes ``self.conn`` after the call.
    """

    @abc.abstractmethod
    def GetConnect(self):
        """Connect to the database, set ``self.conn`` and return a cursor."""
        raise NotImplementedError("无法执行抽象方法")

    @CloseConnect
    def ExecQuery(self, sql):
        """Run a query statement and return all rows.

        Args:
            sql: SQL text passed to ``cursor.execute``.

        Returns:
            Whatever ``cursor.fetchall()`` returns (a list of row tuples for
            the DB-API drivers used in this file).
        """
        cur = self.GetConnect()
        cur.execute(sql)
        return cur.fetchall()

    @CloseConnect
    def ExecNonQuery(self, sql):
        """Run a statement that returns no rows (the decorator commits it)."""
        cur = self.GetConnect()
        cur.execute(sql)

    @CloseConnect
    def To_SQL(self, frame, name, *args, **kwargs):
        """Write a DataFrame to a table via ``pandas.io.sql.to_sql``.

        Args:
            frame: DataFrame to write.
            name: Target table name.
            *args, **kwargs: Forwarded to ``to_sql`` after (frame, name, con).

        Errors are printed rather than raised (best-effort write).
        The original author noted SQL Server did not appear to be supported.
        """
        try:
            # Called for its side effect of opening self.conn.
            self.GetConnect()
            # frame/name/con go positionally so extra *args cannot collide
            # with them ("got multiple values for argument").
            to_sql(frame, name, self.conn, *args, **kwargs)
        except Exception as e:
            print(e)

    @CloseConnect
    def Read_SQL(self, sql, *args, **kwargs):
        """Read a query result into a DataFrame via ``pandas.io.sql.read_sql``.

        Args:
            sql: SQL query text.
            *args, **kwargs: Forwarded to ``read_sql`` after (sql, con).

        Returns:
            The DataFrame from ``read_sql``, or None if an error occurred
            (the error is printed, not raised).
        """
        try:
            # Called for its side effect of opening self.conn.
            self.GetConnect()
            return read_sql(sql, self.conn, *args, **kwargs)
        except Exception as e:
            print(e)
            return None

    def FetchRawDataBaseData(self, indexes=None, stocks=None, start=None, end=None):
        """Download raw rows for the given stocks and/or indexes.

        Args:
            indexes: Iterable of index table names, or None.
            stocks: Iterable of stock codes as strings; each is read from the
                table named ``"a" + code``.
            start: Start date (object with ``strftime``); defaults to
                2005-01-01.
            end: End date (object with ``strftime``); defaults to the current
                time in Asia/Shanghai.

        Returns:
            An OrderedDict mapping each stock code / index name to the result
            of Read_SQL with the ``date`` column as index (None for an entry
            whose query failed).

        Raises:
            AssertionError: If neither ``indexes`` nor ``stocks`` is given, or
                if the start date is not strictly before the end date.
        """
        assert indexes is not None or stocks is not None, """请指明要提取的股票或者指数"""
        if start is None:
            start = datetime.datetime(2005, 1, 1, 0, 0, 0)
        if end is None:
            end = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai'))
        # Quoted ISO dates: usable as SQL literals and comparable as strings.
        start = "'" + start.strftime("%Y-%m-%d") + "'"
        end = "'" + end.strftime("%Y-%m-%d") + "'"
        assert start < end, "起始日必须小于结束日"

        data = OrderedDict()
        # NOTE(review): table names are concatenated into the SQL text and
        # cannot be bound as parameters; callers must pass trusted names only.
        if stocks is not None:
            for stock in stocks:
                # avgprice is exposed as "close" and the table's close as
                # "price". The [bracket] quoting looks SQL Server specific.
                sql = (
                    "select date,[open],high,low,avgprice as [close],"
                    "traamount as volume,[close] as price from a" + stock
                    + " where date>=" + start + " and date<=" + end
                    + " order by date"
                )
                # More code could go here to refresh data when the latest
                # trading day differs from the latest stored day.
                data[stock] = self.Read_SQL(sql=sql, index_col='date')
        if indexes is not None:
            for index in indexes:
                sql = (
                    "select * from " + index
                    + " where date>=" + start + " and date<=" + end
                    + " order by date"
                )
                # Same date index as the stock frames.
                data[index] = self.Read_SQL(sql=sql, index_col='date')
        return data

    def FetchFromDataBase(self, indexes=None, stocks=None, start=None, end=None, adjusted=True):
        """Return one price column per stock/index as a single DataFrame.

        Calls FetchRawDataBaseData and picks the ``"price"`` column from each
        frame when ``adjusted`` is true, otherwise the ``"close"`` column.
        The original author describes ``"price"`` as the adjusted price and
        ``"close"`` as the raw one; that depends on the table contents.

        Returns:
            DataFrame indexed by date, localized to Asia/Shanghai, with one
            column per key of the fetched data.
        """
        data = self.FetchRawDataBaseData(indexes, stocks, start, end)
        close_key = "price" if adjusted else "close"
        df = pd.DataFrame({key: d[close_key] for key, d in data.items()})
        # Assumes the date index is a naive DatetimeIndex.
        df.index = df.index.tz_localize(pytz.timezone("Asia/Shanghai"))
        return df

    def FetchBarsFromDataBase(self, indexes=None, stocks=None, start=None, end=None):
        """Return the fetched bar data as a pandas Panel.

        Calls FetchRawDataBaseData and builds ``pd.Panel`` from the resulting
        dict; the major axis is localized to Asia/Shanghai.
        """
        data = self.FetchRawDataBaseData(indexes, stocks, start, end)
        # NOTE(review): pd.Panel was removed in pandas 0.25; this method needs
        # an older pandas. Kept as-is because the Panel is the return type.
        panel = pd.Panel(data)
        panel.major_axis = panel.major_axis.tz_localize(pytz.timezone("Asia/Shanghai"))
        return panel
class MSSQL(DataBase):
    """DataBase subclass that talks to SQL Server through pymssql."""

    def __init__(self, host, user, pwd, db):
        """Store the connection settings; no connection is opened here."""
        self.host, self.user, self.pwd, self.db = host, user, pwd, db

    def GetConnect(self):
        """Open a pymssql connection on ``self.conn`` and return a cursor.

        Raises:
            NameError: If no database name is set or no cursor is obtained.
        """
        if not self.db:
            raise NameError("没有设置数据库信息")
        settings = dict(
            host=self.host,
            user=self.user,
            password=self.pwd,
            database=self.db,
            charset="utf8",
        )
        self.conn = pymssql.connect(**settings)
        cursor = self.conn.cursor()
        if cursor:
            return cursor
        raise NameError("连接数据库失败")
class MySQL(DataBase):
    """DataBase subclass that talks to MySQL through MySQLdb."""

    def __init__(self, host, user, pwd, db, port=3306):
        """Store the connection settings; no connection is opened here."""
        self.host, self.user, self.pwd, self.db, self.port = host, user, pwd, db, port

    def GetConnect(self):
        """Open a MySQLdb connection on ``self.conn`` and return a cursor.

        Raises:
            NameError: If no database name is set or no cursor is obtained.
        """
        if not self.db:
            raise NameError("没有设置数据库信息")
        settings = dict(
            host=self.host,
            user=self.user,
            passwd=self.pwd,
            db=self.db,
            port=self.port,
        )
        self.conn = MySQLdb.connect(**settings)
        cursor = self.conn.cursor()
        if cursor:
            return cursor
        raise NameError("连接数据库失败")
class SQLite(DataBase):
    """DataBase subclass that talks to a SQLite database through sqlite3."""

    def __init__(self, db):
        """Store the database path/name; no connection is opened here."""
        self.db = db

    def GetConnect(self):
        """Open a sqlite3 connection on ``self.conn`` and return a cursor.

        Raises:
            NameError: If no database is set or no cursor is obtained.
        """
        if not self.db:
            raise NameError("没有设置数据库信息")
        connection = sqlite3.connect(self.db)
        self.conn = connection
        cursor = connection.cursor()
        if cursor:
            return cursor
        raise NameError("连接数据库失败")
def main():
    """Demo: fetch bar data for three stocks from SQL Server and print it."""
    server = MSSQL(host='localhost', user='sa', pwd='7268015', db='stockdb')
    # Other calls the author tried, kept for reference:
    #   resList = server.ExecQuery("SELECT * FROM FG"); print(resList[:10])
    #   data1 = server.Read_SQL(sql="SELECT * FROM FG"); print(data1.head())
    #   server.To_SQL(frame=data1, name='data1', flavor='sqlite')
    codes = ['000001', '000002', '000008']
    bars = server.FetchBarsFromDataBase(stocks=codes)
    print(bars)
    # Print each axis in turn, looked up one at a time.
    for axis_name in ("major_axis", "minor_axis", "items"):
        print(getattr(bars, axis_name))


if __name__ == "__main__":
    main()