fix(training): patch lightgbm sklearn compatibility
This commit is contained in:
@@ -264,6 +264,28 @@ def sample_event(rng, employee):
|
||||
return event
|
||||
|
||||
|
||||
def attach_event_timeline(df):
|
||||
df = df.copy()
|
||||
rng = np.random.default_rng(config.RANDOM_STATE)
|
||||
base_date = np.datetime64('2025-01-01')
|
||||
timelines = []
|
||||
|
||||
for employee_id, group in df.groupby('员工编号', sort=False):
|
||||
group = group.copy().reset_index(drop=True)
|
||||
event_count = len(group)
|
||||
offsets = np.sort(rng.integers(0, 365, size=event_count))
|
||||
group['事件日期'] = [
|
||||
str(pd.Timestamp(base_date + np.timedelta64(int(offset), 'D')).date())
|
||||
for offset in offsets
|
||||
]
|
||||
group['事件日期索引'] = offsets.astype(int)
|
||||
group['事件序号'] = np.arange(1, event_count + 1)
|
||||
group['员工历史事件数'] = event_count
|
||||
timelines.append(group)
|
||||
|
||||
return pd.concat(timelines, ignore_index=True)
|
||||
|
||||
|
||||
def validate_dataset(df):
|
||||
required_columns = [
|
||||
'员工编号',
|
||||
@@ -273,6 +295,9 @@ def validate_dataset(df):
|
||||
'通勤时长分钟',
|
||||
'是否慢性病史',
|
||||
'请假类型',
|
||||
'事件序号',
|
||||
'事件日期索引',
|
||||
'员工历史事件数',
|
||||
'缺勤时长(小时)',
|
||||
]
|
||||
for column in required_columns:
|
||||
@@ -309,7 +334,7 @@ def generate_dataset(output_path=None, sample_count=12000, random_state=None):
|
||||
for idx in employee_idx:
|
||||
events.append(sample_event(rng, employees[int(idx)]))
|
||||
|
||||
df = pd.DataFrame(events)
|
||||
df = attach_event_timeline(pd.DataFrame(events))
|
||||
validate_dataset(df)
|
||||
|
||||
if output_path:
|
||||
|
||||
Reference in New Issue
Block a user