Improve the speed of CWT and PSCF analysis using spatial indexing techniques
I recently studied your code and found that the CWT and PSCF analyses are very slow when processing large amounts of data. I used spatial indexing techniques to reduce the time consumption: a project that originally took 30 minutes to calculate Nij now takes about 30 seconds with spatial indexing. A Python version of the code that mimics your CWT analysis is as follows:
import geopandas as gpd
from shapely.geometry import Point
from rtree import index # 用于空间索引
from tqdm import tqdm # 用于进度条
import numpy as np
def point_in_polygon(polygon, point):
    """
    Test whether a point lies inside a polygon using the ray-crossing (even-odd) rule.

    An upward vertical ray is cast from ``point``; the point is inside when
    the ray crosses an odd number of polygon edges.

    Boundary handling: each edge is treated as covering the half-open
    x-interval (min(x_a, x_b), max(x_a, x_b)], and a point lying exactly on
    an edge is not counted as below it. Boundary points are therefore NOT
    always reported as inside. For example, for axis-aligned rectangular
    cells that share edges, a point in the interior of a shared vertical edge
    is assigned to the left cell and a point in the interior of a shared
    horizontal edge to the upper cell, so it is counted in one cell only.

    Parameters:
    - polygon: sequence of vertices [(x1, y1), (x2, y2), ...]. A repeated
      closing vertex (as produced by shapely ``exterior.coords``) is allowed;
      coordinates beyond the first two (e.g. z) are ignored.
    - point: point coordinates (x, y); extra coordinates are ignored.

    Returns:
    - True if the point is inside according to the rules above, else False.
    - Always False when the polygon has fewer than 3 vertices.
    """
    x, y = point[0], point[1]
    if len(polygon) < 3:
        return False  # fewer than 3 vertices cannot enclose an area

    inside = False
    # Start with the closing edge (last vertex -> first vertex).
    x_old, y_old = polygon[-1][0], polygon[-1][1]
    for vertex in polygon:
        x_new, y_new = vertex[0], vertex[1]
        # Order the edge endpoints by x so that x1 < x2 whenever the edge
        # straddles the point's x coordinate.
        if x_new > x_old:
            x1, y1, x2, y2 = x_old, y_old, x_new, y_new
        else:
            x1, y1, x2, y2 = x_new, y_new, x_old, y_old
        # Toggle when the edge spans x (half-open, see docstring) and the
        # point lies strictly below the edge. The cross-multiplied comparison
        # avoids dividing by (x2 - x1).
        if (x_new < x) == (x <= x_old) and (y - y1) * (x2 - x1) < (y2 - y1) * (x - x1):
            inside = not inside
        x_old, y_old = x_new, y_new
    return inside
def calculate_nij_cwt_wcwt(pscf_path, traj_paths, field_name, a_type, a_criterion, a_n_data, end_points, reduce_ratios, traj_num_end_points, traj_num_reduce_ratios, output_path):
    """
    Compute Nij, N_Traj, CWT and WCWT for every PSCF grid cell and save them.

    Each trajectory endpoint is located with an R-tree query on the cell
    bounding boxes, followed by an exact ``point_in_polygon`` test on the
    candidate cells. An endpoint is counted in the first matching cell only.

    Parameters:
    - pscf_path: path of the PSCF grid-cell layer; it is reprojected to
      EPSG:4326.
    - traj_paths: list of trajectory layer paths. Only LineString features
      are used.
    - field_name: trajectory attribute holding the value to average
      (e.g. a pollutant concentration).
    - a_type: type flag. When it is "Mij", trajectories whose value is below
      ``a_criterion`` are skipped.
    - a_criterion: threshold used when ``a_type == "Mij"``.
    - a_n_data: missing-data marker (e.g. -9999.0). Trajectories with this
      value, or with a null/NaN value, are skipped.
    - end_points: Nij segment end points, e.g. [80, 20, 10]. Segment j covers
      (end_points[j+1], end_points[j]]; the last segment covers
      (-inf, end_points[-1]]. Expected in descending order.
    - reduce_ratios: weight for each Nij segment (at least len(end_points)).
    - traj_num_end_points: N_Traj segment end points, same layout as
      ``end_points``.
    - traj_num_reduce_ratios: weight for each N_Traj segment
      (at least len(traj_num_end_points)).
    - output_path: destination written with ``GeoDataFrame.to_file``.

    Columns added to the grid layer before saving:
    - Nij: number of trajectory endpoints falling in the cell.
    - N_Traj: number of distinct trajectories with an endpoint in the cell.
    - CWT: mean trajectory value over the cell's endpoints (0 when Nij == 0).
    - WCWT: CWT scaled by its Nij-segment weight, then by its
      N_Traj-segment weight.

    Raises:
    - ValueError: if a weight list is shorter than its end-point list.
    """
    # Fail fast on misconfigured weights instead of after the expensive loop.
    _check_segments(end_points, reduce_ratios, "reduce_ratios")
    _check_segments(traj_num_end_points, traj_num_reduce_ratios, "traj_num_reduce_ratios")

    # 1. Read the PSCF grid cells.
    pscf_gdf = gpd.read_file(pscf_path)
    pscf_gdf = pscf_gdf.to_crs("EPSG:4326")  # work in WGS84
    n_cells = len(pscf_gdf)
    print(f"PSCF 网格单元数: {len(pscf_gdf):,}")

    # 2. Per-cell accumulators, indexed by cell position (0..n_cells-1).
    nijs = np.zeros(n_cells, dtype=int)
    n_trajs = np.zeros(n_cells, dtype=int)
    value_sums = np.zeros(n_cells, dtype=float)

    # 3. Build the spatial index and cache each cell's exterior ring once.
    # Positional ids (not DataFrame index labels) are used so that index hits
    # address the numpy arrays and the coordinate cache directly.
    print("构建空间索引...")
    spatial_idx = index.Index()
    cell_coords = []
    for pos, geom in enumerate(pscf_gdf.geometry):
        spatial_idx.insert(pos, geom.bounds)
        cell_coords.append(list(geom.exterior.coords))

    # 4. Accumulate every trajectory layer.
    for traj_path in traj_paths:
        traj_gdf = gpd.read_file(traj_path)
        traj_gdf = traj_gdf.to_crs(pscf_gdf.crs)  # same CRS as the grid
        rows = zip(traj_gdf[field_name], traj_gdf.geometry)
        for value, line in tqdm(rows, total=len(traj_gdf), desc=f"处理 {traj_path}"):
            if _is_missing(value, a_n_data):
                continue
            if a_type == "Mij" and value < a_criterion:
                continue
            if line is None or line.is_empty or line.geom_type != "LineString":
                continue

            visited_cells = set()  # cells already credited to this trajectory
            for x, y, *_ in line.coords:
                for cell_idx in spatial_idx.intersection((x, y, x, y)):
                    if point_in_polygon(cell_coords[cell_idx], (x, y)):
                        nijs[cell_idx] += 1
                        value_sums[cell_idx] += value
                        if cell_idx not in visited_cells:
                            n_trajs[cell_idx] += 1
                            visited_cells.add(cell_idx)
                        break  # an endpoint is counted in one cell only

    # 5-7. Derive CWT and WCWT from the accumulators.
    cwts = _mean_per_cell(value_sums, nijs)
    wcwts = _weighted_cwt(cwts, nijs, n_trajs, end_points, reduce_ratios,
                          traj_num_end_points, traj_num_reduce_ratios)

    # 8. Attach the results to the grid layer and save it.
    pscf_gdf["Nij"] = nijs
    pscf_gdf["N_Traj"] = n_trajs
    pscf_gdf["CWT"] = cwts
    pscf_gdf["WCWT"] = wcwts
    pscf_gdf.to_file(output_path)
    print(f"结果已保存至: {output_path}")


def _check_segments(end_points, ratios, label):
    """Raise ValueError when ``ratios`` has fewer entries than ``end_points``."""
    if len(ratios) < len(end_points):
        raise ValueError(
            f"{label}: expected at least {len(end_points)} weights, got {len(ratios)}"
        )


def _is_missing(value, no_data):
    """Return True for None, NaN, or a value equal to the no-data marker."""
    if value is None:
        return True
    if isinstance(value, (float, np.floating)) and np.isnan(value):
        return True
    return bool(value == no_data)


def _mean_per_cell(sums, counts):
    """Divide sums by counts element-wise, yielding 0.0 where the count is 0."""
    sums = np.asarray(sums, dtype=float)
    counts = np.asarray(counts)
    return np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)


def _segment_masks(counts, end_points):
    """Return one boolean mask per segment.

    Segment j selects counts in (end_points[j+1], end_points[j]]; the last
    segment selects counts <= end_points[-1].
    """
    counts = np.asarray(counts)
    masks = []
    last = len(end_points) - 1
    for j, upper in enumerate(end_points):
        mask = counts <= upper
        if j != last:
            mask &= counts > end_points[j + 1]
        masks.append(mask)
    return masks


def _weighted_cwt(cwts, nijs, n_trajs, end_points, reduce_ratios,
                  traj_num_end_points, traj_num_reduce_ratios):
    """Apply the Nij-segment weights, then the N_Traj-segment weights, to CWT."""
    cwts = np.asarray(cwts, dtype=float)
    # Stage 1: WCWT = CWT * ratio of the matching Nij segment. If segments
    # overlap (non-descending end points), the later segment wins.
    nij_factor = np.ones_like(cwts)
    for mask, ratio in zip(_segment_masks(nijs, end_points), reduce_ratios):
        nij_factor[mask] = ratio
    wcwts = cwts * nij_factor
    # Stage 2: multiply by the ratio of every matching N_Traj segment.
    for mask, ratio in zip(_segment_masks(n_trajs, traj_num_end_points), traj_num_reduce_ratios):
        wcwts[mask] *= ratio
    return wcwts
# Example run
if __name__ == "__main__":
    # Example parameters
    end_points = [80, 20, 10]  # Nij segment end points (descending)
    reduce_ratios = [0.7, 0.42, 0.05]  # weight for each Nij segment
    traj_num_end_points = [3, 1]  # N_Traj segment end points (descending)
    traj_num_reduce_ratios = [0.5, 0.2]  # weight for each N_Traj segment
    # ********************************************* edit these parameters **********
    pscf_path = r"C:\Users\GIGA RTX\Desktop\SWY\HYSPLIT\O3\O3T_2018.shp"  # PSCF grid file
    # Alternative grid file: C:/Users/GIGA RTX/Desktop/SWY/HYSPLIT/O3/grid2020.shp
    traj_paths = [
        r"C:\Users\GIGA RTX\Desktop\SWY\HYSPLIT\calculate_tra\calculate_tra\2020\2020_2_5.shp",
    ]  # trajectory layers
    field_names = ["SO2", "O3", "CO", "PM2d5", "NO2"]  # trajectory attributes to analyse
    # ********************************************* end of parameters **********
    for field_name in field_names:
        calculate_nij_cwt_wcwt(
            pscf_path=pscf_path,
            traj_paths=traj_paths,
            field_name=field_name,  # trajectory attribute to average
            a_type="Nij",  # no threshold filtering ("Mij" would apply a_criterion)
            a_criterion=0,
            a_n_data=-9999.0,  # missing-data marker
            end_points=end_points,
            reduce_ratios=reduce_ratios,
            traj_num_end_points=traj_num_end_points,
            traj_num_reduce_ratios=traj_num_reduce_ratios,
            # One output per field: a single shared path would be overwritten
            # on every iteration, keeping only the last field's results.
            output_path=f"tt25_WCWT_{field_name}.shp",
        )
Great job! Would it be possible for you to write similar code in Java and submit a PR to TrajStat for better performance?
Thanks for the recognition; your open-source work has greatly facilitated our research. However, I am not familiar with Java, so I will provide a Java version when I have some spare time.